Skip to main content

polyc_llm/
turn.rs

1//! Turn helpers: a [`StubProvider`] for wiring/tests and [`collect_turn`],
2//! which folds a provider's [`Chunk`] stream into a single [`TurnOutput`].
3//!
4//! `collect_turn` is the output half of the bridge between this crate's
5//! streaming vocabulary and the message-granular wire types: the harness drains
6//! a provider stream into a `TurnOutput`, then maps that to wire `Message`s.
7
8use async_trait::async_trait;
9use futures::{Stream, StreamExt, stream};
10
11use crate::{
12    Chunk, CompletionRequest, LlmProvider, StopReason, Usage, error::DummyError, request::ToolCall,
13};
14
/// Live side-channel event emitted while a turn's [`Chunk`] stream is folded.
///
/// Streaming surfaces (e.g. Slack `chat.appendStream`, a streaming CLI) can
/// render these as they arrive. They do not replace the buffered
/// [`TurnOutput`], which is still returned in full once the stream ends.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TurnStreamEvent {
    /// Newly generated answer text; concatenating every delta in arrival
    /// order rebuilds the full text.
    TextDelta(String),
    /// A tool call has begun; its `id` and `name` are known immediately.
    ToolStarted {
        /// Call id assigned by the provider.
        id: String,
        /// Name of the tool being invoked.
        name: String,
    },
}
32
33/// The fully-assembled result of one turn, folded from a [`Chunk`] stream.
34#[derive(Debug, Default, Clone)]
35pub struct TurnOutput {
36    /// Concatenated text deltas.
37    pub text: String,
38    /// Completed tool calls, in arrival order.
39    pub tool_calls: Vec<ToolCall>,
40    /// Final token accounting (last [`Chunk::Usage`] seen).
41    pub usage: Usage,
42    /// Why the turn ended, if the stream reported it.
43    pub stop: Option<StopReason>,
44}
45
46/// Drain a provider stream into a [`TurnOutput`].
47///
48/// Text deltas concatenate; a tool call accretes from its
49/// `ToolCallStart`/`ToolCallArgsDelta`/`ToolCallEnd` run (matched by `id`);
50/// usage and stop reason are taken from their chunks.
51///
52/// # Errors
53///
54/// Propagates the first `Err` item from the stream.
55pub async fn collect_turn<S, E>(stream: S) -> Result<TurnOutput, E>
56where
57    S: Stream<Item = Result<Chunk, E>> + Unpin,
58{
59    collect_turn_observed(stream, |_| {}).await
60}
61
62/// Like [`collect_turn`], but observes each streamable event as it arrives.
63///
64/// Invokes `on_event` for each text delta / tool start while still folding and
65/// returning the complete [`TurnOutput`]. `on_event` is synchronous and must
66/// not block (e.g. an unbounded-channel `send`).
67///
68/// # Errors
69///
70/// Propagates the first `Err` item from the stream.
71pub async fn collect_turn_observed<S, E, F>(mut stream: S, mut on_event: F) -> Result<TurnOutput, E>
72where
73    S: Stream<Item = Result<Chunk, E>> + Unpin,
74    F: FnMut(TurnStreamEvent),
75{
76    let mut out = TurnOutput::default();
77    // In-progress tool calls, kept in start order and matched by id. A provider
78    // may interleave several calls (OpenAI's `parallel_tool_calls` defaults to
79    // true) and/or defer all their `ToolCallEnd`s to the end of the stream, so a
80    // single `Option` would let a second `ToolCallStart` clobber the first and
81    // an `ToolCallEnd` close the wrong call. Matching by id throughout keeps
82    // every parallel call intact regardless of emission order.
83    let mut pending: Vec<ToolCall> = Vec::new();
84    while let Some(item) = stream.next().await {
85        match item? {
86            Chunk::TextDelta(s) => {
87                on_event(TurnStreamEvent::TextDelta(s.clone()));
88                out.text.push_str(&s);
89            }
90            Chunk::ToolCallStart {
91                id,
92                name,
93                signature,
94            } => {
95                on_event(TurnStreamEvent::ToolStarted {
96                    id: id.clone(),
97                    name: name.clone(),
98                });
99                pending.push(ToolCall {
100                    id,
101                    name,
102                    args_json: String::new(),
103                    signature,
104                });
105            }
106            Chunk::ToolCallArgsDelta {
107                id,
108                args_json_delta,
109            } => {
110                if let Some(tc) = pending.iter_mut().find(|tc| tc.id == id) {
111                    tc.args_json.push_str(&args_json_delta);
112                }
113            }
114            Chunk::ToolCallEnd { id } => {
115                // Move the matching call to the output in completion order. An
116                // unmatched id is ignored (defensive); calls still open at EOF
117                // are flushed after the loop so none are silently dropped.
118                if let Some(pos) = pending.iter().position(|tc| tc.id == id) {
119                    out.tool_calls.push(pending.remove(pos));
120                }
121            }
122            Chunk::Usage(u) => out.usage = u,
123            // A `ToolUse` stop is sticky against a *later* `EndTurn`. Some
124            // providers stream the tool call in one event and then a separate
125            // trailing terminator event carrying an end-of-turn finish reason;
126            // letting that later `EndTurn` overwrite the `ToolUse` stop would
127            // make the agent loop skip executing the tool and end the turn with
128            // no output.
129            //
130            // A *hard* stop (MaxTokens / Refusal / StopSequence) is the
131            // opposite: it means the turn was truncated or refused, so it must
132            // win over an earlier `ToolUse` — the tool call may be incomplete
133            // and must not be executed.
134            Chunk::Stop(r) => {
135                let keep_tool_use =
136                    out.stop == Some(StopReason::ToolUse) && matches!(r, StopReason::EndTurn);
137                if !keep_tool_use {
138                    out.stop = Some(r);
139                }
140            }
141        }
142    }
143    // Flush any call that started (and may have accreted args) but whose
144    // `ToolCallEnd` never arrived — a provider that omits the terminator must
145    // not lose the call.
146    out.tool_calls.append(&mut pending);
147    Ok(out)
148}
149
/// Env var naming a tool for [`StubProvider`] to "call".
///
/// While this holds a non-empty tool name, the stub's `complete()` emits a
/// synthetic tool call for that tool until a tool result appears in the
/// request transcript. After that it returns its canned `EndTurn` text.
/// Unset or empty keeps the plain canned-text behaviour. Used by the HITL
/// resume loopback verification to drive the data path without a real
/// provider backend.
pub const STUB_TOOL_CALL_ENV: &str = "POLYCHROME_STUB_TOOL_CALL";

