use crate::agent::completions::response;
use serde::{Deserialize, Serialize};
use schemars::JsonSchema;
// One streamed chunk of an agent completion response (schema name
// `agent.completions.response.streaming.AgentCompletionChunk`).
//
// Successive chunks can be folded into an aggregate with
// `AgentCompletionChunk::push`.
//
// NOTE: fields are annotated with plain `//` comments on purpose. The
// `JsonSchema` derive turns `///` doc comments into schema `description`s,
// so switching to `///` would change the generated schema. Field order is
// also significant: serde's derive serializes fields in declaration order.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, JsonSchema, arbitrary::Arbitrary)]
#[schemars(rename = "agent.completions.response.streaming.AgentCompletionChunk")]
pub struct AgentCompletionChunk {
// Completion identifier. `push` never changes it.
pub id: String,
// Creation time. Presumably a Unix timestamp in seconds — TODO confirm.
#[arbitrary(with = crate::arbitrary_util::arbitrary_u64)]
pub created: u64,
// Message deltas. `push` merges incoming messages into the existing message
// with the same `index()` and appends the rest.
pub messages: Vec<super::MessageChunk>,
pub object: super::Object,
// Token usage, if reported. Omitted from serialized output when `None`;
// `push` accumulates it rather than replacing it.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub usage: Option<response::Usage>,
pub upstream: crate::agent::Upstream,
// Error reported by the stream, if any. Omitted when `None`; on `push` the
// most recent non-`None` error wins.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub error: Option<crate::error::ResponseError>,
// Opaque continuation string, if any. Omitted when `None`; on `push` the most
// recent non-`None` value wins.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub continuation: Option<String>,
}
impl AgentCompletionChunk {
    /// Merges a later chunk of the same stream into `self`.
    ///
    /// - `messages` are merged by `index()`. An incoming message whose index
    ///   already exists is pushed into that message; otherwise it is appended
    ///   (cloned).
    /// - `usage` is accumulated with `Usage::push`, or adopted if `self` had none.
    /// - `error` and `continuation` are replaced whenever `other` carries a value;
    ///   a `None` in `other` leaves the existing value untouched.
    ///
    /// `id`, `created`, `object` and `upstream` are left as they are in `self`.
    pub fn push(
        &mut self,
        AgentCompletionChunk {
            messages,
            usage,
            error,
            continuation,
            ..
        }: &AgentCompletionChunk,
    ) {
        self.push_messages(messages);
        match (&mut self.usage, usage) {
            (Some(self_usage), Some(other_usage)) => {
                self_usage.push(other_usage);
            }
            (None, Some(other_usage)) => {
                self.usage = Some(other_usage.clone());
            }
            (_, None) => {}
        }
        if let Some(error) = error {
            self.error = Some(error.clone());
        }
        if let Some(continuation) = continuation {
            self.continuation = Some(continuation.clone());
        }
    }

    /// Splits this chunk into log files under the `agents/completions` route.
    ///
    /// How each part is handled:
    /// - Each message contributes its own files via `MessageChunk::produce_files`.
    ///   In the root document, the message is replaced by the reference that call
    ///   returns.
    /// - A present `continuation` is written to its own
    ///   `agents/completions/continuation` JSON file. In the root document it is
    ///   replaced by a `{"type": "reference", "path": ...}` object.
    /// - The root document is written as a pretty-printed JSON file.
    ///
    /// Returns `None` when `id` is empty. Otherwise returns the reference object
    /// for the root file together with every produced file, in this order:
    /// message files, then the continuation file (if any), then the root file.
    ///
    /// # Panics
    ///
    /// Panics if JSON serialization fails. That is not expected for these field
    /// types. NOTE(review): relies on `Usage`/`ResponseError` serializing
    /// infallibly.
    #[cfg(feature = "filesystem")]
    pub fn produce_files(
        &self,
    ) -> Option<(serde_json::Value, Vec<crate::filesystem::logs::LogFile>)> {
        use crate::filesystem::logs::LogFile;
        const ROUTE: &str = "agents/completions";

        let id = &self.id;
        if id.is_empty() {
            return None;
        }

        let mut files: Vec<LogFile> = Vec::new();
        let mut message_refs: Vec<serde_json::Value> = Vec::with_capacity(self.messages.len());
        for message in &self.messages {
            let (reference, message_files) = message.produce_files(id, ROUTE);
            message_refs.push(reference);
            files.extend(message_files);
        }

        // Serialize every scalar field. `messages` and `continuation` are filled
        // in with references below. `continuation: None` is dropped by
        // `skip_serializing_if`, so no placeholder key has to be removed later.
        // `continuation` is the last field, so inserting it below keeps the same
        // key order as the struct.
        let shell = AgentCompletionChunk {
            id: self.id.clone(),
            created: self.created,
            messages: Vec::new(),
            object: self.object,
            usage: self.usage.clone(),
            upstream: self.upstream,
            error: self.error.clone(),
            continuation: None,
        };
        let mut root = serde_json::to_value(&shell)
            .expect("AgentCompletionChunk shell must serialize to a JSON object");
        root["messages"] = serde_json::Value::Array(message_refs);

        if let Some(continuation) = &self.continuation {
            let cont_file = LogFile {
                route: format!("{ROUTE}/continuation"),
                id: id.clone(),
                message_index: None,
                media_index: None,
                extension: "json".to_string(),
                content: serde_json::to_vec_pretty(continuation)
                    .expect("a String always serializes to JSON"),
            };
            root["continuation"] = serde_json::json!({
                "type": "reference",
                "path": cont_file.path(),
            });
            files.push(cont_file);
        }

        let root_file = LogFile {
            route: ROUTE.to_string(),
            id: id.clone(),
            message_index: None,
            media_index: None,
            extension: "json".to_string(),
            content: serde_json::to_vec_pretty(&root)
                .expect("a serde_json::Value always serializes to JSON"),
        };
        let reference = serde_json::json!({ "type": "reference", "path": root_file.path() });
        files.push(root_file);
        Some((reference, files))
    }

    /// Merges `other_messages` into `self.messages` by `index()`.
    ///
    /// The first existing message with a matching index absorbs the incoming
    /// one via `MessageChunk::push`; unmatched messages are appended as clones.
    /// Messages appended earlier in the same call can absorb later ones with
    /// the same index.
    fn push_messages(&mut self, other_messages: &[super::MessageChunk]) {
        for other in other_messages {
            let index = other.index();
            match self.messages.iter().position(|message| message.index() == index) {
                Some(position) => self.messages[position].push(other),
                None => self.messages.push(other.clone()),
            }
        }
    }
}