use crate::{
utils::audio_utils, AudioFormat, AudioPart, Citation, CitationDelta, ContentDelta, ImagePart,
LanguageModelError, LanguageModelResult, ModelResponse, ModelUsage, Part, PartDelta,
PartialModelResponse, ReasoningPart, ReasoningPartDelta, TextPart, ToolCallPart,
ToolCallPartDelta,
};
use serde_json::Value;
use std::collections::BTreeMap;
/// Running state for a streamed text part.
#[derive(Debug, Clone)]
struct AccumulatedTextData {
    /// Concatenation of every text delta received for this part, in arrival order.
    text: String,
    /// Citation deltas keyed by arrival position (0, 1, 2, ...). Each delta that carries a
    /// citation adds a new entry; citations are never merged with one another.
    citations: BTreeMap<usize, CitationDelta>,
}
/// Running state for a streamed image part.
#[derive(Debug, Clone)]
struct AccumulatedImageData {
    /// Image payload built by appending each delta's `data` string; empty until data arrives.
    data: String,
    /// Most recent non-`None` MIME type seen; required when the part is finalized.
    mime_type: Option<String>,
    /// Most recent non-`None` width reported by a delta.
    width: Option<u32>,
    /// Most recent non-`None` height reported by a delta.
    height: Option<u32>,
    /// Most recent non-`None` part id reported by a delta.
    id: Option<String>,
}
/// Running state for a streamed audio part.
#[derive(Debug, Clone)]
struct AccumulatedAudioData {
    /// Audio data strings, one entry per delta that carried data, joined only when the part
    /// is finalized (presumably base64, since they are handed to
    /// `concatenate_b64_audio_chunks` — confirm against `audio_utils`).
    data_chunks: Vec<String>,
    /// Most recent non-`None` audio format; required when the part is finalized.
    format: Option<AudioFormat>,
    /// Most recent non-`None` sample rate.
    sample_rate: Option<u32>,
    /// Most recent non-`None` channel count.
    channels: Option<u32>,
    /// Concatenation of all transcript fragments; an empty string means "no transcript".
    transcript: String,
    /// Most recent non-`None` part id.
    id: Option<String>,
}
/// Per-index accumulation state, one variant per kind of content part.
///
/// Tool calls and reasoning reuse their delta types directly as the running state; the
/// other kinds use the dedicated `Accumulated*Data` structs.
#[derive(Debug, Clone)]
enum AccumulatedData {
    Text(AccumulatedTextData),
    ToolCall(ToolCallPartDelta),
    Image(AccumulatedImageData),
    Audio(AccumulatedAudioData),
    Reasoning(ReasoningPartDelta),
}
fn initialize_accumulated_data(delta: ContentDelta) -> AccumulatedData {
match delta.part {
PartDelta::Text(text_delta) => AccumulatedData::Text(AccumulatedTextData {
text: text_delta.text,
citations: text_delta
.citation
.map(|citation| {
let mut map = BTreeMap::new();
map.insert(0, citation);
map
})
.unwrap_or_default(),
}),
PartDelta::ToolCall(tool_delta) => AccumulatedData::ToolCall(tool_delta),
PartDelta::Image(image_delta) => AccumulatedData::Image(AccumulatedImageData {
data: image_delta.data.unwrap_or_default(),
mime_type: image_delta.mime_type,
width: image_delta.width,
height: image_delta.height,
id: image_delta.id,
}),
PartDelta::Audio(audio_delta) => AccumulatedData::Audio(AccumulatedAudioData {
data_chunks: audio_delta.data.map(|data| vec![data]).unwrap_or_default(),
format: audio_delta.format,
sample_rate: audio_delta.sample_rate,
channels: audio_delta.channels,
transcript: audio_delta.transcript.unwrap_or_default(),
id: audio_delta.id,
}),
PartDelta::Reasoning(reasoning_delta) => AccumulatedData::Reasoning(reasoning_delta),
}
}
/// Folds an incoming delta into the state already accumulated for the same part index.
///
/// String payloads (text, tool name, tool arguments, image data, transcript, reasoning
/// text) are appended. Scalar metadata (ids, signatures, formats, dimensions, ...) is
/// overwritten whenever the delta carries a value. Audio data is kept as separate chunks
/// so it can be joined when the part is finalized.
///
/// # Errors
///
/// Returns a message when the delta's part kind differs from the kind already accumulated
/// at that index.
fn merge_delta(existing: &mut AccumulatedData, delta: ContentDelta) -> Result<(), String> {
    /// Replaces `slot` only when `incoming` holds a value.
    fn overwrite<T>(slot: &mut Option<T>, incoming: Option<T>) {
        if incoming.is_some() {
            *slot = incoming;
        }
    }

    /// Appends `incoming` to `slot`, starting from an empty string if `slot` is `None`.
    fn append(slot: &mut Option<String>, incoming: Option<String>) {
        if let Some(piece) = incoming {
            slot.get_or_insert_default().push_str(&piece);
        }
    }

    let index = delta.index;
    match (existing, delta.part) {
        (AccumulatedData::Text(acc), PartDelta::Text(incoming)) => {
            acc.text.push_str(&incoming.text);
            if let Some(citation) = incoming.citation {
                // Keys are dense (0..len), so `len` is the next free slot.
                let next_slot = acc.citations.len();
                acc.citations.insert(next_slot, citation);
            }
        }
        (AccumulatedData::ToolCall(acc), PartDelta::ToolCall(incoming)) => {
            append(&mut acc.tool_name, incoming.tool_name);
            append(&mut acc.args, incoming.args);
            overwrite(&mut acc.tool_call_id, incoming.tool_call_id);
            overwrite(&mut acc.signature, incoming.signature);
            overwrite(&mut acc.id, incoming.id);
        }
        (AccumulatedData::Image(acc), PartDelta::Image(incoming)) => {
            // `String: Extend<String>` appends the piece only when it is `Some`.
            acc.data.extend(incoming.data);
            overwrite(&mut acc.mime_type, incoming.mime_type);
            overwrite(&mut acc.width, incoming.width);
            overwrite(&mut acc.height, incoming.height);
            overwrite(&mut acc.id, incoming.id);
        }
        (AccumulatedData::Audio(acc), PartDelta::Audio(incoming)) => {
            // Each data-bearing delta becomes its own chunk.
            acc.data_chunks.extend(incoming.data);
            acc.transcript.extend(incoming.transcript);
            overwrite(&mut acc.format, incoming.format);
            overwrite(&mut acc.sample_rate, incoming.sample_rate);
            overwrite(&mut acc.channels, incoming.channels);
            overwrite(&mut acc.id, incoming.id);
        }
        (AccumulatedData::Reasoning(acc), PartDelta::Reasoning(incoming)) => {
            append(&mut acc.text, incoming.text);
            overwrite(&mut acc.signature, incoming.signature);
            overwrite(&mut acc.id, incoming.id);
        }
        _ => {
            return Err(format!(
                "Type mismatch at index {}: existing type doesn't match incoming type",
                index
            ));
        }
    }
    Ok(())
}
fn create_text_part(data: AccumulatedTextData, index: usize) -> LanguageModelResult<Part> {
let mut text_part = TextPart {
text: data.text,
citations: None,
};
if !data.citations.is_empty() {
let citation_count = data.citations.len();
let mut collected_citations = Vec::with_capacity(citation_count);
for (_, citation_delta) in data.citations {
let CitationDelta {
r#type,
source,
title,
cited_text,
start_index,
end_index,
} = citation_delta;
if !r#type.is_empty() && r#type != "citation" {
return Err(LanguageModelError::Invariant(
"",
format!("Invalid citation type \"{type}\" for text part at index {index}"),
));
}
let source_dbg = source.clone();
let start_dbg = start_index;
let end_dbg = end_index;
let (Some(source), Some(start_index), Some(end_index)) =
(source, start_index, end_index)
else {
return Err(LanguageModelError::Invariant(
"",
format!(
"Incomplete citation data for text part at index {index}: \
source={source_dbg:?}, start_index={start_dbg:?}, end_index={end_dbg:?}"
),
));
};
collected_citations.push(Citation {
source,
title,
cited_text,
start_index,
end_index,
});
}
if !collected_citations.is_empty() {
text_part.citations = Some(collected_citations);
}
}
Ok(Part::Text(text_part))
}
/// Parses accumulated tool-call argument text as JSON.
///
/// Blank (empty or whitespace-only) input is treated as "no arguments" and yields an empty
/// JSON object. Any other input is parsed as-is.
///
/// # Errors
///
/// Returns [`LanguageModelError::Invariant`] carrying the raw text and the parser error when
/// non-blank input is not valid JSON.
fn parse_tool_call_args(args: &str) -> LanguageModelResult<Value> {
    match args.trim() {
        "" => Ok(Value::Object(serde_json::Map::new())),
        _ => serde_json::from_str(args).map_err(|parse_error| {
            LanguageModelError::Invariant(
                "",
                format!("Invalid tool call arguments: {args}: {parse_error}"),
            )
        }),
    }
}
/// Finalizes an accumulated tool-call delta into a [`Part::ToolCall`].
///
/// Missing argument text is treated as empty, which [`parse_tool_call_args`] turns into an
/// empty JSON object.
///
/// # Errors
///
/// Returns [`LanguageModelError::Invariant`] when `tool_call_id` or `tool_name` was never
/// received (checked in that order), or when the arguments are not valid JSON.
fn create_tool_call_part(data: ToolCallPartDelta, index: usize) -> LanguageModelResult<Part> {
    let missing_field = |field: &str| {
        LanguageModelError::Invariant(
            "",
            format!("Missing required field {field} at index {index}"),
        )
    };
    let Some(tool_call_id) = data.tool_call_id else {
        return Err(missing_field("tool_call_id"));
    };
    let Some(tool_name) = data.tool_name else {
        return Err(missing_field("tool_name"));
    };
    let args = parse_tool_call_args(data.args.as_deref().unwrap_or_default())?;
    Ok(Part::ToolCall(ToolCallPart {
        tool_call_id,
        tool_name,
        args,
        signature: data.signature,
        id: data.id,
    }))
}
fn create_image_part(data: AccumulatedImageData, index: usize) -> LanguageModelResult<Part> {
let mime_type = data.mime_type.ok_or_else(|| {
LanguageModelError::Invariant(
"",
format!("Missing required field mime_type for image part at index {index}"),
)
})?;
if data.data.is_empty() {
return Err(LanguageModelError::Invariant(
"",
format!("Missing required field data for image part at index {index}"),
));
}
Ok(Part::Image(ImagePart {
data: data.data,
mime_type,
width: data.width,
height: data.height,
id: data.id,
}))
}
/// Finalizes accumulated audio into a [`Part::Audio`].
///
/// Linear16 audio chunks are joined with `audio_utils::concatenate_b64_audio_chunks`.
/// Other formats are not joined here, so they are accepted only when the stream delivered
/// exactly one data chunk, which is passed through unchanged. An empty transcript is
/// reported as `None`.
///
/// # Errors
///
/// - [`LanguageModelError::Invariant`] when no delta supplied a `format`.
/// - [`LanguageModelError::NotImplemented`] when a non-linear16 format arrived in zero or
///   several chunks.
/// - Any error propagated from `concatenate_b64_audio_chunks`.
fn create_audio_part(data: AccumulatedAudioData) -> LanguageModelResult<Part> {
    let format = data.format.ok_or_else(|| {
        LanguageModelError::Invariant(
            "",
            "Missing required field format for audio part".to_string(),
        )
    })?;
    let audio_data = if matches!(format, AudioFormat::Linear16) {
        audio_utils::concatenate_b64_audio_chunks(&data.data_chunks)?
    } else {
        // Only linear16 chunks are joined by the shared helper. A payload that arrived as a
        // single chunk needs no joining, so it can be used directly for any format.
        match <[String; 1]>::try_from(data.data_chunks) {
            Ok([single_chunk]) => single_chunk,
            Err(_) => {
                return Err(LanguageModelError::NotImplemented(
                    "",
                    format!(
                        "Only linear16 format is supported for audio concatenation. Received: {format:?}"
                    ),
                ));
            }
        }
    };
    Ok(Part::Audio(AudioPart {
        data: audio_data,
        format,
        sample_rate: data.sample_rate,
        channels: data.channels,
        transcript: if data.transcript.is_empty() {
            None
        } else {
            Some(data.transcript)
        },
        id: data.id,
    }))
}
/// Finalizes an accumulated reasoning delta into a reasoning [`Part`].
///
/// Missing text becomes an empty string. The signature and id are attached only when
/// present.
fn create_reasoning_part(data: ReasoningPartDelta) -> Part {
    let base = ReasoningPart::new(data.text.unwrap_or_default());
    let signed = match data.signature {
        Some(signature) => base.with_signature(signature),
        None => base,
    };
    let finished = match data.id {
        Some(id) => signed.with_id(id),
        None => signed,
    };
    finished.into()
}
/// Dispatches accumulated state to the finalizer for its part kind.
///
/// `index` is forwarded to the finalizers that include it in their error messages.
///
/// # Errors
///
/// Propagates whatever error the kind-specific finalizer returns.
fn create_part(data: AccumulatedData, index: usize) -> LanguageModelResult<Part> {
    match data {
        AccumulatedData::Reasoning(delta) => Ok(create_reasoning_part(delta)),
        AccumulatedData::Audio(audio) => create_audio_part(audio),
        AccumulatedData::Image(image) => create_image_part(image, index),
        AccumulatedData::ToolCall(call) => create_tool_call_part(call, index),
        AccumulatedData::Text(text) => create_text_part(text, index),
    }
}
/// Merges the [`PartialModelResponse`]s of a streamed model reply into a single
/// [`ModelResponse`].
///
/// Feed each partial to [`StreamAccumulator::add_partial`], then call
/// [`StreamAccumulator::compute_response`] once the stream ends.
pub struct StreamAccumulator {
    /// Accumulation state keyed by each delta's part index. A `BTreeMap` keeps the parts
    /// ordered by index when the final response is built.
    accumulated_parts: BTreeMap<usize, AccumulatedData>,
    /// Sum of all usage reports seen so far, or `None` if none has been recorded.
    accumulated_usage: Option<ModelUsage>,
    /// Sum of all recorded costs, or `None` if no cost has been recorded.
    cost: Option<f64>,
}
impl StreamAccumulator {
    /// Creates an empty accumulator with no parts, usage, or cost recorded.
    #[must_use]
    pub fn new() -> Self {
        Self {
            accumulated_parts: BTreeMap::new(),
            accumulated_usage: None,
            cost: None,
        }
    }