/// The tool name held in [`STUB_TOOL_CALL_ENV`], or `None` when the variable
/// is unset, empty, or not valid Unicode.
fn stub_tool_name() -> Option<String> {
    match std::env::var(STUB_TOOL_CALL_ENV) {
        Ok(name) if !name.is_empty() => Some(name),
        _ => None,
    }
}
164
/// Canned, offline [`LlmProvider`] used for wiring and tests.
///
/// Needs no network access and no credentials. By default every
/// `complete()` streams two text deltas, one usage tally, and an `EndTurn`
/// stop.
///
/// When [`STUB_TOOL_CALL_ENV`] names a tool, `complete()` instead streams a
/// synthetic tool call (id `stub-call-1`) for that tool plus a `ToolUse`
/// stop. It keeps doing so until the request transcript contains a tool
/// result, and from then on returns the canned text again. The caller's
/// function-calling loop drives everything in between. Used by the HITL
/// resume loopback verification.
#[derive(Default, Clone, Copy)]
pub struct StubProvider;
177
178#[async_trait]
179impl LlmProvider for StubProvider {
180    type Error = DummyError;
181
182    async fn complete(
183        &self,
184        req: CompletionRequest,
185    ) -> Result<futures::stream::BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error> {
186        // If POLYCHROME_STUB_TOOL_CALL is set and we haven't yet seen a
187        // matching tool_result in the transcript, emit the synthetic tool
188        // call. Otherwise fall through to canned text.
189        if let Some(tool_name) = stub_tool_name() {
190            let saw_result = req.messages.iter().any(|m| {
191                m.content
192                    .iter()
193                    .any(|c| matches!(c, crate::Content::ToolResult(_)))
194            });
195            if !saw_result {
196                let chunks = vec![
197                    Ok(Chunk::tool_call_start("stub-call-1", &tool_name)),
198                    Ok(Chunk::tool_call_args_delta("stub-call-1", "{}")),
199                    Ok(Chunk::tool_call_end("stub-call-1")),
200                    Ok(Chunk::Stop(StopReason::ToolUse)),
201                ];
202                return Ok(stream::iter(chunks).boxed());
203            }
204        }
205        let chunks = vec![
206            Ok(Chunk::text_delta("Hello from the ")),
207            Ok(Chunk::text_delta("stub provider.")),
208            Ok(Chunk::Usage(Usage {
209                input_tokens: 5,
210                output_tokens: 4,
211            })),
212            Ok(Chunk::Stop(StopReason::EndTurn)),
213        ];
214        Ok(stream::iter(chunks).boxed())
215    }
216}
217
#[cfg(test)]
mod tests {
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use super::*;

