1use crate::{
2 utils::audio_utils, AudioFormat, AudioPart, Citation, CitationDelta, ContentDelta, ImagePart,
3 LanguageModelError, LanguageModelResult, ModelResponse, ModelUsage, Part, PartDelta,
4 PartialModelResponse, ReasoningPart, ReasoningPartDelta, TextPart, ToolCallPart,
5 ToolCallPartDelta,
6};
7use serde_json::Value;
8use std::collections::BTreeMap;
9
/// Text content accumulated for a single content index.
#[derive(Debug, Clone)]
struct AccumulatedTextData {
    /// Concatenation of every text fragment received so far.
    text: String,
    /// Citation deltas keyed by arrival order (0, 1, 2, ...); each delta contributes at most one.
    citations: BTreeMap<usize, CitationDelta>,
}
16
/// Image content accumulated for a single content index.
#[derive(Debug, Clone)]
struct AccumulatedImageData {
    /// Concatenation of the `data` fragments received so far (empty if none were sent).
    data: String,
    /// Most recent MIME type carried by a delta; deltas without one leave it unchanged.
    mime_type: Option<String>,
    /// Most recent width carried by a delta; deltas without one leave it unchanged.
    width: Option<u32>,
    /// Most recent height carried by a delta; deltas without one leave it unchanged.
    height: Option<u32>,
    /// Most recent part id carried by a delta; deltas without one leave it unchanged.
    id: Option<String>,
}
26
/// Audio content accumulated for a single content index.
#[derive(Debug, Clone)]
struct AccumulatedAudioData {
    /// `data` payloads in arrival order. They are kept separate and only joined (via
    /// `audio_utils::concatenate_b64_audio_chunks`) when the part is finalized.
    data_chunks: Vec<String>,
    /// Most recent audio format carried by a delta; deltas without one leave it unchanged.
    format: Option<AudioFormat>,
    /// Most recent sample rate carried by a delta; deltas without one leave it unchanged.
    sample_rate: Option<u32>,
    /// Most recent channel count carried by a delta; deltas without one leave it unchanged.
    channels: Option<u32>,
    /// Concatenation of the transcript fragments received so far (empty if none were sent).
    transcript: String,
    /// Most recent part id carried by a delta; deltas without one leave it unchanged.
    id: Option<String>,
}
37
/// In-progress state for one content index, with one variant per kind of streamed part.
///
/// Tool-call and reasoning parts reuse their delta types directly as the accumulator.
#[derive(Debug, Clone)]
enum AccumulatedData {
    Text(AccumulatedTextData),
    ToolCall(ToolCallPartDelta),
    Image(AccumulatedImageData),
    Audio(AccumulatedAudioData),
    Reasoning(ReasoningPartDelta),
}
47
48fn initialize_accumulated_data(delta: ContentDelta) -> AccumulatedData {
50 match delta.part {
51 PartDelta::Text(text_delta) => AccumulatedData::Text(AccumulatedTextData {
52 text: text_delta.text,
53 citations: text_delta
54 .citation
55 .map(|citation| {
56 let mut map = BTreeMap::new();
57 map.insert(0, citation);
58 map
59 })
60 .unwrap_or_default(),
61 }),
62 PartDelta::ToolCall(tool_delta) => AccumulatedData::ToolCall(tool_delta),
63 PartDelta::Image(image_delta) => AccumulatedData::Image(AccumulatedImageData {
64 data: image_delta.data.unwrap_or_default(),
65 mime_type: image_delta.mime_type,
66 width: image_delta.width,
67 height: image_delta.height,
68 id: image_delta.id,
69 }),
70 PartDelta::Audio(audio_delta) => AccumulatedData::Audio(AccumulatedAudioData {
71 data_chunks: audio_delta.data.map(|data| vec![data]).unwrap_or_default(),
72 format: audio_delta.format,
73 sample_rate: audio_delta.sample_rate,
74 channels: audio_delta.channels,
75 transcript: audio_delta.transcript.unwrap_or_default(),
76 id: audio_delta.id,
77 }),
78 PartDelta::Reasoning(reasoning_delta) => AccumulatedData::Reasoning(reasoning_delta),
79 }
80}
81
82fn merge_delta(existing: &mut AccumulatedData, delta: ContentDelta) -> Result<(), String> {
84 match (existing, delta.part) {
85 (AccumulatedData::Text(ref mut existing_text), PartDelta::Text(text_delta)) => {
86 existing_text.text.push_str(&text_delta.text);
87 if let Some(citation) = text_delta.citation {
88 let index = existing_text.citations.len();
89 existing_text.citations.insert(index, citation);
90 }
91 }
92 (AccumulatedData::ToolCall(ref mut existing_tool), PartDelta::ToolCall(tool_delta)) => {
93 if let Some(tool_name) = tool_delta.tool_name {
94 existing_tool
95 .tool_name
96 .get_or_insert_default()
97 .push_str(&tool_name);
98 }
99 if tool_delta.tool_call_id.is_some() {
100 existing_tool.tool_call_id = tool_delta.tool_call_id;
101 }
102 if let Some(args) = tool_delta.args {
103 existing_tool.args.get_or_insert_default().push_str(&args);
104 }
105 if tool_delta.id.is_some() {
106 existing_tool.id = tool_delta.id;
107 }
108 }
109 (AccumulatedData::Image(ref mut existing_image), PartDelta::Image(image_delta)) => {
110 if let Some(data) = image_delta.data {
111 existing_image.data.push_str(&data);
112 }
113 if image_delta.mime_type.is_some() {
114 existing_image.mime_type = image_delta.mime_type;
115 }
116 if image_delta.width.is_some() {
117 existing_image.width = image_delta.width;
118 }
119 if image_delta.height.is_some() {
120 existing_image.height = image_delta.height;
121 }
122 if image_delta.id.is_some() {
123 existing_image.id = image_delta.id;
124 }
125 }
126 (AccumulatedData::Audio(ref mut existing_audio), PartDelta::Audio(audio_delta)) => {
127 if let Some(data) = audio_delta.data {
128 existing_audio.data_chunks.push(data);
129 }
130 if audio_delta.format.is_some() {
131 existing_audio.format = audio_delta.format;
132 }
133 if audio_delta.sample_rate.is_some() {
134 existing_audio.sample_rate = audio_delta.sample_rate;
135 }
136 if audio_delta.channels.is_some() {
137 existing_audio.channels = audio_delta.channels;
138 }
139 if let Some(transcript) = audio_delta.transcript {
140 existing_audio.transcript.push_str(&transcript);
141 }
142 if audio_delta.id.is_some() {
143 existing_audio.id = audio_delta.id;
144 }
145 }
146 (
147 AccumulatedData::Reasoning(ref mut existing_reasoning),
148 PartDelta::Reasoning(reasoning_delta),
149 ) => {
150 if let Some(text) = reasoning_delta.text {
151 existing_reasoning
152 .text
153 .get_or_insert_default()
154 .push_str(&text);
155 }
156 if reasoning_delta.signature.is_some() {
157 existing_reasoning.signature = reasoning_delta.signature;
158 }
159 if reasoning_delta.id.is_some() {
160 existing_reasoning.id = reasoning_delta.id;
161 }
162 }
163 _ => Err(format!(
164 "Type mismatch at index {}: existing type doesn't match incoming type",
165 delta.index
166 ))?,
167 }
168
169 Ok(())
170}
171
172fn create_text_part(data: AccumulatedTextData, index: usize) -> LanguageModelResult<Part> {
174 let mut text_part = TextPart {
175 text: data.text,
176 citations: None,
177 };
178
179 if !data.citations.is_empty() {
180 let citation_count = data.citations.len();
181 let mut collected_citations = Vec::with_capacity(citation_count);
182
183 for (_, citation_delta) in data.citations {
184 let CitationDelta {
185 r#type,
186 source,
187 title,
188 cited_text,
189 start_index,
190 end_index,
191 } = citation_delta;
192
193 if !r#type.is_empty() && r#type != "citation" {
194 return Err(LanguageModelError::Invariant(
195 "",
196 format!("Invalid citation type \"{type}\" for text part at index {index}"),
197 ));
198 }
199
200 let source_dbg = source.clone();
201 let start_dbg = start_index;
202 let end_dbg = end_index;
203
204 let (Some(source), Some(start_index), Some(end_index)) =
205 (source, start_index, end_index)
206 else {
207 return Err(LanguageModelError::Invariant(
208 "",
209 format!(
210 "Incomplete citation data for text part at index {index}: \
211 source={source_dbg:?}, start_index={start_dbg:?}, end_index={end_dbg:?}"
212 ),
213 ));
214 };
215
216 collected_citations.push(Citation {
217 source,
218 title,
219 cited_text,
220 start_index,
221 end_index,
222 });
223 }
224
225 if !collected_citations.is_empty() {
226 text_part.citations = Some(collected_citations);
227 }
228 }
229
230 Ok(Part::Text(text_part))
231}
232
233fn parse_tool_call_args(args: &str) -> LanguageModelResult<Value> {
235 if args.trim().is_empty() {
236 return Ok(Value::Object(serde_json::Map::new()));
237 }
238
239 serde_json::from_str(args).map_err(|e| {
240 LanguageModelError::Invariant("", format!("Invalid tool call arguments: {args}: {e}"))
241 })
242}
243
244fn create_tool_call_part(data: ToolCallPartDelta, index: usize) -> LanguageModelResult<Part> {
246 let tool_call_id = data.tool_call_id.ok_or_else(|| {
247 LanguageModelError::Invariant(
248 "",
249 format!("Missing required field tool_call_id at index {index}"),
250 )
251 })?;
252
253 let tool_name = data.tool_name.ok_or_else(|| {
254 LanguageModelError::Invariant(
255 "",
256 format!("Missing required field tool_name at index {index}"),
257 )
258 })?;
259
260 let args = data.args.unwrap_or_default();
261
262 Ok(Part::ToolCall(ToolCallPart {
263 tool_call_id,
264 tool_name,
265 args: parse_tool_call_args(&args)?,
266 id: data.id,
267 }))
268}
269
270fn create_image_part(data: AccumulatedImageData, index: usize) -> LanguageModelResult<Part> {
272 let mime_type = data.mime_type.ok_or_else(|| {
273 LanguageModelError::Invariant(
274 "",
275 format!("Missing required field mime_type for image part at index {index}"),
276 )
277 })?;
278
279 if data.data.is_empty() {
280 return Err(LanguageModelError::Invariant(
281 "",
282 format!("Missing required field data for image part at index {index}"),
283 ));
284 }
285
286 Ok(Part::Image(ImagePart {
287 data: data.data,
288 mime_type,
289 width: data.width,
290 height: data.height,
291 id: data.id,
292 }))
293}
294
295fn create_audio_part(data: AccumulatedAudioData) -> LanguageModelResult<Part> {
297 let format = data.format.ok_or_else(|| {
298 LanguageModelError::Invariant(
299 "",
300 "Missing required field format for audio part".to_string(),
301 )
302 })?;
303
304 if !matches!(format, AudioFormat::Linear16) {
305 return Err(LanguageModelError::NotImplemented(
306 "",
307 format!(
308 "Only linear16 format is supported for audio concatenation. Received: {format:?}"
309 ),
310 ));
311 }
312
313 let concatenated_audio = audio_utils::concatenate_b64_audio_chunks(&data.data_chunks)?;
314
315 Ok(Part::Audio(AudioPart {
316 data: concatenated_audio,
317 format,
318 sample_rate: data.sample_rate,
319 channels: data.channels,
320 transcript: if data.transcript.is_empty() {
321 None
322 } else {
323 Some(data.transcript)
324 },
325 id: data.id,
326 }))
327}
328
329fn create_reasoning_part(data: ReasoningPartDelta) -> Part {
330 let mut reasoning_part = ReasoningPart::new(data.text.unwrap_or_default());
331 if let Some(signature) = data.signature {
332 reasoning_part = reasoning_part.with_signature(signature);
333 }
334 if let Some(id) = data.id {
335 reasoning_part = reasoning_part.with_id(id);
336 }
337 reasoning_part.into()
338}
339
340fn create_part(data: AccumulatedData, index: usize) -> LanguageModelResult<Part> {
342 match data {
343 AccumulatedData::Text(text_data) => create_text_part(text_data, index),
344 AccumulatedData::ToolCall(tool_data) => create_tool_call_part(tool_data, index),
345 AccumulatedData::Image(data) => create_image_part(data, index),
346 AccumulatedData::Audio(data) => create_audio_part(data),
347 AccumulatedData::Reasoning(reasoning_data) => Ok(create_reasoning_part(reasoning_data)),
348 }
349}
350
/// Reassembles a complete [`ModelResponse`] from a stream of [`PartialModelResponse`]s.
pub struct StreamAccumulator {
    /// In-progress state per content index; the `BTreeMap` keeps parts ordered by index.
    accumulated_parts: BTreeMap<usize, AccumulatedData>,
    /// Running total of the usage added so far, or `None` if none has been added.
    accumulated_usage: Option<ModelUsage>,
    /// Running sum of the cost values added so far, or `None` if none has been added.
    cost: Option<f64>,
}
361
362impl StreamAccumulator {
363 #[must_use]
365 pub fn new() -> Self {
366 Self {
367 accumulated_parts: BTreeMap::new(),
368 accumulated_usage: None,
369 cost: None,
370 }
371 }
372
373 pub fn add_partial(&mut self, partial: PartialModelResponse) -> Result<(), String> {
378 if let Some(delta) = partial.delta {
379 self.process_delta(delta.clone())?;
380 }
381 if let Some(usage) = partial.usage {
382 self.process_usage(&usage, partial.cost);
383 }
384 Ok(())
385 }
386
387 pub fn compute_response(self) -> LanguageModelResult<ModelResponse> {
392 let content = self
393 .accumulated_parts
394 .into_iter()
395 .map(|(index, data)| create_part(data, index))
396 .collect::<Result<Vec<_>, _>>()?;
397
398 Ok(ModelResponse {
399 content,
400 cost: self.cost,
401 usage: self.accumulated_usage,
402 })
403 }
404
405 pub fn clear(&mut self) {
407 self.accumulated_parts.clear();
408 }
409
410 #[must_use]
412 pub fn size(&self) -> usize {
413 self.accumulated_parts.len()
414 }
415
416 #[must_use]
418 pub fn is_empty(&self) -> bool {
419 self.accumulated_parts.is_empty()
420 }
421
422 fn process_delta(&mut self, delta: ContentDelta) -> Result<(), String> {
424 let index = delta.index;
425
426 if let Some(existing) = self.accumulated_parts.get_mut(&index) {
427 merge_delta(existing, delta)
428 } else {
429 let accumulated = initialize_accumulated_data(delta);
430 self.accumulated_parts.insert(index, accumulated);
431 Ok(())
432 }
433 }
434
435 fn process_usage(&mut self, usage: &ModelUsage, cost: Option<f64>) {
436 let accumulated_usage = self
437 .accumulated_usage
438 .get_or_insert_with(ModelUsage::default);
439 accumulated_usage.add(usage);
440 if let Some(cost) = cost {
441 self.cost = Some(self.cost.unwrap_or(0.0) + cost);
442 }
443 }
444}
445
446impl Default for StreamAccumulator {
447 fn default() -> Self {
448 Self::new()
449 }
450}