Skip to main content

agy_bridge/streaming/
mod.rs

//! Streaming response bridge for the Antigravity SDK.
//!
//! Bridges the SDK's `ChatResponse` (Python async iterator) to tokio channels
//! so Rust consumers can stream text tokens, thinking tokens, and tool call
//! events independently.

7mod handle;
8mod types;
9mod writer;
10
11use std::sync::{Arc, Mutex};
12
13use tokio::sync::mpsc;
14
15use self::types::StreamReceivers;
16pub use self::{
17    handle::ChatResponseHandle,
18    types::{
19        ChatResponseSharedState, ChatResult, ResponseEvent, StreamChunk, StreamError, ToolCallEvent,
20    },
21    writer::{ChatResponseWriter, WriterError},
22};
23
/// Capacity (256 messages) of each bounded data stream created by [`channel`].
///
/// Sized so the producer rarely has to wait on the consumer during ordinary
/// streaming, while still capping how many undelivered messages can be held
/// in memory at once.
const CHANNEL_BUFFER: usize = 1 << 8;

28/// Create a paired `(ChatResponseWriter, ChatResponseHandle)`.
29///
30/// The writer is handed to the Python bridge thread; the handle is returned
31/// to the Rust caller.
32#[must_use]
33pub fn channel() -> (ChatResponseWriter, ChatResponseHandle) {
34    let (text_tx, text_rx) = mpsc::channel(CHANNEL_BUFFER);
35    let (thought_tx, thought_rx) = mpsc::channel(CHANNEL_BUFFER);
36    let (tool_call_tx, tool_call_rx) = mpsc::channel(CHANNEL_BUFFER);
37    let (error_tx, error_rx) = mpsc::channel(1);
38    let (event_tx, event_rx) = mpsc::channel(CHANNEL_BUFFER);
39    let (step_tx, step_rx) = mpsc::channel(CHANNEL_BUFFER);
40    let (chunk_tx, chunk_rx) = mpsc::channel(CHANNEL_BUFFER);
41
42    let shared_state = Arc::new(Mutex::new(ChatResponseSharedState::default()));
43
44    let writer = ChatResponseWriter {
45        text_tx,
46        thought_tx,
47        tool_call_tx,
48        error_tx,
49        event_tx,
50        step_tx,
51        chunk_tx,
52        shared_state: Arc::clone(&shared_state),
53    };
54
55    let handle = ChatResponseHandle {
56        rx: StreamReceivers::new(
57            text_rx,
58            thought_rx,
59            tool_call_rx,
60            error_rx,
61            event_rx,
62            step_rx,
63            chunk_rx,
64        ),
65        usage: None,
66        structured_output_value: None,
67        shared_state,
68    };
69
70    (writer, handle)
71}
72
#[cfg(test)]
mod tests {
    //! Tests for the streaming bridge: per-stream channel plumbing, one-shot
    //! receiver hand-off, error propagation, the aggregation helpers
    //! (`text`, `resolve`, `finalize`), and serde round-trips of event types.

    use super::*;

    /// Build a [`ToolCallEvent`] with no canonical path.
    fn make_tool_call(name: &str, args: serde_json::Value, id: Option<&str>) -> ToolCallEvent {
        ToolCallEvent {
            name: name.to_owned(),
            args,
            id: id.map(str::to_owned),
            canonical_path: None,
        }
    }

    #[tokio::test]
    async fn streaming_receives_all_tokens_in_order() {
        let (writer, mut handle) = channel();
        let tokens = ["Hello", " ", "world", "!"];

        // Simulate the Python bridge sending tokens. Dropping `writer` when the
        // task ends closes the channel, which terminates the receive loop.
        let send_task = tokio::spawn(async move {
            for token in tokens {
                writer
                    .text_tx
                    .send(token.to_owned())
                    .await
                    .expect("send should succeed");
            }
        });

        // Consume via the stream receiver.
        let mut rx = handle.take_text_stream().expect("should get receiver");
        let mut received = Vec::new();
        while let Some(token) = rx.recv().await {
            received.push(token);
        }

        send_task.await.expect("send task should complete");
        // Compare token-by-token: a concatenation check would not notice
        // tokens being re-split or merged.
        assert_eq!(received, tokens);
    }

