Skip to main content

caliban_provider/
stream.rs

1//! Streaming events.
2
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::{Duration, Instant};
6
7use futures::StreamExt;
8use futures::stream::Stream;
9use serde::{Deserialize, Serialize};
10
11use crate::error::{Error, Result};
12use crate::message::{ContentBlock, Message, Role, TextBlock};
13use crate::response::{StopReason, Usage};
14use crate::thinking::ThinkingBlock;
15use crate::tool::ToolUseBlock;
16
17/// A single event in a streaming completion.
18#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
19#[serde(tag = "type", rename_all = "snake_case")]
20pub enum StreamEvent {
21    /// The message has started; carries the assigned ID and model.
22    MessageStart {
23        /// Provider-assigned message identifier.
24        id: String,
25        /// Model that is generating the message.
26        model: String,
27    },
28    /// A content block is starting at the given index.
29    ContentBlockStart {
30        /// Zero-based block index.
31        index: u32,
32        /// The type of content block that is starting.
33        content_type: StreamingContentType,
34    },
35    /// An incremental delta for the block at the given index.
36    Delta {
37        /// Zero-based block index.
38        index: u32,
39        /// The incremental content.
40        delta: StreamingDelta,
41    },
42    /// The content block at the given index is complete.
43    ContentBlockStop {
44        /// Zero-based block index.
45        index: u32,
46    },
47    /// End-of-message metadata delta.
48    MessageDelta {
49        /// Why the model stopped, if known.
50        #[serde(default, skip_serializing_if = "Option::is_none")]
51        stop_reason: Option<StopReason>,
52        /// Incremental usage update.
53        #[serde(default, skip_serializing_if = "Option::is_none")]
54        usage_delta: Option<Usage>,
55    },
56    /// The message is fully complete.
57    MessageStop,
58    /// A keep-alive ping from the provider.
59    Ping,
60}
61
62/// The type of content block that is opening in a stream.
63#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64#[serde(tag = "type", rename_all = "snake_case")]
65pub enum StreamingContentType {
66    /// A plain-text block.
67    Text,
68    /// A tool-use block with the call ID and tool name.
69    ToolUse {
70        /// Unique call identifier.
71        id: String,
72        /// Name of the tool being called.
73        name: String,
74    },
75    /// An extended-thinking block.
76    Thinking,
77    /// An image block.
78    Image,
79}
80
81/// An incremental delta for a streaming content block.
82#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
83#[serde(tag = "type", rename_all = "snake_case")]
84pub enum StreamingDelta {
85    /// A text increment.
86    Text(String),
87    /// A JSON-fragment increment for a tool-use input.
88    ToolUseInputJson(String),
89    /// A thinking-text increment.
90    Thinking(String),
91}
92
93/// Boxed dynamic stream of stream events.
94pub type MessageStream = Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send + 'static>>;
95
96/// Consume a [`MessageStream`] and assemble the final [`Message`], [`StopReason`], and [`Usage`].
97///
98/// # Errors
99///
100/// Returns the first stream error encountered, or `Error::InvalidRequest` if
101/// an unsupported block type is streamed.
102#[allow(clippy::too_many_lines)]
103pub async fn collect_message(mut stream: MessageStream) -> Result<(Message, StopReason, Usage)> {
104    let mut blocks: Vec<ContentBlock> = Vec::new();
105    let mut block_types: Vec<StreamingContentType> = Vec::new();
106    let mut block_text: Vec<String> = Vec::new();
107    let mut block_json: Vec<String> = Vec::new();
108    let mut stop_reason: Option<StopReason> = None;
109    let mut usage = Usage::default();
110
111    while let Some(evt) = stream.next().await {
112        match evt? {
113            StreamEvent::MessageStart { .. } | StreamEvent::Ping | StreamEvent::MessageStop => {}
114            StreamEvent::ContentBlockStart {
115                index,
116                content_type,
117            } => {
118                let i = index as usize;
119                if blocks.len() <= i {
120                    blocks.resize(
121                        i + 1,
122                        ContentBlock::Text(TextBlock {
123                            text: String::new(),
124                            cache_control: None,
125                        }),
126                    );
127                    block_types.resize(i + 1, StreamingContentType::Text);
128                    block_text.resize(i + 1, String::new());
129                    block_json.resize(i + 1, String::new());
130                }
131                block_types[i] = content_type;
132            }
133            StreamEvent::Delta { index, delta } => {
134                let i = index as usize;
135                if i >= block_types.len() {
136                    return Err(Error::InvalidRequest(format!(
137                        "Delta event for uninitialized block index {i}"
138                    )));
139                }
140                match delta {
141                    StreamingDelta::Text(s) | StreamingDelta::Thinking(s) => {
142                        block_text[i].push_str(&s);
143                    }
144                    StreamingDelta::ToolUseInputJson(s) => block_json[i].push_str(&s),
145                }
146            }
147            StreamEvent::ContentBlockStop { index } => {
148                let i = index as usize;
149                if i >= block_types.len() {
150                    return Err(Error::InvalidRequest(format!(
151                        "ContentBlockStop for uninitialized block index {i}"
152                    )));
153                }
154                let block = match &block_types[i] {
155                    StreamingContentType::Text => ContentBlock::Text(TextBlock {
156                        text: std::mem::take(&mut block_text[i]),
157                        cache_control: None,
158                    }),
159                    StreamingContentType::Thinking => ContentBlock::Thinking(ThinkingBlock {
160                        thinking: std::mem::take(&mut block_text[i]),
161                        signature: None,
162                    }),
163                    StreamingContentType::ToolUse { id, name } => {
164                        let json_str = std::mem::take(&mut block_json[i]);
165                        let input = if json_str.is_empty() {
166                            serde_json::json!({})
167                        } else {
168                            serde_json::from_str(&json_str).map_err(|e| {
169                                Error::InvalidRequest(format!(
170                                    "tool_use input json parse error: {e}"
171                                ))
172                            })?
173                        };
174                        ContentBlock::ToolUse(ToolUseBlock {
175                            id: id.clone(),
176                            name: name.clone(),
177                            input,
178                        })
179                    }
180                    StreamingContentType::Image => {
181                        return Err(Error::InvalidRequest(
182                            "streaming Image blocks are not supported in collect_message".into(),
183                        ));
184                    }
185                };
186                blocks[i] = block;
187            }
188            StreamEvent::MessageDelta {
189                stop_reason: sr,
190                usage_delta,
191            } => {
192                if let Some(sr) = sr {
193                    stop_reason = Some(sr);
194                }
195                if let Some(u) = usage_delta {
196                    usage.merge(u);
197                }
198            }
199        }
200    }
201
202    let stop = stop_reason.unwrap_or(StopReason::EndTurn);
203    Ok((
204        Message {
205            role: Role::Assistant,
206            content: blocks,
207        },
208        stop,
209        usage,
210    ))
211}
212
213// ---------------------------------------------------------------------------
214// WatchedStream — stream-idle watchdog (ADR Plan A, Task 8)
215// ---------------------------------------------------------------------------
216
217/// Wraps a `Stream` and aborts with [`Error::StreamIdle`] when no chunk
218/// arrives within `idle`.
219///
220/// Emits a `tracing::warn` at half-time (helpful operational signal for
221/// observability dashboards) and `Err(Error::StreamIdle)` on full timeout.
222///
223/// `S` must be `Unpin` because we hold the inner stream in a `Box<dyn ...>`
224/// behind a `Pin<&mut Self>`-style `poll_next`. The concrete provider streams
225/// (`MessageStream = Pin<Box<dyn Stream + Send>>`) are already pinned at
226/// construction; `WatchedStream` owns the pointer directly so projection
227/// stays simple without pulling in `pin_project_lite`.
228pub struct WatchedStream<S> {
229    inner: S,
230    idle: Duration,
231    last_chunk_at: Instant,
232    warned: bool,
233}
234
235impl<S> WatchedStream<S> {
236    /// Build a new `WatchedStream`. `idle` is the maximum time the inner
237    /// stream may stay silent before [`Error::StreamIdle`] is surfaced.
238    pub fn new(inner: S, idle: Duration) -> Self {
239        Self {
240            inner,
241            idle,
242            last_chunk_at: Instant::now(),
243            warned: false,
244        }
245    }
246}
247
248impl<S> Stream for WatchedStream<S>
249where
250    S: Stream<Item = Result<StreamEvent>> + Unpin,
251{
252    type Item = Result<StreamEvent>;
253
254    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
255        match Pin::new(&mut self.inner).poll_next(cx) {
256            Poll::Ready(Some(item)) => {
257                self.last_chunk_at = Instant::now();
258                self.warned = false;
259                Poll::Ready(Some(item))
260            }
261            Poll::Ready(None) => Poll::Ready(None),
262            Poll::Pending => {
263                let elapsed = self.last_chunk_at.elapsed();
264                if elapsed >= self.idle {
265                    tracing::error!(
266                        target: "caliban::stream",
267                        elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
268                        "recovery.stream_idle.abort"
269                    );
270                    return Poll::Ready(Some(Err(Error::StreamIdle(elapsed))));
271                }
272                if !self.warned && elapsed >= self.idle / 2 {
273                    self.warned = true;
274                    tracing::warn!(
275                        target: "caliban::stream",
276                        elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
277                        "recovery.stream_idle.warning"
278                    );
279                }
280                // Schedule a wakeup at the remaining time so we can fire the
281                // abort even if `inner` stays Pending. The spawned future is
282                // a single sleep + wake; it self-terminates.
283                let remaining = self.idle.checked_sub(elapsed).unwrap_or(Duration::ZERO);
284                let waker = cx.waker().clone();
285                tokio::spawn(async move {
286                    tokio::time::sleep(remaining + Duration::from_millis(1)).await;
287                    waker.wake();
288                });
289                Poll::Pending
290            }
291        }
292    }
293}
294
#[cfg(test)]
mod watched_tests {
    use super::*;
    use futures::stream;
    use std::time::Duration;

    /// Items from a healthy inner stream come through unchanged and the
    /// wrapper ends when the inner stream ends.
    #[tokio::test]
    async fn passes_through_normal_data() {
        let events: Vec<Result<StreamEvent>> =
            vec![Ok(StreamEvent::MessageStop), Ok(StreamEvent::MessageStop)];
        let mut watched = WatchedStream::new(stream::iter(events), Duration::from_secs(1));

        let mut delivered = 0;
        loop {
            match watched.next().await {
                Some(item) => {
                    item.unwrap();
                    delivered += 1;
                }
                None => break,
            }
        }
        assert_eq!(delivered, 2);
    }

    /// An inner stream that never yields makes the wrapper surface
    /// `Error::StreamIdle` once the idle window has passed.
    #[tokio::test]
    async fn aborts_after_idle_timeout() {
        let silent = stream::pending::<Result<StreamEvent>>();
        let mut watched = WatchedStream::new(silent, Duration::from_millis(20));

        let first = watched.next().await.expect("Some(_)");
        assert!(matches!(first, Err(Error::StreamIdle(_))));
    }
}
323}