1use std::collections::HashMap;
10
11use serde_json::Value;
12
13use cognis_core::{AiMessage, Message, ToolCall};
14
15use crate::chat::{StreamChunk, ToolCallDelta, Usage};
16
17#[derive(Debug, Default, Clone)]
21pub struct StreamAggregator {
22 content: String,
24 tool_calls: HashMap<u32, ToolCallAccumulator>,
26 finish_reason: Option<String>,
28 usage: Option<Usage>,
30}
31
/// In-progress state for one streamed tool call.
#[derive(Clone, Debug, Default)]
struct ToolCallAccumulator {
    /// Call identifier, once some delta has supplied one.
    id: Option<String>,
    /// Tool name, once some delta has supplied one.
    name: Option<String>,
    /// Argument fragments concatenated in arrival order; parsed as JSON when
    /// the aggregator is finalized.
    arguments_raw: String,
}
41
42impl StreamAggregator {
43 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn push(&mut self, chunk: StreamChunk) {
50 if !chunk.content.is_empty() {
51 self.content.push_str(&chunk.content);
52 }
53 for d in chunk.tool_calls_delta {
54 self.merge_tool_delta(d);
55 }
56 if chunk.is_done {
57 if chunk.finish_reason.is_some() {
58 self.finish_reason = chunk.finish_reason;
59 }
60 if chunk.usage.is_some() {
61 self.usage = chunk.usage;
62 }
63 }
64 }
65
66 pub fn finalize(self) -> Aggregated {
68 let mut tool_calls = Vec::with_capacity(self.tool_calls.len());
69 let mut keyed: Vec<(u32, ToolCallAccumulator)> = self.tool_calls.into_iter().collect();
71 keyed.sort_by_key(|(i, _)| *i);
72 for (_, acc) in keyed {
73 let id = acc.id.unwrap_or_default();
74 let name = acc.name.unwrap_or_default();
75 let arguments: Value = if acc.arguments_raw.is_empty() {
76 Value::Null
77 } else {
78 serde_json::from_str(&acc.arguments_raw).unwrap_or(Value::String(acc.arguments_raw))
79 };
80 tool_calls.push(ToolCall {
81 id,
82 name,
83 arguments,
84 });
85 }
86 Aggregated {
87 message: Message::Ai(AiMessage {
88 content: self.content,
89 tool_calls,
90 parts: Vec::new(),
91 }),
92 finish_reason: self.finish_reason,
93 usage: self.usage,
94 }
95 }
96
97 fn merge_tool_delta(&mut self, d: ToolCallDelta) {
98 let entry = self.tool_calls.entry(d.index).or_default();
99 if entry.id.is_none() {
100 entry.id = d.id;
101 }
102 if entry.name.is_none() {
103 entry.name = d.name;
104 }
105 if let Some(frag) = d.arguments_delta {
106 entry.arguments_raw.push_str(&frag);
107 }
108 }
109}
110
111#[derive(Debug, Clone)]
113pub struct Aggregated {
114 pub message: Message,
116 pub finish_reason: Option<String>,
118 pub usage: Option<Usage>,
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-terminal chunk carrying only text.
    fn text(s: &str) -> StreamChunk {
        StreamChunk {
            content: s.into(),
            is_delta: true,
            is_done: false,
            finish_reason: None,
            usage: None,
            tool_calls_delta: Vec::new(),
        }
    }

    /// Builds a terminal chunk with the given finish reason and fixed usage
    /// (5 prompt + 7 completion = 12 total tokens).
    fn done(reason: &str) -> StreamChunk {
        StreamChunk {
            content: String::new(),
            is_delta: false,
            is_done: true,
            finish_reason: Some(reason.into()),
            usage: Some(Usage {
                prompt_tokens: 5,
                completion_tokens: 7,
                total_tokens: 12,
            }),
            tool_calls_delta: Vec::new(),
        }
    }

    /// Builds a non-terminal chunk carrying only the given tool-call deltas.
    fn tool_chunk(deltas: Vec<ToolCallDelta>) -> StreamChunk {
        StreamChunk {
            content: String::new(),
            is_delta: true,
            is_done: false,
            finish_reason: None,
            usage: None,
            tool_calls_delta: deltas,
        }
    }

    /// Builds one tool-call delta; `None` leaves the corresponding field unset.
    fn delta(
        index: u32,
        id: Option<&str>,
        name: Option<&str>,
        args: Option<&str>,
    ) -> ToolCallDelta {
        ToolCallDelta {
            index,
            id: id.map(String::from),
            name: name.map(String::from),
            arguments_delta: args.map(String::from),
        }
    }

    #[test]
    fn concatenates_text_chunks() {
        let mut a = StreamAggregator::new();
        a.push(text("hel"));
        a.push(text("lo "));
        a.push(text("world"));
        a.push(done("stop"));
        let out = a.finalize();
        assert_eq!(out.message.content(), "hello world");
        assert_eq!(out.finish_reason.as_deref(), Some("stop"));
        assert_eq!(out.usage.unwrap().total_tokens, 12);
    }

    #[test]
    fn merges_tool_call_deltas_by_index() {
        let mut a = StreamAggregator::new();
        a.push(tool_chunk(vec![delta(
            0,
            Some("c1"),
            Some("search"),
            Some(r#"{"q":"#),
        )]));
        a.push(tool_chunk(vec![delta(0, None, None, Some(r#""rust"}"#))]));
        a.push(done("tool_calls"));
        let out = a.finalize();
        assert_eq!(out.message.tool_calls().len(), 1);
        let tc = &out.message.tool_calls()[0];
        assert_eq!(tc.id, "c1");
        assert_eq!(tc.name, "search");
        assert_eq!(tc.arguments["q"], "rust");
    }

    #[test]
    fn multiple_tool_calls_kept_in_order_by_index() {
        let mut a = StreamAggregator::new();
        a.push(tool_chunk(vec![
            delta(1, Some("c2"), Some("b_tool"), Some("{}")),
            delta(0, Some("c1"), Some("a_tool"), Some("{}")),
        ]));
        a.push(done("tool_calls"));
        let out = a.finalize();
        let calls = out.message.tool_calls();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "a_tool");
        assert_eq!(calls[1].name, "b_tool");
    }

    #[test]
    fn empty_stream_finalizes_to_empty_message() {
        let out = StreamAggregator::new().finalize();
        assert_eq!(out.message.content(), "");
        assert!(out.message.tool_calls().is_empty());
        assert!(out.finish_reason.is_none());
        assert!(out.usage.is_none());
    }

    #[test]
    fn missing_fields_default_to_empty_strings_and_null_arguments() {
        let mut a = StreamAggregator::new();
        a.push(tool_chunk(vec![delta(0, None, None, None)]));
        let out = a.finalize();
        let tc = &out.message.tool_calls()[0];
        assert_eq!(tc.id, "");
        assert_eq!(tc.name, "");
        assert_eq!(tc.arguments, Value::Null);
    }

    #[test]
    fn malformed_arguments_are_kept_as_raw_string() {
        let mut a = StreamAggregator::new();
        a.push(tool_chunk(vec![delta(
            0,
            Some("c1"),
            Some("search"),
            Some(r#"{"q":"#),
        )]));
        let out = a.finalize();
        assert_eq!(
            out.message.tool_calls()[0].arguments,
            Value::String(r#"{"q":"#.into())
        );
    }

    #[test]
    fn first_id_wins_and_late_name_fills_gap() {
        let mut a = StreamAggregator::new();
        a.push(tool_chunk(vec![delta(0, Some("c1"), None, Some("{"))]));
        a.push(tool_chunk(vec![delta(
            0,
            Some("ignored"),
            Some("late_name"),
            Some("}"),
        )]));
        let out = a.finalize();
        let tc = &out.message.tool_calls()[0];
        assert_eq!(tc.id, "c1");
        assert_eq!(tc.name, "late_name");
        assert_eq!(tc.arguments, Value::Object(Default::default()));
    }
}