    /// Feeds one streamed partial response into the accumulator.
    ///
    /// The partial's content delta, if any, is merged into the part at its index. Its
    /// usage, if any, is added to the running total. The partial's cost is recorded only
    /// when the same partial also carries usage.
    ///
    /// # Errors
    ///
    /// Returns an error message when the delta's part kind differs from the kind already
    /// accumulated at the same index. In that case the partial's usage and cost are not
    /// recorded either.
    pub fn add_partial(&mut self, partial: PartialModelResponse) -> Result<(), String> {
        if let Some(delta) = partial.delta {
            // `partial` is owned, so the delta moves straight in; no clone is needed.
            self.process_delta(delta)?;
        }
        if let Some(usage) = partial.usage {
            // NOTE(review): cost is only recorded together with usage; a partial carrying
            // cost without usage contributes nothing. Confirm providers always pair them.
            self.process_usage(&usage, partial.cost);
        }
        Ok(())
    }

    /// Consumes the accumulator and builds the final [`ModelResponse`].
    ///
    /// Parts are emitted in ascending part-index order. `usage` and `cost` are the running
    /// totals, or `None` if never recorded.
    ///
    /// # Errors
    ///
    /// Returns the first error encountered while finalizing a part. Examples include a tool
    /// call missing its id or name, invalid tool call arguments, an image missing its data or
    /// MIME type, an incomplete citation, or an unsupported audio format.
    pub fn compute_response(self) -> LanguageModelResult<ModelResponse> {
        let content = self
            .accumulated_parts
            .into_iter()
            .map(|(index, data)| create_part(data, index))
            .collect::<Result<Vec<_>, _>>()?;
        Ok(ModelResponse {
            content,
            cost: self.cost,
            usage: self.accumulated_usage,
        })
    }

