use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use futures::StreamExt;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
use crate::message::{ContentBlock, Message, Role, TextBlock};
use crate::response::{StopReason, Usage};
use crate::thinking::ThinkingBlock;
use crate::tool::ToolUseBlock;
/// One incremental event of a streamed model response.
///
/// Serialized with serde's internally tagged representation: the variant
/// name, converted to `snake_case`, is stored under the `"type"` key next to
/// the variant's own fields.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
/// Start of a message, carrying its `id` and `model` strings.
/// `collect_message` ignores this event.
MessageStart {
id: String,
model: String,
},
/// Announces a content block at position `index` together with its type.
/// `collect_message` uses `index` as the position of the block in the
/// assembled message content.
ContentBlockStart {
index: u32,
content_type: StreamingContentType,
},
/// An incremental payload fragment for the block at `index`.
Delta {
index: u32,
delta: StreamingDelta,
},
/// Marks the block at `index` as complete; `collect_message` builds the
/// final content block from the accumulated fragments at this point.
ContentBlockStop {
index: u32,
},
/// Message-level update. Both fields are optional: a missing key
/// deserializes to `None`, and a `None` value is omitted when serializing.
MessageDelta {
#[serde(default, skip_serializing_if = "Option::is_none")]
stop_reason: Option<StopReason>,
#[serde(default, skip_serializing_if = "Option::is_none")]
usage_delta: Option<Usage>,
},
/// End-of-message marker. `collect_message` ignores this event and keeps
/// reading until the underlying stream itself ends.
MessageStop,
/// Payload-free event; `collect_message` ignores it.
Ping,
}
/// Type of a content block as announced by `StreamEvent::ContentBlockStart`.
///
/// Serialized with serde's internally tagged representation (`"type"` key,
/// `snake_case` variant names).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingContentType {
/// Text block; `collect_message` turns it into a `ContentBlock::Text`.
Text,
/// Tool invocation; `id` and `name` are copied into the resulting
/// `ToolUseBlock`, whose input comes from the streamed JSON fragments.
ToolUse {
id: String,
name: String,
},
/// Thinking block; `collect_message` turns it into a `ContentBlock::Thinking`.
Thinking,
/// Image block. `collect_message` returns an error when such a block is
/// stopped, because it cannot assemble streamed images.
Image,
}
/// Incremental payload fragment carried by `StreamEvent::Delta`.
///
/// NOTE(review): serde documents the internally tagged representation
/// (`#[serde(tag = "...")]`) as supporting newtype variants only when they
/// wrap a struct or map. Every variant here wraps a plain `String`, so
/// serializing or deserializing this enum is expected to fail at runtime.
/// Confirm, and consider adjacent tagging (`tag` + `content`) or struct
/// variants if this type ever has to cross a serde boundary.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingDelta {
/// Text fragment; `collect_message` appends it to the block's text buffer.
Text(String),
/// Fragment of a tool-use input JSON document; `collect_message`
/// concatenates the fragments and parses them when the block stops.
ToolUseInputJson(String),
/// Thinking fragment; `collect_message` appends it to the same text buffer
/// that `Text` fragments use.
Thinking(String),
}
/// Heap-allocated, pinned, type-erased stream of `Result<StreamEvent>` items
/// that is `Send` and `'static`.
pub type MessageStream = Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send + 'static>>;
#[allow(clippy::too_many_lines)]
pub async fn collect_message(mut stream: MessageStream) -> Result<(Message, StopReason, Usage)> {
let mut blocks: Vec<ContentBlock> = Vec::new();
let mut block_types: Vec<StreamingContentType> = Vec::new();
let mut block_text: Vec<String> = Vec::new();
let mut block_json: Vec<String> = Vec::new();
let mut stop_reason: Option<StopReason> = None;
let mut usage = Usage::default();
while let Some(evt) = stream.next().await {
match evt? {
StreamEvent::MessageStart { .. } | StreamEvent::Ping | StreamEvent::MessageStop => {}
StreamEvent::ContentBlockStart {
index,
content_type,
} => {
let i = index as usize;
if blocks.len() <= i {
blocks.resize(
i + 1,
ContentBlock::Text(TextBlock {
text: String::new(),
cache_control: None,
}),
);
block_types.resize(i + 1, StreamingContentType::Text);
block_text.resize(i + 1, String::new());
block_json.resize(i + 1, String::new());
}
block_types[i] = content_type;
}
StreamEvent::Delta { index, delta } => {
let i = index as usize;
if i >= block_types.len() {
return Err(Error::InvalidRequest(format!(
"Delta event for uninitialized block index {i}"
)));
}
match delta {
StreamingDelta::Text(s) | StreamingDelta::Thinking(s) => {
block_text[i].push_str(&s);
}
StreamingDelta::ToolUseInputJson(s) => block_json[i].push_str(&s),
}
}
StreamEvent::ContentBlockStop { index } => {
let i = index as usize;
if i >= block_types.len() {
return Err(Error::InvalidRequest(format!(
"ContentBlockStop for uninitialized block index {i}"
)));
}
let block = match &block_types[i] {
StreamingContentType::Text => ContentBlock::Text(TextBlock {
text: std::mem::take(&mut block_text[i]),
cache_control: None,
}),
StreamingContentType::Thinking => ContentBlock::Thinking(ThinkingBlock {
thinking: std::mem::take(&mut block_text[i]),
signature: None,
}),
StreamingContentType::ToolUse { id, name } => {
let json_str = std::mem::take(&mut block_json[i]);
let input = if json_str.is_empty() {
serde_json::json!({})
} else {
serde_json::from_str(&json_str).map_err(|e| {
Error::InvalidRequest(format!(
"tool_use input json parse error: {e}"
))
})?
};
ContentBlock::ToolUse(ToolUseBlock {
id: id.clone(),
name: name.clone(),
input,
})
}
StreamingContentType::Image => {
return Err(Error::InvalidRequest(
"streaming Image blocks are not supported in collect_message".into(),
));
}
};
blocks[i] = block;
}
StreamEvent::MessageDelta {
stop_reason: sr,
usage_delta,
} => {
if let Some(sr) = sr {
stop_reason = Some(sr);
}
if let Some(u) = usage_delta {
usage.merge(u);
}
}
}
}
let stop = stop_reason.unwrap_or(StopReason::EndTurn);
Ok((
Message {
role: Role::Assistant,
content: blocks,
},
stop,
usage,
))
}
/// Stream adapter that aborts with [`Error::StreamIdle`] when the wrapped
/// stream yields nothing for `idle`, and logs a warning at half that time.
///
/// Items from the inner stream are forwarded unchanged. The watchdog relies
/// on a Tokio timer, so the adapter must be polled inside a Tokio runtime
/// with the time driver enabled; constructing it does not require one.
pub struct WatchedStream<S> {
    /// The wrapped stream.
    inner: S,
    /// Maximum time allowed between two items.
    idle: Duration,
    /// When the last item was received (initially: when the adapter was created).
    last_chunk_at: Instant,
    /// Whether the half-idle warning was already logged for the current gap.
    warned: bool,
    /// Wake-up timer for the next threshold. Created lazily in `poll_next`
    /// so that `new` can be called outside a runtime; dropping it cancels it.
    timer: Option<Pin<Box<tokio::time::Sleep>>>,
}

