Skip to main content

polyc_llm/
chunk.rs

//! Streaming response types: [`Chunk`], [`Usage`], and [`StopReason`].
//!
//! Where [`request`](crate::request) describes what goes *into* a provider,
//! this module describes what streams *out*. A provider yields an ordered
//! sequence of [`Chunk`]s; the planner reassembles them into assistant turns.
//!
//! The shape is the richest streaming superset across the backends we target:
//! text deltas, a tool call that announces itself and then accretes its JSON
//! arguments incrementally ([`Chunk::ToolCallArgsDelta`] `args_json_delta`), and
//! usage. Critically, tool-call arguments arrive as **JSON string deltas** in
//! every backend: we do not attempt to deserialize them until the matching
//! [`Chunk::ToolCallEnd`].
13
14use serde::{Deserialize, Serialize};
15
16// ── Chunk ──────────────────────────────────────────────────────────────────────
17
18/// A single event in a provider's streaming response.
19///
20/// A complete stream is an ordered sequence of these. Text generation surfaces
21/// as a run of [`Chunk::TextDelta`] events; a tool call surfaces as exactly one
22/// [`Chunk::ToolCallStart`], zero or more [`Chunk::ToolCallArgsDelta`] fragments
23/// (whose concatenated `args_json_delta`s form the call's JSON arguments), and
24/// exactly one [`Chunk::ToolCallEnd`]. Every tool-call event carries the call
25/// `id` so concurrently-streamed calls can be demultiplexed. [`Chunk::Usage`]
26/// reports token accounting and may appear more than once. A well-formed stream
27/// ends with a single [`Chunk::Stop`].
28#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
29#[serde(rename_all = "snake_case")]
30#[non_exhaustive]
31pub enum Chunk {
32    /// An incremental piece of generated text. Concatenate consecutive
33    /// `TextDelta`s to recover the full text segment.
34    TextDelta(String),
35    /// The model has begun a tool call. The `id` and `name` are known up front;
36    /// arguments stream in as subsequent [`Chunk::ToolCallArgsDelta`]s bearing
37    /// the same `id`.
38    ToolCallStart {
39        /// Provider-assigned call identifier, matching a future
40        /// [`ToolCall::id`](crate::request::ToolCall::id).
41        id: String,
42        /// Name of the tool being called.
43        name: String,
44        /// Opaque provider-specific signature for this call (e.g. a thinking
45        /// model's thought signature), to be carried onto the assembled
46        /// [`ToolCall`](crate::request::ToolCall) and echoed back on the
47        /// follow-up request. `None` when the provider emits no such token.
48        #[serde(default, skip_serializing_if = "Option::is_none")]
49        signature: Option<String>,
50    },
51    /// An incremental fragment of a tool call's JSON arguments. Concatenate the
52    /// `args_json_delta`s of all fragments sharing an `id` to recover the full
53    /// `args_json`. Do not parse until the matching [`Chunk::ToolCallEnd`].
54    ToolCallArgsDelta {
55        /// Identifies which in-progress [`Chunk::ToolCallStart`] this fragment
56        /// belongs to.
57        id: String,
58        /// A partial slice of the call's JSON-encoded arguments.
59        args_json_delta: String,
60    },
61    /// The named tool call's arguments are complete and may now be parsed.
62    ToolCallEnd {
63        /// Identifies the completed [`Chunk::ToolCallStart`].
64        id: String,
65    },
66    /// A token-accounting update. May arrive more than once per stream (e.g.
67    /// input tokens early, output tokens at the end).
68    Usage(Usage),
69    /// Terminal event: generation has finished for the reason given.
70    Stop(StopReason),
71}
72
73impl Chunk {
74    /// Wraps `s` in a [`Chunk::TextDelta`].
75    #[must_use]
76    pub fn text_delta(s: impl Into<String>) -> Self {
77        Self::TextDelta(s.into())
78    }
79
80    /// Constructs a [`Chunk::ToolCallStart`] event (no signature).
81    #[must_use]
82    pub fn tool_call_start(id: impl Into<String>, name: impl Into<String>) -> Self {
83        Self::ToolCallStart {
84            id: id.into(),
85            name: name.into(),
86            signature: None,
87        }
88    }
89
90    /// Constructs a [`Chunk::ToolCallStart`] event carrying an opaque
91    /// provider-specific `signature`.
92    #[must_use]
93    pub fn tool_call_start_signed(
94        id: impl Into<String>,
95        name: impl Into<String>,
96        signature: Option<String>,
97    ) -> Self {
98        Self::ToolCallStart {
99            id: id.into(),
100            name: name.into(),
101            signature,
102        }
103    }
104
105    /// Constructs a [`Chunk::ToolCallArgsDelta`] carrying a partial-args fragment.
106    #[must_use]
107    pub fn tool_call_args_delta(id: impl Into<String>, args_json_delta: impl Into<String>) -> Self {
108        Self::ToolCallArgsDelta {
109            id: id.into(),
110            args_json_delta: args_json_delta.into(),
111        }
112    }
113
114    /// Constructs a [`Chunk::ToolCallEnd`] event.
115    #[must_use]
116    pub fn tool_call_end(id: impl Into<String>) -> Self {
117        Self::ToolCallEnd { id: id.into() }
118    }
119}
120
121// ── Usage ──────────────────────────────────────────────────────────────────────
122
123/// Token accounting for a request/response pair.
124///
125/// Counts are cumulative within a single stream. Surfaced to the
126/// `polychrome_llm_tokens_total{direction}` metric as input/output directions.
127#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
128pub struct Usage {
129    /// Tokens consumed by the prompt (system + messages + tools).
130    pub input_tokens: u64,
131    /// Tokens produced by the model.
132    pub output_tokens: u64,
133}
134
135impl Usage {
136    /// Total tokens billed for this exchange (`input_tokens + output_tokens`).
137    ///
138    /// Saturates rather than overflowing; real responses never approach
139    /// `u64::MAX`, but the arithmetic is total so callers need no guard.
140    #[must_use]
141    pub const fn total_tokens(self) -> u64 {
142        self.input_tokens.saturating_add(self.output_tokens)
143    }
144}
145
146// ── StopReason ───────────────────────────────────────────────────────────────────
147
148/// Why the model stopped generating.
149///
150/// The variants are the provider-agnostic union of the common providers'
151/// terminal states.
152#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
153#[serde(rename_all = "snake_case")]
154#[non_exhaustive]
155pub enum StopReason {
156    /// The model finished its turn naturally.
157    EndTurn,
158    /// Generation hit the request's `max_tokens` ceiling.
159    MaxTokens,
160    /// The model emitted one of the request's `stop` sequences.
161    StopSequence,
162    /// The model paused to call one or more tools; the caller is expected to
163    /// run them and continue the conversation.
164    ToolUse,
165    /// The model declined to answer, or the provider's content filter halted
166    /// generation. Mirrors the wire-side `STOP_REASON_REFUSAL`.
167    Refusal,
168}
169
170// ── Tests ─────────────────────────────────────────────────────────────────────
171
#[cfg(test)]
mod tests {
    // Test code favours directness over lint-cleanliness, so the stricter
    // clippy groups and the doc-comment requirement are relaxed for this
    // module only.
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use serde_json::{Value, json};

