Skip to main content

chimera_core/
session.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_core::Stream;
6
7use crate::{AgentError, AgentEvent, Input, TurnOptions, TurnResult, Usage};
8
9pub trait Session: Send {
10    /// Run a turn and return the buffered result.
11    ///
12    /// Default implementation calls `turn_stream` then `collect_turn`.
13    fn turn(
14        &mut self,
15        input: Input,
16        options: TurnOptions,
17    ) -> Pin<Box<dyn Future<Output = Result<TurnOutput, AgentError>> + Send + '_>> {
18        Box::pin(async move { self.turn_stream(input, options).await?.collect_turn().await })
19    }
20
21    /// Run a turn and return a stream of events.
22    fn turn_stream(
23        &mut self,
24        input: Input,
25        options: TurnOptions,
26    ) -> Pin<Box<dyn Future<Output = Result<EventStream, AgentError>> + Send + '_>>;
27
28    /// Session ID (populated after the first turn completes).
29    fn session_id(&self) -> Option<&str>;
30
31    /// Interrupt the current turn. No-op if idle.
32    fn interrupt(&mut self) -> Pin<Box<dyn Future<Output = Result<(), AgentError>> + Send + '_>>;
33}
34
/// Type-erased stream of `Result<AgentEvent, AgentError>` items, as returned
/// by [`Session::turn_stream`].
///
/// Build one with [`EventStream::new`] or [`EventStream::from_receiver`], then
/// either poll it as a `Stream` or buffer it with [`EventStream::collect_turn`].
#[must_use = "streams do nothing unless polled"]
pub struct EventStream {
    // Boxed and pinned so that any `Send` stream with the right item type can
    // be stored behind a single concrete type.
    inner: Pin<Box<dyn Stream<Item = Result<AgentEvent, AgentError>> + Send>>,
}
39
40impl EventStream {
41    pub fn new(
42        stream: impl Stream<Item = Result<AgentEvent, AgentError>> + Send + 'static,
43    ) -> Self {
44        Self {
45            inner: Box::pin(stream),
46        }
47    }
48
49    pub fn from_receiver(rx: tokio::sync::mpsc::Receiver<Result<AgentEvent, AgentError>>) -> Self {
50        Self::new(ReceiverStream { inner: rx })
51    }
52
53    /// Consume the stream and collect into a `TurnOutput`.
54    pub async fn collect_turn(mut self) -> Result<TurnOutput, AgentError> {
55        let mut events = Vec::new();
56        let mut response = None;
57        let mut response_deltas = String::new();
58        let mut saw_delta = false;
59        let mut usage = None;
60        let mut result = None;
61
62        loop {
63            match std::future::poll_fn(|cx| self.inner.as_mut().poll_next(cx)).await {
64                Some(Ok(event)) => {
65                    match &event {
66                        AgentEvent::Message { text, .. } => {
67                            response = Some(text.clone());
68                        }
69                        AgentEvent::TextDelta { delta, .. } => {
70                            response_deltas.push_str(delta);
71                            saw_delta = true;
72                        }
73                        AgentEvent::TurnCompleted {
74                            usage: u,
75                            result: r,
76                        } => {
77                            usage.clone_from(u);
78                            result.clone_from(r);
79                        }
80                        AgentEvent::TurnFailed { message } => {
81                            return Err(AgentError::TurnFailed {
82                                message: message.clone(),
83                            });
84                        }
85                        _ => {}
86                    }
87                    events.push(event);
88                }
89                Some(Err(e)) => return Err(e),
90                None => break,
91            }
92        }
93
94        if response.is_none() && saw_delta {
95            response = Some(response_deltas);
96        }
97
98        Ok(TurnOutput {
99            events,
100            response,
101            usage,
102            result,
103        })
104    }
105}
106
107impl Stream for EventStream {
108    type Item = Result<AgentEvent, AgentError>;
109
110    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111        self.inner.as_mut().poll_next(cx)
112    }
113}
114
/// Adapts `mpsc::Receiver` to `Stream` without pulling in `tokio-stream`.
///
/// Private to this module; it is constructed only by
/// [`EventStream::from_receiver`].
struct ReceiverStream {
    // The wrapped channel receiver; the `Stream` impl polls it via `poll_recv`.
    inner: tokio::sync::mpsc::Receiver<Result<AgentEvent, AgentError>>,
}
119
120impl Stream for ReceiverStream {
121    type Item = Result<AgentEvent, AgentError>;
122
123    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
124        self.inner.poll_recv(cx)
125    }
126}
127
/// Buffered outcome of a turn, as assembled by [`EventStream::collect_turn`].
///
/// The field descriptions below state what `collect_turn` stores; all fields
/// are public, so values built elsewhere may not follow them.
#[derive(Debug, Clone)]
pub struct TurnOutput {
    /// Every event received from the stream, in arrival order.
    pub events: Vec<AgentEvent>,
    /// Text of the last `Message` event, or the concatenated `TextDelta`
    /// fragments when no `Message` was seen; `None` if neither occurred.
    pub response: Option<String>,
    /// Usage copied from the most recent `TurnCompleted` event, if any.
    pub usage: Option<Usage>,
    /// Result copied from the most recent `TurnCompleted` event, if any.
    pub result: Option<TurnResult>,
}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Queues `items` on a bounded channel, closes the sending side, and runs
    /// `collect_turn` over the receiving side.
    ///
    /// The channel holds 16 items and nothing drains it while sending, so
    /// callers must pass no more than that.
    async fn collect(
        items: Vec<Result<AgentEvent, AgentError>>,
    ) -> Result<TurnOutput, AgentError> {
        let (tx, rx) = tokio::sync::mpsc::channel(16);
        for item in items {
            tx.send(item).await.unwrap();
        }
        drop(tx);

        EventStream::from_receiver(rx).collect_turn().await
    }

    /// A `TurnCompleted` event that carries neither usage nor a result.
    fn bare_completion() -> Result<AgentEvent, AgentError> {
        Ok(AgentEvent::TurnCompleted {
            usage: None,
            result: None,
        })
    }

    #[tokio::test]
    async fn collect_turn_extracts_response() {
        let usage = Usage {
            input_tokens: Some(10),
            output_tokens: Some(5),
            ..Default::default()
        };

        let output = collect(vec![
            Ok(AgentEvent::TurnStarted),
            Ok(AgentEvent::Message {
                id: "msg-1".into(),
                text: "Hello world".into(),
            }),
            Ok(AgentEvent::TurnCompleted {
                usage: Some(usage),
                result: None,
            }),
        ])
        .await
        .unwrap();

        assert_eq!(output.response.as_deref(), Some("Hello world"));
        assert_eq!(output.events.len(), 3);
        assert_eq!(output.usage.as_ref().unwrap().input_tokens, Some(10));
    }

    #[tokio::test]
    async fn collect_turn_fails_on_turn_failed() {
        let err = collect(vec![
            Ok(AgentEvent::TurnStarted),
            Ok(AgentEvent::TurnFailed {
                message: "something broke".into(),
            }),
        ])
        .await
        .unwrap_err();

        match err {
            AgentError::TurnFailed { message } => assert_eq!(message, "something broke"),
            _ => panic!("expected TurnFailed, got {err:?}"),
        }
    }

    #[tokio::test]
    async fn collect_turn_uses_last_message() {
        let output = collect(vec![
            Ok(AgentEvent::Message {
                id: "1".into(),
                text: "first".into(),
            }),
            Ok(AgentEvent::Message {
                id: "2".into(),
                text: "second".into(),
            }),
            bare_completion(),
        ])
        .await
        .unwrap();

        assert_eq!(output.response.as_deref(), Some("second"));
    }

    #[tokio::test]
    async fn collect_turn_uses_deltas_when_no_message() {
        let output = collect(vec![
            Ok(AgentEvent::TextDelta {
                id: "block_0".into(),
                delta: "Hello".into(),
            }),
            Ok(AgentEvent::TextDelta {
                id: "block_0".into(),
                delta: " world".into(),
            }),
            bare_completion(),
        ])
        .await
        .unwrap();

        assert_eq!(output.response.as_deref(), Some("Hello world"));
    }

    #[tokio::test]
    async fn collect_turn_propagates_stream_error() {
        let err = collect(vec![
            Ok(AgentEvent::TurnStarted),
            Err(AgentError::ProcessFailed {
                code: 1,
                stderr: "crash".into(),
            }),
        ])
        .await
        .unwrap_err();

        match err {
            AgentError::ProcessFailed { code, .. } => assert_eq!(code, 1),
            _ => panic!("expected ProcessFailed, got {err:?}"),
        }
    }

    #[tokio::test]
    async fn empty_stream_produces_empty_output() {
        // No items at all: the sender is dropped before anything is queued.
        let output = collect(Vec::new()).await.unwrap();

        assert!(output.events.is_empty());
        assert!(output.response.is_none());
        assert!(output.usage.is_none());
    }
}