// aagt_core/agent/streaming.rs

use std::collections::HashMap;
use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::Stream;
use serde::{Deserialize, Serialize};

use crate::agent::message::ToolCall;
use crate::error::Error;

13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15pub struct Usage {
16 pub prompt_tokens: u32,
18 pub completion_tokens: u32,
20 pub total_tokens: u32,
22}
23
24#[derive(Debug, Clone)]
26pub enum StreamingChoice {
27 Message(String),
29
30 ToolCall {
32 id: String,
34 name: String,
36 arguments: serde_json::Value,
38 },
39
40 ParallelToolCalls(HashMap<usize, ToolCall>),
42
43 Thought(String),
45
46 Usage(Usage),
48
49 Done,
51}
52
53impl StreamingChoice {
54 pub fn is_message(&self) -> bool {
56 matches!(self, Self::Message(_))
57 }
58
59 pub fn is_tool_call(&self) -> bool {
61 matches!(self, Self::ToolCall { .. } | Self::ParallelToolCalls(_))
62 }
63
64 pub fn is_done(&self) -> bool {
66 matches!(self, Self::Done)
67 }
68
69 pub fn as_message(&self) -> Option<&str> {
71 match self {
72 Self::Message(s) => Some(s),
73 _ => None,
74 }
75 }
76}
77
78pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, Error>> + Send>>;
80
81pub struct StreamingResponse {
83 inner: StreamingResult,
84}
85
86impl StreamingResponse {
87 pub fn new(stream: StreamingResult) -> Self {
89 Self { inner: stream }
90 }
91
92 pub fn from_stream<S>(stream: S) -> Self
94 where
95 S: Stream<Item = Result<StreamingChoice, Error>> + Send + 'static,
96 {
97 Self {
98 inner: Box::pin(stream),
99 }
100 }
101
102 pub async fn collect_text(mut self) -> Result<String, Error> {
104 use futures::StreamExt;
105
106 let mut result = String::new();
107 while let Some(chunk) = self.inner.next().await {
108 match chunk? {
109 StreamingChoice::Message(text) => result.push_str(&text),
110 StreamingChoice::Done => break,
111 _ => {}
112 }
113 }
114 Ok(result)
115 }
116
117 pub fn into_inner(self) -> StreamingResult {
119 self.inner
120 }
121}
122
123impl Stream for StreamingResponse {
124 type Item = Result<StreamingChoice, Error>;
125
126 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127 Pin::new(&mut self.inner).poll_next(cx)
128 }
129}
130
131pub struct MockStreamBuilder {
133 chunks: Vec<Result<StreamingChoice, Error>>,
134}
135
136impl Default for MockStreamBuilder {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142impl MockStreamBuilder {
143 pub fn new() -> Self {
145 Self { chunks: Vec::new() }
146 }
147
148 pub fn message(mut self, text: impl Into<String>) -> Self {
150 self.chunks.push(Ok(StreamingChoice::Message(text.into())));
151 self
152 }
153
154 pub fn tool_call(
156 mut self,
157 id: impl Into<String>,
158 name: impl Into<String>,
159 arguments: serde_json::Value,
160 ) -> Self {
161 self.chunks.push(Ok(StreamingChoice::ToolCall {
162 id: id.into(),
163 name: name.into(),
164 arguments,
165 }));
166 self
167 }
168
169 pub fn done(mut self) -> Self {
171 self.chunks.push(Ok(StreamingChoice::Done));
172 self
173 }
174
175 pub fn error(mut self, error: Error) -> Self {
177 self.chunks.push(Err(error));
178 self
179 }
180
181 pub fn usage(mut self, usage: Usage) -> Self {
183 self.chunks.push(Ok(StreamingChoice::Usage(usage)));
184 self
185 }
186
187 pub fn build(self) -> StreamingResponse {
189 StreamingResponse::from_stream(futures::stream::iter(self.chunks))
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_streaming_response() {
        let stream = MockStreamBuilder::new()
            .message("Hello, ")
            .message("world!")
            .done()
            .build();

        let text = stream.collect_text().await.expect("collect should succeed");
        assert_eq!(text, "Hello, world!");
    }

    #[tokio::test]
    async fn test_stream_iteration() {
        let mut stream = MockStreamBuilder::new()
            .message("chunk1")
            .message("chunk2")
            .done()
            .build();

        let mut messages = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let Ok(StreamingChoice::Message(text)) = chunk {
                messages.push(text);
            }
        }

        assert_eq!(messages, vec!["chunk1", "chunk2"]);
    }

    #[tokio::test]
    async fn test_collect_text_stops_at_done() {
        let stream = MockStreamBuilder::new()
            .message("before")
            .done()
            .message("after")
            .build();

        let text = stream.collect_text().await.expect("collect should succeed");
        assert_eq!(text, "before");
    }

    #[tokio::test]
    async fn test_collect_text_skips_non_message_chunks() {
        // No `Done` marker: collection must also finish cleanly when the stream ends.
        let stream = MockStreamBuilder::new()
            .message("a")
            .tool_call("call-1", "search", serde_json::Value::Null)
            .usage(Usage::default())
            .message("b")
            .build();

        let text = stream.collect_text().await.expect("collect should succeed");
        assert_eq!(text, "ab");
    }

    #[test]
    fn test_choice_predicates() {
        let message = StreamingChoice::Message("hi".to_string());
        assert!(message.is_message());
        assert_eq!(message.as_message(), Some("hi"));
        assert!(!message.is_tool_call());
        assert!(!message.is_done());

        let single = StreamingChoice::ToolCall {
            id: "1".to_string(),
            name: "f".to_string(),
            arguments: serde_json::Value::Null,
        };
        assert!(single.is_tool_call());
        assert_eq!(single.as_message(), None);

        let parallel = StreamingChoice::ParallelToolCalls(std::collections::HashMap::new());
        assert!(parallel.is_tool_call());

        assert!(StreamingChoice::Done.is_done());
        assert!(!StreamingChoice::Thought("t".to_string()).is_message());
    }
}