1use async_trait::async_trait;
9use futures::{Stream, StreamExt, stream};
10
11use crate::{
12 Chunk, CompletionRequest, LlmProvider, StopReason, Usage, error::DummyError, request::ToolCall,
13};
14
/// Incremental notification handed to the observer of
/// [`collect_turn_observed`] while a turn is being streamed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TurnStreamEvent {
    /// A fragment of assistant text, forwarded as soon as its chunk arrives.
    TextDelta(String),
    /// A tool call was opened; its argument JSON has not been received yet.
    ToolStarted {
        /// Identifier carried by the `ToolCallStart` chunk.
        id: String,
        /// Name of the tool being invoked.
        name: String,
    },
}
32
33#[derive(Debug, Default, Clone)]
35pub struct TurnOutput {
36 pub text: String,
38 pub tool_calls: Vec<ToolCall>,
40 pub usage: Usage,
42 pub stop: Option<StopReason>,
44}
45
46pub async fn collect_turn<S, E>(stream: S) -> Result<TurnOutput, E>
56where
57 S: Stream<Item = Result<Chunk, E>> + Unpin,
58{
59 collect_turn_observed(stream, |_| {}).await
60}
61
62pub async fn collect_turn_observed<S, E, F>(mut stream: S, mut on_event: F) -> Result<TurnOutput, E>
72where
73 S: Stream<Item = Result<Chunk, E>> + Unpin,
74 F: FnMut(TurnStreamEvent),
75{
76 let mut out = TurnOutput::default();
77 let mut pending: Vec<ToolCall> = Vec::new();
84 while let Some(item) = stream.next().await {
85 match item? {
86 Chunk::TextDelta(s) => {
87 on_event(TurnStreamEvent::TextDelta(s.clone()));
88 out.text.push_str(&s);
89 }
90 Chunk::ToolCallStart {
91 id,
92 name,
93 signature,
94 } => {
95 on_event(TurnStreamEvent::ToolStarted {
96 id: id.clone(),
97 name: name.clone(),
98 });
99 pending.push(ToolCall {
100 id,
101 name,
102 args_json: String::new(),
103 signature,
104 });
105 }
106 Chunk::ToolCallArgsDelta {
107 id,
108 args_json_delta,
109 } => {
110 if let Some(tc) = pending.iter_mut().find(|tc| tc.id == id) {
111 tc.args_json.push_str(&args_json_delta);
112 }
113 }
114 Chunk::ToolCallEnd { id } => {
115 if let Some(pos) = pending.iter().position(|tc| tc.id == id) {
119 out.tool_calls.push(pending.remove(pos));
120 }
121 }
122 Chunk::Usage(u) => out.usage = u,
123 Chunk::Stop(r) => {
135 let keep_tool_use =
136 out.stop == Some(StopReason::ToolUse) && matches!(r, StopReason::EndTurn);
137 if !keep_tool_use {
138 out.stop = Some(r);
139 }
140 }
141 }
142 }
143 out.tool_calls.append(&mut pending);
147 Ok(out)
148}
149
/// Environment variable that switches [`StubProvider`] into tool-calling mode.
///
/// When it holds a non-empty tool name, a request whose history contains no
/// tool result yet is answered with a single `{}`-argument call to that tool.
pub const STUB_TOOL_CALL_ENV: &str = "POLYCHROME_STUB_TOOL_CALL";
158
159fn stub_tool_name() -> Option<String> {
160 std::env::var(STUB_TOOL_CALL_ENV)
161 .ok()
162 .filter(|s| !s.is_empty())
163}
164
/// Canned [`LlmProvider`] that answers from fixed data without contacting
/// any backend.
///
/// By default every request yields the text `"Hello from the stub provider."`
/// with fixed usage and an `EndTurn` stop. If [`STUB_TOOL_CALL_ENV`] names a
/// tool, a request whose history has no tool result instead yields one call
/// to that tool.
#[derive(Clone, Copy, Debug, Default)]
pub struct StubProvider;
177
178#[async_trait]
179impl LlmProvider for StubProvider {
180 type Error = DummyError;
181
182 async fn complete(
183 &self,
184 req: CompletionRequest,
185 ) -> Result<futures::stream::BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error> {
186 if let Some(tool_name) = stub_tool_name() {
190 let saw_result = req.messages.iter().any(|m| {
191 m.content
192 .iter()
193 .any(|c| matches!(c, crate::Content::ToolResult(_)))
194 });
195 if !saw_result {
196 let chunks = vec![
197 Ok(Chunk::tool_call_start("stub-call-1", &tool_name)),
198 Ok(Chunk::tool_call_args_delta("stub-call-1", "{}")),
199 Ok(Chunk::tool_call_end("stub-call-1")),
200 Ok(Chunk::Stop(StopReason::ToolUse)),
201 ];
202 return Ok(stream::iter(chunks).boxed());
203 }
204 }
205 let chunks = vec![
206 Ok(Chunk::text_delta("Hello from the ")),
207 Ok(Chunk::text_delta("stub provider.")),
208 Ok(Chunk::Usage(Usage {
209 input_tokens: 5,
210 output_tokens: 4,
211 })),
212 Ok(Chunk::Stop(StopReason::EndTurn)),
213 ];
214 Ok(stream::iter(chunks).boxed())
215 }
216}
217
#[cfg(test)]
mod tests {
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use super::*;

