Skip to main content

aagt_core/agent/
streaming.rs

1//! Streaming response types
2
3use std::collections::HashMap;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use futures::Stream;
8
9use crate::error::Error;
10use crate::agent::message::ToolCall;
11
12/// A chunk from a streaming response
13#[derive(Debug, Clone)]
14pub enum StreamingChoice {
15    /// Text content chunk
16    Message(String),
17
18    /// Single tool call (sequential)
19    ToolCall {
20        /// Tool call ID
21        id: String,
22        /// Tool name
23        name: String,
24        /// Arguments as JSON
25        arguments: serde_json::Value,
26    },
27
28    /// Multiple tool calls (parallel)
29    ParallelToolCalls(HashMap<usize, ToolCall>),
30
31    /// Thinking/reasoning chunk (e.g., Gemini's thoughts)
32    Thought(String),
33
34    /// Stream finished
35    Done,
36}
37
38impl StreamingChoice {
39    /// Check if this is a message chunk
40    pub fn is_message(&self) -> bool {
41        matches!(self, Self::Message(_))
42    }
43
44    /// Check if this is a tool call
45    pub fn is_tool_call(&self) -> bool {
46        matches!(self, Self::ToolCall { .. } | Self::ParallelToolCalls(_))
47    }
48
49    /// Check if stream is done
50    pub fn is_done(&self) -> bool {
51        matches!(self, Self::Done)
52    }
53
54    /// Get message text if this is a message chunk
55    pub fn as_message(&self) -> Option<&str> {
56        match self {
57            Self::Message(s) => Some(s),
58            _ => None,
59        }
60    }
61}
62
63/// Type alias for streaming result
64pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, Error>> + Send>>;
65
66/// A wrapper for streaming responses with utility methods
67pub struct StreamingResponse {
68    inner: StreamingResult,
69}
70
71impl StreamingResponse {
72    /// Create from a stream
73    pub fn new(stream: StreamingResult) -> Self {
74        Self { inner: stream }
75    }
76
77    /// Create from any stream that implements the right traits
78    pub fn from_stream<S>(stream: S) -> Self
79    where
80        S: Stream<Item = Result<StreamingChoice, Error>> + Send + 'static,
81    {
82        Self {
83            inner: Box::pin(stream),
84        }
85    }
86
87    /// Collect all message chunks into a single string
88    pub async fn collect_text(mut self) -> Result<String, Error> {
89        use futures::StreamExt;
90
91        let mut result = String::new();
92        while let Some(chunk) = self.inner.next().await {
93            match chunk? {
94                StreamingChoice::Message(text) => result.push_str(&text),
95                StreamingChoice::Done => break,
96                _ => {}
97            }
98        }
99        Ok(result)
100    }
101
102    /// Get the inner stream
103    pub fn into_inner(self) -> StreamingResult {
104        self.inner
105    }
106}
107
108impl Stream for StreamingResponse {
109    type Item = Result<StreamingChoice, Error>;
110
111    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112        Pin::new(&mut self.inner).poll_next(cx)
113    }
114}
115
116/// Builder for creating mock streams (useful for testing)
117pub struct MockStreamBuilder {
118    chunks: Vec<Result<StreamingChoice, Error>>,
119}
120
121impl Default for MockStreamBuilder {
122    fn default() -> Self {
123        Self::new()
124    }
125}
126
127impl MockStreamBuilder {
128    /// Create a new builder
129    pub fn new() -> Self {
130        Self { chunks: Vec::new() }
131    }
132
133    /// Add a message chunk
134    pub fn message(mut self, text: impl Into<String>) -> Self {
135        self.chunks.push(Ok(StreamingChoice::Message(text.into())));
136        self
137    }
138
139    /// Add a tool call
140    pub fn tool_call(
141        mut self,
142        id: impl Into<String>,
143        name: impl Into<String>,
144        arguments: serde_json::Value,
145    ) -> Self {
146        self.chunks.push(Ok(StreamingChoice::ToolCall {
147            id: id.into(),
148            name: name.into(),
149            arguments,
150        }));
151        self
152    }
153
154    /// Add done marker
155    pub fn done(mut self) -> Self {
156        self.chunks.push(Ok(StreamingChoice::Done));
157        self
158    }
159
160    /// Add an error
161    pub fn error(mut self, error: Error) -> Self {
162        self.chunks.push(Err(error));
163        self
164    }
165
166    /// Build the stream
167    pub fn build(self) -> StreamingResponse {
168        StreamingResponse::from_stream(futures::stream::iter(self.chunks))
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// `collect_text` joins the message chunks in order.
    #[tokio::test]
    async fn test_streaming_response() {
        let response = MockStreamBuilder::new()
            .message("Hello, ")
            .message("world!")
            .done()
            .build();

        let collected = response.collect_text().await;
        assert_eq!(collected.expect("collect should succeed"), "Hello, world!");
    }

    /// Consuming the response as a `Stream` yields each message chunk in order.
    #[tokio::test]
    async fn test_stream_iteration() {
        let response = MockStreamBuilder::new()
            .message("chunk1")
            .message("chunk2")
            .done()
            .build();

        // Keep only the text of successful message chunks.
        let messages: Vec<String> = response
            .filter_map(|item| async move {
                match item {
                    Ok(StreamingChoice::Message(text)) => Some(text),
                    _ => None,
                }
            })
            .collect()
            .await;

        assert_eq!(messages, vec!["chunk1", "chunk2"]);
    }
}