use std::{
sync::{Arc, Mutex},
time::Duration,
};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use crate::types::{Step, UsageMetadata};
/// Fully collected outcome of a chat response: the concatenated text plus
/// whatever usage metadata and structured output were recorded for it.
///
/// Values of this type are produced by [`ChatResponseHandle::text`].
#[derive(Debug, Clone)]
pub struct ChatResult {
/// The collected response text.
text: String,
/// Usage metadata, if any was recorded for the response.
usage: Option<UsageMetadata>,
/// Structured output value, if any was recorded for the response.
structured_output: Option<serde_json::Value>,
}
impl ChatResult {
    /// Borrows the collected response text.
    #[must_use]
    pub fn text(&self) -> &str {
        self.text.as_str()
    }

    /// Consumes the result and hands back the owned response text; the usage
    /// metadata and structured output are dropped.
    #[must_use]
    pub fn into_string(self) -> String {
        let Self { text, .. } = self;
        text
    }

    /// Borrows the usage metadata, or `None` when none was recorded.
    #[must_use]
    pub fn usage(&self) -> Option<&UsageMetadata> {
        Option::as_ref(&self.usage)
    }

    /// Borrows the structured output value, or `None` when none was recorded.
    #[must_use]
    pub fn structured_output(&self) -> Option<&serde_json::Value> {
        Option::as_ref(&self.structured_output)
    }
}
/// Lets a `ChatResult` be used wherever a `&str` is expected by exposing the
/// response text as the deref target.
impl std::ops::Deref for ChatResult {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        self.text.as_str()
    }
}
/// Compares only the response text against a string slice; usage metadata and
/// structured output do not take part in the comparison.
impl PartialEq<&str> for ChatResult {
    fn eq(&self, other: &&str) -> bool {
        let rhs: &str = other;
        self.text.as_str() == rhs
    }
}
impl PartialEq<String> for ChatResult {
fn eq(&self, other: &String) -> bool {
self.text == *other
}
}
/// Displays the response text exactly as the underlying string would be
/// displayed.
impl std::fmt::Display for ChatResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` is what `str`'s own `Display` uses: it honours the width,
        // fill/alignment and precision flags of the format spec, whereas a
        // raw `write_str` silently ignores them (e.g. `{:>10}` had no effect).
        // With a plain `{}` the output is unchanged.
        f.pad(&self.text)
    }
}
/// Converts a `ChatResult` into its owned response text, discarding the
/// usage metadata and structured output.
impl From<ChatResult> for String {
    fn from(result: ChatResult) -> Self {
        result.into_string()
    }
}
/// Upper bound (50 ms) on how long [`ChatResponseHandle::text`] waits for a
/// pending [`StreamError`] on the error channel once the text stream has ended.
pub(crate) const ERROR_DRAIN_TIMEOUT: Duration = Duration::from_millis(50);
/// A tool invocation carried on the response streams, either on its own
/// channel or wrapped in [`ResponseEvent::ToolCall`] / [`StreamChunk::ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallEvent {
/// Name of the tool being called.
pub name: String,
/// Arguments of the call, as an arbitrary JSON value.
pub args: serde_json::Value,
/// Optional call identifier.
/// NOTE(review): presumably pairs the call with its tool result — confirm against producers.
pub id: Option<String>,
/// Optional canonical path; deserializes to `None` when the field is absent.
/// NOTE(review): what this path refers to is not visible in this file.
#[serde(default)]
pub canonical_path: Option<String>,
}
/// Error delivered over the response's error channel and returned by
/// [`ChatResponseHandle::text`] when one is pending after the text stream ends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamError {
/// Human-readable description of the failure.
pub message: String,
}
/// Renders the error as `stream error: <message>`.
impl std::fmt::Display for StreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        write!(f, "stream error: {message}")
    }
}
// Marker impl so `StreamError` can be used as a `std::error::Error`
// (e.g. behind `dyn Error`); all trait methods keep their defaults.
impl std::error::Error for StreamError {}
/// One item of the event stream that [`ChatResponseHandle::resolve`] collects.
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResponseEvent {
/// A piece of response text.
TextChunk(String),
/// A piece of thought text.
ThoughtChunk(String),
/// A tool invocation.
ToolCall(ToolCallEvent),
/// The result of a tool invocation.
ToolResult(crate::types::ToolResult),
}
/// One item of the chunk stream exposed by [`ChatResponseHandle::receive_chunks`].
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StreamChunk {
/// A piece of response text.
Text(String),
/// A piece of thought text.
Thought(String),
/// A tool invocation.
ToolCall(ToolCallEvent),
}
/// State shared between a [`ChatResponseWriter`] and its [`ChatResponseHandle`]
/// behind an `Arc<Mutex<_>>`: the writer stores values here and the handle
/// copies them out in [`ChatResponseHandle::finalize`].
#[doc(hidden)]
#[derive(Debug, Default)]
pub struct ChatResponseSharedState {
/// Usage metadata set through `ChatResponseWriter::set_usage`.
pub usage: Option<UsageMetadata>,
/// Structured output set through `ChatResponseWriter::set_structured_output`.
pub structured_output: Option<serde_json::Value>,
}
/// Receiving halves of every response channel. Each slot is an `Option` so a
/// receiver can be moved out exactly once; afterwards the slot stays `None`.
#[derive(Debug)]
pub(crate) struct StreamReceivers {
// Response text tokens.
text: Option<mpsc::Receiver<String>>,
// Thought tokens.
thought: Option<mpsc::Receiver<String>>,
// Tool invocations.
tool_call: Option<mpsc::Receiver<ToolCallEvent>>,
// Stream errors.
error: Option<mpsc::Receiver<StreamError>>,
// Response events, drained by `ChatResponseHandle::resolve`.
event: Option<mpsc::Receiver<ResponseEvent>>,
// Steps, shared by `take_step_stream` and `receive_steps`.
step: Option<mpsc::Receiver<Step>>,
// Stream chunks, exposed by `receive_chunks`.
chunk: Option<mpsc::Receiver<StreamChunk>>,
}
/// Consumer side of a chat response created by [`channel`]: owns the channel
/// receivers and exposes the values captured from the shared state.
#[derive(Debug)]
pub struct ChatResponseHandle {
// Receivers for every response channel; each can be taken once.
rx: StreamReceivers,
// Usage metadata copied from the shared state by `finalize`; `None` before that.
usage: Option<UsageMetadata>,
// Structured output copied from the shared state by `finalize`; `None` before that.
structured_output_value: Option<serde_json::Value>,
// State shared with the paired `ChatResponseWriter`.
pub(crate) shared_state: Arc<Mutex<ChatResponseSharedState>>,
// Optional semaphore permit kept for as long as the handle lives; `channel()`
// initialises it to `None`.
// NOTE(review): presumably set by crate code to keep some resource alive — confirm.
pub(crate) keep_alive_permit: Option<tokio::sync::OwnedSemaphorePermit>,
}
/// Error returned by the writer side of a chat response when a value could
/// not be delivered.
#[derive(Debug)]
pub struct WriterError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl WriterError {
    /// Builds a `WriterError` from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

/// Displays the message verbatim, without any prefix.
impl std::fmt::Display for WriterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message.as_str())
    }
}

