mod handle;
mod types;
mod writer;
use std::sync::{Arc, Mutex};
use tokio::sync::mpsc;
use self::types::StreamReceivers;
pub use self::{
handle::ChatResponseHandle,
types::{
ChatResponseSharedState, ChatResult, ResponseEvent, StreamChunk, StreamError, ToolCallEvent,
},
writer::{ChatResponseWriter, WriterError},
};
/// Capacity of every bounded stream channel created by [`channel`], except the error channel.
const CHANNEL_BUFFER: usize = 256;

/// Creates a linked producer/consumer pair for a single chat response.
///
/// The returned [`ChatResponseWriter`] holds the sending half of each stream:
/// text, thoughts, tool calls, errors, events, steps and chunks.
/// The [`ChatResponseHandle`] holds the matching receiving halves, bundled in a
/// `StreamReceivers`.
///
/// Capacities:
/// - Every stream except errors is bounded at `CHANNEL_BUFFER` items.
/// - The error channel buffers at most one pending [`StreamError`].
///
/// Both halves point at the same `ChatResponseSharedState` through one
/// `Arc<Mutex<_>>`. The handle starts with `usage` and `structured_output_value`
/// unset (`None`).
#[must_use]
pub fn channel() -> (ChatResponseWriter, ChatResponseHandle) {
    // A single shared-state cell, reachable from both the writer and the handle.
    let state = Arc::new(Mutex::new(ChatResponseSharedState::default()));

    // Bounded data streams, all sized to `CHANNEL_BUFFER`.
    let (text_sender, text_receiver) = mpsc::channel(CHANNEL_BUFFER);
    let (thought_sender, thought_receiver) = mpsc::channel(CHANNEL_BUFFER);
    let (tool_call_sender, tool_call_receiver) = mpsc::channel(CHANNEL_BUFFER);
    // The error channel buffers at most one pending item.
    let (error_sender, error_receiver) = mpsc::channel(1);
    let (event_sender, event_receiver) = mpsc::channel(CHANNEL_BUFFER);
    let (step_sender, step_receiver) = mpsc::channel(CHANNEL_BUFFER);
    let (chunk_sender, chunk_receiver) = mpsc::channel(CHANNEL_BUFFER);

    // Argument order must match `StreamReceivers::new`'s parameter order.
    let receivers = StreamReceivers::new(
        text_receiver,
        thought_receiver,
        tool_call_receiver,
        error_receiver,
        event_receiver,
        step_receiver,
        chunk_receiver,
    );

    (
        ChatResponseWriter {
            text_tx: text_sender,
            thought_tx: thought_sender,
            tool_call_tx: tool_call_sender,
            error_tx: error_sender,
            event_tx: event_sender,
            step_tx: step_sender,
            chunk_tx: chunk_sender,
            shared_state: Arc::clone(&state),
        },
        ChatResponseHandle {
            rx: receivers,
            usage: None,
            structured_output_value: None,
            shared_state: state,
        },
    )
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`channel`] and the writer/handle pair it produces.
    use super::*;

    /// Tokens sent on the text channel arrive intact, one per message, in send order.
    #[tokio::test]
    async fn streaming_receives_all_tokens_in_order() {
        let (writer, mut handle) = channel();
        let tokens = ["Hello", " ", "world", "!"];
        let expected: String = tokens.iter().copied().collect();
        // `tokens` is a `Copy` array, so the `async move` block gets its own copy
        // and the original stays usable for the assertions below.
        let send_task = tokio::spawn(async move {
            for token in tokens {
                writer
                    .text_tx
                    .send(token.to_owned())
                    .await
                    .expect("send should succeed");
            }
        });
        let mut rx = handle.take_text_stream().expect("should get receiver");
        let mut received = Vec::new();
        while let Some(token) = rx.recv().await {
            received.push(token);
        }
        send_task.await.expect("send task should complete");
        // Check token boundaries and order, not just the concatenation.
        assert_eq!(received, tokens);
        let full: String = received.iter().map(String::as_str).collect();
        assert_eq!(full, expected);
    }

    #[tokio::test]
    async fn text_returns_complete_response() {
        let (writer, handle) = channel();
        tokio::spawn(async move {
            for token in &["The ", "answer ", "is ", "42."] {
                writer
                    .text_tx
                    .send((*token).to_owned())
                    .await
                    .expect("send");
            }
        });
        let text = handle.text().await.expect("should succeed");
        assert_eq!(text, "The answer is 42.");
    }

    #[tokio::test]
    async fn text_returns_empty_when_no_tokens() {
        let (writer, handle) = channel();
        drop(writer);
        let text = handle.text().await.expect("should succeed");
        assert!(text.is_empty());
    }

    /// An error sent after partial text surfaces as `Err` from `text()`.
    #[tokio::test]
    async fn stream_error_propagated() {
        let (writer, handle) = channel();
        tokio::spawn(async move {
            writer
                .text_tx
                .send("partial".to_owned())
                .await
                .expect("send");
            writer
                .error_tx
                .send(StreamError {
                    message: "Python exception: quota exceeded".to_owned(),
                })
                .await
                .expect("send error");
        });
        let err = handle
            .text()
            .await
            .expect_err("a stream error should surface from text()");
        assert!(err.message.contains("quota exceeded"));
    }

    #[tokio::test]
    async fn thought_stream_works() {
        let (writer, mut handle) = channel();
        tokio::spawn(async move {
            writer
                .thought_tx
                .send("thinking...".to_owned())
                .await
                .expect("send");
            writer
                .thought_tx
                .send("done.".to_owned())
                .await
                .expect("send");
        });
        let mut rx = handle.take_thought_stream().expect("should get receiver");
        let mut thoughts = Vec::new();
        while let Some(t) = rx.recv().await {
            thoughts.push(t);
        }
        assert_eq!(thoughts, vec!["thinking...", "done."]);
    }

    /// A tool-call event round-trips through the tool-call channel unchanged.
    #[tokio::test]
    async fn tool_call_stream_works() {
        let (writer, mut handle) = channel();
        let event = ToolCallEvent {
            name: "view_file".to_owned(),
            args: serde_json::json!({"path": "/tmp/test.txt"}),
            id: Some("call_1".to_owned()),
            canonical_path: None,
        };
        // Keep `event` to compare against what the receiver gets.
        let sent = event.clone();
        tokio::spawn(async move {
            writer.tool_call_tx.send(sent).await.expect("send");
        });
        let mut rx = handle.take_tool_call_stream().expect("should get receiver");
        let received = rx.recv().await.expect("should receive event");
        assert_eq!(received.name, event.name);
        assert_eq!(received.args, event.args);
        assert_eq!(received.id, event.id);
    }

    #[tokio::test]
    async fn usage_metadata_available_after_finalize() {
        let (writer, mut handle) = channel();
        assert!(handle.usage_metadata().is_none());
        writer.set_usage(crate::types::UsageMetadata {
            prompt_token_count: Some(100),
            cached_content_token_count: Some(10),
            candidates_token_count: Some(50),
            thoughts_token_count: Some(20),
            total_token_count: Some(170),
        });
        drop(writer);
        handle.finalize();
        let usage = handle.usage_metadata().expect("should have usage");
        assert_eq!(usage.prompt_token_count, Some(100));
        assert_eq!(usage.total_token_count, Some(170));
    }

    #[test]
    fn take_text_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_text_stream().is_some());
        assert!(handle.take_text_stream().is_none());
    }

    #[test]
    fn tool_call_event_serde_roundtrip() {
        let event = ToolCallEvent {
            name: "run_command".to_owned(),
            args: serde_json::json!({"command": "ls"}),
            id: Some("call_42".to_owned()),
            canonical_path: None,
        };
        let json = serde_json::to_string(&event).expect("serialize");
        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, event.name);
        assert_eq!(parsed.args, event.args);
        assert_eq!(parsed.id, event.id);
    }

    #[test]
    fn take_thought_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_thought_stream().is_some());
        assert!(handle.take_thought_stream().is_none());
    }

    #[test]
    fn take_tool_call_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_tool_call_stream().is_some());
        assert!(handle.take_tool_call_stream().is_none());
    }

    #[test]
    fn stream_error_display() {
        let err = StreamError {
            message: "quota exceeded".to_owned(),
        };
        assert_eq!(format!("{err}"), "stream error: quota exceeded");
    }

    #[test]
    fn stream_error_is_std_error() {
        let err = StreamError {
            message: "test".to_owned(),
        };
        let _: &dyn std::error::Error = &err;
    }

    #[tokio::test]
    async fn concurrent_text_and_thought_streams() {
        let (writer, mut handle) = channel();
        tokio::spawn(async move {
            writer
                .text_tx
                .send("Hello".to_owned())
                .await
                .expect("send text");
            writer
                .thought_tx
                .send("thinking...".to_owned())
                .await
                .expect("send thought");
        });
        let mut text_rx = handle.take_text_stream().expect("text rx");
        let mut thought_rx = handle.take_thought_stream().expect("thought rx");
        let text = text_rx.recv().await.expect("receive text");
        let thought = thought_rx.recv().await.expect("receive thought");
        assert_eq!(text, "Hello");
        assert_eq!(thought, "thinking...");
    }

    #[tokio::test]
    async fn writer_dropped_without_sending_closes_text() {
        let (writer, handle) = channel();
        drop(writer);
        let text = handle.text().await.expect("should succeed");
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn writer_dropped_without_sending_closes_thought_stream() {
        let (writer, mut handle) = channel();
        drop(writer);
        let mut thought_rx = handle.take_thought_stream().expect("rx");
        assert!(thought_rx.recv().await.is_none());
    }

    #[test]
    fn tool_call_event_without_id() {
        let event = ToolCallEvent {
            name: "custom".to_owned(),
            args: serde_json::json!(null),
            id: None,
            canonical_path: None,
        };
        let json = serde_json::to_string(&event).expect("serialize");
        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, "custom");
        assert_eq!(parsed.args, serde_json::json!(null));
    }

    /// Many tokens are all delivered, exactly once each, in send order.
    #[tokio::test]
    async fn large_token_stream() {
        let (writer, handle) = channel();
        let token_count = 200;
        // Compare against the exact expected concatenation. A per-token
        // `contains` check is vacuous: "t1" is a substring of "t10".."t199",
        // so missing tokens go unnoticed, and it cannot detect reordering.
        let expected: String = (0..token_count).map(|i| format!("t{i}")).collect();
        tokio::spawn(async move {
            for i in 0..token_count {
                writer.text_tx.send(format!("t{i}")).await.expect("send");
            }
        });
        let text: String = handle.text().await.expect("should succeed").into();
        assert_eq!(text, expected, "tokens missing, duplicated, or out of order");
    }

    /// `resolve()` returns events from the unified event channel in send order.
    #[tokio::test]
    async fn resolve_returns_events_in_order() {
        let (writer, handle) = channel();
        let tool_event = ToolCallEvent {
            name: "view_file".to_owned(),
            args: serde_json::json!({"path": "/tmp/x.rs"}),
            id: Some("call_1".to_owned()),
            canonical_path: None,
        };
        let tool_clone = tool_event.clone();
        tokio::spawn(async move {
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("Hello ".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ThoughtChunk("hmm".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ToolCall(tool_clone))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("world".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ToolResult(crate::types::ToolResult {
                    name: "view_file".to_owned(),
                    id: Some("call_1".to_owned()),
                    result: serde_json::json!({"output": "file contents"}),
                    error: None,
                }))
                .await
                .expect("send");
        });
        let events = handle.resolve().await;
        assert_eq!(events.len(), 5, "Expected 5 events, got {}", events.len());
        assert!(
            matches!(&events[0], ResponseEvent::TextChunk(s) if s == "Hello "),
            "events[0] should be TextChunk(\"Hello \")"
        );
        assert!(
            matches!(&events[1], ResponseEvent::ThoughtChunk(s) if s == "hmm"),
            "events[1] should be ThoughtChunk(\"hmm\")"
        );
        assert!(
            matches!(&events[2], ResponseEvent::ToolCall(tc) if tc.name == tool_event.name),
            "events[2] should be ToolCall(view_file)"
        );
        assert!(
            matches!(&events[3], ResponseEvent::TextChunk(s) if s == "world"),
            "events[3] should be TextChunk(\"world\")"
        );
        assert!(
            matches!(&events[4], ResponseEvent::ToolResult(tr) if tr.name == "view_file"),
            "events[4] should be ToolResult(view_file)"
        );
    }

    #[test]
    fn response_event_serde_roundtrip() {
        let events = vec![
            ResponseEvent::TextChunk("hello".to_owned()),
            ResponseEvent::ThoughtChunk("thinking".to_owned()),
            ResponseEvent::ToolCall(ToolCallEvent {
                name: "run_command".to_owned(),
                args: serde_json::json!({"cmd": "ls"}),
                id: Some("c1".to_owned()),
                canonical_path: None,
            }),
            ResponseEvent::ToolResult(crate::types::ToolResult {
                name: "run_command".to_owned(),
                id: Some("c1".to_owned()),
                result: serde_json::json!({"output": "done"}),
                error: None,
            }),
        ];
        let json = serde_json::to_string(&events).expect("serialize");
        let parsed: Vec<ResponseEvent> = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.len(), events.len());
    }

    #[tokio::test]
    async fn receive_chunks_returns_chunks_in_order() {
        use tokio_stream::StreamExt;
        let (writer, mut handle) = channel();
        tokio::spawn(async move {
            writer
                .chunk_tx
                .send(StreamChunk::Text("hello".to_owned()))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::Thought("hmm".to_owned()))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::ToolCall(ToolCallEvent {
                    name: "view_file".to_owned(),
                    args: serde_json::json!({}),
                    id: None,
                    canonical_path: None,
                }))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::Text(" world".to_owned()))
                .await
                .expect("send");
        });
        let mut stream = handle.receive_chunks().expect("should get stream");
        let mut items = Vec::new();
        while let Some(chunk) = stream.next().await {
            items.push(chunk);
        }
        assert_eq!(items.len(), 4);
        assert!(matches!(&items[0], StreamChunk::Text(t) if t == "hello"));
        assert!(matches!(&items[1], StreamChunk::Thought(t) if t == "hmm"));
        assert!(matches!(&items[2], StreamChunk::ToolCall(tc) if tc.name == "view_file"));
        assert!(matches!(&items[3], StreamChunk::Text(t) if t == " world"));
    }

    #[tokio::test]
    async fn receive_steps_returns_steps() {
        use tokio_stream::StreamExt;
        let (writer, mut handle) = channel();
        tokio::spawn(async move {
            writer
                .step_tx
                .send(crate::types::Step {
                    id: "step-0".to_owned(),
                    step_index: 0,
                    step_type: crate::types::StepType::TextResponse,
                    source: crate::types::StepSource::Model,
                    target: crate::types::StepTarget::User,
                    status: crate::types::StepStatus::Done,
                    content: "Hello".to_owned(),
                    content_delta: "Hello".to_owned(),
                    thinking: String::new(),
                    thinking_delta: String::new(),
                    tool_calls: vec![],
                    error: String::new(),
                    is_complete_response: Some(true),
                    structured_output: None,
                    usage_metadata: None,
                })
                .await
                .expect("send");
        });
        let mut stream = handle.receive_steps().expect("should get stream");
        let step = stream.next().await.expect("should get a step");
        assert_eq!(step.id, "step-0");
        assert_eq!(step.step_type, crate::types::StepType::TextResponse);
        assert_eq!(step.content, "Hello");
    }

    /// The legacy per-kind channels and the chunk stream are independent.
    #[tokio::test]
    async fn existing_channels_work_alongside_chunk_stream() {
        use tokio_stream::StreamExt;
        let (writer, mut handle) = channel();
        tokio::spawn(async move {
            writer
                .text_tx
                .send("text-tok".to_owned())
                .await
                .expect("send text");
            writer
                .chunk_tx
                .send(StreamChunk::Text("text-tok".to_owned()))
                .await
                .expect("send chunk");
        });
        let mut text_rx = handle.take_text_stream().expect("text rx");
        let text = text_rx.recv().await.expect("receive text");
        assert_eq!(text, "text-tok");
        let mut chunk_stream = handle.receive_chunks().expect("chunk stream");
        let chunk = chunk_stream.next().await.expect("receive chunk");
        assert!(matches!(chunk, StreamChunk::Text(t) if t == "text-tok"));
    }

    #[test]
    fn receive_chunks_returns_none_on_second_call() {
        let (_writer, mut handle) = channel();
        assert!(handle.receive_chunks().is_some());
        assert!(handle.receive_chunks().is_none());
    }

    #[test]
    fn receive_steps_returns_none_on_second_call() {
        let (_writer, mut handle) = channel();
        assert!(handle.receive_steps().is_some());
        assert!(handle.receive_steps().is_none());
    }

    #[test]
    fn stream_chunk_serde_roundtrip() {
        let chunks = vec![
            StreamChunk::Text("hello".to_owned()),
            StreamChunk::Thought("hmm".to_owned()),
            StreamChunk::ToolCall(ToolCallEvent {
                name: "run".to_owned(),
                args: serde_json::json!({"cmd": "ls"}),
                id: Some("c1".to_owned()),
                canonical_path: None,
            }),
        ];
        for chunk in &chunks {
            let json = serde_json::to_string(chunk).expect("serialize");
            let parsed: StreamChunk = serde_json::from_str(&json).expect("deserialize");
            match (chunk, &parsed) {
                (StreamChunk::Text(a), StreamChunk::Text(b))
                | (StreamChunk::Thought(a), StreamChunk::Thought(b)) => assert_eq!(a, b),
                (StreamChunk::ToolCall(a), StreamChunk::ToolCall(b)) => {
                    assert_eq!(a.name, b.name);
                    assert_eq!(a.id, b.id);
                }
                _ => panic!("variant mismatch after roundtrip"),
            }
        }
    }

    /// Usage and structured output set on the writer are visible via the shared
    /// state once `resolve()` has returned.
    #[tokio::test]
    async fn usage_metadata_populated_from_writer_after_resolve() {
        let (writer, handle) = channel();
        tokio::spawn(async move {
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("hello".to_owned()))
                .await
                .expect("send");
            writer.set_usage(crate::types::UsageMetadata {
                prompt_token_count: Some(5),
                cached_content_token_count: None,
                candidates_token_count: Some(1),
                thoughts_token_count: None,
                total_token_count: Some(6),
            });
            writer.set_structured_output(serde_json::json!({"key": "value"}));
            // `writer` is dropped here, after both setters have run, which is
            // what lets the event stream (and thus `resolve()`) finish.
        });
        let shared = handle.shared_state();
        let events = handle.resolve().await;
        assert_eq!(events.len(), 1);
        let state = shared.lock().expect("lock shared state");
        assert_eq!(
            state
                .usage
                .as_ref()
                .expect("usage should be set")
                .total_token_count,
            Some(6)
        );
        assert_eq!(
            state
                .structured_output
                .as_ref()
                .expect("structured output should be set"),
            &serde_json::json!({"key": "value"})
        );
    }

    #[test]
    fn chat_result_into_string() {
        let (writer, handle) = channel();
        drop(writer);
        let rt = tokio::runtime::Runtime::new().expect("build tokio runtime");
        let result = rt.block_on(handle.text()).expect("should succeed");
        let s: String = result.into();
        assert!(s.is_empty());
    }
}