    #[tokio::test]
    async fn text_returns_complete_response() {
        let (writer, handle) = channel();

        tokio::spawn(async move {
            for token in ["The ", "answer ", "is ", "42."] {
                writer
                    .text_tx
                    .send(token.to_owned())
                    .await
                    .expect("send");
            }
        });

        let text = handle.text().await.expect("should succeed");
        assert_eq!(text, "The answer is 42.");
    }

    #[tokio::test]
    async fn text_returns_empty_when_no_tokens() {
        let (writer, handle) = channel();
        // Drop the writer immediately to close the channel.
        drop(writer);

        let text = handle.text().await.expect("should succeed");
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn stream_error_propagated() {
        let (writer, handle) = channel();

        tokio::spawn(async move {
            writer
                .text_tx
                .send("partial".to_owned())
                .await
                .expect("send");
            writer
                .error_tx
                .send(StreamError {
                    message: "Python exception: quota exceeded".to_owned(),
                })
                .await
                .expect("send error");
        });

        // Partial text followed by an error must still yield an error.
        let err = handle
            .text()
            .await
            .expect_err("text() should surface the stream error");
        assert!(err.message.contains("quota exceeded"));
    }

    #[tokio::test]
    async fn thought_stream_works() {
        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .thought_tx
                .send("thinking...".to_owned())
                .await
                .expect("send");
            writer
                .thought_tx
                .send("done.".to_owned())
                .await
                .expect("send");
        });