// Marker impl: all `std::error::Error` methods keep their defaults.
impl std::error::Error for WriterError {}
/// Wraps a failed tokio `mpsc` send into a [`WriterError`] whose message is
/// `channel send failed: <send error>`. The unsent value is dropped.
impl<T> From<mpsc::error::SendError<T>> for WriterError {
    fn from(err: mpsc::error::SendError<T>) -> Self {
        Self::new(format!("channel send failed: {err}"))
    }
}
/// Producer side of a chat response created by [`channel`]: owns the sending
/// half of every response channel plus the state shared with the handle.
pub struct ChatResponseWriter {
// Sender for response text tokens.
pub(crate) text_tx: mpsc::Sender<String>,
// Sender for thought tokens.
pub(crate) thought_tx: mpsc::Sender<String>,
// Sender for tool invocations.
pub(crate) tool_call_tx: mpsc::Sender<ToolCallEvent>,
// Sender for stream errors (created with capacity 1 by `channel()`).
pub(crate) error_tx: mpsc::Sender<StreamError>,
// Sender for response events.
pub(crate) event_tx: mpsc::Sender<ResponseEvent>,
// Sender for steps.
pub(crate) step_tx: mpsc::Sender<Step>,
// Sender for stream chunks.
pub(crate) chunk_tx: mpsc::Sender<StreamChunk>,
// State shared with the paired `ChatResponseHandle`.
pub(crate) shared_state: Arc<Mutex<ChatResponseSharedState>>,
}
/// Async send helpers, one per response channel.
///
/// Every method awaits the underlying tokio `mpsc` send and converts a send
/// failure into a [`WriterError`] through its `From<SendError<T>>` impl.
impl ChatResponseWriter {
    /// Sends one piece of response text on the text channel.
    ///
    /// # Errors
    /// Returns a [`WriterError`] when the channel send fails.
    pub async fn send_text(&self, text: String) -> Result<(), WriterError> {
        self.text_tx.send(text).await?;
        Ok(())
    }