impl<S> WatchedStream<S> {
    /// Wraps `inner`, allowing at most `idle` between consecutive items.
    /// The idle clock starts now.
    pub fn new(inner: S, idle: Duration) -> Self {
        Self {
            inner,
            idle,
            last_chunk_at: Instant::now(),
            warned: false,
            timer: None,
        }
    }
}

impl<S> Stream for WatchedStream<S>
where
    S: Stream<Item = Result<StreamEvent>> + Unpin,
{
    type Item = Result<StreamEvent>;

    /// Forwards the inner stream. While it is pending, checks the idle clock:
    /// past `idle` an `Err(Error::StreamIdle(elapsed))` item is yielded, past
    /// `idle / 2` a warning is logged once per gap. Otherwise a timer is
    /// armed so the task is polled again when the next threshold is reached.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Self: Unpin` because `S: Unpin` and the timer is boxed.
        let this = self.get_mut();
        match Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Ready(Some(item)) => {
                this.last_chunk_at = Instant::now();
                this.warned = false;
                // Cancel the pending wake-up; it belonged to the previous gap.
                this.timer = None;
                Poll::Ready(Some(item))
            }
            Poll::Ready(None) => {
                this.timer = None;
                Poll::Ready(None)
            }
            Poll::Pending => {
                let elapsed = this.last_chunk_at.elapsed();
                if elapsed >= this.idle {
                    tracing::error!(
                        target: "caliban::stream",
                        elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
                        "recovery.stream_idle.abort"
                    );
                    return Poll::Ready(Some(Err(Error::StreamIdle(elapsed))));
                }
                if !this.warned && elapsed >= this.idle / 2 {
                    this.warned = true;
                    tracing::warn!(
                        target: "caliban::stream",
                        elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
                        "recovery.stream_idle.warning"
                    );
                }

                // Wake up at the next threshold: the warning point first,
                // then the abort point. The extra millisecond makes sure the
                // threshold has really been crossed when the timer fires.
                let threshold = if this.warned {
                    this.idle
                } else {
                    this.idle / 2
                };
                let delay = threshold
                    .saturating_sub(elapsed)
                    .saturating_add(Duration::from_millis(1));

                // A single owned timer replaces the previous one (if any), so
                // repeated polls never pile up background sleepers, and the
                // timer is cancelled as soon as the stream is dropped.
                let mut timer = Box::pin(tokio::time::sleep(delay));
                // Polling registers this task's waker with the timer.
                let already_fired = std::future::Future::poll(timer.as_mut(), cx).is_ready();
                this.timer = Some(timer);
                if already_fired {
                    // Not expected for a non-zero delay; ask for a re-poll
                    // rather than risk never being woken.
                    cx.waker().wake_by_ref();
                }
                Poll::Pending
            }
        }
    }
}
#[cfg(test)]
mod watched_tests {
    use super::*;
    use futures::stream;
    use std::time::Duration;

    /// Every item of a prompt inner stream is forwarded, and none is an error.
    #[tokio::test]
    async fn passes_through_normal_data() {
        let events: Vec<Result<StreamEvent>> = vec![
            Ok(StreamEvent::MessageStop),
            Ok(StreamEvent::MessageStop),
        ];
        let watched = WatchedStream::new(stream::iter(events), Duration::from_secs(1));
        // `Result::unwrap` panics on the first error item, failing the test.
        let delivered: Vec<StreamEvent> = watched.map(Result::unwrap).collect().await;
        assert_eq!(delivered.len(), 2);
    }

    /// An inner stream that never yields is cut off with `Error::StreamIdle`.
    #[tokio::test]
    async fn aborts_after_idle_timeout() {
        let never = stream::pending::<Result<StreamEvent>>();
        let mut watched = WatchedStream::new(never, Duration::from_millis(20));
        let first = watched.next().await.expect("Some(_)");
        assert!(matches!(first, Err(Error::StreamIdle(_))));
    }
}