        let mut rx = handle.take_thought_stream().expect("should get receiver");
        let mut thoughts = Vec::new();
        while let Some(t) = rx.recv().await {
            thoughts.push(t);
        }
        assert_eq!(thoughts, vec!["thinking...", "done."]);
    }

    #[tokio::test]
    async fn tool_call_stream_works() {
        let (writer, mut handle) = channel();

        let event = make_tool_call(
            "view_file",
            serde_json::json!({"path": "/tmp/test.txt"}),
            Some("call_1"),
        );

        let sent = event.clone();
        tokio::spawn(async move {
            writer.tool_call_tx.send(sent).await.expect("send");
        });

        let mut rx = handle.take_tool_call_stream().expect("should get receiver");
        let received = rx.recv().await.expect("should receive event");
        assert_eq!(received.name, event.name);
        assert_eq!(received.id, event.id);
        assert_eq!(received.args, event.args);
    }

    #[tokio::test]
    async fn usage_metadata_available_after_finalize() {
        let (writer, mut handle) = channel();
        assert!(handle.usage_metadata().is_none());

        writer.set_usage(crate::types::UsageMetadata {
            prompt_token_count: Some(100),
            cached_content_token_count: Some(10),
            candidates_token_count: Some(50),
            thoughts_token_count: Some(20),
            total_token_count: Some(170),
        });
        drop(writer);
        handle.finalize();

        let usage = handle.usage_metadata().expect("should have usage");
        assert_eq!(usage.prompt_token_count, Some(100));
        assert_eq!(usage.total_token_count, Some(170));
    }

    #[test]
    fn take_text_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_text_stream().is_some());
        assert!(handle.take_text_stream().is_none());
    }

    #[test]
    fn tool_call_event_serde_roundtrip() {
        let event = make_tool_call(
            "run_command",
            serde_json::json!({"command": "ls"}),
            Some("call_42"),
        );
        let json = serde_json::to_string(&event).expect("serialize");
        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, event.name);
        assert_eq!(parsed.args, event.args);
        assert_eq!(parsed.id, event.id);
    }

    #[test]
    fn take_thought_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_thought_stream().is_some());
        assert!(handle.take_thought_stream().is_none());
    }

    #[test]
    fn take_tool_call_stream_returns_none_second_time() {
        let (_writer, mut handle) = channel();
        assert!(handle.take_tool_call_stream().is_some());
        assert!(handle.take_tool_call_stream().is_none());
    }

    #[test]
    fn stream_error_display() {
        let err = StreamError {
            message: "quota exceeded".to_owned(),
        };
        assert_eq!(format!("{err}"), "stream error: quota exceeded");
    }

    #[test]
    fn stream_error_is_std_error() {
        let err = StreamError {
            message: "test".to_owned(),
        };
        // Compile-time check that StreamError implements std::error::Error.
        let _: &dyn std::error::Error = &err;
    }

    #[tokio::test]
    async fn concurrent_text_and_thought_streams() {
        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .text_tx
                .send("Hello".to_owned())
                .await
                .expect("send text");
            writer
                .thought_tx
                .send("thinking...".to_owned())
                .await
                .expect("send thought");
        });

        let mut text_rx = handle.take_text_stream().expect("text rx");
        let mut thought_rx = handle.take_thought_stream().expect("thought rx");

        let text = text_rx.recv().await.expect("receive text");
        let thought = thought_rx.recv().await.expect("receive thought");

        assert_eq!(text, "Hello");
        assert_eq!(thought, "thinking...");
    }

    #[tokio::test]
    async fn writer_dropped_without_sending_closes_text() {
        let (writer, handle) = channel();
        drop(writer);

        let text = handle.text().await.expect("should succeed");
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn writer_dropped_without_sending_closes_thought_stream() {
        let (writer, mut handle) = channel();
        drop(writer);

        let mut thought_rx = handle.take_thought_stream().expect("rx");
        assert!(thought_rx.recv().await.is_none());
    }

    #[test]
    fn tool_call_event_without_id() {
        let event = make_tool_call("custom", serde_json::json!(null), None);
        let json = serde_json::to_string(&event).expect("serialize");
        let parsed: ToolCallEvent = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, "custom");
        assert_eq!(parsed.args, serde_json::json!(null));
        // The missing id must round-trip as `None`, not as some placeholder.
        assert!(parsed.id.is_none());
    }

    #[tokio::test]
    async fn large_token_stream() {
        let (writer, handle) = channel();
        let token_count = 200;

        tokio::spawn(async move {
            for i in 0..token_count {
                writer.text_tx.send(format!("t{i}")).await.expect("send");
            }
        });

        let text: String = handle.text().await.expect("should succeed").into();
        // Compare against the exact ordered concatenation. A per-token
        // substring check passes vacuously ("t1" is contained in "t10").
        let expected: String = (0..token_count).map(|i| format!("t{i}")).collect();
        assert_eq!(text, expected);
    }

    #[tokio::test]
    async fn resolve_returns_events_in_order() {
        let (writer, handle) = channel();

        let tool_event = make_tool_call(
            "view_file",
            serde_json::json!({"path": "/tmp/x.rs"}),
            Some("call_1"),
        );

        tokio::spawn(async move {
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("Hello ".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ThoughtChunk("hmm".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ToolCall(tool_event))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("world".to_owned()))
                .await
                .expect("send");
            writer
                .event_tx
                .send(ResponseEvent::ToolResult(crate::types::ToolResult {
                    name: "view_file".to_owned(),
                    id: Some("call_1".to_owned()),
                    result: serde_json::json!({"output": "file contents"}),
                    error: None,
                }))
                .await
                .expect("send");
            // Dropping the writer here closes the channel.
        });

        let events = handle.resolve().await;
        assert_eq!(events.len(), 5, "Expected 5 events, got {}", events.len());

        // Verify ordering and types.
        assert!(
            matches!(&events[0], ResponseEvent::TextChunk(s) if s == "Hello "),
            "events[0] should be TextChunk(\"Hello \")"
        );
        assert!(
            matches!(&events[1], ResponseEvent::ThoughtChunk(s) if s == "hmm"),
            "events[1] should be ThoughtChunk(\"hmm\")"
        );
        assert!(
            matches!(&events[2], ResponseEvent::ToolCall(tc) if tc.name == "view_file"),
            "events[2] should be ToolCall(view_file)"
        );
        assert!(
            matches!(&events[3], ResponseEvent::TextChunk(s) if s == "world"),
            "events[3] should be TextChunk(\"world\")"
        );
        assert!(
            matches!(&events[4], ResponseEvent::ToolResult(tr) if tr.name == "view_file"),
            "events[4] should be ToolResult(view_file)"
        );
    }

    #[test]
    fn response_event_serde_roundtrip() {
        let events = vec![
            ResponseEvent::TextChunk("hello".to_owned()),
            ResponseEvent::ThoughtChunk("thinking".to_owned()),
            ResponseEvent::ToolCall(make_tool_call(
                "run_command",
                serde_json::json!({"cmd": "ls"}),
                Some("c1"),
            )),
            ResponseEvent::ToolResult(crate::types::ToolResult {
                name: "run_command".to_owned(),
                id: Some("c1".to_owned()),
                result: serde_json::json!({"output": "done"}),
                error: None,
            }),
        ];

        let json = serde_json::to_string(&events).expect("serialize");
        let parsed: Vec<ResponseEvent> = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.len(), events.len());

        // Variants and payloads must survive the round trip, not just the count.
        assert!(matches!(&parsed[0], ResponseEvent::TextChunk(s) if s == "hello"));
        assert!(matches!(&parsed[1], ResponseEvent::ThoughtChunk(s) if s == "thinking"));
        assert!(matches!(
            &parsed[2],
            ResponseEvent::ToolCall(tc)
                if tc.name == "run_command" && tc.id.as_deref() == Some("c1")
        ));
        assert!(matches!(
            &parsed[3],
            ResponseEvent::ToolResult(tr)
                if tr.name == "run_command"
                    && tr.result == serde_json::json!({"output": "done"})
        ));
    }

    // ── receive_chunks / receive_steps tests ─────────────────────────────

    #[tokio::test]
    async fn receive_chunks_returns_chunks_in_order() {
        use tokio_stream::StreamExt;

        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .chunk_tx
                .send(StreamChunk::Text("hello".to_owned()))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::Thought("hmm".to_owned()))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::ToolCall(make_tool_call(
                    "view_file",
                    serde_json::json!({}),
                    None,
                )))
                .await
                .expect("send");
            writer
                .chunk_tx
                .send(StreamChunk::Text(" world".to_owned()))
                .await
                .expect("send");
        });

        let mut stream = handle.receive_chunks().expect("should get stream");
        let mut items = Vec::new();
        while let Some(chunk) = stream.next().await {
            items.push(chunk);
        }

        assert_eq!(items.len(), 4);
        assert!(matches!(&items[0], StreamChunk::Text(t) if t == "hello"));
        assert!(matches!(&items[1], StreamChunk::Thought(t) if t == "hmm"));
        assert!(matches!(&items[2], StreamChunk::ToolCall(tc) if tc.name == "view_file"));
        assert!(matches!(&items[3], StreamChunk::Text(t) if t == " world"));
    }

    #[tokio::test]
    async fn receive_steps_returns_steps() {
        use tokio_stream::StreamExt;

        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            writer
                .step_tx
                .send(crate::types::Step {
                    id: "step-0".to_owned(),
                    step_index: 0,
                    step_type: crate::types::StepType::TextResponse,
                    source: crate::types::StepSource::Model,
                    target: crate::types::StepTarget::User,
                    status: crate::types::StepStatus::Done,
                    content: "Hello".to_owned(),
                    content_delta: "Hello".to_owned(),
                    thinking: String::new(),
                    thinking_delta: String::new(),
                    tool_calls: vec![],
                    error: String::new(),
                    is_complete_response: Some(true),
                    structured_output: None,
                    usage_metadata: None,
                })
                .await
                .expect("send");
        });

        let mut stream = handle.receive_steps().expect("should get stream");
        let step = stream.next().await.expect("should get a step");
        assert_eq!(step.id, "step-0");
        assert_eq!(step.step_type, crate::types::StepType::TextResponse);
        assert_eq!(step.content, "Hello");
    }

    #[tokio::test]
    async fn existing_channels_work_alongside_chunk_stream() {
        use tokio_stream::StreamExt;

        let (writer, mut handle) = channel();

        tokio::spawn(async move {
            // Send through both the dedicated text channel and the chunk channel.
            writer
                .text_tx
                .send("text-tok".to_owned())
                .await
                .expect("send text");
            writer
                .chunk_tx
                .send(StreamChunk::Text("text-tok".to_owned()))
                .await
                .expect("send chunk");
        });

        let mut text_rx = handle.take_text_stream().expect("text rx");
        let text = text_rx.recv().await.expect("receive text");
        assert_eq!(text, "text-tok");

        let mut chunk_stream = handle.receive_chunks().expect("chunk stream");
        let chunk = chunk_stream.next().await.expect("receive chunk");
        assert!(matches!(chunk, StreamChunk::Text(t) if t == "text-tok"));
    }

    #[test]
    fn receive_chunks_returns_none_on_second_call() {
        let (_writer, mut handle) = channel();
        assert!(handle.receive_chunks().is_some());
        assert!(handle.receive_chunks().is_none());
    }

    #[test]
    fn receive_steps_returns_none_on_second_call() {
        let (_writer, mut handle) = channel();
        assert!(handle.receive_steps().is_some());
        assert!(handle.receive_steps().is_none());
    }

    #[test]
    fn stream_chunk_serde_roundtrip() {
        let chunks = vec![
            StreamChunk::Text("hello".to_owned()),
            StreamChunk::Thought("hmm".to_owned()),
            StreamChunk::ToolCall(make_tool_call(
                "run",
                serde_json::json!({"cmd": "ls"}),
                Some("c1"),
            )),
        ];
        for chunk in &chunks {
            let json = serde_json::to_string(chunk).expect("serialize");
            let parsed: StreamChunk = serde_json::from_str(&json).expect("deserialize");
            // The discriminant and payload must both survive the round trip.
            match (chunk, &parsed) {
                (StreamChunk::Text(a), StreamChunk::Text(b))
                | (StreamChunk::Thought(a), StreamChunk::Thought(b)) => assert_eq!(a, b),
                (StreamChunk::ToolCall(a), StreamChunk::ToolCall(b)) => {
                    assert_eq!(a.name, b.name);
                    assert_eq!(a.id, b.id);
                    assert_eq!(a.args, b.args);
                }
                _ => panic!("variant mismatch after roundtrip"),
            }
        }
    }

    #[tokio::test]
    async fn usage_metadata_populated_from_writer_after_resolve() {
        let (writer, handle) = channel();

        tokio::spawn(async move {
            writer
                .event_tx
                .send(ResponseEvent::TextChunk("hello".to_owned()))
                .await
                .expect("send");
            writer.set_usage(crate::types::UsageMetadata {
                prompt_token_count: Some(5),
                cached_content_token_count: None,
                candidates_token_count: Some(1),
                thoughts_token_count: None,
                total_token_count: Some(6),
            });
            writer.set_structured_output(serde_json::json!({"key": "value"}));
            // The writer is dropped here, after the shared state is populated.
        });

        // resolve() consumes the handle but finalize() runs internally,
        // so we verify via the shared state directly instead.
        let shared = handle.shared_state();
        let events = handle.resolve().await;
        assert_eq!(events.len(), 1);

        let state = shared.lock().expect("lock shared state");
        let usage = state.usage.as_ref().expect("usage should be recorded");
        assert_eq!(usage.total_token_count, Some(6));
        assert_eq!(
            state
                .structured_output
                .as_ref()
                .expect("structured output should be recorded"),
            &serde_json::json!({"key": "value"})
        );
    }

    #[tokio::test]
    async fn chat_result_into_string() {
        let (writer, handle) = channel();
        drop(writer);
        let result = handle.text().await.expect("should succeed");
        let s: String = result.into();
        assert!(s.is_empty());
    }
}