1use std::{
8 sync::{Arc, Mutex},
9 time::Duration,
10};
11
12use serde::{Deserialize, Serialize};
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15
16use crate::types::{Step, UsageMetadata};
17
18#[derive(Debug, Clone)]
46pub struct ChatResult {
47 text: String,
48 usage: Option<UsageMetadata>,
49 structured_output: Option<serde_json::Value>,
50}
51
52impl ChatResult {
53 #[must_use]
55 pub fn text(&self) -> &str {
56 &self.text
57 }
58
59 #[must_use]
61 pub fn into_string(self) -> String {
62 self.text
63 }
64
65 #[must_use]
67 pub fn usage(&self) -> Option<&UsageMetadata> {
68 self.usage.as_ref()
69 }
70
71 #[must_use]
74 pub fn structured_output(&self) -> Option<&serde_json::Value> {
75 self.structured_output.as_ref()
76 }
77}
78
79impl std::ops::Deref for ChatResult {
80 type Target = str;
81 fn deref(&self) -> &str {
82 &self.text
83 }
84}
85
86impl PartialEq<&str> for ChatResult {
87 fn eq(&self, other: &&str) -> bool {
88 self.text == *other
89 }
90}
91
92impl PartialEq<String> for ChatResult {
93 fn eq(&self, other: &String) -> bool {
94 self.text == *other
95 }
96}
97
98impl std::fmt::Display for ChatResult {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.write_str(&self.text)
101 }
102}
103
104impl From<ChatResult> for String {
105 fn from(result: ChatResult) -> Self {
106 result.text
107 }
108}
109
/// How long `ChatResponseHandle::text` waits for an error after the text stream closes (50 ms).
pub(crate) const ERROR_DRAIN_TIMEOUT: Duration = Duration::from_micros(50_000);
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct ToolCallEvent {
117 pub name: String,
119 pub args: serde_json::Value,
121 pub id: Option<String>,
123 #[serde(default)]
125 pub canonical_path: Option<String>,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct StreamError {
131 pub message: String,
133}
134
135impl std::fmt::Display for StreamError {
136 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137 write!(f, "stream error: {}", self.message)
138 }
139}
140
141impl std::error::Error for StreamError {}
142
143#[non_exhaustive]
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub enum ResponseEvent {
150 TextChunk(String),
152 ThoughtChunk(String),
154 ToolCall(ToolCallEvent),
156 ToolResult(crate::types::ToolResult),
158}
159
160#[non_exhaustive]
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub enum StreamChunk {
168 Text(String),
170 Thought(String),
172 ToolCall(ToolCallEvent),
174}
175
176#[doc(hidden)]
192#[derive(Debug, Default)]
193pub struct ChatResponseSharedState {
194 pub usage: Option<UsageMetadata>,
196 pub structured_output: Option<serde_json::Value>,
198}
199
200#[derive(Debug)]
205pub(crate) struct StreamReceivers {
206 text: Option<mpsc::Receiver<String>>,
208 thought: Option<mpsc::Receiver<String>>,
210 tool_call: Option<mpsc::Receiver<ToolCallEvent>>,
212 error: Option<mpsc::Receiver<StreamError>>,
214 event: Option<mpsc::Receiver<ResponseEvent>>,
216 step: Option<mpsc::Receiver<Step>>,
218 chunk: Option<mpsc::Receiver<StreamChunk>>,
220}
221
222#[derive(Debug)]
230pub struct ChatResponseHandle {
231 rx: StreamReceivers,
233 usage: Option<UsageMetadata>,
235 structured_output_value: Option<serde_json::Value>,
237 pub(crate) shared_state: Arc<Mutex<ChatResponseSharedState>>,
239 pub(crate) keep_alive_permit: Option<tokio::sync::OwnedSemaphorePermit>,
241}
242
/// Error returned when a `ChatResponseWriter` cannot deliver an item.
#[derive(Debug)]
pub struct WriterError {
    /// Description of the failure.
    pub message: String,
}
261
262impl WriterError {
263 pub fn new(message: impl Into<String>) -> Self {
265 Self {
266 message: message.into(),
267 }
268 }
269}
270
271impl std::fmt::Display for WriterError {
272 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273 write!(f, "{}", self.message)
274 }
275}
276
277impl std::error::Error for WriterError {}
278
279impl<T> From<mpsc::error::SendError<T>> for WriterError {
280 fn from(err: mpsc::error::SendError<T>) -> Self {
281 Self {
282 message: format!("channel send failed: {err}"),
283 }
284 }
285}
286
287pub struct ChatResponseWriter {
290 pub(crate) text_tx: mpsc::Sender<String>,
292 pub(crate) thought_tx: mpsc::Sender<String>,
294 pub(crate) tool_call_tx: mpsc::Sender<ToolCallEvent>,
296 pub(crate) error_tx: mpsc::Sender<StreamError>,
298 pub(crate) event_tx: mpsc::Sender<ResponseEvent>,
300 pub(crate) step_tx: mpsc::Sender<Step>,
306 pub(crate) chunk_tx: mpsc::Sender<StreamChunk>,
308 pub(crate) shared_state: Arc<Mutex<ChatResponseSharedState>>,
310}
311
312impl ChatResponseWriter {
313 pub async fn send_text(&self, text: String) -> Result<(), WriterError> {
319 self.text_tx.send(text).await.map_err(WriterError::from)
320 }
321
322 pub async fn send_thought(&self, thought: String) -> Result<(), WriterError> {
328 self.thought_tx
329 .send(thought)
330 .await
331 .map_err(WriterError::from)
332 }
333
334 pub async fn send_tool_call(&self, event: ToolCallEvent) -> Result<(), WriterError> {
340 self.tool_call_tx
341 .send(event)
342 .await
343 .map_err(WriterError::from)
344 }
345
346 pub async fn send_error(&self, error: StreamError) -> Result<(), WriterError> {
352 self.error_tx.send(error).await.map_err(WriterError::from)
353 }
354
355 pub async fn send_event(&self, event: ResponseEvent) -> Result<(), WriterError> {
361 self.event_tx.send(event).await.map_err(WriterError::from)
362 }
363
364 pub async fn send_step(&self, step: crate::types::Step) -> Result<(), WriterError> {
370 self.step_tx.send(step).await.map_err(WriterError::from)
371 }
372
373 pub async fn send_chunk(&self, chunk: StreamChunk) -> Result<(), WriterError> {
379 self.chunk_tx.send(chunk).await.map_err(WriterError::from)
380 }
381}
382
/// Capacity of every stream channel created by `channel`, except the single-slot error channel.
const CHANNEL_BUFFER: usize = 256;
386
387#[must_use]
392pub fn channel() -> (ChatResponseWriter, ChatResponseHandle) {
393 let (text_tx, text_rx) = mpsc::channel(CHANNEL_BUFFER);
394 let (thought_tx, thought_rx) = mpsc::channel(CHANNEL_BUFFER);
395 let (tool_call_tx, tool_call_rx) = mpsc::channel(CHANNEL_BUFFER);
396 let (error_tx, error_rx) = mpsc::channel(1);
397 let (event_tx, event_rx) = mpsc::channel(CHANNEL_BUFFER);
398 let (step_tx, step_rx) = mpsc::channel(CHANNEL_BUFFER);
399 let (chunk_tx, chunk_rx) = mpsc::channel(CHANNEL_BUFFER);
400
401 let shared_state = Arc::new(Mutex::new(ChatResponseSharedState::default()));
402
403 let writer = ChatResponseWriter {
404 text_tx,
405 thought_tx,
406 tool_call_tx,
407 error_tx,
408 event_tx,
409 step_tx,
410 chunk_tx,
411 shared_state: Arc::clone(&shared_state),
412 };
413
414 let handle = ChatResponseHandle {
415 keep_alive_permit: None,
416 rx: StreamReceivers {
417 text: Some(text_rx),
418 thought: Some(thought_rx),
419 tool_call: Some(tool_call_rx),
420 error: Some(error_rx),
421 event: Some(event_rx),
422 step: Some(step_rx),
423 chunk: Some(chunk_rx),
424 },
425 usage: None,
426 structured_output_value: None,
427 shared_state,
428 };
429
430 (writer, handle)
431}
432
433impl ChatResponseHandle {
434 pub const fn take_text_stream(&mut self) -> Option<mpsc::Receiver<String>> {
438 self.rx.text.take()
439 }
440
441 pub const fn take_thought_stream(&mut self) -> Option<mpsc::Receiver<String>> {
445 self.rx.thought.take()
446 }
447
448 pub const fn take_tool_call_stream(&mut self) -> Option<mpsc::Receiver<ToolCallEvent>> {
452 self.rx.tool_call.take()
453 }
454
455 pub const fn take_step_stream(&mut self) -> Option<mpsc::Receiver<Step>> {
460 self.rx.step.take()
461 }
462
463 pub fn receive_steps(&mut self) -> Option<impl tokio_stream::Stream<Item = Step>> {
482 self.rx.step.take().map(ReceiverStream::new)
483 }
484
485 pub fn receive_chunks(&mut self) -> Option<impl tokio_stream::Stream<Item = StreamChunk>> {
509 self.rx.chunk.take().map(ReceiverStream::new)
510 }
511
512 pub async fn text(mut self) -> Result<ChatResult, StreamError> {
521 let mut buf = String::new();
522
523 if let Some(mut rx) = self.rx.text.take() {
524 while let Some(token) = rx.recv().await {
525 buf.push_str(&token);
526 }
527 }
528
529 if let Some(mut err_rx) = self.rx.error.take()
532 && let Ok(Some(err)) = tokio::time::timeout(ERROR_DRAIN_TIMEOUT, err_rx.recv()).await
533 {
534 return Err(err);
535 }
536
537 self.finalize();
538
539 Ok(ChatResult {
540 text: buf,
541 usage: self.usage,
542 structured_output: self.structured_output_value,
543 })
544 }
545
546 pub fn finalize(&mut self) {
549 if let Ok(state) = self.shared_state.lock() {
550 self.usage = state.usage.clone();
551 self.structured_output_value = state.structured_output.clone();
552 } else {
553 tracing::error!(
554 "ChatResponseHandle shared_state mutex poisoned during finalize — \
555 usage and structured_output will be unavailable"
556 );
557 }
558 }
559
560 #[must_use]
565 pub const fn structured_output(&self) -> Option<&serde_json::Value> {
566 self.structured_output_value.as_ref()
567 }
568
569 #[must_use]
573 pub const fn usage_metadata(&self) -> Option<&UsageMetadata> {
574 self.usage.as_ref()
575 }
576
577 #[doc(hidden)]
584 #[must_use]
585 pub fn shared_state(&self) -> Arc<Mutex<ChatResponseSharedState>> {
586 Arc::clone(&self.shared_state)
587 }
588
589 pub async fn resolve(mut self) -> Vec<ResponseEvent> {
594 let mut events = Vec::new();
595 if let Some(mut rx) = self.rx.event.take() {
596 while let Some(event) = rx.recv().await {
597 events.push(event);
598 }
599 }
600 self.finalize();
601 events
602 }
603}
604
605impl ChatResponseWriter {
606 pub fn set_usage(&self, usage: crate::types::UsageMetadata) {
609 match self.shared_state.lock() {
610 Ok(mut state) => {
611 state.usage = Some(usage);
612 }
613 Err(e) => {
614 tracing::error!(
615 error = %e,
616 "ChatResponseWriter shared_state mutex poisoned in set_usage"
617 );
618 }
619 }
620 }
621
622 pub fn set_structured_output(&self, value: serde_json::Value) {
625 match self.shared_state.lock() {
626 Ok(mut state) => {
627 state.structured_output = Some(value);
628 }
629 Err(e) => {
630 tracing::error!(
631 error = %e,
632 "ChatResponseWriter shared_state mutex poisoned in set_structured_output"
633 );
634 }
635 }
636 }
637}
638
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn streaming_receives_all_tokens_in_order() {
        let (writer, mut handle) = channel();
        let tokens = ["Hello", " ", "world", "!"];

        // `tokens` is `Copy`: the task gets its own copy and the assertion reuses this one.
        let send_task = tokio::spawn(async move {
            for token in tokens {
                writer
                    .text_tx
                    .send(token.to_owned())
                    .await
                    .expect("send should succeed");
            }
        });

        let mut rx = handle.take_text_stream().expect("should get receiver");
        let mut received = Vec::new();
        while let Some(token) = rx.recv().await {
            received.push(token);
        }

        send_task.await.expect("send task should complete");
        // Token-by-token comparison also catches split or merged tokens.
        assert_eq!(received, tokens);
    }

    #[tokio::test]
    async fn text_returns_complete_response() {
        let (writer, handle) = channel();

        tokio::spawn(async move {
            for token in &["The ", "answer ", "is ", "42."] {
                writer
                    .text_tx
                    .send((*token).to_owned())
                    .await
                    .expect("send");
            }
        });

        let text = handle.text().await.expect("should succeed");
        assert_eq!(text, "The answer is 42.");
    }

    #[tokio::test]
    async fn text_returns_empty_when_no_tokens() {
        let (writer, handle) = channel();
        drop(writer);

        let text = handle.text().await.expect("should succeed");
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn text_result_carries_usage_and_structured_output() {
        let (writer, handle) = channel();

        tokio::spawn(async move {
            writer.text_tx.send("ok".to_owned()).await.expect("send");
            writer.set_usage(UsageMetadata {
                prompt_token_count: Some(3),
                cached_content_token_count: None,
                candidates_token_count: Some(1),
                thoughts_token_count: None,
                total_token_count: Some(4),
            });
            writer.set_structured_output(serde_json::json!({"answer": 42}));
        });

        let result = handle.text().await.expect("should succeed");
        assert_eq!(result, "ok");
        assert_eq!(result.usage().and_then(|u| u.total_token_count), Some(4));
        assert_eq!(
            result.structured_output(),
            Some(&serde_json::json!({"answer": 42}))
        );
    }

    #[tokio::test]
    async fn stream_error_propagated() {
        let (writer, handle) = channel();

        tokio::spawn(async move {
            writer
                .text_tx
                .send("partial".to_owned())
                .await
                .expect("send");
            writer
                .error_tx
                .send(StreamError {
                    message: "Python exception: quota exceeded".to_owned(),
                })
                .await
                .expect("send error");
        });

        let err = handle
            .text()
            .await
            .expect_err("writer reported an error");
        assert!(err.message.contains("quota exceeded"));
    }

    #[tokio::test]
    async fn thought_stream_works() {
        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .thought_tx
                .send("thinking...".to_owned())
                .await
                .expect("send");
            writer
                .thought_tx
                .send("done.".to_owned())
                .await
                .expect("send");
        });

        let mut rx = handle.take_thought_stream().expect("should get receiver");
        let mut thoughts = Vec::new();
        while let Some(t) = rx.recv().await {
            thoughts.push(t);
        }
        assert_eq!(thoughts, vec!["thinking...", "done."]);
    }

    #[tokio::test]
    async fn tool_call_stream_works() {
        let (writer, mut handle) = channel();

        let event = ToolCallEvent {
            name: "view_file".to_owned(),
            args: serde_json::json!({"path": "/tmp/test.txt"}),
            id: Some("call_1".to_owned()),
            canonical_path: None,
        };

        let event_clone = event.clone();
        tokio::spawn(async move {
            writer.tool_call_tx.send(event_clone).await.expect("send");
        });

        let mut rx = handle.take_tool_call_stream().expect("should get receiver");
        let received = rx.recv().await.expect("should receive event");
        assert_eq!(received.name, "view_file");
        assert_eq!(received.id, Some("call_1".to_owned()));
        assert_eq!(received.args, event.args);
    }

    #[tokio::test]
    async fn usage_metadata_available_after_finalize() {
        let (writer, mut handle) = channel();
        assert!(handle.usage_metadata().is_none());

        writer.set_usage(UsageMetadata {
            prompt_token_count: Some(100),
            cached_content_token_count: Some(10),
            candidates_token_count: Some(50),
            thoughts_token_count: Some(20),
            total_token_count: Some(170),
        });
        drop(writer);
        handle.finalize();

        let usage = handle.usage_metadata().expect("should have usage");
        assert_eq!(usage.prompt_token_count, Some(100));
        assert_eq!(usage.total_token_count, Some(170));
    }

    #[test]
    fn take_text_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_text_stream().is_some());
        assert!(handle.take_text_stream().is_none());
    }

    #[test]
    fn tool_call_event_serde_roundtrip() {
        let event = ToolCallEvent {
            name: "run_command".to_owned(),
            args: serde_json::json!({"command": "ls"}),
            id: Some("call_42".to_owned()),
            canonical_path: Some("/workspace/src/main.rs".to_owned()),
        };
        let json = serde_json::to_string(&event).expect("serialize");
        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, event.name);
        assert_eq!(parsed.args, event.args);
        assert_eq!(parsed.id, event.id);
        assert_eq!(parsed.canonical_path, event.canonical_path);
    }

    #[test]
    fn tool_call_event_canonical_path_defaults_when_absent() {
        let parsed: ToolCallEvent =
            serde_json::from_str(r#"{"name":"custom","args":{},"id":null}"#)
                .expect("deserialize");
        assert_eq!(parsed.name, "custom");
        assert!(parsed.id.is_none());
        assert!(parsed.canonical_path.is_none());
    }

    #[test]
    fn take_thought_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_thought_stream().is_some());
        assert!(handle.take_thought_stream().is_none());
    }

    #[test]
    fn take_tool_call_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_tool_call_stream().is_some());
        assert!(handle.take_tool_call_stream().is_none());
    }

    #[test]
    fn stream_error_display() {
        let err = StreamError {
            message: "quota exceeded".to_owned(),
        };
        assert_eq!(format!("{err}"), "stream error: quota exceeded");
    }

    #[test]
    fn stream_error_is_std_error() {
        let err = StreamError {
            message: "test".to_owned(),
        };
        let _: &dyn std::error::Error = &err;
    }

    #[tokio::test]
    async fn concurrent_text_and_thought_streams() {
        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .text_tx
                .send("Hello".to_owned())
                .await
                .expect("send text");
            writer
                .thought_tx
                .send("thinking...".to_owned())
                .await
                .expect("send thought");
        });

        let mut text_rx = handle.take_text_stream().expect("text rx");
        let mut thought_rx = handle.take_thought_stream().expect("thought rx");

        let text = text_rx.recv().await.expect("receive text");
        let thought = thought_rx.recv().await.expect("receive thought");

        assert_eq!(text, "Hello");
        assert_eq!(thought, "thinking...");
    }

    #[tokio::test]
    async fn writer_dropped_without_sending_closes_text() {
        let (writer, handle) = channel();
        drop(writer);

        let text = handle.text().await.expect("should succeed");
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn writer_dropped_without_sending_closes_thought_stream() {
        let (writer, mut handle) = channel();
        drop(writer);

        let mut thought_rx = handle.take_thought_stream().expect("rx");
        assert!(thought_rx.recv().await.is_none());
    }

    #[tokio::test]
    async fn send_after_handle_dropped_reports_writer_error() {
        let (writer, handle) = channel();
        drop(handle);

        let err = writer
            .send_text("late".to_owned())
            .await
            .expect_err("receiver is gone");
        assert_eq!(err.to_string(), "channel send failed: channel closed");
    }

    #[test]
    fn tool_call_event_without_id() {
        let event = ToolCallEvent {
            name: "custom".to_owned(),
            args: serde_json::json!(null),
            id: None,
            canonical_path: None,
        };
        let json = serde_json::to_string(&event).expect("serialize");
        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, "custom");
        assert_eq!(parsed.args, serde_json::json!(null));
        assert!(parsed.id.is_none());
    }

    #[tokio::test]
    async fn large_token_stream() {
        let (writer, handle) = channel();
        // More tokens than the channel holds, so the sender must wait on back-pressure.
        let token_count = CHANNEL_BUFFER * 4;

        tokio::spawn(async move {
            for i in 0..token_count {
                writer.text_tx.send(format!("t{i}")).await.expect("send");
            }
        });

        let expected: String = (0..token_count).map(|i| format!("t{i}")).collect();
        let text = handle.text().await.expect("should succeed");
        // Exact comparison: a `contains` check accepts `t1` inside `t10` and ignores order.
        assert_eq!(text.text(), expected);
    }

    #[tokio::test]
    async fn resolve_returns_events_in_order() {
        let (writer, handle) = channel();

        let tool_event = ToolCallEvent {
            name: "view_file".to_owned(),
            args: serde_json::json!({"path": "/tmp/x.rs"}),
            id: Some("call_1".to_owned()),
            canonical_path: None,
        };

        let tool_clone = tool_event.clone();
        tokio::spawn(async move {
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("Hello ".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ThoughtChunk("hmm".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ToolCall(tool_clone))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("world".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ToolResult(crate::types::ToolResult {
                    name: "view_file".to_owned(),
                    id: Some("call_1".to_owned()),
                    result: serde_json::json!({"output": "file contents"}),
                    error: None,
                }))
                .await
                .expect("send");
        });

        let events = handle.resolve().await;
        assert_eq!(events.len(), 5, "Expected 5 events, got {}", events.len());

        assert!(
            matches!(&events[0], ResponseEvent::TextChunk(s) if s == "Hello "),
            "events[0] should be TextChunk(\"Hello \")"
        );
        assert!(
            matches!(&events[1], ResponseEvent::ThoughtChunk(s) if s == "hmm"),
            "events[1] should be ThoughtChunk(\"hmm\")"
        );
        assert!(
            matches!(&events[2], ResponseEvent::ToolCall(tc) if tc.name == "view_file"),
            "events[2] should be ToolCall(view_file)"
        );
        assert!(
            matches!(&events[3], ResponseEvent::TextChunk(s) if s == "world"),
            "events[3] should be TextChunk(\"world\")"
        );
        assert!(
            matches!(&events[4], ResponseEvent::ToolResult(tr) if tr.name == "view_file"),
            "events[4] should be ToolResult(view_file)"
        );
    }

    #[tokio::test]
    async fn resolve_returns_events_when_writer_reports_error() {
        let (writer, handle) = channel();

        tokio::spawn(async move {
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("partial".to_owned()))
                .await
                .expect("send");
            writer
                .error_tx
                .send(StreamError {
                    message: "boom".to_owned(),
                })
                .await
                .expect("send error");
        });

        let events = handle.resolve().await;
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], ResponseEvent::TextChunk(s) if s == "partial"));
    }

    #[test]
    fn response_event_serde_roundtrip() {
        let events = vec![
            ResponseEvent::TextChunk("hello".to_owned()),
            ResponseEvent::ThoughtChunk("thinking".to_owned()),
            ResponseEvent::ToolCall(ToolCallEvent {
                name: "run_command".to_owned(),
                args: serde_json::json!({"cmd": "ls"}),
                id: Some("c1".to_owned()),
                canonical_path: None,
            }),
            ResponseEvent::ToolResult(crate::types::ToolResult {
                name: "run_command".to_owned(),
                id: Some("c1".to_owned()),
                result: serde_json::json!({"output": "done"}),
                error: None,
            }),
        ];

        let json = serde_json::to_string(&events).expect("serialize");
        let parsed: Vec<ResponseEvent> = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.len(), events.len());
    }

    #[tokio::test]
    async fn receive_chunks_returns_chunks_in_order() {
        use tokio_stream::StreamExt;

        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .chunk_tx
                .send(StreamChunk::Text("hello".to_owned()))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::Thought("hmm".to_owned()))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::ToolCall(ToolCallEvent {
                    name: "view_file".to_owned(),
                    args: serde_json::json!({}),
                    id: None,
                    canonical_path: None,
                }))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::Text(" world".to_owned()))
                .await
                .expect("send");
        });

        let mut stream = handle.receive_chunks().expect("should get stream");
        let mut items = Vec::new();
        while let Some(chunk) = stream.next().await {
            items.push(chunk);
        }

        assert_eq!(items.len(), 4);
        assert!(matches!(&items[0], StreamChunk::Text(t) if t == "hello"));
        assert!(matches!(&items[1], StreamChunk::Thought(t) if t == "hmm"));
        assert!(matches!(&items[2], StreamChunk::ToolCall(tc) if tc.name == "view_file"));
        assert!(matches!(&items[3], StreamChunk::Text(t) if t == " world"));
    }

    #[tokio::test]
    async fn receive_steps_returns_steps() {
        use tokio_stream::StreamExt;

        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .step_tx
                .send(crate::types::Step {
                    id: "step-0".to_owned(),
                    step_index: 0,
                    step_type: crate::types::StepType::TextResponse,
                    source: crate::types::StepSource::Model,
                    target: crate::types::StepTarget::User,
                    status: crate::types::StepStatus::Done,
                    content: "Hello".to_owned(),
                    content_delta: "Hello".to_owned(),
                    thinking: String::new(),
                    thinking_delta: String::new(),
                    tool_calls: vec![],
                    error: String::new(),
                    is_complete_response: Some(true),
                    structured_output: None,
                    usage_metadata: None,
                })
                .await
                .expect("send");
        });

        let mut stream = handle.receive_steps().expect("should get stream");
        let step = stream.next().await.expect("should get a step");
        assert_eq!(step.id, "step-0");
        assert_eq!(step.step_type, crate::types::StepType::TextResponse);
        assert_eq!(step.content, "Hello");
    }

    #[tokio::test]
    async fn existing_channels_work_alongside_chunk_stream() {
        use tokio_stream::StreamExt;

        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .text_tx
                .send("text-tok".to_owned())
                .await
                .expect("send text");
            writer
                .chunk_tx
                .send(StreamChunk::Text("text-tok".to_owned()))
                .await
                .expect("send chunk");
        });

        let mut text_rx = handle.take_text_stream().expect("text rx");
        let text = text_rx.recv().await.expect("receive text");
        assert_eq!(text, "text-tok");

        let mut chunk_stream = handle.receive_chunks().expect("chunk stream");
        let chunk = chunk_stream.next().await.expect("receive chunk");
        assert!(matches!(chunk, StreamChunk::Text(t) if t == "text-tok"));
    }

    #[test]
    fn receive_chunks_returns_none_on_second_call() {
        let (_writer, mut handle) = channel();
        assert!(handle.receive_chunks().is_some());
        assert!(handle.receive_chunks().is_none());
    }

    #[test]
    fn receive_steps_returns_none_on_second_call() {
        let (_writer, mut handle) = channel();
        assert!(handle.receive_steps().is_some());
        assert!(handle.receive_steps().is_none());
    }

    #[test]
    fn stream_chunk_serde_roundtrip() {
        let chunks = vec![
            StreamChunk::Text("hello".to_owned()),
            StreamChunk::Thought("hmm".to_owned()),
            StreamChunk::ToolCall(ToolCallEvent {
                name: "run".to_owned(),
                args: serde_json::json!({"cmd": "ls"}),
                id: Some("c1".to_owned()),
                canonical_path: None,
            }),
        ];
        for chunk in &chunks {
            let json = serde_json::to_string(chunk).expect("serialize");
            let parsed: StreamChunk = serde_json::from_str(&json).expect("deserialize");
            match (chunk, &parsed) {
                (StreamChunk::Text(a), StreamChunk::Text(b))
                | (StreamChunk::Thought(a), StreamChunk::Thought(b)) => assert_eq!(a, b),
                (StreamChunk::ToolCall(a), StreamChunk::ToolCall(b)) => {
                    assert_eq!(a.name, b.name);
                    assert_eq!(a.id, b.id);
                }
                _ => panic!("variant mismatch after roundtrip"),
            }
        }
    }

    #[tokio::test]
    async fn usage_metadata_populated_from_writer_after_resolve() {
        let (writer, handle) = channel();

        tokio::spawn(async move {
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("hello".to_owned()))
                .await
                .expect("send");
            writer.set_usage(crate::types::UsageMetadata {
                prompt_token_count: Some(5),
                cached_content_token_count: None,
                candidates_token_count: Some(1),
                thoughts_token_count: None,
                total_token_count: Some(6),
            });
            writer.set_structured_output(serde_json::json!({"key": "value"}));
        });

        // `resolve` consumes the handle, so keep a reference to the shared state first.
        let shared = handle.shared_state();
        let events = handle.resolve().await;
        assert_eq!(events.len(), 1);

        let state = shared.lock().expect("lock shared state");
        assert_eq!(
            state.usage.as_ref().expect("usage set").total_token_count,
            Some(6)
        );
        assert_eq!(
            state.structured_output.as_ref().expect("structured output set"),
            &serde_json::json!({"key": "value"})
        );
    }

    #[test]
    fn chat_result_into_string() {
        let (writer, handle) = channel();
        drop(writer);
        let rt = tokio::runtime::Runtime::new().expect("build runtime");
        let result = rt.block_on(handle.text()).expect("should succeed");
        let s: String = result.into();
        assert!(s.is_empty());
    }
}