aagt_core/agent/
streaming.rs1use std::collections::HashMap;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use futures::Stream;
8
9use crate::error::Error;
10use crate::agent::message::ToolCall;
11
12#[derive(Debug, Clone)]
14pub enum StreamingChoice {
15 Message(String),
17
18 ToolCall {
20 id: String,
22 name: String,
24 arguments: serde_json::Value,
26 },
27
28 ParallelToolCalls(HashMap<usize, ToolCall>),
30
31 Thought(String),
33
34 Done,
36}
37
38impl StreamingChoice {
39 pub fn is_message(&self) -> bool {
41 matches!(self, Self::Message(_))
42 }
43
44 pub fn is_tool_call(&self) -> bool {
46 matches!(self, Self::ToolCall { .. } | Self::ParallelToolCalls(_))
47 }
48
49 pub fn is_done(&self) -> bool {
51 matches!(self, Self::Done)
52 }
53
54 pub fn as_message(&self) -> Option<&str> {
56 match self {
57 Self::Message(s) => Some(s),
58 _ => None,
59 }
60 }
61}
62
63pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, Error>> + Send>>;
65
66pub struct StreamingResponse {
68 inner: StreamingResult,
69}
70
71impl StreamingResponse {
72 pub fn new(stream: StreamingResult) -> Self {
74 Self { inner: stream }
75 }
76
77 pub fn from_stream<S>(stream: S) -> Self
79 where
80 S: Stream<Item = Result<StreamingChoice, Error>> + Send + 'static,
81 {
82 Self {
83 inner: Box::pin(stream),
84 }
85 }
86
87 pub async fn collect_text(mut self) -> Result<String, Error> {
89 use futures::StreamExt;
90
91 let mut result = String::new();
92 while let Some(chunk) = self.inner.next().await {
93 match chunk? {
94 StreamingChoice::Message(text) => result.push_str(&text),
95 StreamingChoice::Done => break,
96 _ => {}
97 }
98 }
99 Ok(result)
100 }
101
102 pub fn into_inner(self) -> StreamingResult {
104 self.inner
105 }
106}
107
108impl Stream for StreamingResponse {
109 type Item = Result<StreamingChoice, Error>;
110
111 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112 Pin::new(&mut self.inner).poll_next(cx)
113 }
114}
115
116pub struct MockStreamBuilder {
118 chunks: Vec<Result<StreamingChoice, Error>>,
119}
120
121impl Default for MockStreamBuilder {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
127impl MockStreamBuilder {
128 pub fn new() -> Self {
130 Self { chunks: Vec::new() }
131 }
132
133 pub fn message(mut self, text: impl Into<String>) -> Self {
135 self.chunks.push(Ok(StreamingChoice::Message(text.into())));
136 self
137 }
138
139 pub fn tool_call(
141 mut self,
142 id: impl Into<String>,
143 name: impl Into<String>,
144 arguments: serde_json::Value,
145 ) -> Self {
146 self.chunks.push(Ok(StreamingChoice::ToolCall {
147 id: id.into(),
148 name: name.into(),
149 arguments,
150 }));
151 self
152 }
153
154 pub fn done(mut self) -> Self {
156 self.chunks.push(Ok(StreamingChoice::Done));
157 self
158 }
159
160 pub fn error(mut self, error: Error) -> Self {
162 self.chunks.push(Err(error));
163 self
164 }
165
166 pub fn build(self) -> StreamingResponse {
168 StreamingResponse::from_stream(futures::stream::iter(self.chunks))
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_streaming_response() {
        let stream = MockStreamBuilder::new()
            .message("Hello, ")
            .message("world!")
            .done()
            .build();

        let text = stream.collect_text().await.expect("collect should succeed");
        assert_eq!(text, "Hello, world!");
    }

    #[tokio::test]
    async fn test_stream_iteration() {
        let mut stream = MockStreamBuilder::new()
            .message("chunk1")
            .message("chunk2")
            .done()
            .build();

        let mut messages = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let Ok(StreamingChoice::Message(text)) = chunk {
                messages.push(text);
            }
        }

        assert_eq!(messages, vec!["chunk1", "chunk2"]);
    }

    /// `collect_text` must not read past the `Done` marker.
    #[tokio::test]
    async fn test_collect_text_stops_at_done() {
        let stream = MockStreamBuilder::new()
            .message("kept")
            .done()
            .message(" ignored")
            .build();

        let text = stream.collect_text().await.expect("collect should succeed");
        assert_eq!(text, "kept");
    }

    /// Non-message chunks are skipped, not treated as terminators.
    #[tokio::test]
    async fn test_collect_text_skips_tool_calls() {
        let stream = MockStreamBuilder::new()
            .message("before ")
            .tool_call("call_1", "search", serde_json::json!({ "q": "rust" }))
            .message("after")
            .done()
            .build();

        let text = stream.collect_text().await.expect("collect should succeed");
        assert_eq!(text, "before after");
    }

    /// A stream that ends without `Done` still yields the text gathered so far.
    #[tokio::test]
    async fn test_collect_text_without_done() {
        let stream = MockStreamBuilder::new().message("partial").build();

        let text = stream.collect_text().await.expect("collect should succeed");
        assert_eq!(text, "partial");
    }

    #[test]
    fn test_choice_predicates() {
        let message = StreamingChoice::Message("hi".to_string());
        assert!(message.is_message());
        assert!(!message.is_tool_call());
        assert!(!message.is_done());
        assert_eq!(message.as_message(), Some("hi"));

        let call = StreamingChoice::ToolCall {
            id: "1".to_string(),
            name: "search".to_string(),
            arguments: serde_json::Value::Null,
        };
        assert!(call.is_tool_call());
        assert!(!call.is_message());
        assert_eq!(call.as_message(), None);

        let parallel = StreamingChoice::ParallelToolCalls(std::collections::HashMap::new());
        assert!(parallel.is_tool_call());

        assert!(StreamingChoice::Done.is_done());
        assert!(!StreamingChoice::Thought("t".to_string()).is_message());
    }
}