1mod handle;
8mod types;
9mod writer;
10
11use std::sync::{Arc, Mutex};
12
13use tokio::sync::mpsc;
14
15use self::types::StreamReceivers;
16pub use self::{
17 handle::ChatResponseHandle,
18 types::{
19 ChatResponseSharedState, ChatResult, ResponseEvent, StreamChunk, StreamError, ToolCallEvent,
20 },
21 writer::{ChatResponseWriter, WriterError},
22};
23
24const CHANNEL_BUFFER: usize = 256;
27
28#[must_use]
33pub fn channel() -> (ChatResponseWriter, ChatResponseHandle) {
34 let (text_tx, text_rx) = mpsc::channel(CHANNEL_BUFFER);
35 let (thought_tx, thought_rx) = mpsc::channel(CHANNEL_BUFFER);
36 let (tool_call_tx, tool_call_rx) = mpsc::channel(CHANNEL_BUFFER);
37 let (error_tx, error_rx) = mpsc::channel(1);
38 let (event_tx, event_rx) = mpsc::channel(CHANNEL_BUFFER);
39 let (step_tx, step_rx) = mpsc::channel(CHANNEL_BUFFER);
40 let (chunk_tx, chunk_rx) = mpsc::channel(CHANNEL_BUFFER);
41
42 let shared_state = Arc::new(Mutex::new(ChatResponseSharedState::default()));
43
44 let writer = ChatResponseWriter {
45 text_tx,
46 thought_tx,
47 tool_call_tx,
48 error_tx,
49 event_tx,
50 step_tx,
51 chunk_tx,
52 shared_state: Arc::clone(&shared_state),
53 };
54
55 let handle = ChatResponseHandle {
56 rx: StreamReceivers::new(
57 text_rx,
58 thought_rx,
59 tool_call_rx,
60 error_rx,
61 event_rx,
62 step_rx,
63 chunk_rx,
64 ),
65 usage: None,
66 structured_output_value: None,
67 shared_state,
68 };
69
70 (writer, handle)
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn streaming_receives_all_tokens_in_order() {
        let (writer, mut handle) = channel();

        let tokens = ["Hello", " ", "world", "!"];
        let expected = tokens.concat();

        // `tokens` is a `Copy` array, so the task sends exactly the sequence the
        // assertions below compare against.
        let send_task = tokio::spawn(async move {
            for token in tokens {
                writer
                    .text_tx
                    .send(token.to_owned())
                    .await
                    .expect("send should succeed");
            }
        });

        let mut rx = handle.take_text_stream().expect("should get receiver");
        let mut received = Vec::new();
        while let Some(token) = rx.recv().await {
            received.push(token);
        }

        send_task.await.expect("send task should complete");
        // Check both the individual token boundaries and the joined text.
        assert_eq!(received, tokens);
        assert_eq!(received.concat(), expected);
    }

    #[tokio::test]
    async fn text_returns_complete_response() {
        let (writer, handle) = channel();

        tokio::spawn(async move {
            for token in &["The ", "answer ", "is ", "42."] {
                writer
                    .text_tx
                    .send((*token).to_owned())
                    .await
                    .expect("send");
            }
        });

        let text = handle.text().await.expect("should succeed");
        assert_eq!(text, "The answer is 42.");
    }

    #[tokio::test]
    async fn text_returns_empty_when_no_tokens() {
        let (writer, handle) = channel();
        drop(writer);

        let text = handle.text().await.expect("should succeed");
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn stream_error_propagated() {
        let (writer, handle) = channel();

        tokio::spawn(async move {
            writer
                .text_tx
                .send("partial".to_owned())
                .await
                .expect("send");
            writer
                .error_tx
                .send(StreamError {
                    message: "Python exception: quota exceeded".to_owned(),
                })
                .await
                .expect("send error");
        });

        let err = handle
            .text()
            .await
            .expect_err("an error on the error channel should fail text()");
        assert!(err.message.contains("quota exceeded"));
    }

    #[tokio::test]
    async fn thought_stream_works() {
        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .thought_tx
                .send("thinking...".to_owned())
                .await
                .expect("send");
            writer
                .thought_tx
                .send("done.".to_owned())
                .await
                .expect("send");
        });

        let mut rx = handle.take_thought_stream().expect("should get receiver");
        let mut thoughts = Vec::new();
        while let Some(t) = rx.recv().await {
            thoughts.push(t);
        }
        assert_eq!(thoughts, vec!["thinking...", "done."]);
    }

    #[tokio::test]
    async fn tool_call_stream_works() {
        let (writer, mut handle) = channel();

        let event = ToolCallEvent {
            name: "view_file".to_owned(),
            args: serde_json::json!({"path": "/tmp/test.txt"}),
            id: Some("call_1".to_owned()),
            canonical_path: None,
        };

        // Keep `event` for the assertions; the clone is what gets sent.
        let event_clone = event.clone();
        tokio::spawn(async move {
            writer.tool_call_tx.send(event_clone).await.expect("send");
        });

        let mut rx = handle.take_tool_call_stream().expect("should get receiver");
        let received = rx.recv().await.expect("should receive event");
        assert_eq!(received.name, event.name);
        assert_eq!(received.args, event.args);
        assert_eq!(received.id, event.id);
    }

    #[tokio::test]
    async fn usage_metadata_available_after_finalize() {
        let (writer, mut handle) = channel();
        assert!(handle.usage_metadata().is_none());

        writer.set_usage(crate::types::UsageMetadata {
            prompt_token_count: Some(100),
            cached_content_token_count: Some(10),
            candidates_token_count: Some(50),
            thoughts_token_count: Some(20),
            total_token_count: Some(170),
        });
        drop(writer);
        handle.finalize();

        let usage = handle.usage_metadata().expect("should have usage");
        assert_eq!(usage.prompt_token_count, Some(100));
        assert_eq!(usage.total_token_count, Some(170));
    }

    #[test]
    fn take_text_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_text_stream().is_some());
        assert!(handle.take_text_stream().is_none());
    }

    #[test]
    fn tool_call_event_serde_roundtrip() {
        let event = ToolCallEvent {
            name: "run_command".to_owned(),
            args: serde_json::json!({"command": "ls"}),
            id: Some("call_42".to_owned()),
            canonical_path: None,
        };
        let json = serde_json::to_string(&event).expect("serialize");
        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, event.name);
        assert_eq!(parsed.args, event.args);
        assert_eq!(parsed.id, event.id);
    }

    #[test]
    fn take_thought_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_thought_stream().is_some());
        assert!(handle.take_thought_stream().is_none());
    }

    #[test]
    fn take_tool_call_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_tool_call_stream().is_some());
        assert!(handle.take_tool_call_stream().is_none());
    }

    #[test]
    fn stream_error_display() {
        let err = StreamError {
            message: "quota exceeded".to_owned(),
        };
        assert_eq!(format!("{err}"), "stream error: quota exceeded");
    }

    #[test]
    fn stream_error_is_std_error() {
        let err = StreamError {
            message: "test".to_owned(),
        };
        // Compiles only if `StreamError: std::error::Error`.
        let _: &dyn std::error::Error = &err;
    }

    #[tokio::test]
    async fn concurrent_text_and_thought_streams() {
        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .text_tx
                .send("Hello".to_owned())
                .await
                .expect("send text");
            writer
                .thought_tx
                .send("thinking...".to_owned())
                .await
                .expect("send thought");
        });

        let mut text_rx = handle.take_text_stream().expect("text rx");
        let mut thought_rx = handle.take_thought_stream().expect("thought rx");

        let text = text_rx.recv().await.expect("receive text");
        let thought = thought_rx.recv().await.expect("receive thought");

        assert_eq!(text, "Hello");
        assert_eq!(thought, "thinking...");
    }

    #[tokio::test]
    async fn writer_dropped_without_sending_closes_text() {
        // Counterpart of the thought-stream test below. The `text()` path is
        // already covered by `text_returns_empty_when_no_tokens`.
        let (writer, mut handle) = channel();
        drop(writer);

        let mut text_rx = handle.take_text_stream().expect("rx");
        assert!(text_rx.recv().await.is_none());
    }

    #[tokio::test]
    async fn writer_dropped_without_sending_closes_thought_stream() {
        let (writer, mut handle) = channel();
        drop(writer);

        let mut thought_rx = handle.take_thought_stream().expect("rx");
        assert!(thought_rx.recv().await.is_none());
    }

    #[test]
    fn tool_call_event_without_id() {
        let event = ToolCallEvent {
            name: "custom".to_owned(),
            args: serde_json::json!(null),
            id: None,
            canonical_path: None,
        };
        let json = serde_json::to_string(&event).expect("serialize");
        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, "custom");
        assert_eq!(parsed.args, serde_json::json!(null));
        // The property this test is named for: a missing id survives the roundtrip.
        assert_eq!(parsed.id, None);
    }

    #[tokio::test]
    async fn large_token_stream() {
        let (writer, handle) = channel();
        let token_count = 200;

        tokio::spawn(async move {
            for i in 0..token_count {
                writer.text_tx.send(format!("t{i}")).await.expect("send");
            }
        });

        // Compare against the exact in-order concatenation. A per-token `contains`
        // check is vacuous here ("t1" is a substring of "t10".."t199") and
        // ignores ordering.
        let expected: String = (0..token_count).map(|i| format!("t{i}")).collect();
        let text = handle.text().await.expect("should succeed");
        assert_eq!(text, expected.as_str());
    }

    #[tokio::test]
    async fn resolve_returns_events_in_order() {
        let (writer, handle) = channel();

        let tool_event = ToolCallEvent {
            name: "view_file".to_owned(),
            args: serde_json::json!({"path": "/tmp/x.rs"}),
            id: Some("call_1".to_owned()),
            canonical_path: None,
        };

        // Keep `tool_event` for the assertions; the clone is what gets sent.
        let tool_clone = tool_event.clone();
        tokio::spawn(async move {
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("Hello ".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ThoughtChunk("hmm".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ToolCall(tool_clone))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("world".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ToolResult(crate::types::ToolResult {
                    name: "view_file".to_owned(),
                    id: Some("call_1".to_owned()),
                    result: serde_json::json!({"output": "file contents"}),
                    error: None,
                }))
                .await
                .expect("send");
        });

        let events = handle.resolve().await;
        assert_eq!(events.len(), 5, "Expected 5 events, got {}", events.len());

        assert!(
            matches!(&events[0], ResponseEvent::TextChunk(s) if s == "Hello "),
            "events[0] should be TextChunk(\"Hello \")"
        );
        assert!(
            matches!(&events[1], ResponseEvent::ThoughtChunk(s) if s == "hmm"),
            "events[1] should be ThoughtChunk(\"hmm\")"
        );
        assert!(
            matches!(
                &events[2],
                ResponseEvent::ToolCall(tc) if tc.name == tool_event.name && tc.id == tool_event.id
            ),
            "events[2] should be ToolCall(view_file)"
        );
        assert!(
            matches!(&events[3], ResponseEvent::TextChunk(s) if s == "world"),
            "events[3] should be TextChunk(\"world\")"
        );
        assert!(
            matches!(&events[4], ResponseEvent::ToolResult(tr) if tr.name == "view_file"),
            "events[4] should be ToolResult(view_file)"
        );
    }

    #[test]
    fn response_event_serde_roundtrip() {
        let events = vec![
            ResponseEvent::TextChunk("hello".to_owned()),
            ResponseEvent::ThoughtChunk("thinking".to_owned()),
            ResponseEvent::ToolCall(ToolCallEvent {
                name: "run_command".to_owned(),
                args: serde_json::json!({"cmd": "ls"}),
                id: Some("c1".to_owned()),
                canonical_path: None,
            }),
            ResponseEvent::ToolResult(crate::types::ToolResult {
                name: "run_command".to_owned(),
                id: Some("c1".to_owned()),
                result: serde_json::json!({"output": "done"}),
                error: None,
            }),
        ];

        let json = serde_json::to_string(&events).expect("serialize");
        let parsed: Vec<ResponseEvent> = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.len(), events.len());
    }

    #[tokio::test]
    async fn receive_chunks_returns_chunks_in_order() {
        use tokio_stream::StreamExt;

        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .chunk_tx
                .send(StreamChunk::Text("hello".to_owned()))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::Thought("hmm".to_owned()))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::ToolCall(ToolCallEvent {
                    name: "view_file".to_owned(),
                    args: serde_json::json!({}),
                    id: None,
                    canonical_path: None,
                }))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::Text(" world".to_owned()))
                .await
                .expect("send");
        });

        let mut stream = handle.receive_chunks().expect("should get stream");
        let mut items = Vec::new();
        while let Some(chunk) = stream.next().await {
            items.push(chunk);
        }

        assert_eq!(items.len(), 4);
        assert!(matches!(&items[0], StreamChunk::Text(t) if t == "hello"));
        assert!(matches!(&items[1], StreamChunk::Thought(t) if t == "hmm"));
        assert!(matches!(&items[2], StreamChunk::ToolCall(tc) if tc.name == "view_file"));
        assert!(matches!(&items[3], StreamChunk::Text(t) if t == " world"));
    }

    #[tokio::test]
    async fn receive_steps_returns_steps() {
        use tokio_stream::StreamExt;

        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .step_tx
                .send(crate::types::Step {
                    id: "step-0".to_owned(),
                    step_index: 0,
                    step_type: crate::types::StepType::TextResponse,
                    source: crate::types::StepSource::Model,
                    target: crate::types::StepTarget::User,
                    status: crate::types::StepStatus::Done,
                    content: "Hello".to_owned(),
                    content_delta: "Hello".to_owned(),
                    thinking: String::new(),
                    thinking_delta: String::new(),
                    tool_calls: vec![],
                    error: String::new(),
                    is_complete_response: Some(true),
                    structured_output: None,
                    usage_metadata: None,
                })
                .await
                .expect("send");
        });

        let mut stream = handle.receive_steps().expect("should get stream");
        let step = stream.next().await.expect("should get a step");
        assert_eq!(step.id, "step-0");
        assert_eq!(step.step_type, crate::types::StepType::TextResponse);
        assert_eq!(step.content, "Hello");
    }

    #[tokio::test]
    async fn existing_channels_work_alongside_chunk_stream() {
        use tokio_stream::StreamExt;

        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .text_tx
                .send("text-tok".to_owned())
                .await
                .expect("send text");
            writer
                .chunk_tx
                .send(StreamChunk::Text("text-tok".to_owned()))
                .await
                .expect("send chunk");
        });

        let mut text_rx = handle.take_text_stream().expect("text rx");
        let text = text_rx.recv().await.expect("receive text");
        assert_eq!(text, "text-tok");

        let mut chunk_stream = handle.receive_chunks().expect("chunk stream");
        let chunk = chunk_stream.next().await.expect("receive chunk");
        assert!(matches!(chunk, StreamChunk::Text(t) if t == "text-tok"));
    }

    #[test]
    fn receive_chunks_returns_none_on_second_call() {
        let (_writer, mut handle) = channel();
        assert!(handle.receive_chunks().is_some());
        assert!(handle.receive_chunks().is_none());
    }

    #[test]
    fn receive_steps_returns_none_on_second_call() {
        let (_writer, mut handle) = channel();
        assert!(handle.receive_steps().is_some());
        assert!(handle.receive_steps().is_none());
    }

    #[test]
    fn stream_chunk_serde_roundtrip() {
        let chunks = vec![
            StreamChunk::Text("hello".to_owned()),
            StreamChunk::Thought("hmm".to_owned()),
            StreamChunk::ToolCall(ToolCallEvent {
                name: "run".to_owned(),
                args: serde_json::json!({"cmd": "ls"}),
                id: Some("c1".to_owned()),
                canonical_path: None,
            }),
        ];
        for chunk in &chunks {
            let json = serde_json::to_string(chunk).expect("serialize");
            let parsed: StreamChunk = serde_json::from_str(&json).expect("deserialize");
            match (chunk, &parsed) {
                (StreamChunk::Text(a), StreamChunk::Text(b))
                | (StreamChunk::Thought(a), StreamChunk::Thought(b)) => assert_eq!(a, b),
                (StreamChunk::ToolCall(a), StreamChunk::ToolCall(b)) => {
                    assert_eq!(a.name, b.name);
                    assert_eq!(a.id, b.id);
                }
                _ => panic!("variant mismatch after roundtrip"),
            }
        }
    }

    #[tokio::test]
    async fn usage_metadata_populated_from_writer_after_resolve() {
        let (writer, handle) = channel();

        tokio::spawn(async move {
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("hello".to_owned()))
                .await
                .expect("send");
            writer.set_usage(crate::types::UsageMetadata {
                prompt_token_count: Some(5),
                cached_content_token_count: None,
                candidates_token_count: Some(1),
                thoughts_token_count: None,
                total_token_count: Some(6),
            });
            writer.set_structured_output(serde_json::json!({"key": "value"}));
        });

        // Grab the shared state before `resolve` consumes the handle.
        let shared = handle.shared_state();
        let events = handle.resolve().await;
        assert_eq!(events.len(), 1);

        let state = shared.lock().expect("lock shared state");
        assert_eq!(
            state.usage.as_ref().expect("usage should be set").total_token_count,
            Some(6)
        );
        assert_eq!(
            state
                .structured_output
                .as_ref()
                .expect("structured output should be set"),
            &serde_json::json!({"key": "value"})
        );
    }

    #[test]
    fn chat_result_into_string() {
        let (writer, handle) = channel();
        drop(writer);
        let rt = tokio::runtime::Runtime::new().expect("build tokio runtime");
        let result = rt.block_on(handle.text()).expect("should succeed");
        let s: String = result.into();
        assert!(s.is_empty());
    }
}