Skip to main content

agy_bridge/
streaming.rs

1//! Streaming response bridge for the Antigravity SDK.
2//!
3//! Bridges the SDK's `ChatResponse` (Python async iterator) to tokio channels
4//! so Rust consumers can stream text tokens, thinking tokens, and tool call
5//! events independently.
6
7use std::{
8    sync::{Arc, Mutex},
9    time::Duration,
10};
11
12use serde::{Deserialize, Serialize};
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15
16use crate::types::{Step, UsageMetadata};
17
18/// The result of draining a chat response via [`ChatResponseHandle::text()`].
19///
20/// Carries the full response text alongside optional metadata (token usage,
21/// structured output). Dereferences to `&str` for ergonomic use:
22///
23/// ```rust
24/// # #[tokio::main]
25/// # async fn main() -> Result<(), agy_bridge::error::Error> {
26/// # agy_bridge::load_dotenv();
27/// # let bridge = agy_bridge::AgyBridge::builder().build()?;
28/// # let agent = bridge.agent(
29/// #     agy_bridge::config::AgentConfig::builder()
30/// #         .system_instructions("Reply with 'Hello!' and nothing else. Never use tools.")
31/// #         .capabilities(agy_bridge::config::CapabilitiesConfig::custom_tools_only())
32/// #         .build()
33/// # ).await?;
34/// let result = agent
35///     .chat("Reply with 'Hello!' and nothing else.")
36///     .await?
37///     .text()
38///     .await?;
39/// println!("{result}"); // prints text
40/// if let Some(usage) = result.usage() { /* access metadata */ }
41/// # agent.shutdown().await?;
42/// # Ok(())
43/// # }
44/// ```
45#[derive(Debug, Clone)]
46pub struct ChatResult {
47    text: String,
48    usage: Option<UsageMetadata>,
49    structured_output: Option<serde_json::Value>,
50}
51
52impl ChatResult {
53    /// The full response text.
54    #[must_use]
55    pub fn text(&self) -> &str {
56        &self.text
57    }
58
59    /// Consume the result and return the inner `String`.
60    #[must_use]
61    pub fn into_string(self) -> String {
62        self.text
63    }
64
65    /// Token usage metadata, if available.
66    #[must_use]
67    pub fn usage(&self) -> Option<&UsageMetadata> {
68        self.usage.as_ref()
69    }
70
71    /// Structured output (JSON), if the agent was configured with a
72    /// `response_schema` and the model returned valid JSON.
73    #[must_use]
74    pub fn structured_output(&self) -> Option<&serde_json::Value> {
75        self.structured_output.as_ref()
76    }
77}
78
79impl std::ops::Deref for ChatResult {
80    type Target = str;
81    fn deref(&self) -> &str {
82        &self.text
83    }
84}
85
86impl PartialEq<&str> for ChatResult {
87    fn eq(&self, other: &&str) -> bool {
88        self.text == *other
89    }
90}
91
92impl PartialEq<String> for ChatResult {
93    fn eq(&self, other: &String) -> bool {
94        self.text == *other
95    }
96}
97
98impl std::fmt::Display for ChatResult {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.write_str(&self.text)
101    }
102}
103
104impl From<ChatResult> for String {
105    fn from(result: ChatResult) -> Self {
106        result.text
107    }
108}
109
/// How long [`ChatResponseHandle::text()`] keeps listening on the error channel
/// after the text stream has closed. This ensures that an error sent just after
/// the final token is still observed. The constant is also used by
/// [`crate::interactive`].
pub(crate) const ERROR_DRAIN_TIMEOUT: Duration = Duration::from_micros(50_000);
113
114/// A tool call event received during streaming.
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct ToolCallEvent {
117    /// Tool name (e.g. `"view_file"` or a custom tool name).
118    pub name: String,
119    /// Arguments as a JSON object.
120    pub args: serde_json::Value,
121    /// Optional call identifier assigned by the backend.
122    pub id: Option<String>,
123    /// Optional canonical path for file tools.
124    #[serde(default)]
125    pub canonical_path: Option<String>,
126}
127
128/// Error sent over the error channel when the Python stream fails.
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct StreamError {
131    /// Error message from the Python side.
132    pub message: String,
133}
134
135impl std::fmt::Display for StreamError {
136    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137        write!(f, "stream error: {}", self.message)
138    }
139}
140
141impl std::error::Error for StreamError {}
142
143/// An ordered event from a response timeline, produced by [`ChatResponseHandle::resolve`].
144///
145/// Mirrors the Python SDK's `ChatResponse.resolve()` which returns
146/// `list[StreamChunk | ToolCall | ToolResult]`.
147#[non_exhaustive]
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub enum ResponseEvent {
150    /// A text chunk from the model.
151    TextChunk(String),
152    /// A thinking/reasoning chunk from the model.
153    ThoughtChunk(String),
154    /// A tool call request from the model.
155    ToolCall(ToolCallEvent),
156    /// A tool execution result.
157    ToolResult(crate::types::ToolResult),
158}
159
160/// A chunk from the streaming response, combining text, thought, and tool call events.
161///
162/// This provides a unified stream of all chunk types, unlike the separate
163/// `take_text_stream()` / `take_thought_stream()` / `take_tool_call_stream()`
164/// accessors which split events by kind.
165#[non_exhaustive]
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub enum StreamChunk {
168    /// A text token from the model.
169    Text(String),
170    /// A thinking/reasoning token.
171    Thought(String),
172    /// A tool call event.
173    ToolCall(ToolCallEvent),
174}
175
176/// Handle to a streaming chat response.
177///
178/// Created by the Python bridge when `agent.chat()` is called. Provides
179/// independent channels for text tokens, thinking tokens, and tool call events.
180///
181/// # Ownership
182///
183/// The receivers are consumed when you call the corresponding accessor.
184/// Each accessor can only be called once — subsequent calls return `None`
185/// because the receiver has already been taken.
186/// Shared mutable state between the writer and handle.
187///
188/// Uses `std::sync::Mutex` rather than `tokio::sync::Mutex` because the lock
189/// is held only for brief field reads/clones (never across `.await`). This is
190/// safe from deadlocks and cheaper than an async mutex.
191#[doc(hidden)]
192#[derive(Debug, Default)]
193pub struct ChatResponseSharedState {
194    /// Token usage metadata, populated by the writer after the stream completes.
195    pub usage: Option<UsageMetadata>,
196    /// Structured output, populated by the writer after the stream completes.
197    pub structured_output: Option<serde_json::Value>,
198}
199
200/// Grouped receivers for each independent stream channel.
201///
202/// Extracted from [`ChatResponseHandle`] so the seven channel receivers
203/// are logically grouped, keeping the handle's field list manageable.
204#[derive(Debug)]
205pub(crate) struct StreamReceivers {
206    /// Receives text tokens as they arrive from the model.
207    text: Option<mpsc::Receiver<String>>,
208    /// Receives thinking/reasoning tokens.
209    thought: Option<mpsc::Receiver<String>>,
210    /// Receives tool call events.
211    tool_call: Option<mpsc::Receiver<ToolCallEvent>>,
212    /// Receives at most one error if the stream fails.
213    error: Option<mpsc::Receiver<StreamError>>,
214    /// Receives ordered [`ResponseEvent`]s for [`resolve()`](ChatResponseHandle::resolve).
215    event: Option<mpsc::Receiver<ResponseEvent>>,
216    /// Receives [`Step`] objects as they are produced.
217    step: Option<mpsc::Receiver<Step>>,
218    /// Receives unified [`StreamChunk`]s (text, thought, and tool call events).
219    chunk: Option<mpsc::Receiver<StreamChunk>>,
220}
221
222/// Handle to a streaming chat response.
223///
224/// Created by [`AgentHandle::chat()`](crate::agent::AgentHandle::chat). Provides
225/// independent channels for text tokens, thinking tokens, and tool-call events.
226///
227/// Each stream accessor can only be called once — subsequent calls return `None`
228/// because the underlying receiver has already been taken.
229#[derive(Debug)]
230pub struct ChatResponseHandle {
231    /// All per-stream receivers, grouped for clarity.
232    rx: StreamReceivers,
233    /// Token usage metadata, populated after the stream completes.
234    usage: Option<UsageMetadata>,
235    /// Structured output from a `response_schema`-configured agent.
236    structured_output_value: Option<serde_json::Value>,
237    /// Shared state to receive metadata updates from the python bridge thread.
238    pub(crate) shared_state: Arc<Mutex<ChatResponseSharedState>>,
239    /// Semaphore permit that keeps the agent alive while the handle exists.
240    pub(crate) keep_alive_permit: Option<tokio::sync::OwnedSemaphorePermit>,
241}
242
/// Failure to deliver a value through one of a [`ChatResponseWriter`]'s channels.
///
/// Wraps the underlying channel error so that the generic
/// `tokio::sync::mpsc::error::SendError<T>` does not leak into the public API.
///
/// # Example
///
/// ```
/// use agy_bridge::streaming::WriterError;
///
/// let err = WriterError::new("receiver dropped");
/// assert_eq!(err.to_string(), "receiver dropped");
/// ```
#[derive(Debug)]
pub struct WriterError {
    /// Human-readable explanation of what went wrong.
    pub message: String,
}
261
262impl WriterError {
263    /// Create a new writer error.
264    pub fn new(message: impl Into<String>) -> Self {
265        Self {
266            message: message.into(),
267        }
268    }
269}
270
271impl std::fmt::Display for WriterError {
272    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273        write!(f, "{}", self.message)
274    }
275}
276
277impl std::error::Error for WriterError {}
278
279impl<T> From<mpsc::error::SendError<T>> for WriterError {
280    fn from(err: mpsc::error::SendError<T>) -> Self {
281        Self {
282            message: format!("channel send failed: {err}"),
283        }
284    }
285}
286
287/// The sending side of a [`ChatResponseHandle`], held by the Python bridge
288/// thread that drives the SDK's async iterator.
289pub struct ChatResponseWriter {
290    /// Sends text tokens.
291    pub(crate) text_tx: mpsc::Sender<String>,
292    /// Sends thinking tokens.
293    pub(crate) thought_tx: mpsc::Sender<String>,
294    /// Sends tool call events.
295    pub(crate) tool_call_tx: mpsc::Sender<ToolCallEvent>,
296    /// Sends a stream error (at most one).
297    pub(crate) error_tx: mpsc::Sender<StreamError>,
298    /// Sends ordered [`ResponseEvent`]s for the resolve timeline.
299    pub(crate) event_tx: mpsc::Sender<ResponseEvent>,
300    /// Sends [`Step`] objects as they are produced.
301    ///
302    /// The sender must be held to keep the channel alive for
303    /// [`ChatResponseHandle::take_step_stream()`]. It will be actively
304    /// written once step-level streaming is wired through the command loop.
305    pub(crate) step_tx: mpsc::Sender<Step>,
306    /// Sends unified [`StreamChunk`]s.
307    pub(crate) chunk_tx: mpsc::Sender<StreamChunk>,
308    /// Shared state to send metadata updates back to the handle.
309    pub(crate) shared_state: Arc<Mutex<ChatResponseSharedState>>,
310}
311
312impl ChatResponseWriter {
313    /// Send a text token.
314    ///
315    /// # Errors
316    ///
317    /// Returns [`WriterError`] if the receiver has been dropped.
318    pub async fn send_text(&self, text: String) -> Result<(), WriterError> {
319        self.text_tx.send(text).await.map_err(WriterError::from)
320    }
321
322    /// Send a thinking token.
323    ///
324    /// # Errors
325    ///
326    /// Returns [`WriterError`] if the receiver has been dropped.
327    pub async fn send_thought(&self, thought: String) -> Result<(), WriterError> {
328        self.thought_tx
329            .send(thought)
330            .await
331            .map_err(WriterError::from)
332    }
333
334    /// Send a tool call event.
335    ///
336    /// # Errors
337    ///
338    /// Returns [`WriterError`] if the receiver has been dropped.
339    pub async fn send_tool_call(&self, event: ToolCallEvent) -> Result<(), WriterError> {
340        self.tool_call_tx
341            .send(event)
342            .await
343            .map_err(WriterError::from)
344    }
345
346    /// Send an error.
347    ///
348    /// # Errors
349    ///
350    /// Returns [`WriterError`] if the receiver has been dropped.
351    pub async fn send_error(&self, error: StreamError) -> Result<(), WriterError> {
352        self.error_tx.send(error).await.map_err(WriterError::from)
353    }
354
355    /// Send a response event.
356    ///
357    /// # Errors
358    ///
359    /// Returns [`WriterError`] if the receiver has been dropped.
360    pub async fn send_event(&self, event: ResponseEvent) -> Result<(), WriterError> {
361        self.event_tx.send(event).await.map_err(WriterError::from)
362    }
363
364    /// Send a step.
365    ///
366    /// # Errors
367    ///
368    /// Returns [`WriterError`] if the receiver has been dropped.
369    pub async fn send_step(&self, step: crate::types::Step) -> Result<(), WriterError> {
370        self.step_tx.send(step).await.map_err(WriterError::from)
371    }
372
373    /// Send a unified stream chunk.
374    ///
375    /// # Errors
376    ///
377    /// Returns [`WriterError`] if the receiver has been dropped.
378    pub async fn send_chunk(&self, chunk: StreamChunk) -> Result<(), WriterError> {
379        self.chunk_tx.send(chunk).await.map_err(WriterError::from)
380    }
381}
382
/// Capacity of each data channel created by [`channel`]. It is sized so that
/// normal responses avoid backpressure while memory use stays bounded.
const CHANNEL_BUFFER: usize = 256;
386
387/// Create a paired `(ChatResponseWriter, ChatResponseHandle)`.
388///
389/// The writer is handed to the Python bridge thread; the handle is returned
390/// to the Rust caller.
391#[must_use]
392pub fn channel() -> (ChatResponseWriter, ChatResponseHandle) {
393    let (text_tx, text_rx) = mpsc::channel(CHANNEL_BUFFER);
394    let (thought_tx, thought_rx) = mpsc::channel(CHANNEL_BUFFER);
395    let (tool_call_tx, tool_call_rx) = mpsc::channel(CHANNEL_BUFFER);
396    let (error_tx, error_rx) = mpsc::channel(1);
397    let (event_tx, event_rx) = mpsc::channel(CHANNEL_BUFFER);
398    let (step_tx, step_rx) = mpsc::channel(CHANNEL_BUFFER);
399    let (chunk_tx, chunk_rx) = mpsc::channel(CHANNEL_BUFFER);
400
401    let shared_state = Arc::new(Mutex::new(ChatResponseSharedState::default()));
402
403    let writer = ChatResponseWriter {
404        text_tx,
405        thought_tx,
406        tool_call_tx,
407        error_tx,
408        event_tx,
409        step_tx,
410        chunk_tx,
411        shared_state: Arc::clone(&shared_state),
412    };
413
414    let handle = ChatResponseHandle {
415        keep_alive_permit: None,
416        rx: StreamReceivers {
417            text: Some(text_rx),
418            thought: Some(thought_rx),
419            tool_call: Some(tool_call_rx),
420            error: Some(error_rx),
421            event: Some(event_rx),
422            step: Some(step_rx),
423            chunk: Some(chunk_rx),
424        },
425        usage: None,
426        structured_output_value: None,
427        shared_state,
428    };
429
430    (writer, handle)
431}
432
433impl ChatResponseHandle {
434    /// Take the text token receiver for token-by-token streaming.
435    ///
436    /// Returns `None` if the receiver was already taken.
437    pub const fn take_text_stream(&mut self) -> Option<mpsc::Receiver<String>> {
438        self.rx.text.take()
439    }
440
441    /// Take the thinking token receiver.
442    ///
443    /// Returns `None` if the receiver was already taken.
444    pub const fn take_thought_stream(&mut self) -> Option<mpsc::Receiver<String>> {
445        self.rx.thought.take()
446    }
447
448    /// Take the tool call event receiver.
449    ///
450    /// Returns `None` if the receiver was already taken.
451    pub const fn take_tool_call_stream(&mut self) -> Option<mpsc::Receiver<ToolCallEvent>> {
452        self.rx.tool_call.take()
453    }
454
455    /// Take the raw step receiver.
456    ///
457    /// Returns `None` if the receiver was already taken.
458    /// Prefer [`receive_steps()`](Self::receive_steps) for `StreamExt`-compatible usage.
459    pub const fn take_step_stream(&mut self) -> Option<mpsc::Receiver<Step>> {
460        self.rx.step.take()
461    }
462
463    /// Take the step stream for consuming with `StreamExt::next()`.
464    ///
465    /// Returns `None` if the stream was already taken.
466    ///
467    /// # Example
468    ///
469    /// ```
470    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
471    /// use agy_bridge::streaming;
472    /// use tokio_stream::StreamExt;
473    ///
474    /// let (_writer, mut handle) = streaming::channel();
475    /// drop(_writer); // close the channel so the stream ends
476    /// let mut steps = handle.receive_steps().unwrap();
477    /// while let Some(step) = steps.next().await {
478    ///     println!("step: {:?}", step.step_type);
479    /// }
480    /// # });
481    pub fn receive_steps(&mut self) -> Option<impl tokio_stream::Stream<Item = Step>> {
482        self.rx.step.take().map(ReceiverStream::new)
483    }
484
485    /// Take the unified chunk stream for consuming with `StreamExt::next()`.
486    ///
487    /// Returns `None` if the stream was already taken.
488    ///
489    /// # Example
490    ///
491    /// ```
492    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
493    /// use agy_bridge::streaming::{self, StreamChunk};
494    /// use tokio_stream::StreamExt;
495    ///
496    /// let (_writer, mut handle) = streaming::channel();
497    /// drop(_writer); // close the channel so the stream ends
498    /// let mut chunks = handle.receive_chunks().unwrap();
499    /// while let Some(chunk) = chunks.next().await {
500    ///     match chunk {
501    ///         StreamChunk::Text(t) => print!("{t}"),
502    ///         StreamChunk::Thought(t) => eprintln!("thought: {t}"),
503    ///         StreamChunk::ToolCall(tc) => eprintln!("tool: {}", tc.name),
504    ///         _ => {}
505    ///     }
506    /// }
507    /// # });
508    pub fn receive_chunks(&mut self) -> Option<impl tokio_stream::Stream<Item = StreamChunk>> {
509        self.rx.chunk.take().map(ReceiverStream::new)
510    }
511
512    /// Drain the text stream and return the complete response text.
513    ///
514    /// Consumes the handle — use the `take_*` methods instead if you need
515    /// to keep streaming individual channels.
516    ///
517    /// # Errors
518    ///
519    /// Returns a [`StreamError`] if the Python side reported an error.
520    pub async fn text(mut self) -> Result<ChatResult, StreamError> {
521        let mut buf = String::new();
522
523        if let Some(mut rx) = self.rx.text.take() {
524            while let Some(token) = rx.recv().await {
525                buf.push_str(&token);
526            }
527        }
528
529        // Check for errors. Use a brief timeout rather than try_recv() to
530        // catch errors that are sent just after the text channel closes.
531        if let Some(mut err_rx) = self.rx.error.take()
532            && let Ok(Some(err)) = tokio::time::timeout(ERROR_DRAIN_TIMEOUT, err_rx.recv()).await
533        {
534            return Err(err);
535        }
536
537        self.finalize();
538
539        Ok(ChatResult {
540            text: buf,
541            usage: self.usage,
542            structured_output: self.structured_output_value,
543        })
544    }
545
546    /// Finalize the response handle by pulling usage and structured output
547    /// from the shared state. Called after the stream has been fully drained.
548    pub fn finalize(&mut self) {
549        if let Ok(state) = self.shared_state.lock() {
550            self.usage = state.usage.clone();
551            self.structured_output_value = state.structured_output.clone();
552        } else {
553            tracing::error!(
554                "ChatResponseHandle shared_state mutex poisoned during finalize — \
555                 usage and structured_output will be unavailable"
556            );
557        }
558    }
559
560    /// Return the structured output, if available.
561    ///
562    /// Only populated when the agent was configured with a `response_schema`
563    /// and the model returned a valid JSON payload.
564    #[must_use]
565    pub const fn structured_output(&self) -> Option<&serde_json::Value> {
566        self.structured_output_value.as_ref()
567    }
568
569    /// Return the token usage metadata, if available.
570    ///
571    /// Populated after [`finalize()`](Self::finalize) or [`text()`](Self::text).
572    #[must_use]
573    pub const fn usage_metadata(&self) -> Option<&UsageMetadata> {
574        self.usage.as_ref()
575    }
576
577    /// Return a reference-counted handle to the shared state.
578    ///
579    /// This allows callers to clone the `Arc` **before** consuming the handle
580    /// via [`text()`](Self::text) or [`resolve()`](Self::resolve), and then
581    /// read usage metadata / structured output from the shared state
582    /// afterwards.
583    #[doc(hidden)]
584    #[must_use]
585    pub fn shared_state(&self) -> Arc<Mutex<ChatResponseSharedState>> {
586        Arc::clone(&self.shared_state)
587    }
588
589    /// Drain all events and return them as an ordered timeline.
590    ///
591    /// Consumes the handle — use the `take_*` methods instead if you need
592    /// to keep streaming individual channels.
593    pub async fn resolve(mut self) -> Vec<ResponseEvent> {
594        let mut events = Vec::new();
595        if let Some(mut rx) = self.rx.event.take() {
596            while let Some(event) = rx.recv().await {
597                events.push(event);
598            }
599        }
600        self.finalize();
601        events
602    }
603}
604
605impl ChatResponseWriter {
606    /// Store usage metadata in the shared state so the handle can read it
607    /// after the stream completes.
608    pub fn set_usage(&self, usage: crate::types::UsageMetadata) {
609        match self.shared_state.lock() {
610            Ok(mut state) => {
611                state.usage = Some(usage);
612            }
613            Err(e) => {
614                tracing::error!(
615                    error = %e,
616                    "ChatResponseWriter shared_state mutex poisoned in set_usage"
617                );
618            }
619        }
620    }
621
622    /// Store structured output in the shared state so the handle can read it
623    /// after the stream completes.
624    pub fn set_structured_output(&self, value: serde_json::Value) {
625        match self.shared_state.lock() {
626            Ok(mut state) => {
627                state.structured_output = Some(value);
628            }
629            Err(e) => {
630                tracing::error!(
631                    error = %e,
632                    "ChatResponseWriter shared_state mutex poisoned in set_structured_output"
633                );
634            }
635        }
636    }
637}
638
639#[cfg(test)]
640mod tests {
641    use super::*;
642
643    #[tokio::test]
644    async fn streaming_receives_all_tokens_in_order() {
645        let (writer, mut handle) = channel();
646
647        let tokens = ["Hello", " ", "world", "!"];
648        let expected: String = tokens.iter().copied().collect();
649
650        // Simulate the Python bridge sending tokens
651        let send_task = tokio::spawn(async move {
652            for token in &["Hello", " ", "world", "!"] {
653                writer
654                    .text_tx
655                    .send((*token).to_owned())
656                    .await
657                    .expect("send should succeed");
658            }
659            // Dropping writer closes the channel
660        });
661
662        // Consume via the stream receiver
663        let mut rx = handle.take_text_stream().expect("should get receiver");
664        let mut received = Vec::new();
665        while let Some(token) = rx.recv().await {
666            received.push(token);
667        }
668
669        send_task.await.expect("send task should complete");
670        let full: String = received.iter().map(String::as_str).collect();
671        assert_eq!(full, expected);
672    }
673
674    #[tokio::test]
675    async fn text_returns_complete_response() {
676        let (writer, handle) = channel();
677
678        tokio::spawn(async move {
679            for token in &["The ", "answer ", "is ", "42."] {
680                writer
681                    .text_tx
682                    .send((*token).to_owned())
683                    .await
684                    .expect("send");
685            }
686        });
687
688        let text = handle.text().await.expect("should succeed");
689        assert_eq!(text, "The answer is 42.");
690    }
691
692    #[tokio::test]
693    async fn text_returns_empty_when_no_tokens() {
694        let (writer, handle) = channel();
695        // Drop the writer immediately to close the channel
696        drop(writer);
697
698        let text = handle.text().await.expect("should succeed");
699        assert!(text.is_empty());
700    }
701
702    #[tokio::test]
703    async fn stream_error_propagated() {
704        let (writer, handle) = channel();
705
706        tokio::spawn(async move {
707            writer
708                .text_tx
709                .send("partial".to_owned())
710                .await
711                .expect("send");
712            writer
713                .error_tx
714                .send(StreamError {
715                    message: "Python exception: quota exceeded".to_owned(),
716                })
717                .await
718                .expect("send error");
719        });
720
721        let result = handle.text().await;
722        assert!(result.is_err());
723        let err = result.unwrap_err();
724        assert!(err.message.contains("quota exceeded"));
725    }
726
727    #[tokio::test]
728    async fn thought_stream_works() {
729        let (writer, mut handle) = channel();
730
731        tokio::spawn(async move {
732            writer
733                .thought_tx
734                .send("thinking...".to_owned())
735                .await
736                .expect("send");
737            writer
738                .thought_tx
739                .send("done.".to_owned())
740                .await
741                .expect("send");
742        });
743
744        let mut rx = handle.take_thought_stream().expect("should get receiver");
745        let mut thoughts = Vec::new();
746        while let Some(t) = rx.recv().await {
747            thoughts.push(t);
748        }
749        assert_eq!(thoughts, vec!["thinking...", "done."]);
750    }
751
752    #[tokio::test]
753    async fn tool_call_stream_works() {
754        let (writer, mut handle) = channel();
755
756        let event = ToolCallEvent {
757            name: "view_file".to_owned(),
758            args: serde_json::json!({"path": "/tmp/test.txt"}),
759            id: Some("call_1".to_owned()),
760            canonical_path: None,
761        };
762
763        let event_clone = event.clone();
764        tokio::spawn(async move {
765            writer.tool_call_tx.send(event_clone).await.expect("send");
766        });
767
768        let mut rx = handle.take_tool_call_stream().expect("should get receiver");
769        let received = rx.recv().await.expect("should receive event");
770        assert_eq!(received.name, "view_file");
771        assert_eq!(received.id, Some("call_1".to_owned()));
772    }
773
774    #[tokio::test]
775    async fn usage_metadata_available_after_finalize() {
776        let (writer, mut handle) = channel();
777        assert!(handle.usage_metadata().is_none());
778
779        writer.set_usage(UsageMetadata {
780            prompt_token_count: Some(100),
781            cached_content_token_count: Some(10),
782            candidates_token_count: Some(50),
783            thoughts_token_count: Some(20),
784            total_token_count: Some(170),
785        });
786        drop(writer);
787        handle.finalize();
788
789        let usage = handle.usage_metadata().expect("should have usage");
790        assert_eq!(usage.prompt_token_count, Some(100));
791        assert_eq!(usage.total_token_count, Some(170));
792    }
793
794    #[test]
795    fn take_text_stream_returns_none_second_time() {
796        let (_writer, mut handle) = channel();
797        assert!(handle.take_text_stream().is_some());
798        assert!(handle.take_text_stream().is_none());
799    }
800
801    #[test]
802    fn tool_call_event_serde_roundtrip() {
803        let event = ToolCallEvent {
804            name: "run_command".to_owned(),
805            args: serde_json::json!({"command": "ls"}),
806            id: Some("call_42".to_owned()),
807            canonical_path: None,
808        };
809        let json = serde_json::to_string(&event).expect("serialize");
810        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
811        assert_eq!(parsed.name, event.name);
812        assert_eq!(parsed.args, event.args);
813        assert_eq!(parsed.id, event.id);
814    }
815
816    #[test]
817    fn take_thought_stream_returns_none_second_time() {
818        let (_writer, mut handle) = channel();
819        assert!(handle.take_thought_stream().is_some());
820        assert!(handle.take_thought_stream().is_none());
821    }
822
823    #[test]
824    fn take_tool_call_stream_returns_none_second_time() {
825        let (_writer, mut handle) = channel();
826        assert!(handle.take_tool_call_stream().is_some());
827        assert!(handle.take_tool_call_stream().is_none());
828    }
829
830    #[test]
831    fn stream_error_display() {
832        let err = StreamError {
833            message: "quota exceeded".to_owned(),
834        };
835        assert_eq!(format!("{err}"), "stream error: quota exceeded");
836    }
837
838    #[test]
839    fn stream_error_is_std_error() {
840        let err = StreamError {
841            message: "test".to_owned(),
842        };
843        // Verify it implements std::error::Error
844        let _: &dyn std::error::Error = &err;
845    }
846
847    #[tokio::test]
848    async fn concurrent_text_and_thought_streams() {
849        let (writer, mut handle) = channel();
850
851        tokio::spawn(async move {
852            writer
853                .text_tx
854                .send("Hello".to_owned())
855                .await
856                .expect("send text");
857            writer
858                .thought_tx
859                .send("thinking...".to_owned())
860                .await
861                .expect("send thought");
862        });
863
864        let mut text_rx = handle.take_text_stream().expect("text rx");
865        let mut thought_rx = handle.take_thought_stream().expect("thought rx");
866
867        let text = text_rx.recv().await.expect("receive text");
868        let thought = thought_rx.recv().await.expect("receive thought");
869
870        assert_eq!(text, "Hello");
871        assert_eq!(thought, "thinking...");
872    }
873
874    #[tokio::test]
875    async fn writer_dropped_without_sending_closes_text() {
876        let (writer, handle) = channel();
877        drop(writer);
878
879        let text = handle.text().await.expect("should succeed");
880        assert!(text.is_empty());
881    }
882
883    #[tokio::test]
884    async fn writer_dropped_without_sending_closes_thought_stream() {
885        let (writer, mut handle) = channel();
886        drop(writer);
887
888        let mut thought_rx = handle.take_thought_stream().expect("rx");
889        assert!(thought_rx.recv().await.is_none());
890    }
891
892    #[test]
893    fn tool_call_event_without_id() {
894        let event = ToolCallEvent {
895            name: "custom".to_owned(),
896            args: serde_json::json!(null),
897            id: None,
898            canonical_path: None,
899        };
900        let json = serde_json::to_string(&event).expect("serialize");
901        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
902        assert_eq!(parsed.name, "custom");
903        assert_eq!(parsed.args, serde_json::json!(null));
904    }
905
906    #[tokio::test]
907    async fn large_token_stream() {
908        let (writer, handle) = channel();
909        let token_count = 200;
910
911        tokio::spawn(async move {
912            for i in 0..token_count {
913                writer.text_tx.send(format!("t{i}")).await.expect("send");
914            }
915        });
916
917        let text = handle.text().await.expect("should succeed");
918        // Verify all 200 tokens were collected
919        for i in 0..token_count {
920            assert!(
921                text.contains(&format!("t{i}")),
922                "Missing token t{i} in output"
923            );
924        }
925    }
926
927    #[tokio::test]
928    async fn resolve_returns_events_in_order() {
929        let (writer, handle) = channel();
930
931        let tool_event = ToolCallEvent {
932            name: "view_file".to_owned(),
933            args: serde_json::json!({"path": "/tmp/x.rs"}),
934            id: Some("call_1".to_owned()),
935            canonical_path: None,
936        };
937
938        let tool_clone = tool_event.clone();
939        tokio::spawn(async move {
940            writer
941                .event_tx
942                .send(ResponseEvent::TextChunk("Hello ".to_owned()))
943                .await
944                .expect("send");
945            writer
946                .event_tx
947                .send(ResponseEvent::ThoughtChunk("hmm".to_owned()))
948                .await
949                .expect("send");
950            writer
951                .event_tx
952                .send(ResponseEvent::ToolCall(tool_clone))
953                .await
954                .expect("send");
955            writer
956                .event_tx
957                .send(ResponseEvent::TextChunk("world".to_owned()))
958                .await
959                .expect("send");
960            writer
961                .event_tx
962                .send(ResponseEvent::ToolResult(crate::types::ToolResult {
963                    name: "view_file".to_owned(),
964                    id: Some("call_1".to_owned()),
965                    result: serde_json::json!({"output": "file contents"}),
966                    error: None,
967                }))
968                .await
969                .expect("send");
970            // Drop writer to close the channel
971        });
972
973        let events = handle.resolve().await;
974        assert_eq!(events.len(), 5, "Expected 5 events, got {}", events.len());
975
976        // Verify ordering and types
977        assert!(
978            matches!(&events[0], ResponseEvent::TextChunk(s) if s == "Hello "),
979            "events[0] should be TextChunk(\"Hello \")"
980        );
981        assert!(
982            matches!(&events[1], ResponseEvent::ThoughtChunk(s) if s == "hmm"),
983            "events[1] should be ThoughtChunk(\"hmm\")"
984        );
985        assert!(
986            matches!(&events[2], ResponseEvent::ToolCall(tc) if tc.name == "view_file"),
987            "events[2] should be ToolCall(view_file)"
988        );
989        assert!(
990            matches!(&events[3], ResponseEvent::TextChunk(s) if s == "world"),
991            "events[3] should be TextChunk(\"world\")"
992        );
993        assert!(
994            matches!(&events[4], ResponseEvent::ToolResult(tr) if tr.name == "view_file"),
995            "events[4] should be ToolResult(view_file)"
996        );
997    }
998
999    #[test]
1000    fn response_event_serde_roundtrip() {
1001        let events = vec![
1002            ResponseEvent::TextChunk("hello".to_owned()),
1003            ResponseEvent::ThoughtChunk("thinking".to_owned()),
1004            ResponseEvent::ToolCall(ToolCallEvent {
1005                name: "run_command".to_owned(),
1006                args: serde_json::json!({"cmd": "ls"}),
1007                id: Some("c1".to_owned()),
1008                canonical_path: None,
1009            }),
1010            ResponseEvent::ToolResult(crate::types::ToolResult {
1011                name: "run_command".to_owned(),
1012                id: Some("c1".to_owned()),
1013                result: serde_json::json!({"output": "done"}),
1014                error: None,
1015            }),
1016        ];
1017
1018        let json = serde_json::to_string(&events).expect("serialize");
1019        let parsed: Vec<ResponseEvent> = serde_json::from_str(&json).expect("deserialize");
1020        assert_eq!(parsed.len(), events.len());
1021    }
1022
1023    // ── receive_chunks / receive_steps tests ─────────────────────────────
1024
1025    #[tokio::test]
1026    async fn receive_chunks_returns_chunks_in_order() {
1027        use tokio_stream::StreamExt;
1028
1029        let (writer, mut handle) = channel();
1030
1031        tokio::spawn(async move {
1032            writer
1033                .chunk_tx
1034                .send(StreamChunk::Text("hello".to_owned()))
1035                .await
1036                .expect("send");
1037            writer
1038                .chunk_tx
1039                .send(StreamChunk::Thought("hmm".to_owned()))
1040                .await
1041                .expect("send");
1042            writer
1043                .chunk_tx
1044                .send(StreamChunk::ToolCall(ToolCallEvent {
1045                    name: "view_file".to_owned(),
1046                    args: serde_json::json!({}),
1047                    id: None,
1048                    canonical_path: None,
1049                }))
1050                .await
1051                .expect("send");
1052            writer
1053                .chunk_tx
1054                .send(StreamChunk::Text(" world".to_owned()))
1055                .await
1056                .expect("send");
1057        });
1058
1059        let mut stream = handle.receive_chunks().expect("should get stream");
1060        let mut items = Vec::new();
1061        while let Some(chunk) = stream.next().await {
1062            items.push(chunk);
1063        }
1064
1065        assert_eq!(items.len(), 4);
1066        assert!(matches!(&items[0], StreamChunk::Text(t) if t == "hello"));
1067        assert!(matches!(&items[1], StreamChunk::Thought(t) if t == "hmm"));
1068        assert!(matches!(&items[2], StreamChunk::ToolCall(tc) if tc.name == "view_file"));
1069        assert!(matches!(&items[3], StreamChunk::Text(t) if t == " world"));
1070    }
1071
1072    #[tokio::test]
1073    async fn receive_steps_returns_steps() {
1074        use tokio_stream::StreamExt;
1075
1076        let (writer, mut handle) = channel();
1077
1078        tokio::spawn(async move {
1079            writer
1080                .step_tx
1081                .send(crate::types::Step {
1082                    id: "step-0".to_owned(),
1083                    step_index: 0,
1084                    step_type: crate::types::StepType::TextResponse,
1085                    source: crate::types::StepSource::Model,
1086                    target: crate::types::StepTarget::User,
1087                    status: crate::types::StepStatus::Done,
1088                    content: "Hello".to_owned(),
1089                    content_delta: "Hello".to_owned(),
1090                    thinking: String::new(),
1091                    thinking_delta: String::new(),
1092                    tool_calls: vec![],
1093                    error: String::new(),
1094                    is_complete_response: Some(true),
1095                    structured_output: None,
1096                    usage_metadata: None,
1097                })
1098                .await
1099                .expect("send");
1100        });
1101
1102        let mut stream = handle.receive_steps().expect("should get stream");
1103        let step = stream.next().await.expect("should get a step");
1104        assert_eq!(step.id, "step-0");
1105        assert_eq!(step.step_type, crate::types::StepType::TextResponse);
1106        assert_eq!(step.content, "Hello");
1107    }
1108
1109    #[tokio::test]
1110    async fn existing_channels_work_alongside_chunk_stream() {
1111        use tokio_stream::StreamExt;
1112
1113        let (writer, mut handle) = channel();
1114
1115        tokio::spawn(async move {
1116            // Send through both the dedicated text channel and the chunk channel.
1117            writer
1118                .text_tx
1119                .send("text-tok".to_owned())
1120                .await
1121                .expect("send text");
1122            writer
1123                .chunk_tx
1124                .send(StreamChunk::Text("text-tok".to_owned()))
1125                .await
1126                .expect("send chunk");
1127        });
1128
1129        let mut text_rx = handle.take_text_stream().expect("text rx");
1130        let text = text_rx.recv().await.expect("receive text");
1131        assert_eq!(text, "text-tok");
1132
1133        let mut chunk_stream = handle.receive_chunks().expect("chunk stream");
1134        let chunk = chunk_stream.next().await.expect("receive chunk");
1135        assert!(matches!(chunk, StreamChunk::Text(t) if t == "text-tok"));
1136    }
1137
1138    #[test]
1139    fn receive_chunks_returns_none_on_second_call() {
1140        let (_writer, mut handle) = channel();
1141        assert!(handle.receive_chunks().is_some());
1142        assert!(handle.receive_chunks().is_none());
1143    }
1144
1145    #[test]
1146    fn receive_steps_returns_none_on_second_call() {
1147        let (_writer, mut handle) = channel();
1148        assert!(handle.receive_steps().is_some());
1149        assert!(handle.receive_steps().is_none());
1150    }
1151
1152    #[test]
1153    fn stream_chunk_serde_roundtrip() {
1154        let chunks = vec![
1155            StreamChunk::Text("hello".to_owned()),
1156            StreamChunk::Thought("hmm".to_owned()),
1157            StreamChunk::ToolCall(ToolCallEvent {
1158                name: "run".to_owned(),
1159                args: serde_json::json!({"cmd": "ls"}),
1160                id: Some("c1".to_owned()),
1161                canonical_path: None,
1162            }),
1163        ];
1164        for chunk in &chunks {
1165            let json = serde_json::to_string(chunk).expect("serialize");
1166            let parsed: StreamChunk = serde_json::from_str(&json).expect("deserialize");
1167            // Verify discriminant matches.
1168            match (chunk, &parsed) {
1169                (StreamChunk::Text(a), StreamChunk::Text(b))
1170                | (StreamChunk::Thought(a), StreamChunk::Thought(b)) => assert_eq!(a, b),
1171                (StreamChunk::ToolCall(a), StreamChunk::ToolCall(b)) => {
1172                    assert_eq!(a.name, b.name);
1173                    assert_eq!(a.id, b.id);
1174                }
1175                _ => panic!("variant mismatch after roundtrip"),
1176            }
1177        }
1178    }
1179
1180    #[tokio::test]
1181    async fn usage_metadata_populated_from_writer_after_resolve() {
1182        let (writer, handle) = channel();
1183
1184        tokio::spawn(async move {
1185            writer
1186                .event_tx
1187                .send(ResponseEvent::TextChunk("hello".to_owned()))
1188                .await
1189                .unwrap();
1190            writer.set_usage(crate::types::UsageMetadata {
1191                prompt_token_count: Some(5),
1192                cached_content_token_count: None,
1193                candidates_token_count: Some(1),
1194                thoughts_token_count: None,
1195                total_token_count: Some(6),
1196            });
1197            writer.set_structured_output(serde_json::json!({"key": "value"}));
1198        });
1199
1200        // resolve() consumes the handle but finalize() runs internally,
1201        // so we verify via the shared state directly instead.
1202        let shared = handle.shared_state();
1203        let events = handle.resolve().await;
1204        assert_eq!(events.len(), 1);
1205
1206        let state = shared.lock().expect("lock shared state");
1207        assert_eq!(state.usage.as_ref().unwrap().total_token_count, Some(6));
1208        assert_eq!(
1209            state.structured_output.as_ref().unwrap(),
1210            &serde_json::json!({"key": "value"})
1211        );
1212    }
1213
1214    #[test]
1215    fn chat_result_into_string() {
1216        let (writer, handle) = channel();
1217        drop(writer);
1218        let rt = tokio::runtime::Runtime::new().unwrap();
1219        let result = rt.block_on(handle.text()).unwrap();
1220        let s: String = result.into();
1221        assert!(s.is_empty());
1222    }
1223}