    /// Resets the accumulator so it can be reused for a new stream.
    ///
    /// This discards the accumulated parts and also the running usage and cost totals, so a
    /// later [`compute_response`](Self::compute_response) does not report figures carried
    /// over from the previous stream.
    pub fn clear(&mut self) {
        self.accumulated_parts.clear();
        self.accumulated_usage = None;
        self.cost = None;
    }

    /// Returns the number of distinct part indices accumulated so far.
    #[must_use]
    pub fn size(&self) -> usize {
        self.accumulated_parts.len()
    }

    /// Returns `true` when no content delta has been accumulated yet.
    ///
    /// Recorded usage and cost are not considered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.accumulated_parts.is_empty()
    }

    /// Merges `delta` into the part at its index, creating that part on first sight.
    fn process_delta(&mut self, delta: ContentDelta) -> Result<(), String> {
        let index = delta.index;
        if let Some(existing) = self.accumulated_parts.get_mut(&index) {
            merge_delta(existing, delta)
        } else {
            let accumulated = initialize_accumulated_data(delta);
            self.accumulated_parts.insert(index, accumulated);
            Ok(())
        }
    }

    /// Adds `usage` to the running usage total and, when present, `cost` to the running cost.
    fn process_usage(&mut self, usage: &ModelUsage, cost: Option<f64>) {
        let accumulated_usage = self
            .accumulated_usage
            .get_or_insert_with(ModelUsage::default);
        accumulated_usage.add(usage);
        if let Some(cost) = cost {
            *self.cost.get_or_insert(0.0) += cost;
        }
    }
}
impl Default for StreamAccumulator {
fn default() -> Self {
Self::new()
}
}