    /// Sends one piece of thought text on the thought channel.
    ///
    /// # Errors
    /// Returns a [`WriterError`] when the channel send fails.
    pub async fn send_thought(&self, thought: String) -> Result<(), WriterError> {
        self.thought_tx.send(thought).await?;
        Ok(())
    }

    /// Sends a tool invocation on the tool-call channel.
    ///
    /// # Errors
    /// Returns a [`WriterError`] when the channel send fails.
    pub async fn send_tool_call(&self, event: ToolCallEvent) -> Result<(), WriterError> {
        self.tool_call_tx.send(event).await?;
        Ok(())
    }

    /// Sends a stream error on the error channel.
    ///
    /// # Errors
    /// Returns a [`WriterError`] when the channel send fails.
    pub async fn send_error(&self, error: StreamError) -> Result<(), WriterError> {
        self.error_tx.send(error).await?;
        Ok(())
    }

    /// Sends a response event on the event channel.
    ///
    /// # Errors
    /// Returns a [`WriterError`] when the channel send fails.
    pub async fn send_event(&self, event: ResponseEvent) -> Result<(), WriterError> {
        self.event_tx.send(event).await?;
        Ok(())
    }

    /// Sends a step on the step channel.
    ///
    /// # Errors
    /// Returns a [`WriterError`] when the channel send fails.
    pub async fn send_step(&self, step: crate::types::Step) -> Result<(), WriterError> {
        self.step_tx.send(step).await?;
        Ok(())
    }

