Skip to main content

aagt_core/agent/
streaming.rs

1//! Streaming response types
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use futures::Stream;
9
10use crate::error::Error;
11use crate::agent::message::ToolCall;
12
13/// Token usage information
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15pub struct Usage {
16    /// Number of tokens in the prompt
17    pub prompt_tokens: u32,
18    /// Number of tokens in the completion
19    pub completion_tokens: u32,
20    /// Total number of tokens
21    pub total_tokens: u32,
22}
23
24/// A chunk from a streaming response
25#[derive(Debug, Clone)]
26pub enum StreamingChoice {
27    /// Text content chunk
28    Message(String),
29
30    /// Single tool call (sequential)
31    ToolCall {
32        /// Tool call ID
33        id: String,
34        /// Tool name
35        name: String,
36        /// Arguments as JSON
37        arguments: serde_json::Value,
38    },
39
40    /// Multiple tool calls (parallel)
41    ParallelToolCalls(HashMap<usize, ToolCall>),
42
43    /// Thinking/reasoning chunk (e.g., Gemini's thoughts)
44    Thought(String),
45
46    /// Usage information (emitted at the end)
47    Usage(Usage),
48
49    /// Stream finished
50    Done,
51}
52
53impl StreamingChoice {
54    /// Check if this is a message chunk
55    pub fn is_message(&self) -> bool {
56        matches!(self, Self::Message(_))
57    }
58
59    /// Check if this is a tool call
60    pub fn is_tool_call(&self) -> bool {
61        matches!(self, Self::ToolCall { .. } | Self::ParallelToolCalls(_))
62    }
63
64    /// Check if stream is done
65    pub fn is_done(&self) -> bool {
66        matches!(self, Self::Done)
67    }
68
69    /// Get message text if this is a message chunk
70    pub fn as_message(&self) -> Option<&str> {
71        match self {
72            Self::Message(s) => Some(s),
73            _ => None,
74        }
75    }
76}
77
78/// Type alias for streaming result
79pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, Error>> + Send>>;
80
81/// A wrapper for streaming responses with utility methods
82pub struct StreamingResponse {
83    inner: StreamingResult,
84}
85
86impl StreamingResponse {
87    /// Create from a stream
88    pub fn new(stream: StreamingResult) -> Self {
89        Self { inner: stream }
90    }
91
92    /// Create from any stream that implements the right traits
93    pub fn from_stream<S>(stream: S) -> Self
94    where
95        S: Stream<Item = Result<StreamingChoice, Error>> + Send + 'static,
96    {
97        Self {
98            inner: Box::pin(stream),
99        }
100    }
101
102    /// Collect all message chunks into a single string
103    pub async fn collect_text(mut self) -> Result<String, Error> {
104        use futures::StreamExt;
105
106        let mut result = String::new();
107        while let Some(chunk) = self.inner.next().await {
108            match chunk? {
109                StreamingChoice::Message(text) => result.push_str(&text),
110                StreamingChoice::Done => break,
111                _ => {}
112            }
113        }
114        Ok(result)
115    }
116
117    /// Get the inner stream
118    pub fn into_inner(self) -> StreamingResult {
119        self.inner
120    }
121}
122
123impl Stream for StreamingResponse {
124    type Item = Result<StreamingChoice, Error>;
125
126    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127        Pin::new(&mut self.inner).poll_next(cx)
128    }
129}
130
131/// Builder for creating mock streams (useful for testing)
132pub struct MockStreamBuilder {
133    chunks: Vec<Result<StreamingChoice, Error>>,
134}
135
136impl Default for MockStreamBuilder {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142impl MockStreamBuilder {
143    /// Create a new builder
144    pub fn new() -> Self {
145        Self { chunks: Vec::new() }
146    }
147
148    /// Add a message chunk
149    pub fn message(mut self, text: impl Into<String>) -> Self {
150        self.chunks.push(Ok(StreamingChoice::Message(text.into())));
151        self
152    }
153
154    /// Add a tool call
155    pub fn tool_call(
156        mut self,
157        id: impl Into<String>,
158        name: impl Into<String>,
159        arguments: serde_json::Value,
160    ) -> Self {
161        self.chunks.push(Ok(StreamingChoice::ToolCall {
162            id: id.into(),
163            name: name.into(),
164            arguments,
165        }));
166        self
167    }
168
169    /// Add done marker
170    pub fn done(mut self) -> Self {
171        self.chunks.push(Ok(StreamingChoice::Done));
172        self
173    }
174
175    /// Add an error
176    pub fn error(mut self, error: Error) -> Self {
177        self.chunks.push(Err(error));
178        self
179    }
180
181    /// Add usage info
182    pub fn usage(mut self, usage: Usage) -> Self {
183        self.chunks.push(Ok(StreamingChoice::Usage(usage)));
184        self
185    }
186
187    /// Build the stream
188    pub fn build(self) -> StreamingResponse {
189        StreamingResponse::from_stream(futures::stream::iter(self.chunks))
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_streaming_response() {
        let stream = MockStreamBuilder::new()
            .message("Hello, ")
            .message("world!")
            .done()
            .build();

        let text = stream.collect_text().await.expect("collect should succeed");
        assert_eq!(text, "Hello, world!");
    }

    #[tokio::test]
    async fn test_stream_iteration() {
        let mut stream = MockStreamBuilder::new()
            .message("chunk1")
            .message("chunk2")
            .done()
            .build();

        let mut messages = Vec::new();
        let mut saw_done = false;
        while let Some(chunk) = stream.next().await {
            // Fail loudly on stream errors instead of silently skipping them.
            match chunk.expect("mock stream should not yield errors") {
                StreamingChoice::Message(text) => messages.push(text),
                StreamingChoice::Done => saw_done = true,
                other => panic!("unexpected chunk: {:?}", other),
            }
        }

        assert_eq!(messages, vec!["chunk1", "chunk2"]);
        assert!(saw_done, "stream should end with Done");
    }

    #[tokio::test]
    async fn collect_text_skips_other_chunks_and_stops_at_done() {
        let text = MockStreamBuilder::new()
            .message("a")
            .tool_call("call-1", "search", serde_json::json!({ "q": "rust" }))
            .usage(Usage::default())
            .message("b")
            .done()
            .message("after done")
            .build()
            .collect_text()
            .await
            .expect("collect should succeed");

        assert_eq!(text, "ab");
    }

    #[tokio::test]
    async fn collect_text_without_done_reads_to_end() {
        let text = MockStreamBuilder::new()
            .message("x")
            .message("y")
            .build()
            .collect_text()
            .await
            .expect("collect should succeed");

        assert_eq!(text, "xy");
    }

    #[test]
    fn choice_predicates() {
        let msg = StreamingChoice::Message("hi".to_string());
        assert!(msg.is_message());
        assert_eq!(msg.as_message(), Some("hi"));
        assert!(!msg.is_tool_call());
        assert!(!msg.is_done());

        let call = StreamingChoice::ToolCall {
            id: "1".to_string(),
            name: "tool".to_string(),
            arguments: serde_json::Value::Null,
        };
        assert!(call.is_tool_call());
        assert!(call.as_message().is_none());

        let parallel = StreamingChoice::ParallelToolCalls(std::collections::HashMap::new());
        assert!(parallel.is_tool_call());

        assert!(StreamingChoice::Done.is_done());
        assert!(!StreamingChoice::Thought("t".to_string()).is_done());
    }
}