use crate::vector;
use serde::{Deserialize, Serialize};
/// Parameters, held either by value ([`ParamsOwned`]) or by reference
/// ([`ParamsRef`]).
///
/// `#[serde(untagged)]` means each variant serializes as its bare inner
/// struct, with no variant name or tag in the output.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum Params<'i, 't, 'to, 'm> {
/// Fully owned parameters. The manual `Deserialize` impl for
/// `Params<'static, ..>` always produces this variant.
Owned(ParamsOwned),
/// Borrowed parameters; lets callers serialize without cloning the data.
Ref(ParamsRef<'i, 't, 'to, 'm>),
}
impl<'de> serde::Deserialize<'de> for Params<'static, 'static, 'static, 'static> {
    /// Decodes the input as a [`ParamsOwned`] and wraps it in
    /// [`Params::Owned`].
    ///
    /// This impl never produces `Params::Ref`. Any error from decoding the
    /// inner struct is returned unchanged.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        ParamsOwned::deserialize(deserializer).map(Self::Owned)
    }
}
/// Owned form of the parameters; the by-value payload of [`Params::Owned`].
///
/// Field declaration order is the key order of the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParamsOwned {
/// The input value.
pub input: super::Input,
/// Per-task outputs. NOTE(review): `None` presumably marks a task that has
/// no output (yet) — confirm against the producer.
pub tasks: Vec<Option<TaskOutputOwned>>,
/// Optional secondary input. NOTE(review): the name suggests the input used
/// for "map" tasks — confirm.
pub map: Option<super::Input>,
}
/// Borrowed counterpart of [`ParamsOwned`], field for field.
///
/// Each field has the same name as in `ParamsOwned`, so both forms serialize
/// with the same keys. Only `Serialize` is derived.
#[derive(Debug, Clone, Serialize)]
pub struct ParamsRef<'i, 't, 'to, 'm> {
/// Borrowed input value.
pub input: &'i super::Input,
/// Borrowed slice of per-task outputs; each output may itself be owned or
/// borrowed (see [`TaskOutput`]).
pub tasks: &'t [Option<TaskOutput<'to>>],
/// Optional borrowed secondary input.
pub map: Option<&'m super::Input>,
}
/// A task's output, held either by value or by reference.
///
/// `#[serde(untagged)]` means each variant serializes as its bare inner
/// value, with no variant tag.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum TaskOutput<'a> {
/// Owned output. The manual `Deserialize` impl for `TaskOutput<'static>`
/// always produces this variant.
Owned(TaskOutputOwned),
/// Borrowed output.
Ref(TaskOutputRef<'a>),
}
impl<'de> serde::Deserialize<'de> for TaskOutput<'static> {
    /// Decodes the input as a [`TaskOutputOwned`] and wraps it in
    /// [`TaskOutput::Owned`].
    ///
    /// This impl never produces `TaskOutput::Ref`. Decoding errors from the
    /// inner enum are returned unchanged.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        TaskOutputOwned::deserialize(deserializer).map(Self::Owned)
    }
}
/// Owned output of a single task.
///
/// Untagged: each variant serializes as its bare payload. On deserialization,
/// serde tries the variants in declaration order and keeps the first match.
///
/// NOTE(review): `Function` is listed first and wraps [`FunctionOutput`].
/// Its `Err(serde_json::Value)` variant accepts any JSON value, so
/// deserialization appears to always yield `Function(..)`. `MapFunction`,
/// `VectorCompletion` and `MapVectorCompletion` then look unreachable when
/// deserializing. For example, a serialized `VectorCompletion` would come
/// back as `Function(FunctionOutput::Err(..))`. Confirm whether round-tripping
/// is required. Reordering variants changes which input shapes map where (e.g.
/// `[]`), so decide on the intended precedence before changing this.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TaskOutputOwned {
/// Output of a single function task.
Function(FunctionOutput),
/// A list of function outputs. NOTE(review): presumably one per mapped
/// item — confirm.
MapFunction(Vec<FunctionOutput>),
/// Output of a single vector-completion task.
VectorCompletion(VectorCompletionOutput),
/// A list of vector-completion outputs. NOTE(review): presumably one per
/// mapped item — confirm.
MapVectorCompletion(Vec<VectorCompletionOutput>),
}
/// Borrowed counterpart of [`TaskOutputOwned`].
///
/// Variants have the same names as in `TaskOutputOwned`; `Vec`s become
/// slices. It is untagged and derives only `Serialize`.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum TaskOutputRef<'a> {
/// Borrowed single function output.
Function(&'a FunctionOutput),
/// Borrowed list of function outputs.
MapFunction(&'a [FunctionOutput]),
/// Borrowed single vector-completion output.
VectorCompletion(&'a VectorCompletionOutput),
/// Borrowed list of vector-completion outputs.
MapVectorCompletion(&'a [VectorCompletionOutput]),
}
/// The `votes`, `scores` and `weights` extracted from a vector completion.
///
/// Can be built from a streaming chunk or a unary response via the `From`
/// impls. `default_from_request_responses_len` sizes `scores` and `weights`
/// to the number of request responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorCompletionOutput {
/// Votes copied from the completion response.
pub votes: Vec<vector::completions::response::Vote>,
/// Scores. NOTE(review): presumably one per request response — confirm.
pub scores: Vec<rust_decimal::Decimal>,
/// Weights. NOTE(review): presumably one per request response — confirm.
pub weights: Vec<rust_decimal::Decimal>,
}
impl VectorCompletionOutput {
    /// Builds a placeholder output for `request_responses_len` responses.
    ///
    /// The result has:
    /// - no votes,
    /// - a weight of zero for each response,
    /// - a uniform score of `1 / request_responses_len` for each response.
    ///
    /// When `request_responses_len` is zero, `scores` and `weights` are both
    /// empty.
    pub fn default_from_request_responses_len(
        request_responses_len: usize,
    ) -> Self {
        let weights = vec![rust_decimal::Decimal::ZERO; request_responses_len];
        // Guard the zero case explicitly: `Decimal` division panics on a zero
        // divisor, and `vec![elem; 0]` still evaluates `elem` once.
        let scores = if request_responses_len == 0 {
            Vec::new()
        } else {
            let uniform = rust_decimal::Decimal::ONE
                / rust_decimal::Decimal::from(request_responses_len);
            vec![uniform; request_responses_len]
        };
        Self {
            votes: Vec::new(),
            scores,
            weights,
        }
    }
}
impl From<vector::completions::response::streaming::VectorCompletionChunk>
    for VectorCompletionOutput
{
    /// Keeps the `votes`, `scores` and `weights` of a streaming chunk and
    /// discards any other fields.
    fn from(
        chunk: vector::completions::response::streaming::VectorCompletionChunk,
    ) -> Self {
        Self {
            votes: chunk.votes,
            scores: chunk.scores,
            weights: chunk.weights,
        }
    }
}
impl From<vector::completions::response::unary::VectorCompletion>
    for VectorCompletionOutput
{
    /// Keeps the `votes`, `scores` and `weights` of a unary completion and
    /// discards any other fields.
    fn from(
        completion: vector::completions::response::unary::VectorCompletion,
    ) -> Self {
        Self {
            votes: completion.votes,
            scores: completion.scores,
            weights: completion.weights,
        }
    }
}
/// Output of a function: a scalar, a vector, or an error payload.
///
/// Untagged: each variant serializes as its bare payload. On deserialization,
/// serde tries the variants in declaration order (`Scalar`, then `Vector`,
/// then `Err`). `Err` wraps an arbitrary `serde_json::Value`, so it accepts
/// any JSON the first two variants reject. Keep it last.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum FunctionOutput {
/// A single decimal value.
Scalar(rust_decimal::Decimal),
/// A list of decimal values.
Vector(Vec<rust_decimal::Decimal>),
/// Arbitrary JSON error payload; `into_err` converts the other variants
/// into this one.
Err(serde_json::Value),
}
impl FunctionOutput {
    /// Converts this output into the [`FunctionOutput::Err`] variant.
    ///
    /// `Scalar` and `Vector` payloads are serialized with
    /// `serde_json::to_value` and wrapped in `Err`. An existing `Err` is
    /// returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `serde_json::to_value` fails for the payload. That is not
    /// expected for a decimal or a sequence of decimals. `to_value` fails only
    /// when a `Serialize` impl errors or a map has non-string keys.
    pub fn into_err(self) -> Self {
        let value = match self {
            Self::Err(err) => return Self::Err(err),
            Self::Scalar(scalar) => serde_json::to_value(scalar),
            Self::Vector(vector) => serde_json::to_value(vector),
        };
        Self::Err(value.expect("serializing decimals to a JSON value should not fail"))
    }
}
/// A [`FunctionOutput`] paired with a validity flag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompiledFunctionOutput {
/// The function's output.
pub output: FunctionOutput,
/// NOTE(review): presumably whether `output` passed validation — confirm
/// against the code that sets it.
pub valid: bool,
}