    /// Sends a chunk on the chunk channel.
    ///
    /// # Errors
    /// Returns a [`WriterError`] when the channel send fails.
    pub async fn send_chunk(&self, chunk: StreamChunk) -> Result<(), WriterError> {
        self.chunk_tx.send(chunk).await?;
        Ok(())
    }
}
/// Capacity given by `channel()` to every response channel except the error
/// channel, which is created with capacity 1.
const CHANNEL_BUFFER: usize = 256;
#[must_use]
pub fn channel() -> (ChatResponseWriter, ChatResponseHandle) {
let (text_tx, text_rx) = mpsc::channel(CHANNEL_BUFFER);
let (thought_tx, thought_rx) = mpsc::channel(CHANNEL_BUFFER);
let (tool_call_tx, tool_call_rx) = mpsc::channel(CHANNEL_BUFFER);
let (error_tx, error_rx) = mpsc::channel(1);
let (event_tx, event_rx) = mpsc::channel(CHANNEL_BUFFER);
let (step_tx, step_rx) = mpsc::channel(CHANNEL_BUFFER);
let (chunk_tx, chunk_rx) = mpsc::channel(CHANNEL_BUFFER);
let shared_state = Arc::new(Mutex::new(ChatResponseSharedState::default()));
let writer = ChatResponseWriter {
text_tx,
thought_tx,
tool_call_tx,
error_tx,
event_tx,
step_tx,
chunk_tx,
shared_state: Arc::clone(&shared_state),
};
let handle = ChatResponseHandle {
keep_alive_permit: None,
rx: StreamReceivers {
text: Some(text_rx),
thought: Some(thought_rx),
tool_call: Some(tool_call_rx),
error: Some(error_rx),
event: Some(event_rx),
step: Some(step_rx),
chunk: Some(chunk_rx),
},
usage: None,
structured_output_value: None,
shared_state,
};
(writer, handle)
}
impl ChatResponseHandle {
pub const fn take_text_stream(&mut self) -> Option<mpsc::Receiver<String>> {
self.rx.text.take()
}
pub const fn take_thought_stream(&mut self) -> Option<mpsc::Receiver<String>> {
self.rx.thought.take()
}
pub const fn take_tool_call_stream(&mut self) -> Option<mpsc::Receiver<ToolCallEvent>> {
self.rx.tool_call.take()
}
pub const fn take_step_stream(&mut self) -> Option<mpsc::Receiver<Step>> {
self.rx.step.take()
}
pub fn receive_steps(&mut self) -> Option<impl tokio_stream::Stream<Item = Step>> {
self.rx.step.take().map(ReceiverStream::new)
}
pub fn receive_chunks(&mut self) -> Option<impl tokio_stream::Stream<Item = StreamChunk>> {
self.rx.chunk.take().map(ReceiverStream::new)
}
pub async fn text(mut self) -> Result<ChatResult, StreamError> {
let mut buf = String::new();
if let Some(mut rx) = self.rx.text.take() {
while let Some(token) = rx.recv().await {
buf.push_str(&token);
}
}
if let Some(mut err_rx) = self.rx.error.take()
&& let Ok(Some(err)) = tokio::time::timeout(ERROR_DRAIN_TIMEOUT, err_rx.recv()).await
{
return Err(err);
}
self.finalize();
Ok(ChatResult {
text: buf,
usage: self.usage,
structured_output: self.structured_output_value,
})
}
pub fn finalize(&mut self) {
if let Ok(state) = self.shared_state.lock() {
self.usage = state.usage.clone();
self.structured_output_value = state.structured_output.clone();
} else {
tracing::error!(
"ChatResponseHandle shared_state mutex poisoned during finalize — \
usage and structured_output will be unavailable"
);
}
}
#[must_use]
pub const fn structured_output(&self) -> Option<&serde_json::Value> {
self.structured_output_value.as_ref()
}
#[must_use]
pub const fn usage_metadata(&self) -> Option<&UsageMetadata> {
self.usage.as_ref()
}
#[doc(hidden)]
#[must_use]
pub fn shared_state(&self) -> Arc<Mutex<ChatResponseSharedState>> {
Arc::clone(&self.shared_state)
}
pub async fn resolve(mut self) -> Vec<ResponseEvent> {
let mut events = Vec::new();
if let Some(mut rx) = self.rx.event.take() {
while let Some(event) = rx.recv().await {
events.push(event);
}
}
self.finalize();
events
}
}
/// Setters that publish values to the state shared with the handle.
impl ChatResponseWriter {
    /// Stores `usage` in the shared state, replacing any previous value.
    ///
    /// If the shared-state mutex is poisoned, the value is dropped and an
    /// error is logged.
    pub fn set_usage(&self, usage: crate::types::UsageMetadata) {
        let mut shared = match self.shared_state.lock() {
            Ok(guard) => guard,
            Err(e) => {
                tracing::error!(
                    error = %e,
                    "ChatResponseWriter shared_state mutex poisoned in set_usage"
                );
                return;
            }
        };
        shared.usage = Some(usage);
    }

