1use std::pin::Pin;
5
6use futures::{Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use uuid::Uuid;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(tag = "type")]
14pub enum Event {
15 OnStart {
17 runnable: String,
19 run_id: Uuid,
21 input: serde_json::Value,
23 },
24 OnNodeStart {
26 node: String,
28 step: u64,
30 run_id: Uuid,
32 },
33 OnNodeEnd {
35 node: String,
37 step: u64,
39 output: serde_json::Value,
41 run_id: Uuid,
43 },
44 OnLlmToken {
46 token: String,
48 run_id: Uuid,
50 },
51 OnToolStart {
53 tool: String,
55 args: serde_json::Value,
57 run_id: Uuid,
59 },
60 OnToolEnd {
62 tool: String,
64 result: serde_json::Value,
66 run_id: Uuid,
68 },
69 OnError {
71 error: String,
73 run_id: Uuid,
75 },
76 OnEnd {
78 runnable: String,
80 run_id: Uuid,
82 output: serde_json::Value,
84 },
85 OnCheckpoint {
87 step: u64,
89 run_id: Uuid,
91 },
92 Custom {
97 kind: String,
99 payload: serde_json::Value,
101 run_id: Uuid,
103 },
104}
105
106pub trait Observer: Send + Sync {
108 fn on_event(&self, event: &Event);
111}
112
113impl<F> Observer for F
115where
116 F: Fn(&Event) + Send + Sync,
117{
118 fn on_event(&self, event: &Event) {
119 self(event)
120 }
121}
122
123pub struct EventStream(Pin<Box<dyn Stream<Item = Event> + Send>>);
127
128impl EventStream {
129 pub fn new(s: impl Stream<Item = Event> + Send + 'static) -> Self {
131 Self(Box::pin(s))
132 }
133}
134
135impl Stream for EventStream {
136 type Item = Event;
137 fn poll_next(
138 mut self: Pin<&mut Self>,
139 cx: &mut std::task::Context<'_>,
140 ) -> std::task::Poll<Option<Self::Item>> {
141 self.0.as_mut().poll_next(cx)
142 }
143}
144
145pub struct RunnableStream<O> {
149 inner: Pin<Box<dyn Stream<Item = crate::Result<O>> + Send>>,
150}
151
152impl<O> RunnableStream<O>
153where
154 O: Send + 'static,
155{
156 pub fn new(s: impl Stream<Item = crate::Result<O>> + Send + 'static) -> Self {
158 Self { inner: Box::pin(s) }
159 }
160
161 pub fn once(value: crate::Result<O>) -> Self {
163 Self::new(futures::stream::once(async move { value }))
164 }
165
166 pub async fn collect_into_vec(mut self) -> crate::Result<Vec<O>> {
168 let mut out = Vec::new();
169 while let Some(item) = self.inner.next().await {
170 out.push(item?);
171 }
172 Ok(out)
173 }
174
175 pub fn with_callback<F>(self, f: F) -> Self
177 where
178 F: Fn(&O) + Send + Sync + 'static,
179 {
180 let inner = self.inner.map(move |item| {
181 if let Ok(ref v) = item {
182 f(v);
183 }
184 item
185 });
186 Self::new(inner)
187 }
188}
189
190impl<O> Stream for RunnableStream<O> {
191 type Item = crate::Result<O>;
192 fn poll_next(
193 mut self: Pin<&mut Self>,
194 cx: &mut std::task::Context<'_>,
195 ) -> std::task::Poll<Option<Self::Item>> {
196 self.inner.as_mut().poll_next(cx)
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn fn_observer_works() {
        let count = Arc::new(AtomicUsize::new(0));
        let count2 = count.clone();
        let observer: Arc<dyn Observer> = Arc::new(move |e: &Event| {
            if matches!(e, Event::OnStart { .. } | Event::OnEnd { .. }) {
                count2.fetch_add(1, Ordering::SeqCst);
            }
        });

        let e = Event::OnStart {
            runnable: "x".into(),
            run_id: Uuid::nil(),
            input: serde_json::json!({}),
        };
        observer.on_event(&e);
        observer.on_event(&e);
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn event_serialization_tagged() {
        let e = Event::OnLlmToken {
            token: "hi".into(),
            run_id: Uuid::nil(),
        };
        let s = serde_json::to_string(&e).unwrap();
        assert!(s.contains("\"type\":\"OnLlmToken\""));
        assert!(s.contains("\"token\":\"hi\""));
    }

    #[test]
    fn event_roundtrips_through_json() {
        let e = Event::OnToolEnd {
            tool: "search".into(),
            result: serde_json::json!({ "hits": 3 }),
            run_id: Uuid::nil(),
        };
        let json = serde_json::to_value(&e).unwrap();
        assert_eq!(json["type"], "OnToolEnd");
        let back: Event = serde_json::from_value(json.clone()).unwrap();
        assert_eq!(serde_json::to_value(&back).unwrap(), json);
    }

    #[tokio::test]
    async fn event_stream_forwards_items() {
        let events = vec![
            Event::OnCheckpoint {
                step: 1,
                run_id: Uuid::nil(),
            },
            Event::OnCheckpoint {
                step: 2,
                run_id: Uuid::nil(),
            },
        ];
        let got: Vec<Event> = EventStream::new(futures::stream::iter(events))
            .collect()
            .await;
        assert_eq!(got.len(), 2);
        assert!(matches!(got[1], Event::OnCheckpoint { step: 2, .. }));
    }

    #[tokio::test]
    async fn runnable_stream_once_yields_single_item() {
        let v = RunnableStream::once(Ok(7u32))
            .collect_into_vec()
            .await
            .unwrap();
        assert_eq!(v, vec![7]);
    }

    #[tokio::test]
    async fn runnable_stream_collect() {
        let s = RunnableStream::new(futures::stream::iter(vec![Ok(1u32), Ok(2), Ok(3)]));
        let v = s.collect_into_vec().await.unwrap();
        assert_eq!(v, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn runnable_stream_callback() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter2 = counter.clone();
        let s = RunnableStream::new(futures::stream::iter(vec![Ok(10u32), Ok(20)])).with_callback(
            move |v| {
                counter2.fetch_add(*v as usize, Ordering::SeqCst);
            },
        );
        let _ = s.collect_into_vec().await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 30);
    }

    #[tokio::test]
    async fn runnable_stream_short_circuits_on_error() {
        // Count how many `Ok` items are actually pulled through the stream.
        let seen = Arc::new(AtomicUsize::new(0));
        let seen2 = seen.clone();
        let s: RunnableStream<u32> = RunnableStream::new(futures::stream::iter(vec![
            Ok(1u32),
            Err(crate::CognisError::Internal("stop".into())),
            Ok(3),
        ]))
        .with_callback(move |_| {
            seen2.fetch_add(1, Ordering::SeqCst);
        });
        let result = s.collect_into_vec().await;
        assert!(result.is_err());
        // Only `Ok(1)` was pulled; `Ok(3)` after the error must never be polled.
        assert_eq!(seen.load(Ordering::SeqCst), 1);
    }
}