llm_sdk/
accumulator.rs

1use crate::{
2    utils::audio_utils, AudioFormat, AudioPart, Citation, CitationDelta, ContentDelta, ImagePart,
3    LanguageModelError, LanguageModelResult, ModelResponse, ModelUsage, Part, PartDelta,
4    PartialModelResponse, ReasoningPart, ReasoningPartDelta, TextPart, ToolCallPart,
5    ToolCallPartDelta,
6};
7use serde_json::Value;
8use std::collections::BTreeMap;
9
10/// Internal representation of accumulated text data
11#[derive(Debug, Clone)]
12struct AccumulatedTextData {
13    text: String,
14    citations: BTreeMap<usize, CitationDelta>,
15}
16
/// Running state of an image part while its deltas are still arriving.
#[derive(Clone, Debug)]
struct AccumulatedImageData {
    /// Image payload fragments, concatenated in arrival order. Empty until a
    /// delta carrying data has been seen.
    data: String,
    /// MIME type reported by the most recent delta that carried one.
    mime_type: Option<String>,
    /// Width reported by the most recent delta that carried one.
    width: Option<u32>,
    /// Height reported by the most recent delta that carried one.
    height: Option<u32>,
    /// Part identifier reported by the most recent delta that carried one.
    id: Option<String>,
}
26
27/// Internal representation of accumulated audio data
28#[derive(Debug, Clone)]
29struct AccumulatedAudioData {
30    data_chunks: Vec<String>,
31    format: Option<AudioFormat>,
32    sample_rate: Option<u32>,
33    channels: Option<u32>,
34    transcript: String,
35    id: Option<String>,
36}
37
38/// Represents accumulated data for different part types
39#[derive(Debug, Clone)]
40enum AccumulatedData {
41    Text(AccumulatedTextData),
42    ToolCall(ToolCallPartDelta),
43    Image(AccumulatedImageData),
44    Audio(AccumulatedAudioData),
45    Reasoning(ReasoningPartDelta),
46}
47
48/// Initializes accumulated data from a delta
49fn initialize_accumulated_data(delta: ContentDelta) -> AccumulatedData {
50    match delta.part {
51        PartDelta::Text(text_delta) => AccumulatedData::Text(AccumulatedTextData {
52            text: text_delta.text,
53            citations: text_delta
54                .citation
55                .map(|citation| {
56                    let mut map = BTreeMap::new();
57                    map.insert(0, citation);
58                    map
59                })
60                .unwrap_or_default(),
61        }),
62        PartDelta::ToolCall(tool_delta) => AccumulatedData::ToolCall(tool_delta),
63        PartDelta::Image(image_delta) => AccumulatedData::Image(AccumulatedImageData {
64            data: image_delta.data.unwrap_or_default(),
65            mime_type: image_delta.mime_type,
66            width: image_delta.width,
67            height: image_delta.height,
68            id: image_delta.id,
69        }),
70        PartDelta::Audio(audio_delta) => AccumulatedData::Audio(AccumulatedAudioData {
71            data_chunks: audio_delta.data.map(|data| vec![data]).unwrap_or_default(),
72            format: audio_delta.format,
73            sample_rate: audio_delta.sample_rate,
74            channels: audio_delta.channels,
75            transcript: audio_delta.transcript.unwrap_or_default(),
76            id: audio_delta.id,
77        }),
78        PartDelta::Reasoning(reasoning_delta) => AccumulatedData::Reasoning(reasoning_delta),
79    }
80}
81
82/// Merges an incoming delta with existing accumulated data
83fn merge_delta(existing: &mut AccumulatedData, delta: ContentDelta) -> Result<(), String> {
84    match (existing, delta.part) {
85        (AccumulatedData::Text(ref mut existing_text), PartDelta::Text(text_delta)) => {
86            existing_text.text.push_str(&text_delta.text);
87            if let Some(citation) = text_delta.citation {
88                let index = existing_text.citations.len();
89                existing_text.citations.insert(index, citation);
90            }
91        }
92        (AccumulatedData::ToolCall(ref mut existing_tool), PartDelta::ToolCall(tool_delta)) => {
93            if let Some(tool_name) = tool_delta.tool_name {
94                existing_tool
95                    .tool_name
96                    .get_or_insert_default()
97                    .push_str(&tool_name);
98            }
99            if tool_delta.tool_call_id.is_some() {
100                existing_tool.tool_call_id = tool_delta.tool_call_id;
101            }
102            if let Some(args) = tool_delta.args {
103                existing_tool.args.get_or_insert_default().push_str(&args);
104            }
105            if tool_delta.id.is_some() {
106                existing_tool.id = tool_delta.id;
107            }
108        }
109        (AccumulatedData::Image(ref mut existing_image), PartDelta::Image(image_delta)) => {
110            if let Some(data) = image_delta.data {
111                existing_image.data.push_str(&data);
112            }
113            if image_delta.mime_type.is_some() {
114                existing_image.mime_type = image_delta.mime_type;
115            }
116            if image_delta.width.is_some() {
117                existing_image.width = image_delta.width;
118            }
119            if image_delta.height.is_some() {
120                existing_image.height = image_delta.height;
121            }
122            if image_delta.id.is_some() {
123                existing_image.id = image_delta.id;
124            }
125        }
126        (AccumulatedData::Audio(ref mut existing_audio), PartDelta::Audio(audio_delta)) => {
127            if let Some(data) = audio_delta.data {
128                existing_audio.data_chunks.push(data);
129            }
130            if audio_delta.format.is_some() {
131                existing_audio.format = audio_delta.format;
132            }
133            if audio_delta.sample_rate.is_some() {
134                existing_audio.sample_rate = audio_delta.sample_rate;
135            }
136            if audio_delta.channels.is_some() {
137                existing_audio.channels = audio_delta.channels;
138            }
139            if let Some(transcript) = audio_delta.transcript {
140                existing_audio.transcript.push_str(&transcript);
141            }
142            if audio_delta.id.is_some() {
143                existing_audio.id = audio_delta.id;
144            }
145        }
146        (
147            AccumulatedData::Reasoning(ref mut existing_reasoning),
148            PartDelta::Reasoning(reasoning_delta),
149        ) => {
150            if let Some(text) = reasoning_delta.text {
151                existing_reasoning
152                    .text
153                    .get_or_insert_default()
154                    .push_str(&text);
155            }
156            if reasoning_delta.signature.is_some() {
157                existing_reasoning.signature = reasoning_delta.signature;
158            }
159            if reasoning_delta.id.is_some() {
160                existing_reasoning.id = reasoning_delta.id;
161            }
162        }
163        _ => Err(format!(
164            "Type mismatch at index {}: existing type doesn't match incoming type",
165            delta.index
166        ))?,
167    }
168
169    Ok(())
170}
171
172/// Creates a text part from accumulated text data
173fn create_text_part(data: AccumulatedTextData, index: usize) -> LanguageModelResult<Part> {
174    let mut text_part = TextPart {
175        text: data.text,
176        citations: None,
177    };
178
179    if !data.citations.is_empty() {
180        let citation_count = data.citations.len();
181        let mut collected_citations = Vec::with_capacity(citation_count);
182
183        for (_, citation_delta) in data.citations {
184            let CitationDelta {
185                r#type,
186                source,
187                title,
188                cited_text,
189                start_index,
190                end_index,
191            } = citation_delta;
192
193            if !r#type.is_empty() && r#type != "citation" {
194                return Err(LanguageModelError::Invariant(
195                    "",
196                    format!("Invalid citation type \"{type}\" for text part at index {index}"),
197                ));
198            }
199
200            let source_dbg = source.clone();
201            let start_dbg = start_index;
202            let end_dbg = end_index;
203
204            let (Some(source), Some(start_index), Some(end_index)) =
205                (source, start_index, end_index)
206            else {
207                return Err(LanguageModelError::Invariant(
208                    "",
209                    format!(
210                        "Incomplete citation data for text part at index {index}: \
211                         source={source_dbg:?}, start_index={start_dbg:?}, end_index={end_dbg:?}"
212                    ),
213                ));
214            };
215
216            collected_citations.push(Citation {
217                source,
218                title,
219                cited_text,
220                start_index,
221                end_index,
222            });
223        }
224
225        if !collected_citations.is_empty() {
226            text_part.citations = Some(collected_citations);
227        }
228    }
229
230    Ok(Part::Text(text_part))
231}
232
233/// Parses tool call arguments from JSON string
234fn parse_tool_call_args(args: &str) -> LanguageModelResult<Value> {
235    if args.trim().is_empty() {
236        return Ok(Value::Object(serde_json::Map::new()));
237    }
238
239    serde_json::from_str(args).map_err(|e| {
240        LanguageModelError::Invariant("", format!("Invalid tool call arguments: {args}: {e}"))
241    })
242}
243
244/// Creates a tool call part from accumulated tool call data
245fn create_tool_call_part(data: ToolCallPartDelta, index: usize) -> LanguageModelResult<Part> {
246    let tool_call_id = data.tool_call_id.ok_or_else(|| {
247        LanguageModelError::Invariant(
248            "",
249            format!("Missing required field tool_call_id at index {index}"),
250        )
251    })?;
252
253    let tool_name = data.tool_name.ok_or_else(|| {
254        LanguageModelError::Invariant(
255            "",
256            format!("Missing required field tool_name at index {index}"),
257        )
258    })?;
259
260    let args = data.args.unwrap_or_default();
261
262    Ok(Part::ToolCall(ToolCallPart {
263        tool_call_id,
264        tool_name,
265        args: parse_tool_call_args(&args)?,
266        id: data.id,
267    }))
268}
269
270/// Creates an image part from accumulated image data
271fn create_image_part(data: AccumulatedImageData, index: usize) -> LanguageModelResult<Part> {
272    let mime_type = data.mime_type.ok_or_else(|| {
273        LanguageModelError::Invariant(
274            "",
275            format!("Missing required field mime_type for image part at index {index}"),
276        )
277    })?;
278
279    if data.data.is_empty() {
280        return Err(LanguageModelError::Invariant(
281            "",
282            format!("Missing required field data for image part at index {index}"),
283        ));
284    }
285
286    Ok(Part::Image(ImagePart {
287        data: data.data,
288        mime_type,
289        width: data.width,
290        height: data.height,
291        id: data.id,
292    }))
293}
294
295/// Creates an audio part from accumulated audio data
296fn create_audio_part(data: AccumulatedAudioData) -> LanguageModelResult<Part> {
297    let format = data.format.ok_or_else(|| {
298        LanguageModelError::Invariant(
299            "",
300            "Missing required field format for audio part".to_string(),
301        )
302    })?;
303
304    if !matches!(format, AudioFormat::Linear16) {
305        return Err(LanguageModelError::NotImplemented(
306            "",
307            format!(
308                "Only linear16 format is supported for audio concatenation. Received: {format:?}"
309            ),
310        ));
311    }
312
313    let concatenated_audio = audio_utils::concatenate_b64_audio_chunks(&data.data_chunks)?;
314
315    Ok(Part::Audio(AudioPart {
316        data: concatenated_audio,
317        format,
318        sample_rate: data.sample_rate,
319        channels: data.channels,
320        transcript: if data.transcript.is_empty() {
321            None
322        } else {
323            Some(data.transcript)
324        },
325        id: data.id,
326    }))
327}
328
329fn create_reasoning_part(data: ReasoningPartDelta) -> Part {
330    let mut reasoning_part = ReasoningPart::new(data.text.unwrap_or_default());
331    if let Some(signature) = data.signature {
332        reasoning_part = reasoning_part.with_signature(signature);
333    }
334    if let Some(id) = data.id {
335        reasoning_part = reasoning_part.with_id(id);
336    }
337    reasoning_part.into()
338}
339
340/// Creates a final Part from accumulated data
341fn create_part(data: AccumulatedData, index: usize) -> LanguageModelResult<Part> {
342    match data {
343        AccumulatedData::Text(text_data) => create_text_part(text_data, index),
344        AccumulatedData::ToolCall(tool_data) => create_tool_call_part(tool_data, index),
345        AccumulatedData::Image(data) => create_image_part(data, index),
346        AccumulatedData::Audio(data) => create_audio_part(data),
347        AccumulatedData::Reasoning(reasoning_data) => Ok(create_reasoning_part(reasoning_data)),
348    }
349}
350
351/// Manages the accumulation and merging of content deltas for streaming
352/// responses
353pub struct StreamAccumulator {
354    /// Map of index to accumulated data, using `BTreeMap` for automatic sorting
355    accumulated_parts: BTreeMap<usize, AccumulatedData>,
356    /// Accumulated usage statistics
357    accumulated_usage: Option<ModelUsage>,
358    /// Accumulated cost
359    cost: Option<f64>,
360}
361
362impl StreamAccumulator {
363    /// Creates a new `StreamAccumulator`
364    #[must_use]
365    pub fn new() -> Self {
366        Self {
367            accumulated_parts: BTreeMap::new(),
368            accumulated_usage: None,
369            cost: None,
370        }
371    }
372
373    /// Adds a chunk of content deltas to the accumulator
374    ///
375    /// # Errors
376    /// Returns an error if delta types mismatch for the same index
377    pub fn add_partial(&mut self, partial: PartialModelResponse) -> Result<(), String> {
378        if let Some(delta) = partial.delta {
379            self.process_delta(delta.clone())?;
380        }
381        if let Some(usage) = partial.usage {
382            self.process_usage(&usage, partial.cost);
383        }
384        Ok(())
385    }
386
387    /// Computes the final response from accumulated deltas
388    ///
389    /// # Errors
390    /// Returns an error if required fields are missing or format is unsupported
391    pub fn compute_response(self) -> LanguageModelResult<ModelResponse> {
392        let content = self
393            .accumulated_parts
394            .into_iter()
395            .map(|(index, data)| create_part(data, index))
396            .collect::<Result<Vec<_>, _>>()?;
397
398        Ok(ModelResponse {
399            content,
400            cost: self.cost,
401            usage: self.accumulated_usage,
402        })
403    }
404
405    /// Clears all accumulated data
406    pub fn clear(&mut self) {
407        self.accumulated_parts.clear();
408    }
409
410    /// Gets the number of accumulated parts
411    #[must_use]
412    pub fn size(&self) -> usize {
413        self.accumulated_parts.len()
414    }
415
416    /// Checks if the accumulator has any data
417    #[must_use]
418    pub fn is_empty(&self) -> bool {
419        self.accumulated_parts.is_empty()
420    }
421
422    /// Processes a single delta, either merging with existing or creating new
423    fn process_delta(&mut self, delta: ContentDelta) -> Result<(), String> {
424        let index = delta.index;
425
426        if let Some(existing) = self.accumulated_parts.get_mut(&index) {
427            merge_delta(existing, delta)
428        } else {
429            let accumulated = initialize_accumulated_data(delta);
430            self.accumulated_parts.insert(index, accumulated);
431            Ok(())
432        }
433    }
434
435    fn process_usage(&mut self, usage: &ModelUsage, cost: Option<f64>) {
436        let accumulated_usage = self
437            .accumulated_usage
438            .get_or_insert_with(ModelUsage::default);
439        accumulated_usage.add(usage);
440        if let Some(cost) = cost {
441            self.cost = Some(self.cost.unwrap_or(0.0) + cost);
442        }
443    }
444}
445
446impl Default for StreamAccumulator {
447    fn default() -> Self {
448        Self::new()
449    }
450}