use crate::agent::completions::{message, response};
use serde::{Deserialize, Serialize};
use schemars::JsonSchema;
// One chunk of a streamed assistant response produced by an agent completion
// (schema name `agent.completions.response.streaming.AssistantResponseChunk`).
// Successive chunks can be folded into one with `AssistantResponseChunk::push`.
//
// NOTE: plain `//` comments are used on this type instead of `///` on purpose:
// the `JsonSchema` derive turns doc comments into schema `description`s, so
// adding `///` here would change the generated schema.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, JsonSchema, arbitrary::Arbitrary)]
#[schemars(rename = "agent.completions.response.streaming.AssistantResponseChunk")]
pub struct AssistantResponseChunk {
    pub role: response::AssistantRole,
    // Used by `produce_files` as the `message_index` of every log file it emits.
    // Generated for fuzzing by a project-specific `u64` strategy
    // (NOTE(review): presumably to keep values JSON-friendly — confirm in
    // `arbitrary_util`).
    #[arbitrary(with = crate::arbitrary_util::arbitrary_u64)]
    pub index: u64,
    // NOTE(review): presumably a Unix timestamp — confirm the unit.
    #[arbitrary(with = crate::arbitrary_util::arbitrary_u64)]
    pub created: u64,
    pub agent: String,
    pub model: String,
    // `push` keeps the first non-empty value seen.
    pub upstream_id: String,
    // Every field below that carries `skip_serializing_if = "Option::is_none"` is
    // omitted from serialized JSON when `None`, and is tagged with an `omitempty`
    // extension keyword in the generated schema.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub reasoning: Option<String>,
    // Tool-call deltas; `push` merges them by their `index`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub tool_calls: Option<Vec<message::AssistantToolCallDelta>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub content: Option<message::RichContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub refusal: Option<String>,
    // Unlike its neighbours this field has no `skip_serializing_if`, so `None`
    // is serialized as an explicit `null`.
    pub finish_reason: Option<response::FinishReason>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub logprobs: Option<response::Logprobs>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub service_tier: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub system_fingerprint: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub provider: Option<String>,
    // `push` accumulates this across chunks via `UpstreamUsage::push`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub usage: Option<response::UpstreamUsage>,
}
impl AssistantResponseChunk {
    /// Merges `other`, a later chunk of the same streamed response, into `self`.
    ///
    /// Field-by-field behaviour:
    /// - `reasoning` and `refusal` are merged with `response::util::push_option_string`.
    /// - `tool_calls` are merged entry-by-entry, keyed on each delta's `index`
    ///   (see `push_tool_calls`).
    /// - `content`, `logprobs` and `usage` are merged with the value's own `push`
    ///   method when both sides are `Some`, cloned from `other` when only `other`
    ///   has a value, and left untouched when `other` has none.
    /// - `finish_reason`, `service_tier`, `system_fingerprint` and `provider` keep the
    ///   first `Some` value observed; `upstream_id` keeps the first non-empty value.
    /// - `role`, `index`, `created`, `agent` and `model` are never modified; `self`
    ///   is assumed to already carry the right values for them.
    pub fn push(
        &mut self,
        AssistantResponseChunk {
            reasoning,
            tool_calls,
            content,
            refusal,
            finish_reason,
            logprobs,
            upstream_id,
            service_tier,
            system_fingerprint,
            provider,
            usage,
            ..
        }: &AssistantResponseChunk,
    ) {
        response::util::push_option_string(&mut self.reasoning, reasoning);
        self.push_tool_calls(tool_calls);
        match (&mut self.content, content) {
            (Some(self_content), Some(other_content)) => {
                self_content.push(other_content);
            }
            (None, Some(other_content)) => {
                self.content = Some(other_content.clone());
            }
            _ => {}
        }
        response::util::push_option_string(&mut self.refusal, refusal);
        // First finish reason wins; later chunks cannot overwrite it.
        if self.finish_reason.is_none() {
            self.finish_reason = finish_reason.clone();
        }
        match (&mut self.logprobs, logprobs) {
            (Some(self_logprobs), Some(other_logprobs)) => {
                self_logprobs.push(other_logprobs);
            }
            (None, Some(other_logprobs)) => {
                self.logprobs = Some(other_logprobs.clone());
            }
            _ => {}
        }
        if self.upstream_id.is_empty() {
            self.upstream_id = upstream_id.clone();
        }
        if self.service_tier.is_none() {
            self.service_tier = service_tier.clone();
        }
        if self.system_fingerprint.is_none() {
            self.system_fingerprint = system_fingerprint.clone();
        }
        if self.provider.is_none() {
            self.provider = provider.clone();
        }
        match (&mut self.usage, usage) {
            (Some(self_usage), Some(other_usage)) => {
                self_usage.push(other_usage);
            }
            (None, Some(other_usage)) => {
                self.usage = Some(other_usage.clone());
            }
            _ => {}
        }
    }