    /// Stores `value` as the structured output in the shared state, replacing
    /// any previous value.
    ///
    /// If the shared-state mutex is poisoned, the value is dropped and an
    /// error is logged.
    pub fn set_structured_output(&self, value: serde_json::Value) {
        let mut shared = match self.shared_state.lock() {
            Ok(guard) => guard,
            Err(e) => {
                tracing::error!(
                    error = %e,
                    "ChatResponseWriter shared_state mutex poisoned in set_structured_output"
                );
                return;
            }
        };
        shared.structured_output = Some(value);
    }
}
/// Unit tests for the writer/handle channel pair and its value types.
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokens sent on the text channel arrive one-for-one and in send order.
    #[tokio::test]
    async fn streaming_receives_all_tokens_in_order() {
        let (writer, mut handle) = channel();
        // One fixture feeds both the sender and the assertions so the two
        // cannot drift apart (the array is `Copy`, so the task gets its own copy).
        let tokens = ["Hello", " ", "world", "!"];
        let send_task = tokio::spawn(async move {
            for token in tokens {
                writer
                    .text_tx
                    .send(token.to_owned())
                    .await
                    .expect("send should succeed");
            }
        });
        let mut rx = handle.take_text_stream().expect("should get receiver");
        let mut received = Vec::new();
        while let Some(token) = rx.recv().await {
            received.push(token);
        }
        send_task.await.expect("send task should complete");
        // Token boundaries and order are preserved, not just the concatenation.
        assert_eq!(received, tokens);
        assert_eq!(received.concat(), tokens.concat());
    }

    /// `text()` concatenates every token into the final result.
    #[tokio::test]
    async fn text_returns_complete_response() {
        let (writer, handle) = channel();
        tokio::spawn(async move {
            for token in &["The ", "answer ", "is ", "42."] {
                writer
                    .text_tx
                    .send((*token).to_owned())
                    .await
                    .expect("send");
            }
        });
        let text = handle.text().await.expect("should succeed");
        assert_eq!(text, "The answer is 42.");
    }

    /// `text()` yields an empty result when the writer never sent anything.
    #[tokio::test]
    async fn text_returns_empty_when_no_tokens() {
        let (writer, handle) = channel();
        drop(writer);
        let text = handle.text().await.expect("should succeed");
        assert!(text.is_empty());
    }

    /// An error sent on the error channel makes `text()` return `Err`.
    #[tokio::test]
    async fn stream_error_propagated() {
        let (writer, handle) = channel();
        tokio::spawn(async move {
            writer
                .text_tx
                .send("partial".to_owned())
                .await
                .expect("send");
            writer
                .error_tx
                .send(StreamError {
                    message: "Python exception: quota exceeded".to_owned(),
                })
                .await
                .expect("send error");
        });
        let result = handle.text().await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("quota exceeded"));
    }

    /// Thought tokens are delivered in order on their own channel.
    #[tokio::test]
    async fn thought_stream_works() {
        let (writer, mut handle) = channel();
        tokio::spawn(async move {
            writer
                .thought_tx
                .send("thinking...".to_owned())
                .await
                .expect("send");
            writer
                .thought_tx
                .send("done.".to_owned())
                .await
                .expect("send");
        });
        let mut rx = handle.take_thought_stream().expect("should get receiver");
        let mut thoughts = Vec::new();
        while let Some(t) = rx.recv().await {
            thoughts.push(t);
        }
        assert_eq!(thoughts, vec!["thinking...", "done."]);
    }

    /// Tool-call events are delivered on the tool-call channel.
    #[tokio::test]
    async fn tool_call_stream_works() {
        let (writer, mut handle) = channel();
        let event = ToolCallEvent {
            name: "view_file".to_owned(),
            args: serde_json::json!({"path": "/tmp/test.txt"}),
            id: Some("call_1".to_owned()),
            canonical_path: None,
        };
        let event_clone = event.clone();
        tokio::spawn(async move {
            writer.tool_call_tx.send(event_clone).await.expect("send");
        });
        let mut rx = handle.take_tool_call_stream().expect("should get receiver");
        let received = rx.recv().await.expect("should receive event");
        assert_eq!(received.name, "view_file");
        assert_eq!(received.id, Some("call_1".to_owned()));
    }

    /// Usage set by the writer becomes visible on the handle after `finalize`.
    #[tokio::test]
    async fn usage_metadata_available_after_finalize() {
        let (writer, mut handle) = channel();
        assert!(handle.usage_metadata().is_none());
        writer.set_usage(UsageMetadata {
            prompt_token_count: Some(100),
            cached_content_token_count: Some(10),
            candidates_token_count: Some(50),
            thoughts_token_count: Some(20),
            total_token_count: Some(170),
        });
        drop(writer);
        handle.finalize();
        let usage = handle.usage_metadata().expect("should have usage");
        assert_eq!(usage.prompt_token_count, Some(100));
        assert_eq!(usage.total_token_count, Some(170));
    }

    /// The text receiver can only be taken once.
    #[test]
    fn take_text_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_text_stream().is_some());
        assert!(handle.take_text_stream().is_none());
    }

    /// `ToolCallEvent` survives a JSON round trip.
    #[test]
    fn tool_call_event_serde_roundtrip() {
        let event = ToolCallEvent {
            name: "run_command".to_owned(),
            args: serde_json::json!({"command": "ls"}),
            id: Some("call_42".to_owned()),
            canonical_path: None,
        };
        let json = serde_json::to_string(&event).expect("serialize");
        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, event.name);
        assert_eq!(parsed.args, event.args);
        assert_eq!(parsed.id, event.id);
    }

    /// The thought receiver can only be taken once.
    #[test]
    fn take_thought_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_thought_stream().is_some());
        assert!(handle.take_thought_stream().is_none());
    }

    /// The tool-call receiver can only be taken once.
    #[test]
    fn take_tool_call_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_tool_call_stream().is_some());
        assert!(handle.take_tool_call_stream().is_none());
    }

    /// `StreamError` displays with the `stream error: ` prefix.
    #[test]
    fn stream_error_display() {
        let err = StreamError {
            message: "quota exceeded".to_owned(),
        };
        assert_eq!(format!("{err}"), "stream error: quota exceeded");
    }

    /// `StreamError` is usable as a `dyn std::error::Error`.
    #[test]
    fn stream_error_is_std_error() {
        let err = StreamError {
            message: "test".to_owned(),
        };
        let _: &dyn std::error::Error = &err;
    }

    /// Text and thought channels can be consumed side by side.
    #[tokio::test]
    async fn concurrent_text_and_thought_streams() {
        let (writer, mut handle) = channel();
        tokio::spawn(async move {
            writer
                .text_tx
                .send("Hello".to_owned())
                .await
                .expect("send text");
            writer
                .thought_tx
                .send("thinking...".to_owned())
                .await
                .expect("send thought");
        });
        let mut text_rx = handle.take_text_stream().expect("text rx");
        let mut thought_rx = handle.take_thought_stream().expect("thought rx");
        let text = text_rx.recv().await.expect("receive text");
        let thought = thought_rx.recv().await.expect("receive thought");
        assert_eq!(text, "Hello");
        assert_eq!(thought, "thinking...");
    }

    /// Dropping the writer closes the text channel, so `text()` completes.
    #[tokio::test]
    async fn writer_dropped_without_sending_closes_text() {
        let (writer, handle) = channel();
        drop(writer);
        let text = handle.text().await.expect("should succeed");
        assert!(text.is_empty());
    }

    /// Dropping the writer closes the thought channel.
    #[tokio::test]
    async fn writer_dropped_without_sending_closes_thought_stream() {
        let (writer, mut handle) = channel();
        drop(writer);
        let mut thought_rx = handle.take_thought_stream().expect("rx");
        assert!(thought_rx.recv().await.is_none());
    }

    /// A `ToolCallEvent` without an id and with `null` args round-trips.
    #[test]
    fn tool_call_event_without_id() {
        let event = ToolCallEvent {
            name: "custom".to_owned(),
            args: serde_json::json!(null),
            id: None,
            canonical_path: None,
        };
        let json = serde_json::to_string(&event).expect("serialize");
        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, "custom");
        assert_eq!(parsed.args, serde_json::json!(null));
    }

    /// A long run of tokens is collected completely and in order.
    #[tokio::test]
    async fn large_token_stream() {
        let (writer, handle) = channel();
        let token_count = 200;
        // Exact expected text: a substring check would accept "t1" inside
        // "t10" and could not detect missing or reordered tokens.
        let expected: String = (0..token_count).map(|i| format!("t{i}")).collect();
        tokio::spawn(async move {
            for i in 0..token_count {
                writer.text_tx.send(format!("t{i}")).await.expect("send");
            }
        });
        let text = handle.text().await.expect("should succeed");
        assert_eq!(text.text(), expected);
    }

    /// `resolve()` returns every event in the order it was sent.
    #[tokio::test]
    async fn resolve_returns_events_in_order() {
        let (writer, handle) = channel();
        let tool_event = ToolCallEvent {
            name: "view_file".to_owned(),
            args: serde_json::json!({"path": "/tmp/x.rs"}),
            id: Some("call_1".to_owned()),
            canonical_path: None,
        };
        let tool_clone = tool_event.clone();
        tokio::spawn(async move {
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("Hello ".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ThoughtChunk("hmm".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ToolCall(tool_clone))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("world".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ToolResult(crate::types::ToolResult {
                    name: "view_file".to_owned(),
                    id: Some("call_1".to_owned()),
                    result: serde_json::json!({"output": "file contents"}),
                    error: None,
                }))
                .await
                .expect("send");
        });
        let events = handle.resolve().await;
        assert_eq!(events.len(), 5, "Expected 5 events, got {}", events.len());
        assert!(
            matches!(&events[0], ResponseEvent::TextChunk(s) if s == "Hello "),
            "events[0] should be TextChunk(\"Hello \")"
        );
        assert!(
            matches!(&events[1], ResponseEvent::ThoughtChunk(s) if s == "hmm"),
            "events[1] should be ThoughtChunk(\"hmm\")"
        );
        assert!(
            matches!(&events[2], ResponseEvent::ToolCall(tc) if tc.name == "view_file"),
            "events[2] should be ToolCall(view_file)"
        );
        assert!(
            matches!(&events[3], ResponseEvent::TextChunk(s) if s == "world"),
            "events[3] should be TextChunk(\"world\")"
        );
        assert!(
            matches!(&events[4], ResponseEvent::ToolResult(tr) if tr.name == "view_file"),
            "events[4] should be ToolResult(view_file)"
        );
    }

    /// A list covering every `ResponseEvent` variant survives a JSON round trip.
    #[test]
    fn response_event_serde_roundtrip() {
        let events = vec![
            ResponseEvent::TextChunk("hello".to_owned()),
            ResponseEvent::ThoughtChunk("thinking".to_owned()),
            ResponseEvent::ToolCall(ToolCallEvent {
                name: "run_command".to_owned(),
                args: serde_json::json!({"cmd": "ls"}),
                id: Some("c1".to_owned()),
                canonical_path: None,
            }),
            ResponseEvent::ToolResult(crate::types::ToolResult {
                name: "run_command".to_owned(),
                id: Some("c1".to_owned()),
                result: serde_json::json!({"output": "done"}),
                error: None,
            }),
        ];
        let json = serde_json::to_string(&events).expect("serialize");
        let parsed: Vec<ResponseEvent> = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.len(), events.len());
    }

    /// `receive_chunks()` yields chunks in the order they were sent.
    #[tokio::test]
    async fn receive_chunks_returns_chunks_in_order() {
        use tokio_stream::StreamExt;
        let (writer, mut handle) = channel();
        tokio::spawn(async move {
            writer
                .chunk_tx
                .send(StreamChunk::Text("hello".to_owned()))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::Thought("hmm".to_owned()))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::ToolCall(ToolCallEvent {
                    name: "view_file".to_owned(),
                    args: serde_json::json!({}),
                    id: None,
                    canonical_path: None,
                }))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::Text(" world".to_owned()))
                .await
                .expect("send");
        });
        let mut stream = handle.receive_chunks().expect("should get stream");
        let mut items = Vec::new();
        while let Some(chunk) = stream.next().await {
            items.push(chunk);
        }
        assert_eq!(items.len(), 4);
        assert!(matches!(&items[0], StreamChunk::Text(t) if t == "hello"));
        assert!(matches!(&items[1], StreamChunk::Thought(t) if t == "hmm"));
        assert!(matches!(&items[2], StreamChunk::ToolCall(tc) if tc.name == "view_file"));
        assert!(matches!(&items[3], StreamChunk::Text(t) if t == " world"));
    }

    /// `receive_steps()` yields the step sent by the writer.
    #[tokio::test]
    async fn receive_steps_returns_steps() {
        use tokio_stream::StreamExt;
        let (writer, mut handle) = channel();
        tokio::spawn(async move {
            writer
                .step_tx
                .send(crate::types::Step {
                    id: "step-0".to_owned(),
                    step_index: 0,
                    step_type: crate::types::StepType::TextResponse,
                    source: crate::types::StepSource::Model,
                    target: crate::types::StepTarget::User,
                    status: crate::types::StepStatus::Done,
                    content: "Hello".to_owned(),
                    content_delta: "Hello".to_owned(),
                    thinking: String::new(),
                    thinking_delta: String::new(),
                    tool_calls: vec![],
                    error: String::new(),
                    is_complete_response: Some(true),
                    structured_output: None,
                    usage_metadata: None,
                })
                .await
                .expect("send");
        });
        let mut stream = handle.receive_steps().expect("should get stream");
        let step = stream.next().await.expect("should get a step");
        assert_eq!(step.id, "step-0");
        assert_eq!(step.step_type, crate::types::StepType::TextResponse);
        assert_eq!(step.content, "Hello");
    }

    /// The text channel and the chunk stream can be used on the same handle.
    #[tokio::test]
    async fn existing_channels_work_alongside_chunk_stream() {
        use tokio_stream::StreamExt;
        let (writer, mut handle) = channel();
        tokio::spawn(async move {
            writer
                .text_tx
                .send("text-tok".to_owned())
                .await
                .expect("send text");
            writer
                .chunk_tx
                .send(StreamChunk::Text("text-tok".to_owned()))
                .await
                .expect("send chunk");
        });
        let mut text_rx = handle.take_text_stream().expect("text rx");
        let text = text_rx.recv().await.expect("receive text");
        assert_eq!(text, "text-tok");
        let mut chunk_stream = handle.receive_chunks().expect("chunk stream");
        let chunk = chunk_stream.next().await.expect("receive chunk");
        assert!(matches!(chunk, StreamChunk::Text(t) if t == "text-tok"));
    }

    /// The chunk stream can only be taken once.
    #[test]
    fn receive_chunks_returns_none_on_second_call() {
        let (_writer, mut handle) = channel();
        assert!(handle.receive_chunks().is_some());
        assert!(handle.receive_chunks().is_none());
    }

    /// The step stream can only be taken once.
    #[test]
    fn receive_steps_returns_none_on_second_call() {
        let (_writer, mut handle) = channel();
        assert!(handle.receive_steps().is_some());
        assert!(handle.receive_steps().is_none());
    }

    /// Every `StreamChunk` variant keeps its variant and payload through JSON.
    #[test]
    fn stream_chunk_serde_roundtrip() {
        let chunks = vec![
            StreamChunk::Text("hello".to_owned()),
            StreamChunk::Thought("hmm".to_owned()),
            StreamChunk::ToolCall(ToolCallEvent {
                name: "run".to_owned(),
                args: serde_json::json!({"cmd": "ls"}),
                id: Some("c1".to_owned()),
                canonical_path: None,
            }),
        ];
        for chunk in &chunks {
            let json = serde_json::to_string(chunk).expect("serialize");
            let parsed: StreamChunk = serde_json::from_str(&json).expect("deserialize");
            match (chunk, &parsed) {
                (StreamChunk::Text(a), StreamChunk::Text(b))
                | (StreamChunk::Thought(a), StreamChunk::Thought(b)) => assert_eq!(a, b),
                (StreamChunk::ToolCall(a), StreamChunk::ToolCall(b)) => {
                    assert_eq!(a.name, b.name);
                    assert_eq!(a.id, b.id);
                }
                _ => panic!("variant mismatch after roundtrip"),
            }
        }
    }

    /// Values set by the writer are readable from the shared state once
    /// `resolve()` has drained the event channel.
    #[tokio::test]
    async fn usage_metadata_populated_from_writer_after_resolve() {
        let (writer, handle) = channel();
        tokio::spawn(async move {
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("hello".to_owned()))
                .await
                .unwrap();
            writer.set_usage(crate::types::UsageMetadata {
                prompt_token_count: Some(5),
                cached_content_token_count: None,
                candidates_token_count: Some(1),
                thoughts_token_count: None,
                total_token_count: Some(6),
            });
            writer.set_structured_output(serde_json::json!({"key": "value"}));
        });
        // `resolve()` consumes the handle, so grab the shared state first.
        let shared = handle.shared_state();
        let events = handle.resolve().await;
        assert_eq!(events.len(), 1);
        let state = shared.lock().expect("lock shared state");
        assert_eq!(state.usage.as_ref().unwrap().total_token_count, Some(6));
        assert_eq!(
            state.structured_output.as_ref().unwrap(),
            &serde_json::json!({"key": "value"})
        );
    }

    /// A `ChatResult` converts into its owned text via `Into<String>`.
    #[test]
    fn chat_result_into_string() {
        let (writer, handle) = channel();
        drop(writer);
        let rt = tokio::runtime::Runtime::new().unwrap();
        let result = rt.block_on(handle.text()).unwrap();
        let s: String = result.into();
        assert!(s.is_empty());
    }
}