use crate::{agent, vector::completions::response};
use serde::{Deserialize, Serialize};
use schemars::JsonSchema;
// Unary (non-streaming) vector completion response; it can be built from a
// streaming `VectorCompletionChunk` via the `From` impl.
//
// NOTE: plain `//` comments are used deliberately here: schemars turns `///`
// doc comments into `description`s in the generated JSON schema, so adding
// them would change the schema output.
#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)]
#[schemars(rename = "vector.completions.response.unary.VectorCompletion")]
pub struct VectorCompletion {
    // Response identifier; cleared by `normalize_for_tests`.
    pub id: String,
    // Per-agent completions; `normalize_for_tests` sorts these and renumbers
    // their `index` fields.
    pub completions: Vec<super::AgentCompletion>,
    // Votes; `normalize_for_tests` orders them by `flat_swarm_index`.
    pub votes: Vec<response::Vote>,
    // Deserialized through `crate::serde_util::vec_decimal` and advertised in
    // the schema as an array of `f64`.
    // NOTE(review): only deserialization is customized; serialization uses
    // `Decimal`'s default serde representation, which may not match the
    // `Vec<f64>` schema depending on enabled rust_decimal features — confirm.
    #[serde(deserialize_with = "crate::serde_util::vec_decimal")]
    #[schemars(with = "Vec<f64>")]
    pub scores: Vec<rust_decimal::Decimal>,
    // Same (de)serialization and schema treatment as `scores`.
    #[serde(deserialize_with = "crate::serde_util::vec_decimal")]
    #[schemars(with = "Vec<f64>")]
    pub weights: Vec<rust_decimal::Decimal>,
    // Creation timestamp (presumably Unix seconds — TODO confirm); zeroed by
    // `normalize_for_tests`.
    pub created: u64,
    // Swarm identifier (presumably the swarm that produced this response —
    // verify against callers).
    pub swarm: String,
    // Object-kind tag for this response type.
    pub object: super::Object,
    // Token/usage accounting; defaults when the source chunk carried none.
    pub usage: agent::completions::response::Usage,
}
impl VectorCompletion {
    /// Puts the response into a canonical, deterministic form so two responses
    /// can be compared for equality in tests.
    ///
    /// - clears `id` and zeroes `created`;
    /// - calls `normalize_for_tests` on every completion's `inner` value;
    /// - sorts `votes` by `flat_swarm_index`;
    /// - sorts `completions` by the JSON serialization of their `inner` value,
    ///   then renumbers their `index` fields `0, 1, 2, …` in that new order.
    ///
    /// # Panics
    ///
    /// Panics if a completion's `inner` value fails to serialize to JSON.
    pub fn normalize_for_tests(&mut self) {
        self.id = String::new();
        self.created = 0;
        for completion in &mut self.completions {
            completion.inner.normalize_for_tests();
        }
        self.votes.sort_by_key(|v| v.flat_swarm_index);
        // Serialization is the expensive part of the key, so compute it once
        // per element rather than once per comparison.
        self.completions.sort_by_cached_key(|c| {
            serde_json::to_string(&c.inner)
                .expect("completion `inner` must serialize to JSON for test normalization")
        });
        // `0..` lets the counter adopt whatever integer type `index` has.
        for (index, completion) in (0..).zip(self.completions.iter_mut()) {
            completion.index = index;
        }
    }

    /// Builds an otherwise-default response sized for `request_responses_len`
    /// responses: every weight is zero and the scores are uniform, each equal
    /// to `1 / request_responses_len`.
    ///
    /// When `request_responses_len` is `0`, both `scores` and `weights` are
    /// empty; this no longer panics.
    pub fn default_from_request_responses_len(
        request_responses_len: usize,
    ) -> Self {
        let weights = vec![rust_decimal::Decimal::ZERO; request_responses_len];
        // `vec![elem; n]` evaluates `elem` even when `n == 0`, and dividing a
        // `Decimal` by zero panics, so the uniform share is only computed for
        // a non-empty request.
        let scores = if request_responses_len == 0 {
            Vec::new()
        } else {
            let uniform_share = rust_decimal::Decimal::ONE
                / rust_decimal::Decimal::from(request_responses_len);
            vec![uniform_share; request_responses_len]
        };
        // Every other field matches the derived `Default` (empty strings and
        // vectors, zero `created`, default `object` and `usage`).
        Self {
            scores,
            weights,
            ..Self::default()
        }
    }
}
impl From<response::streaming::VectorCompletionChunk> for VectorCompletion {
    /// Collapses a streaming chunk into a unary response.
    ///
    /// Every field is moved over as-is, except for three:
    /// - each chunk completion is converted with `AgentCompletion::from`;
    /// - `object` goes through its `Into` conversion;
    /// - a missing `usage` becomes `Usage::default()`.
    fn from(chunk: response::streaming::VectorCompletionChunk) -> Self {
        // Exhaustive destructure: adding a field to the chunk type forces this
        // conversion to be revisited.
        let response::streaming::VectorCompletionChunk {
            id,
            completions: chunk_completions,
            votes,
            scores,
            weights,
            created,
            swarm,
            object,
            usage,
        } = chunk;

        let completions: Vec<super::AgentCompletion> = chunk_completions
            .into_iter()
            .map(super::AgentCompletion::from)
            .collect();

        Self {
            id,
            completions,
            votes,
            scores,
            weights,
            created,
            swarm,
            object: object.into(),
            usage: usage.unwrap_or_default(),
        }
    }
}