    /// Merges tool-call deltas from another chunk into `self.tool_calls`.
    ///
    /// Each incoming delta is matched against an existing entry with the same
    /// `index`. If one exists, the delta is merged into it with
    /// `AssistantToolCallDelta::push`; otherwise a clone of the delta is appended.
    ///
    /// When `self.tool_calls` is `None` and `other_tool_calls` is `Some`, the vector
    /// is created first (possibly staying empty) and the deltas go through the same
    /// merge. Deltas that share an `index` inside the very first chunk are therefore
    /// coalesced exactly as they would be in any later chunk.
    fn push_tool_calls(
        &mut self,
        other_tool_calls: &Option<Vec<message::AssistantToolCallDelta>>,
    ) {
        if let Some(other_tool_calls) = other_tool_calls {
            let tool_calls = self.tool_calls.get_or_insert_with(Vec::new);
            for other in other_tool_calls {
                // Linear scan by index; assumes the number of tool calls per
                // response stays small.
                match tool_calls
                    .iter_mut()
                    .find(|tool_call| tool_call.index == other.index)
                {
                    Some(tool_call) => {
                        tool_call.push(other);
                    }
                    None => {
                        tool_calls.push(other.clone());
                    }
                }
            }
        }
    }

    /// Splits this chunk into filesystem log files and returns a JSON reference to it.
    ///
    /// The chunk is serialized to JSON with two fields rewritten:
    /// - `logprobs`, when present, is written to its own `json` file under route
    ///   `{route_base}/messages/logprobs`. The field is replaced by a
    ///   `{ "type": "reference", "path": ... }` object pointing at that file.
    /// - `content`, when present, is `prepare`d and passed through `extract_media`.
    ///   The returned JSON replaces the field, and the returned media files are
    ///   collected.
    ///
    /// When either field is `None` it is omitted from the JSON. The resulting message
    /// JSON is then written to a `json` file under route `{route_base}/messages`.
    ///
    /// # Returns
    /// `(reference, files)`:
    /// - `reference` is `{ "type": "reference", "path": ... }` pointing at the
    ///   message file.
    /// - `files` holds every produced file in this order: the logprobs file (if
    ///   any), the media files, then the message file.
    ///
    /// # Panics
    /// Panics if `serde_json` fails to serialize the chunk or its logprobs.
    /// NOTE(review): this is not expected for these derived data types — confirm
    /// for `RichContent`/`Logprobs`.
    #[cfg(feature = "filesystem")]
    pub fn produce_files(
        &self,
        id: &str,
        route_base: &str,
    ) -> (serde_json::Value, Vec<crate::filesystem::logs::LogFile>) {
        use crate::filesystem::logs::LogFile;
        let mut files = Vec::new();
        // `content` and `logprobs` are overwritten below, so cheap placeholders are
        // serialized instead of cloning the real (possibly large) values.
        // - Mapping the placeholders from the real `Option`s keeps each key exactly
        //   when it will be replaced; `skip_serializing_if` drops it otherwise.
        // - This avoids `Map::remove`, which reorders keys under serde_json's
        //   `preserve_order` feature.
        let shell = AssistantResponseChunk {
            role: self.role,
            index: self.index,
            created: self.created,
            agent: self.agent.clone(),
            model: self.model.clone(),
            upstream_id: self.upstream_id.clone(),
            reasoning: self.reasoning.clone(),
            tool_calls: self.tool_calls.clone(),
            content: self
                .content
                .as_ref()
                .map(|_| message::RichContent::Text(String::new())),
            refusal: self.refusal.clone(),
            finish_reason: self.finish_reason.clone(),
            logprobs: self
                .logprobs
                .as_ref()
                .map(|_| response::Logprobs::default()),
            service_tier: self.service_tier.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            provider: self.provider.clone(),
            usage: self.usage.clone(),
        };
        let mut msg_json = serde_json::to_value(&shell)
            .expect("AssistantResponseChunk should always serialize to JSON");
        if let Some(logprobs) = &self.logprobs {
            let logprobs_file = LogFile {
                route: format!("{route_base}/messages/logprobs"),
                id: id.to_string(),
                message_index: Some(self.index),
                media_index: None,
                extension: "json".to_string(),
                content: serde_json::to_vec_pretty(logprobs)
                    .expect("Logprobs should always serialize to JSON"),
            };
            msg_json["logprobs"] = serde_json::json!({
                "type": "reference",
                "path": logprobs_file.path(),
            });
            files.push(logprobs_file);
        }
        if let Some(content) = &self.content {
            let mut content = content.clone();
            content.prepare();
            let (content_json, media_files) = content.extract_media(route_base, id, self.index);
            msg_json["content"] = content_json;
            files.extend(media_files);
        }
        let msg_file = LogFile {
            route: format!("{route_base}/messages"),
            id: id.to_string(),
            message_index: Some(self.index),
            media_index: None,
            extension: "json".to_string(),
            content: serde_json::to_vec_pretty(&msg_json)
                .expect("a serde_json::Value should always serialize"),
        };
        let reference = serde_json::json!({ "type": "reference", "path": msg_file.path() });
        // The message file is pushed last, after any logprobs and media files.
        files.push(msg_file);
        (reference, files)
    }
}