    use super::*;

    // ── Constructors ─────────────────────────────────────────────────────

    #[test]
    fn text_delta_constructor() {
        assert_eq!(Chunk::text_delta("hi"), Chunk::TextDelta("hi".to_owned()));
    }

    #[test]
    fn tool_call_start_constructor() {
        // Compare the whole value so the implicit `signature: None` is pinned
        // too, not just `id` and `name`.
        assert_eq!(
            Chunk::tool_call_start("call-1", "search"),
            Chunk::ToolCallStart {
                id: "call-1".to_owned(),
                name: "search".to_owned(),
                signature: None,
            },
        );
    }

    #[test]
    fn tool_call_start_signed_constructor() {
        assert_eq!(
            Chunk::tool_call_start_signed("call-1", "search", Some("sig".to_owned())),
            Chunk::ToolCallStart {
                id: "call-1".to_owned(),
                name: "search".to_owned(),
                signature: Some("sig".to_owned()),
            },
        );
        // A `None` signature is indistinguishable from the unsigned constructor.
        assert_eq!(
            Chunk::tool_call_start_signed("call-1", "search", None),
            Chunk::tool_call_start("call-1", "search"),
        );
    }

    #[test]
    fn tool_call_args_delta_constructor() {
        assert_eq!(
            Chunk::tool_call_args_delta("call-1", r#"{"q":"#),
            Chunk::ToolCallArgsDelta {
                id: "call-1".to_owned(),
                args_json_delta: r#"{"q":"#.to_owned(),
            },
        );
    }

    #[test]
    fn tool_call_end_constructor() {
        assert_eq!(
            Chunk::tool_call_end("call-1"),
            Chunk::ToolCallEnd {
                id: "call-1".to_owned()
            },
        );
    }

    // ── Chunk wire format ────────────────────────────────────────────────

    #[test]
    fn text_delta_serializes_as_tagged_object() {
        let v: Value = serde_json::to_value(Chunk::text_delta("hello")).unwrap();
        assert_eq!(v, json!({"text_delta": "hello"}));
    }

    #[test]
    fn tool_call_start_serializes_with_named_fields() {
        // An absent signature must be omitted entirely, not emitted as null.
        let v: Value = serde_json::to_value(Chunk::tool_call_start("id-1", "calc")).unwrap();
        assert_eq!(
            v,
            json!({"tool_call_start": {"id": "id-1", "name": "calc"}})
        );
    }

    #[test]
    fn tool_call_start_serializes_signature_when_present() {
        let chunk = Chunk::tool_call_start_signed("id-1", "calc", Some("sig".to_owned()));
        let v: Value = serde_json::to_value(chunk).unwrap();
        assert_eq!(
            v,
            json!({"tool_call_start": {"id": "id-1", "name": "calc", "signature": "sig"}}),
        );
    }

    #[test]
    fn tool_call_start_deserializes_without_signature() {
        // A payload with no `signature` key reads back as `None`.
        let chunk: Chunk =
            serde_json::from_value(json!({"tool_call_start": {"id": "id-1", "name": "calc"}}))
                .unwrap();
        assert_eq!(chunk, Chunk::tool_call_start("id-1", "calc"));
    }

    #[test]
    fn tool_call_args_delta_serializes_with_named_fields() {
        let v: Value =
            serde_json::to_value(Chunk::tool_call_args_delta("id-1", r#"{"x":1}"#)).unwrap();
        assert_eq!(
            v,
            json!({"tool_call_args_delta": {"id": "id-1", "args_json_delta": r#"{"x":1}"#}}),
        );
    }

    #[test]
    fn chunk_round_trips_all_variants() {
        for chunk in [
            Chunk::text_delta("partial"),
            Chunk::tool_call_start("c1", "weather"),
            Chunk::tool_call_start_signed("c2", "weather", Some("sig".to_owned())),
            Chunk::tool_call_args_delta("c1", r#"{"city":"NYC"}"#),
            Chunk::tool_call_end("c1"),
            Chunk::Usage(Usage {
                input_tokens: 10,
                output_tokens: 20,
            }),
            Chunk::Stop(StopReason::EndTurn),
        ] {
            let json = serde_json::to_string(&chunk).unwrap();
            let back: Chunk = serde_json::from_str(&json).unwrap();
            assert_eq!(back, chunk);
        }
    }

    #[test]
    fn reassemble_tool_call_args_from_deltas_by_id() {
        // Concatenating same-id ToolCallArgsDelta payloads recovers the full
        // JSON; a foreign id must not bleed into the assembly.
        let stream = [
            Chunk::tool_call_start("a", "weather"),
            Chunk::tool_call_args_delta("a", r#"{"city":"#),
            Chunk::tool_call_args_delta("b", "IGNORED"),
            Chunk::tool_call_args_delta("a", r#""NYC"}"#),
            Chunk::tool_call_end("a"),
        ];
        let assembled: String = stream
            .iter()
            .filter_map(|c| match c {
                Chunk::ToolCallArgsDelta {
                    id,
                    args_json_delta,
                } if id == "a" => Some(args_json_delta.as_str()),
                _ => None,
            })
            .collect();
        let parsed: Value = serde_json::from_str(&assembled).unwrap();
        assert_eq!(parsed, json!({"city": "NYC"}));
    }

    // ── Usage ────────────────────────────────────────────────────────────

    #[test]
    fn usage_total_sums_input_and_output() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 250,
        };
        assert_eq!(u.total_tokens(), 350);
    }

    #[test]
    fn usage_total_saturates_on_overflow() {
        let u = Usage {
            input_tokens: u64::MAX,
            output_tokens: 1,
        };
        assert_eq!(u.total_tokens(), u64::MAX);
    }

    #[test]
    fn usage_default_is_all_zero() {
        let u = Usage::default();
        assert_eq!(u.input_tokens, 0);
        assert_eq!(u.output_tokens, 0);
        assert_eq!(u.total_tokens(), 0);
    }

    // ── StopReason ───────────────────────────────────────────────────────

    #[test]
    fn stop_reason_serializes_to_snake_case() {
        for (reason, wire) in [
            (StopReason::EndTurn, r#""end_turn""#),
            (StopReason::MaxTokens, r#""max_tokens""#),
            (StopReason::StopSequence, r#""stop_sequence""#),
            (StopReason::ToolUse, r#""tool_use""#),
            (StopReason::Refusal, r#""refusal""#),
        ] {
            assert_eq!(serde_json::to_string(&reason).unwrap(), wire);
        }
    }

    #[test]
    fn stop_reason_round_trips() {
        for reason in [
            StopReason::EndTurn,
            StopReason::MaxTokens,
            StopReason::StopSequence,
            StopReason::ToolUse,
            StopReason::Refusal,
        ] {
            let json = serde_json::to_string(&reason).unwrap();
            let back: StopReason = serde_json::from_str(&json).unwrap();
            assert_eq!(back, reason);
        }
    }
}