Skip to main content

llm_stack/
stream.rs

1//! Streaming response types.
2//!
3//! When a provider streams its response, it yields a sequence of
4//! [`StreamEvent`]s through a [`ChatStream`]. Events arrive
5//! incrementally — text deltas, tool-call fragments, and finally a
6//! [`Done`](StreamEvent::Done) event with the stop reason.
7//!
8//! # Collecting a stream
9//!
10//! ```rust,no_run
11//! use futures::StreamExt;
12//! use llm_stack::{ChatStream, StreamEvent};
13//!
14//! async fn print_stream(mut stream: ChatStream) {
15//!     while let Some(event) = stream.next().await {
16//!         match event {
17//!             Ok(StreamEvent::TextDelta(text)) => print!("{text}"),
18//!             Ok(StreamEvent::Done { stop_reason }) => {
19//!                 println!("\n[done: {stop_reason:?}]");
20//!             }
21//!             Err(e) => eprintln!("stream error: {e}"),
22//!             _ => {} // handle other events as needed
23//!         }
24//!     }
25//! }
26//! ```
27//!
28//! # Tool-call reassembly
29//!
30//! Tool calls arrive in three phases:
31//! 1. [`ToolCallStart`](StreamEvent::ToolCallStart) — announces the
32//!    call's `id` and `name`.
33//! 2. [`ToolCallDelta`](StreamEvent::ToolCallDelta) — one or more JSON
34//!    argument fragments, streamed as they're generated.
35//! 3. [`ToolCallComplete`](StreamEvent::ToolCallComplete) — the fully
36//!    assembled [`ToolCall`] with parsed arguments.
37//!
38//! The `index` field on each event identifies which call it belongs to
39//! when the model invokes multiple tools in parallel.
40
41use std::pin::Pin;
42
43use futures::Stream;
44use serde::{Deserialize, Serialize};
45
46use crate::chat::{StopReason, ToolCall};
47use crate::error::LlmError;
48use crate::usage::Usage;
49
/// A pinned, boxed, `Send` stream of [`StreamEvent`] results.
///
/// Each item is a `Result`: `Ok` carries the next [`StreamEvent`], while
/// `Err` carries an [`LlmError`]. Errors are delivered in-band as ordinary
/// stream items, so a consumer has to inspect every item rather than
/// assume success.
///
/// The stream is type-erased (`dyn Stream`), so any `Send` stream with a
/// matching item type can become a `ChatStream` by wrapping it in
/// `Box::pin(..)`. The trait object names no explicit lifetime, which means
/// the boxed stream must be `'static` and cannot borrow local data.
///
/// This type alias keeps signatures readable. Consume it with
/// [`StreamExt`](futures::StreamExt) from the `futures` crate.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>;
55
/// An incremental event emitted during a streaming response.
///
/// Events fall into four groups:
///
/// * output fragments — [`TextDelta`](StreamEvent::TextDelta) and
///   [`ReasoningDelta`](StreamEvent::ReasoningDelta);
/// * the tool-call phases — [`ToolCallStart`](StreamEvent::ToolCallStart),
///   [`ToolCallDelta`](StreamEvent::ToolCallDelta) and
///   [`ToolCallComplete`](StreamEvent::ToolCallComplete), tied together by
///   their shared `index`;
/// * accounting — [`Usage`](StreamEvent::Usage);
/// * termination — [`Done`](StreamEvent::Done).
///
/// The enum is `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it, which lets new variants be added
/// without a breaking change.
///
/// `Serialize` / `Deserialize` are derived without any `#[serde(...)]`
/// attributes, so serde's default enum representation applies.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum StreamEvent {
    /// A fragment of the model's text output.
    ///
    /// Fragments are meant to be consumed in arrival order, as the
    /// module-level example does by printing each one as it arrives.
    TextDelta(String),
    /// A fragment of the model's reasoning (chain-of-thought) output.
    ///
    /// Carried separately from [`TextDelta`](StreamEvent::TextDelta) so
    /// that consumers can tell reasoning apart from the visible answer.
    ReasoningDelta(String),
    /// Announces that a new tool call has started.
    ///
    /// This is the first of the three tool-call phases; it carries the
    /// call's `id` and `name` before any argument data is streamed.
    ToolCallStart {
        /// Zero-based index identifying this call when multiple tools
        /// are invoked in parallel.
        index: u32,
        /// Provider-assigned identifier linking start → deltas → complete.
        id: String,
        /// The name of the tool being called.
        name: String,
    },
    /// A JSON fragment of the tool call's arguments.
    ///
    /// This is the second tool-call phase; one or more of these may be
    /// emitted per call.
    ToolCallDelta {
        /// The tool-call index this delta belongs to.
        index: u32,
        /// A chunk of the JSON arguments string.
        ///
        /// A single chunk is only a slice of the serialized arguments and
        /// need not be valid JSON on its own (e.g. `{"q":`).
        json_chunk: String,
    },
    /// The fully assembled tool call, ready to execute.
    ///
    /// This is the third and last tool-call phase.
    ToolCallComplete {
        /// The tool-call index this completion corresponds to.
        index: u32,
        /// The complete, parsed tool call.
        call: ToolCall,
    },
    /// Token usage information for the request so far.
    Usage(Usage),
    /// The stream has ended.
    ///
    /// Per the module-level description this is the final event of a
    /// response and reports the stop reason.
    Done {
        /// Why the model stopped generating.
        stop_reason: StopReason,
    },
}
96
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Boxes a fixed list of results into a [`ChatStream`] that yields them
    /// in order and then terminates.
    fn stream_of(items: Vec<Result<StreamEvent, LlmError>>) -> ChatStream {
        Box::pin(futures::stream::iter(items))
    }

    // ---- Variant construction, cloning and equality -------------------

    #[test]
    fn test_stream_event_text_delta_eq() {
        let original = StreamEvent::TextDelta(String::from("hello"));
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    #[test]
    fn test_stream_event_reasoning_delta_eq() {
        let original = StreamEvent::ReasoningDelta(String::from("step 1"));
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    #[test]
    fn test_stream_event_tool_call_start() {
        let event = StreamEvent::ToolCallStart {
            index: 0,
            id: String::from("tc_1"),
            name: String::from("search"),
        };

        // Destructure so that every field is checked individually.
        match &event {
            StreamEvent::ToolCallStart { index, id, name } => {
                assert_eq!(*index, 0);
                assert_eq!(id.as_str(), "tc_1");
                assert_eq!(name.as_str(), "search");
            }
            other => panic!("expected ToolCallStart, got {other:?}"),
        }
    }

    #[test]
    fn test_stream_event_tool_call_delta() {
        // The chunk is deliberately an incomplete JSON prefix.
        let original = StreamEvent::ToolCallDelta {
            index: 0,
            json_chunk: String::from(r#"{"q":"#),
        };
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    #[test]
    fn test_stream_event_tool_call_complete() {
        let expected_call = ToolCall {
            id: String::from("tc_1"),
            name: String::from("search"),
            arguments: serde_json::json!({"q": "rust"}),
        };
        let event = StreamEvent::ToolCallComplete {
            index: 0,
            call: expected_call.clone(),
        };

        // Only the embedded call is compared; the index is not inspected.
        match &event {
            StreamEvent::ToolCallComplete { call: inner, .. } => {
                assert!(*inner == expected_call);
            }
            other => panic!("expected ToolCallComplete, got {other:?}"),
        }
    }

    #[test]
    fn test_stream_event_usage() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Usage::default()
        };
        let original = StreamEvent::Usage(usage);
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    #[test]
    fn test_stream_event_done() {
        let event = StreamEvent::Done {
            stop_reason: StopReason::EndTurn,
        };

        match &event {
            StreamEvent::Done { stop_reason } => {
                assert!(*stop_reason == StopReason::EndTurn);
            }
            other => panic!("expected Done, got {other:?}"),
        }
    }

    // ---- ChatStream behaviour -----------------------------------------

    #[tokio::test]
    async fn test_chat_stream_collect() {
        let stream = stream_of(vec![
            Ok(StreamEvent::TextDelta(String::from("hello "))),
            Ok(StreamEvent::TextDelta(String::from("world"))),
            Ok(StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            }),
        ]);

        let results = stream.collect::<Vec<_>>().await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|item| item.is_ok()));
    }

    #[tokio::test]
    async fn test_chat_stream_error_mid_stream() {
        let failure = LlmError::Http {
            status: Some(http::StatusCode::INTERNAL_SERVER_ERROR),
            message: String::from("server error"),
            retryable: true,
        };
        let stream = stream_of(vec![
            Ok(StreamEvent::TextDelta(String::from("hello"))),
            Ok(StreamEvent::TextDelta(String::from(" world"))),
            Err(failure),
        ]);

        let results = stream.collect::<Vec<_>>().await;

        // Two successful items followed by the error, in that order.
        assert_eq!(results.len(), 3);
        let outcomes: Vec<bool> = results.iter().map(Result::is_ok).collect();
        assert_eq!(outcomes, [true, true, false]);
    }

    #[test]
    fn test_chat_stream_is_send() {
        // Compile-time check: this only builds if `ChatStream: Send`.
        fn require_send<T>()
        where
            T: Send,
        {
        }
        require_send::<ChatStream>();
    }
}