Skip to main content

cognis_llm/
streaming.rs

1//! Stream-chunk aggregation utilities.
2//!
3//! Streaming chat surfaces a sequence of [`StreamChunk`]s; agents and
4//! callers often want the *complete* result reconstructed at the end.
5//! [`StreamAggregator`] owns the accumulation: text deltas concatenate,
6//! tool-call deltas merge by index, the final usage / finish_reason are
7//! captured.
8
9use std::collections::HashMap;
10
11use serde_json::Value;
12
13use cognis_core::{AiMessage, Message, ToolCall};
14
15use crate::chat::{StreamChunk, ToolCallDelta, Usage};
16
17/// Accumulates a streaming response into a final [`Message`] plus
18/// aggregated metadata. Chunk ordering matters for text content and for
19/// tool-call argument deltas — feed chunks in the order they arrive.
20#[derive(Debug, Default, Clone)]
21pub struct StreamAggregator {
22    /// Concatenated text content.
23    content: String,
24    /// Per-tool-call accumulators, keyed by chunk `index`.
25    tool_calls: HashMap<u32, ToolCallAccumulator>,
26    /// Reason the stream terminated (last `is_done` chunk's value).
27    finish_reason: Option<String>,
28    /// Final usage stats (last reported).
29    usage: Option<Usage>,
30}
31
32#[derive(Debug, Default, Clone)]
33struct ToolCallAccumulator {
34    /// First-seen `id` for this index — providers send it once.
35    id: Option<String>,
36    /// First-seen `name` — providers send it once.
37    name: Option<String>,
38    /// Concatenated argument fragments (typically a JSON-encoded string).
39    arguments_raw: String,
40}
41
42impl StreamAggregator {
43    /// Construct an empty aggregator.
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    /// Feed a single chunk. Cheap; safe to call inline as chunks arrive.
49    pub fn push(&mut self, chunk: StreamChunk) {
50        if !chunk.content.is_empty() {
51            self.content.push_str(&chunk.content);
52        }
53        for d in chunk.tool_calls_delta {
54            self.merge_tool_delta(d);
55        }
56        if chunk.is_done {
57            if chunk.finish_reason.is_some() {
58                self.finish_reason = chunk.finish_reason;
59            }
60            if chunk.usage.is_some() {
61                self.usage = chunk.usage;
62            }
63        }
64    }
65
66    /// Drain the aggregator into a finalized assistant message + metadata.
67    pub fn finalize(self) -> Aggregated {
68        let mut tool_calls = Vec::with_capacity(self.tool_calls.len());
69        // Stable order: by index ascending.
70        let mut keyed: Vec<(u32, ToolCallAccumulator)> = self.tool_calls.into_iter().collect();
71        keyed.sort_by_key(|(i, _)| *i);
72        for (_, acc) in keyed {
73            let id = acc.id.unwrap_or_default();
74            let name = acc.name.unwrap_or_default();
75            let arguments: Value = if acc.arguments_raw.is_empty() {
76                Value::Null
77            } else {
78                serde_json::from_str(&acc.arguments_raw).unwrap_or(Value::String(acc.arguments_raw))
79            };
80            tool_calls.push(ToolCall {
81                id,
82                name,
83                arguments,
84            });
85        }
86        Aggregated {
87            message: Message::Ai(AiMessage {
88                content: self.content,
89                tool_calls,
90                parts: Vec::new(),
91            }),
92            finish_reason: self.finish_reason,
93            usage: self.usage,
94        }
95    }
96
97    fn merge_tool_delta(&mut self, d: ToolCallDelta) {
98        let entry = self.tool_calls.entry(d.index).or_default();
99        if entry.id.is_none() {
100            entry.id = d.id;
101        }
102        if entry.name.is_none() {
103            entry.name = d.name;
104        }
105        if let Some(frag) = d.arguments_delta {
106            entry.arguments_raw.push_str(&frag);
107        }
108    }
109}
110
111/// Output of [`StreamAggregator::finalize`].
112#[derive(Debug, Clone)]
113pub struct Aggregated {
114    /// The reconstructed assistant message.
115    pub message: Message,
116    /// Reason the stream stopped, if any.
117    pub finish_reason: Option<String>,
118    /// Final usage stats, if reported.
119    pub usage: Option<Usage>,
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a non-terminal delta chunk with the given text and tool deltas.
    fn chunk(content: &str, tool_calls_delta: Vec<ToolCallDelta>) -> StreamChunk {
        StreamChunk {
            content: content.into(),
            is_delta: true,
            is_done: false,
            finish_reason: None,
            usage: None,
            tool_calls_delta,
        }
    }

    /// Text-only delta chunk.
    fn text(s: &str) -> StreamChunk {
        chunk(s, Vec::new())
    }

    /// Tool-call-only delta chunk.
    fn tools(deltas: Vec<ToolCallDelta>) -> StreamChunk {
        chunk("", deltas)
    }

    /// One tool-call delta; `None` fields model a provider omitting them.
    fn delta(index: u32, id: Option<&str>, name: Option<&str>, args: Option<&str>) -> ToolCallDelta {
        ToolCallDelta {
            index,
            id: id.map(Into::into),
            name: name.map(Into::into),
            arguments_delta: args.map(Into::into),
        }
    }

    /// Terminal chunk carrying a finish reason and fixed usage (total = 12).
    fn done(reason: &str) -> StreamChunk {
        StreamChunk {
            content: String::new(),
            is_delta: false,
            is_done: true,
            finish_reason: Some(reason.into()),
            usage: Some(Usage {
                prompt_tokens: 5,
                completion_tokens: 7,
                total_tokens: 12,
            }),
            tool_calls_delta: Vec::new(),
        }
    }

    #[test]
    fn concatenates_text_chunks() {
        let mut a = StreamAggregator::new();
        a.push(text("hel"));
        a.push(text("lo "));
        a.push(text("world"));
        a.push(done("stop"));
        let out = a.finalize();
        assert_eq!(out.message.content(), "hello world");
        assert_eq!(out.finish_reason.as_deref(), Some("stop"));
        assert_eq!(out.usage.unwrap().total_tokens, 12);
    }

    #[test]
    fn empty_stream_finalizes_to_empty_message() {
        let out = StreamAggregator::new().finalize();
        assert_eq!(out.message.content(), "");
        assert!(out.message.tool_calls().is_empty());
        assert!(out.finish_reason.is_none());
        assert!(out.usage.is_none());
    }

    #[test]
    fn merges_tool_call_deltas_by_index() {
        let mut a = StreamAggregator::new();
        a.push(tools(vec![delta(0, Some("c1"), Some("search"), Some(r#"{"q":"#))]));
        a.push(tools(vec![delta(0, None, None, Some(r#""rust"}"#))]));
        a.push(done("tool_calls"));
        let out = a.finalize();
        assert_eq!(out.message.tool_calls().len(), 1);
        let tc = &out.message.tool_calls()[0];
        assert_eq!(tc.id, "c1");
        assert_eq!(tc.name, "search");
        assert_eq!(tc.arguments["q"], "rust");
    }

    #[test]
    fn multiple_tool_calls_kept_in_order_by_index() {
        let mut a = StreamAggregator::new();
        a.push(tools(vec![
            delta(1, Some("c2"), Some("b_tool"), Some("{}")),
            delta(0, Some("c1"), Some("a_tool"), Some("{}")),
        ]));
        a.push(done("tool_calls"));
        let out = a.finalize();
        let calls = out.message.tool_calls();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "a_tool");
        assert_eq!(calls[1].name, "b_tool");
    }

    #[test]
    fn invalid_json_arguments_are_kept_as_string() {
        let mut a = StreamAggregator::new();
        a.push(tools(vec![delta(0, Some("c1"), Some("echo"), Some("not json"))]));
        let out = a.finalize();
        let tc = &out.message.tool_calls()[0];
        assert_eq!(tc.arguments, Value::String("not json".into()));
    }

    #[test]
    fn missing_arguments_become_null() {
        let mut a = StreamAggregator::new();
        a.push(tools(vec![delta(0, Some("c1"), Some("ping"), None)]));
        let out = a.finalize();
        assert!(out.message.tool_calls()[0].arguments.is_null());
    }

    #[test]
    fn missing_id_and_name_default_to_empty() {
        let mut a = StreamAggregator::new();
        a.push(tools(vec![delta(0, None, None, Some("{}"))]));
        let out = a.finalize();
        let tc = &out.message.tool_calls()[0];
        assert_eq!(tc.id, "");
        assert_eq!(tc.name, "");
    }
}