1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_core::Stream;
6
7use crate::{AgentError, AgentEvent, Input, TurnOptions, TurnResult, Usage};
8
9pub trait Session: Send {
10 fn turn(
14 &mut self,
15 input: Input,
16 options: TurnOptions,
17 ) -> Pin<Box<dyn Future<Output = Result<TurnOutput, AgentError>> + Send + '_>> {
18 Box::pin(async move { self.turn_stream(input, options).await?.collect_turn().await })
19 }
20
21 fn turn_stream(
23 &mut self,
24 input: Input,
25 options: TurnOptions,
26 ) -> Pin<Box<dyn Future<Output = Result<EventStream, AgentError>> + Send + '_>>;
27
28 fn session_id(&self) -> Option<&str>;
30
31 fn interrupt(&mut self) -> Pin<Box<dyn Future<Output = Result<(), AgentError>> + Send + '_>>;
33}
34
35#[must_use = "streams do nothing unless polled"]
36pub struct EventStream {
37 inner: Pin<Box<dyn Stream<Item = Result<AgentEvent, AgentError>> + Send>>,
38}
39
40impl EventStream {
41 pub fn new(
42 stream: impl Stream<Item = Result<AgentEvent, AgentError>> + Send + 'static,
43 ) -> Self {
44 Self {
45 inner: Box::pin(stream),
46 }
47 }
48
49 pub fn from_receiver(rx: tokio::sync::mpsc::Receiver<Result<AgentEvent, AgentError>>) -> Self {
50 Self::new(ReceiverStream { inner: rx })
51 }
52
53 pub async fn collect_turn(mut self) -> Result<TurnOutput, AgentError> {
55 let mut events = Vec::new();
56 let mut response = None;
57 let mut response_deltas = String::new();
58 let mut saw_delta = false;
59 let mut usage = None;
60 let mut result = None;
61
62 loop {
63 match std::future::poll_fn(|cx| self.inner.as_mut().poll_next(cx)).await {
64 Some(Ok(event)) => {
65 match &event {
66 AgentEvent::Message { text, .. } => {
67 response = Some(text.clone());
68 }
69 AgentEvent::TextDelta { delta, .. } => {
70 response_deltas.push_str(delta);
71 saw_delta = true;
72 }
73 AgentEvent::TurnCompleted {
74 usage: u,
75 result: r,
76 } => {
77 usage.clone_from(u);
78 result.clone_from(r);
79 }
80 AgentEvent::TurnFailed { message } => {
81 return Err(AgentError::TurnFailed {
82 message: message.clone(),
83 });
84 }
85 _ => {}
86 }
87 events.push(event);
88 }
89 Some(Err(e)) => return Err(e),
90 None => break,
91 }
92 }
93
94 if response.is_none() && saw_delta {
95 response = Some(response_deltas);
96 }
97
98 Ok(TurnOutput {
99 events,
100 response,
101 usage,
102 result,
103 })
104 }
105}
106
107impl Stream for EventStream {
108 type Item = Result<AgentEvent, AgentError>;
109
110 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111 self.inner.as_mut().poll_next(cx)
112 }
113}
114
115struct ReceiverStream {
117 inner: tokio::sync::mpsc::Receiver<Result<AgentEvent, AgentError>>,
118}
119
120impl Stream for ReceiverStream {
121 type Item = Result<AgentEvent, AgentError>;
122
123 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
124 self.inner.poll_recv(cx)
125 }
126}
127
128#[derive(Debug, Clone)]
129pub struct TurnOutput {
130 pub events: Vec<AgentEvent>,
131 pub response: Option<String>,
132 pub usage: Option<Usage>,
133 pub result: Option<TurnResult>,
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `EventStream` that yields `items` in order and then ends.
    ///
    /// The channel holds every item, so the sends never wait on a consumer. The sender is
    /// dropped before returning, so the stream terminates after the last item.
    async fn stream_of(items: Vec<Result<AgentEvent, AgentError>>) -> EventStream {
        // `channel` panics on zero capacity; keep one slot for the empty case.
        let (tx, rx) = tokio::sync::mpsc::channel(items.len().max(1));
        for item in items {
            tx.send(item).await.unwrap();
        }
        drop(tx);
        EventStream::from_receiver(rx)
    }

    #[tokio::test]
    async fn collect_turn_extracts_response() {
        let output = stream_of(vec![
            Ok(AgentEvent::TurnStarted),
            Ok(AgentEvent::Message {
                id: "msg-1".into(),
                text: "Hello world".into(),
            }),
            Ok(AgentEvent::TurnCompleted {
                usage: Some(Usage {
                    input_tokens: Some(10),
                    output_tokens: Some(5),
                    ..Default::default()
                }),
                result: None,
            }),
        ])
        .await
        .collect_turn()
        .await
        .unwrap();

        assert_eq!(output.response.as_deref(), Some("Hello world"));
        assert_eq!(output.events.len(), 3);
        assert_eq!(output.usage.as_ref().unwrap().input_tokens, Some(10));
    }

    #[tokio::test]
    async fn collect_turn_fails_on_turn_failed() {
        let err = stream_of(vec![
            Ok(AgentEvent::TurnStarted),
            Ok(AgentEvent::TurnFailed {
                message: "something broke".into(),
            }),
        ])
        .await
        .collect_turn()
        .await
        .unwrap_err();

        match err {
            AgentError::TurnFailed { message } => assert_eq!(message, "something broke"),
            _ => panic!("expected TurnFailed, got {err:?}"),
        }
    }

    #[tokio::test]
    async fn collect_turn_uses_last_message() {
        let output = stream_of(vec![
            Ok(AgentEvent::Message {
                id: "1".into(),
                text: "first".into(),
            }),
            Ok(AgentEvent::Message {
                id: "2".into(),
                text: "second".into(),
            }),
            Ok(AgentEvent::TurnCompleted {
                usage: None,
                result: None,
            }),
        ])
        .await
        .collect_turn()
        .await
        .unwrap();

        assert_eq!(output.response.as_deref(), Some("second"));
    }

    #[tokio::test]
    async fn collect_turn_uses_deltas_when_no_message() {
        let output = stream_of(vec![
            Ok(AgentEvent::TextDelta {
                id: "block_0".into(),
                delta: "Hello".into(),
            }),
            Ok(AgentEvent::TextDelta {
                id: "block_0".into(),
                delta: " world".into(),
            }),
            Ok(AgentEvent::TurnCompleted {
                usage: None,
                result: None,
            }),
        ])
        .await
        .collect_turn()
        .await
        .unwrap();

        assert_eq!(output.response.as_deref(), Some("Hello world"));
    }

    #[tokio::test]
    async fn collect_turn_prefers_message_over_deltas() {
        let output = stream_of(vec![
            Ok(AgentEvent::TextDelta {
                id: "block_0".into(),
                delta: "Hel".into(),
            }),
            Ok(AgentEvent::TextDelta {
                id: "block_0".into(),
                delta: "lo".into(),
            }),
            Ok(AgentEvent::Message {
                id: "msg-1".into(),
                text: "Hello!".into(),
            }),
            Ok(AgentEvent::TurnCompleted {
                usage: None,
                result: None,
            }),
        ])
        .await
        .collect_turn()
        .await
        .unwrap();

        assert_eq!(output.response.as_deref(), Some("Hello!"));
        assert_eq!(output.events.len(), 4);
    }

    #[tokio::test]
    async fn collect_turn_keeps_empty_delta_response() {
        // A delta was seen, so the response is `Some("")` rather than `None`.
        let output = stream_of(vec![Ok(AgentEvent::TextDelta {
            id: "block_0".into(),
            delta: "".into(),
        })])
        .await
        .collect_turn()
        .await
        .unwrap();

        assert_eq!(output.response.as_deref(), Some(""));
    }

    #[tokio::test]
    async fn collect_turn_propagates_stream_error() {
        let err = stream_of(vec![
            Ok(AgentEvent::TurnStarted),
            Err(AgentError::ProcessFailed {
                code: 1,
                stderr: "crash".into(),
            }),
        ])
        .await
        .collect_turn()
        .await
        .unwrap_err();

        match err {
            AgentError::ProcessFailed { code, .. } => assert_eq!(code, 1),
            _ => panic!("expected ProcessFailed, got {err:?}"),
        }
    }

    #[tokio::test]
    async fn empty_stream_produces_empty_output() {
        let output = stream_of(Vec::new()).await.collect_turn().await.unwrap();

        assert!(output.events.is_empty());
        assert!(output.response.is_none());
        assert!(output.usage.is_none());
    }
}