    #[tokio::test]
    async fn stub_provider_collects_into_text() {
        let stream = StubProvider
            .complete(CompletionRequest::new("stub"))
            .await
            .expect("stream opens");
        let out = collect_turn(stream).await.expect("collect");
        assert_eq!(out.text, "Hello from the stub provider.");
        assert!(out.tool_calls.is_empty());
        assert_eq!(out.usage.output_tokens, 4);
        assert_eq!(out.stop, Some(StopReason::EndTurn));
    }

    #[tokio::test]
    async fn collect_assembles_tool_call_from_deltas() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("calling ")),
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_args_delta("c1", r#"{"q":"#)),
            Ok(Chunk::tool_call_args_delta("c1", r#""rust"}"#)),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.text, "calling ");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(out.tool_calls[0].name, "search");
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"rust"}"#);
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn collect_keeps_parallel_tool_calls_with_deferred_ends() {
        // Two interleaved calls whose `ToolCallEnd`s are both deferred to the
        // end of the stream (the OpenAI-compatible provider's shape). A single
        // `Option` would drop call 0 and close the survivor with the wrong end;
        // id-matching must preserve both, in completion order.
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c0", "search")),
            Ok(Chunk::tool_call_args_delta("c0", r#"{"q":"a"}"#)),
            Ok(Chunk::tool_call_start("c1", "fetch")),
            Ok(Chunk::tool_call_args_delta("c1", r#"{"u":"b"}"#)),
            Ok(Chunk::tool_call_end("c0")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.tool_calls.len(), 2, "both parallel calls preserved");
        assert_eq!(out.tool_calls[0].id, "c0");
        assert_eq!(out.tool_calls[0].name, "search");
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"a"}"#);
        assert_eq!(out.tool_calls[1].id, "c1");
        assert_eq!(out.tool_calls[1].name, "fetch");
        assert_eq!(out.tool_calls[1].args_json, r#"{"u":"b"}"#);
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn collect_flushes_a_call_left_open_at_eof() {
        // A provider that omits the terminal `ToolCallEnd` must not lose the
        // call — it is flushed when the stream ends.
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c0", "search")),
            Ok(Chunk::tool_call_args_delta("c0", r#"{"q":"a"}"#)),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"a"}"#);
    }

    #[tokio::test]
    async fn collect_ignores_unmatched_ids() {
        // Args deltas / ends for an id that was never started are dropped
        // rather than fabricating a tool call.
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_args_delta("ghost", "{}")),
            Ok(Chunk::tool_call_end("ghost")),
            Ok(Chunk::text_delta("ok")),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert!(out.tool_calls.is_empty());
        assert_eq!(out.text, "ok");
        assert_eq!(out.stop, None);
    }

    #[tokio::test]
    async fn collect_observed_reports_text_and_tool_starts() {
        // The side channel sees every text delta and tool start, in stream
        // order, while the buffered output is still assembled in full.
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("a")),
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_args_delta("c1", "{}")),
            Ok(Chunk::text_delta("b")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let mut events = Vec::new();
        let out = collect_turn_observed(stream::iter(chunks), |e| events.push(e))
            .await
            .expect("collect");
        assert_eq!(out.text, "ab");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(
            events,
            vec![
                TurnStreamEvent::TextDelta("a".to_owned()),
                TurnStreamEvent::ToolStarted {
                    id: "c1".to_owned(),
                    name: "search".to_owned(),
                },
                TurnStreamEvent::TextDelta("b".to_owned()),
            ]
        );
    }

    #[tokio::test]
    async fn tool_use_stop_is_sticky_against_later_end_turn() {
        // Provider streams the tool call (ToolUse) then a trailing terminator
        // event (EndTurn). The terminator must NOT clobber ToolUse, else the
        // agent loop skips the tool.
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
            Ok(Chunk::Stop(StopReason::EndTurn)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn hard_stop_wins_over_earlier_tool_use() {
        // A later MaxTokens (truncation) MUST override an earlier ToolUse so
        // the agent doesn't execute a tool call with truncated arguments.
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
            Ok(Chunk::Stop(StopReason::MaxTokens)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.stop, Some(StopReason::MaxTokens));
    }

    #[tokio::test]
    async fn collect_propagates_error() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("partial")),
            Err(DummyError::Other("mid-stream fault".to_owned())),
        ];
        let res = collect_turn(stream::iter(chunks)).await;
        assert!(res.is_err());
    }
}