    // Assumes `POLYCHROME_STUB_TOOL_CALL` is unset in the test environment;
    // if it is set, the stub answers with a tool call instead of text.
    #[tokio::test]
    async fn stub_provider_collects_into_text() {
        let stream = StubProvider
            .complete(CompletionRequest::new("stub"))
            .await
            .expect("stream opens");
        let out = collect_turn(stream).await.expect("collect");
        assert_eq!(out.text, "Hello from the stub provider.");
        assert!(out.tool_calls.is_empty());
        assert_eq!(out.usage.output_tokens, 4);
        assert_eq!(out.stop, Some(StopReason::EndTurn));
    }

    #[tokio::test]
    async fn collect_assembles_tool_call_from_deltas() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("calling ")),
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_args_delta("c1", r#"{"q":"#)),
            Ok(Chunk::tool_call_args_delta("c1", r#""rust"}"#)),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.text, "calling ");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(out.tool_calls[0].name, "search");
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"rust"}"#);
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn collect_keeps_parallel_tool_calls_with_deferred_ends() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c0", "search")),
            Ok(Chunk::tool_call_args_delta("c0", r#"{"q":"a"}"#)),
            Ok(Chunk::tool_call_start("c1", "fetch")),
            Ok(Chunk::tool_call_args_delta("c1", r#"{"u":"b"}"#)),
            Ok(Chunk::tool_call_end("c0")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.tool_calls.len(), 2, "both parallel calls preserved");
        assert_eq!(out.tool_calls[0].id, "c0");
        assert_eq!(out.tool_calls[0].name, "search");
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"a"}"#);
        assert_eq!(out.tool_calls[1].id, "c1");
        assert_eq!(out.tool_calls[1].name, "fetch");
        assert_eq!(out.tool_calls[1].args_json, r#"{"u":"b"}"#);
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn collect_flushes_a_call_left_open_at_eof() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c0", "search")),
            Ok(Chunk::tool_call_args_delta("c0", r#"{"q":"a"}"#)),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"a"}"#);
    }

    #[tokio::test]
    async fn tool_use_stop_is_sticky_against_later_end_turn() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
            Ok(Chunk::Stop(StopReason::EndTurn)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn hard_stop_wins_over_earlier_tool_use() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
            Ok(Chunk::Stop(StopReason::MaxTokens)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.stop, Some(StopReason::MaxTokens));
    }

    #[tokio::test]
    async fn collect_propagates_error() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("partial")),
            Err(DummyError::Other("mid-stream fault".to_owned())),
        ];
        let res = collect_turn(stream::iter(chunks)).await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn observer_receives_text_and_tool_start_events_in_order() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("a")),
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_args_delta("c1", "{}")),
            Ok(Chunk::text_delta("b")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let mut events = Vec::new();
        let out = collect_turn_observed(stream::iter(chunks), |e| events.push(e))
            .await
            .expect("collect");
        assert_eq!(
            events,
            vec![
                TurnStreamEvent::TextDelta("a".to_owned()),
                TurnStreamEvent::ToolStarted {
                    id: "c1".to_owned(),
                    name: "search".to_owned(),
                },
                TurnStreamEvent::TextDelta("b".to_owned()),
            ]
        );
        assert_eq!(out.text, "ab");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(out.tool_calls[0].args_json, "{}");
    }

    #[tokio::test]
    async fn deltas_and_ends_for_unknown_ids_are_ignored() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_args_delta("ghost", "{}")),
            Ok(Chunk::tool_call_end("ghost")),
            Ok(Chunk::Stop(StopReason::EndTurn)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert!(out.tool_calls.is_empty());
        assert!(out.text.is_empty());
        assert_eq!(out.stop, Some(StopReason::EndTurn));
    }
}