1use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::{Duration, Instant};
6
7use futures::StreamExt;
8use futures::stream::Stream;
9use serde::{Deserialize, Serialize};
10
11use crate::error::{Error, Result};
12use crate::message::{ContentBlock, Message, Role, TextBlock};
13use crate::response::{StopReason, Usage};
14use crate::thinking::ThinkingBlock;
15use crate::tool::ToolUseBlock;
16
17#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
19#[serde(tag = "type", rename_all = "snake_case")]
20pub enum StreamEvent {
21 MessageStart {
23 id: String,
25 model: String,
27 },
28 ContentBlockStart {
30 index: u32,
32 content_type: StreamingContentType,
34 },
35 Delta {
37 index: u32,
39 delta: StreamingDelta,
41 },
42 ContentBlockStop {
44 index: u32,
46 },
47 MessageDelta {
49 #[serde(default, skip_serializing_if = "Option::is_none")]
51 stop_reason: Option<StopReason>,
52 #[serde(default, skip_serializing_if = "Option::is_none")]
54 usage_delta: Option<Usage>,
55 },
56 MessageStop,
58 Ping,
60}
61
62#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64#[serde(tag = "type", rename_all = "snake_case")]
65pub enum StreamingContentType {
66 Text,
68 ToolUse {
70 id: String,
72 name: String,
74 },
75 Thinking,
77 Image,
79}
80
81#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
83#[serde(tag = "type", rename_all = "snake_case")]
84pub enum StreamingDelta {
85 Text(String),
87 ToolUseInputJson(String),
89 Thinking(String),
91}
92
93pub type MessageStream = Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send + 'static>>;
95
96#[allow(clippy::too_many_lines)]
103pub async fn collect_message(mut stream: MessageStream) -> Result<(Message, StopReason, Usage)> {
104 let mut blocks: Vec<ContentBlock> = Vec::new();
105 let mut block_types: Vec<StreamingContentType> = Vec::new();
106 let mut block_text: Vec<String> = Vec::new();
107 let mut block_json: Vec<String> = Vec::new();
108 let mut stop_reason: Option<StopReason> = None;
109 let mut usage = Usage::default();
110
111 while let Some(evt) = stream.next().await {
112 match evt? {
113 StreamEvent::MessageStart { .. } | StreamEvent::Ping | StreamEvent::MessageStop => {}
114 StreamEvent::ContentBlockStart {
115 index,
116 content_type,
117 } => {
118 let i = index as usize;
119 if blocks.len() <= i {
120 blocks.resize(
121 i + 1,
122 ContentBlock::Text(TextBlock {
123 text: String::new(),
124 cache_control: None,
125 }),
126 );
127 block_types.resize(i + 1, StreamingContentType::Text);
128 block_text.resize(i + 1, String::new());
129 block_json.resize(i + 1, String::new());
130 }
131 block_types[i] = content_type;
132 }
133 StreamEvent::Delta { index, delta } => {
134 let i = index as usize;
135 if i >= block_types.len() {
136 return Err(Error::InvalidRequest(format!(
137 "Delta event for uninitialized block index {i}"
138 )));
139 }
140 match delta {
141 StreamingDelta::Text(s) | StreamingDelta::Thinking(s) => {
142 block_text[i].push_str(&s);
143 }
144 StreamingDelta::ToolUseInputJson(s) => block_json[i].push_str(&s),
145 }
146 }
147 StreamEvent::ContentBlockStop { index } => {
148 let i = index as usize;
149 if i >= block_types.len() {
150 return Err(Error::InvalidRequest(format!(
151 "ContentBlockStop for uninitialized block index {i}"
152 )));
153 }
154 let block = match &block_types[i] {
155 StreamingContentType::Text => ContentBlock::Text(TextBlock {
156 text: std::mem::take(&mut block_text[i]),
157 cache_control: None,
158 }),
159 StreamingContentType::Thinking => ContentBlock::Thinking(ThinkingBlock {
160 thinking: std::mem::take(&mut block_text[i]),
161 signature: None,
162 }),
163 StreamingContentType::ToolUse { id, name } => {
164 let json_str = std::mem::take(&mut block_json[i]);
165 let input = if json_str.is_empty() {
166 serde_json::json!({})
167 } else {
168 serde_json::from_str(&json_str).map_err(|e| {
169 Error::InvalidRequest(format!(
170 "tool_use input json parse error: {e}"
171 ))
172 })?
173 };
174 ContentBlock::ToolUse(ToolUseBlock {
175 id: id.clone(),
176 name: name.clone(),
177 input,
178 })
179 }
180 StreamingContentType::Image => {
181 return Err(Error::InvalidRequest(
182 "streaming Image blocks are not supported in collect_message".into(),
183 ));
184 }
185 };
186 blocks[i] = block;
187 }
188 StreamEvent::MessageDelta {
189 stop_reason: sr,
190 usage_delta,
191 } => {
192 if let Some(sr) = sr {
193 stop_reason = Some(sr);
194 }
195 if let Some(u) = usage_delta {
196 usage.merge(u);
197 }
198 }
199 }
200 }
201
202 let stop = stop_reason.unwrap_or(StopReason::EndTurn);
203 Ok((
204 Message {
205 role: Role::Assistant,
206 content: blocks,
207 },
208 stop,
209 usage,
210 ))
211}
212
213pub struct WatchedStream<S> {
229 inner: S,
230 idle: Duration,
231 last_chunk_at: Instant,
232 warned: bool,
233}
234
235impl<S> WatchedStream<S> {
236 pub fn new(inner: S, idle: Duration) -> Self {
239 Self {
240 inner,
241 idle,
242 last_chunk_at: Instant::now(),
243 warned: false,
244 }
245 }
246}
247
248impl<S> Stream for WatchedStream<S>
249where
250 S: Stream<Item = Result<StreamEvent>> + Unpin,
251{
252 type Item = Result<StreamEvent>;
253
254 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
255 match Pin::new(&mut self.inner).poll_next(cx) {
256 Poll::Ready(Some(item)) => {
257 self.last_chunk_at = Instant::now();
258 self.warned = false;
259 Poll::Ready(Some(item))
260 }
261 Poll::Ready(None) => Poll::Ready(None),
262 Poll::Pending => {
263 let elapsed = self.last_chunk_at.elapsed();
264 if elapsed >= self.idle {
265 tracing::error!(
266 target: "caliban::stream",
267 elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
268 "recovery.stream_idle.abort"
269 );
270 return Poll::Ready(Some(Err(Error::StreamIdle(elapsed))));
271 }
272 if !self.warned && elapsed >= self.idle / 2 {
273 self.warned = true;
274 tracing::warn!(
275 target: "caliban::stream",
276 elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
277 "recovery.stream_idle.warning"
278 );
279 }
280 let remaining = self.idle.checked_sub(elapsed).unwrap_or(Duration::ZERO);
284 let waker = cx.waker().clone();
285 tokio::spawn(async move {
286 tokio::time::sleep(remaining + Duration::from_millis(1)).await;
287 waker.wake();
288 });
289 Poll::Pending
290 }
291 }
292 }
293}
294
#[cfg(test)]
mod watched_tests {
    use super::*;
    use futures::stream;
    use std::time::Duration;

    /// Items from a healthy inner stream are forwarded unchanged and in full.
    #[tokio::test]
    async fn passes_through_normal_data() {
        let events = vec![Ok(StreamEvent::MessageStop), Ok(StreamEvent::MessageStop)];
        let watched = WatchedStream::new(stream::iter(events), Duration::from_secs(1));
        let received: Vec<Result<StreamEvent>> = watched.collect().await;
        assert_eq!(received.len(), 2);
        for event in received {
            event.unwrap();
        }
    }

    /// A stream that never produces anything is aborted with `StreamIdle`.
    #[tokio::test]
    async fn aborts_after_idle_timeout() {
        let silent = stream::pending::<Result<StreamEvent>>();
        let mut watched = WatchedStream::new(silent, Duration::from_millis(20));
        let outcome = watched.next().await.expect("Some(_)");
        assert!(matches!(outcome, Err(Error::StreamIdle(_))));
    }
}