Skip to main content

cognis_core/
stream.rs

//! Streaming primitives for cognis: `RunnableStream` for runnable output
//! streams, and `Event` / `EventStream` for structured execution-event streams.
3
4use std::pin::Pin;
5
6use futures::{Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use uuid::Uuid;
9
10/// A structured event emitted by `stream_events()` — exposes per-step
11/// graph activity, tool calls, token deltas, and errors.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(tag = "type")]
14pub enum Event {
15    /// A `Runnable` started.
16    OnStart {
17        /// Name of the runnable that started.
18        runnable: String,
19        /// Correlation ID for this run.
20        run_id: Uuid,
21        /// Serialized input value.
22        input: serde_json::Value,
23    },
24    /// A graph node started.
25    OnNodeStart {
26        /// Node name.
27        node: String,
28        /// Superstep number.
29        step: u64,
30        /// Correlation ID for this run.
31        run_id: Uuid,
32    },
33    /// A graph node finished.
34    OnNodeEnd {
35        /// Node name.
36        node: String,
37        /// Superstep number.
38        step: u64,
39        /// Serialized output value.
40        output: serde_json::Value,
41        /// Correlation ID for this run.
42        run_id: Uuid,
43    },
44    /// LLM emitted a token.
45    OnLlmToken {
46        /// The token text.
47        token: String,
48        /// Correlation ID for this run.
49        run_id: Uuid,
50    },
51    /// Tool execution started.
52    OnToolStart {
53        /// Tool name.
54        tool: String,
55        /// Serialized arguments.
56        args: serde_json::Value,
57        /// Correlation ID for this run.
58        run_id: Uuid,
59    },
60    /// Tool execution finished.
61    OnToolEnd {
62        /// Tool name.
63        tool: String,
64        /// Serialized result.
65        result: serde_json::Value,
66        /// Correlation ID for this run.
67        run_id: Uuid,
68    },
69    /// A `Runnable` errored.
70    OnError {
71        /// Error description.
72        error: String,
73        /// Correlation ID for this run.
74        run_id: Uuid,
75    },
76    /// A `Runnable` finished successfully.
77    OnEnd {
78        /// Name of the runnable that finished.
79        runnable: String,
80        /// Correlation ID for this run.
81        run_id: Uuid,
82        /// Serialized output value.
83        output: serde_json::Value,
84    },
85    /// A graph engine persisted a checkpoint at a superstep boundary.
86    OnCheckpoint {
87        /// Step number that was just persisted.
88        step: u64,
89        /// Correlation ID for this run.
90        run_id: Uuid,
91    },
92    /// User-emitted event from a graph node via `NodeCtx::write_custom`.
93    /// Carries an arbitrary `kind` label and a JSON payload — the consumer
94    /// decides how to interpret it. Used by `StreamMode::Custom` to surface
95    /// node-authored progress signals without cluttering the typed enum.
96    Custom {
97        /// Caller-defined label (e.g. `"progress"`, `"chunk"`).
98        kind: String,
99        /// Arbitrary JSON payload.
100        payload: serde_json::Value,
101        /// Correlation ID for this run.
102        run_id: Uuid,
103    },
104}
105
106/// Pluggable event sink. Multiple observers can subscribe to a single run.
107pub trait Observer: Send + Sync {
108    /// Called for every event emitted during execution. Implementations
109    /// should be cheap and non-blocking — a slow observer slows execution.
110    fn on_event(&self, event: &Event);
111}
112
113/// Convenience: any `Fn(&Event) + Send + Sync` is an `Observer`.
114impl<F> Observer for F
115where
116    F: Fn(&Event) + Send + Sync,
117{
118    fn on_event(&self, event: &Event) {
119        self(event)
120    }
121}
122
123/// A stream of structured events. Same shape as `RunnableStream<Event>`,
124/// but named separately to make stream-of-events vs stream-of-output
125/// distinguishable at the type level.
126pub struct EventStream(Pin<Box<dyn Stream<Item = Event> + Send>>);
127
128impl EventStream {
129    /// Wrap an arbitrary `Stream<Item = Event>`.
130    pub fn new(s: impl Stream<Item = Event> + Send + 'static) -> Self {
131        Self(Box::pin(s))
132    }
133}
134
135impl Stream for EventStream {
136    type Item = Event;
137    fn poll_next(
138        mut self: Pin<&mut Self>,
139        cx: &mut std::task::Context<'_>,
140    ) -> std::task::Poll<Option<Self::Item>> {
141        self.0.as_mut().poll_next(cx)
142    }
143}
144
145/// A stream of `Result<O>` items — the canonical output stream type for
146/// `Runnable::stream`. Wraps `Pin<Box<dyn Stream>>` for trait-object
147/// flexibility, with helper combinators on the wrapper.
148pub struct RunnableStream<O> {
149    inner: Pin<Box<dyn Stream<Item = crate::Result<O>> + Send>>,
150}
151
152impl<O> RunnableStream<O>
153where
154    O: Send + 'static,
155{
156    /// Wrap any `Stream<Item = Result<O>>`.
157    pub fn new(s: impl Stream<Item = crate::Result<O>> + Send + 'static) -> Self {
158        Self { inner: Box::pin(s) }
159    }
160
161    /// Build from a single value (one-shot stream).
162    pub fn once(value: crate::Result<O>) -> Self {
163        Self::new(futures::stream::once(async move { value }))
164    }
165
166    /// Collect all items into a `Vec`. Stops at the first `Err`.
167    pub async fn collect_into_vec(mut self) -> crate::Result<Vec<O>> {
168        let mut out = Vec::new();
169        while let Some(item) = self.inner.next().await {
170            out.push(item?);
171        }
172        Ok(out)
173    }
174
175    /// Apply a side-effect callback to each item (errors pass through unchanged).
176    pub fn with_callback<F>(self, f: F) -> Self
177    where
178        F: Fn(&O) + Send + Sync + 'static,
179    {
180        let inner = self.inner.map(move |item| {
181            if let Ok(ref v) = item {
182                f(v);
183            }
184            item
185        });
186        Self::new(inner)
187    }
188}
189
190impl<O> Stream for RunnableStream<O> {
191    type Item = crate::Result<O>;
192    fn poll_next(
193        mut self: Pin<&mut Self>,
194        cx: &mut std::task::Context<'_>,
195    ) -> std::task::Poll<Option<Self::Item>> {
196        self.inner.as_mut().poll_next(cx)
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn fn_observer_works() {
        let count = Arc::new(AtomicUsize::new(0));
        let count2 = count.clone();
        let observer: Arc<dyn Observer> = Arc::new(move |e: &Event| {
            if matches!(e, Event::OnStart { .. } | Event::OnEnd { .. }) {
                count2.fetch_add(1, Ordering::SeqCst);
            }
        });

        let e = Event::OnStart {
            runnable: "x".into(),
            run_id: Uuid::nil(),
            input: serde_json::json!({}),
        };
        observer.on_event(&e);
        observer.on_event(&e);
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn event_serialization_tagged() {
        let e = Event::OnLlmToken {
            token: "hi".into(),
            run_id: Uuid::nil(),
        };
        let s = serde_json::to_string(&e).unwrap();
        assert!(s.contains("\"type\":\"OnLlmToken\""));
        assert!(s.contains("\"token\":\"hi\""));
    }

    #[test]
    fn event_round_trips_through_json() {
        // The internally tagged representation must also deserialize back
        // into the same variant and fields.
        let e = Event::Custom {
            kind: "progress".into(),
            payload: serde_json::json!({ "pct": 50 }),
            run_id: Uuid::nil(),
        };
        let s = serde_json::to_string(&e).unwrap();
        let back: Event = serde_json::from_str(&s).unwrap();
        match back {
            Event::Custom {
                kind,
                payload,
                run_id,
            } => {
                assert_eq!(kind, "progress");
                assert_eq!(payload, serde_json::json!({ "pct": 50 }));
                assert_eq!(run_id, Uuid::nil());
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    #[tokio::test]
    async fn event_stream_yields_wrapped_events() {
        let events = vec![
            Event::OnCheckpoint {
                step: 1,
                run_id: Uuid::nil(),
            },
            Event::OnCheckpoint {
                step: 2,
                run_id: Uuid::nil(),
            },
        ];
        let steps: Vec<u64> = EventStream::new(futures::stream::iter(events))
            .map(|e| match e {
                Event::OnCheckpoint { step, .. } => step,
                other => panic!("unexpected event: {:?}", other),
            })
            .collect()
            .await;
        assert_eq!(steps, vec![1, 2]);
    }

    #[tokio::test]
    async fn runnable_stream_collect() {
        let s = RunnableStream::new(futures::stream::iter(vec![Ok(1u32), Ok(2), Ok(3)]));
        let v = s.collect_into_vec().await.unwrap();
        assert_eq!(v, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn runnable_stream_once_yields_single_item() {
        let v = RunnableStream::once(Ok(7u32))
            .collect_into_vec()
            .await
            .unwrap();
        assert_eq!(v, vec![7]);
    }

    #[tokio::test]
    async fn runnable_stream_callback() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter2 = counter.clone();
        let s = RunnableStream::new(futures::stream::iter(vec![Ok(10u32), Ok(20)])).with_callback(
            move |v| {
                counter2.fetch_add(*v as usize, Ordering::SeqCst);
            },
        );
        let _ = s.collect_into_vec().await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 30);
    }

    #[tokio::test]
    async fn runnable_stream_callback_skips_errors() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls2 = calls.clone();
        let s: RunnableStream<u32> = RunnableStream::new(futures::stream::iter(vec![
            Ok(1u32),
            Err(crate::CognisError::Internal("boom".into())),
        ]))
        .with_callback(move |_| {
            calls2.fetch_add(1, Ordering::SeqCst);
        });
        // Collect raw items so the error is observed rather than short-circuited.
        let items: Vec<crate::Result<u32>> = s.collect().await;
        assert_eq!(items.len(), 2);
        assert!(items[1].is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn runnable_stream_short_circuits_on_error() {
        let pulled = Arc::new(AtomicUsize::new(0));
        let pulled2 = pulled.clone();
        let source = futures::stream::iter(vec![
            Ok(1u32),
            Err(crate::CognisError::Internal("stop".into())),
            Ok(3u32),
        ])
        .inspect(move |_| {
            pulled2.fetch_add(1, Ordering::SeqCst);
        });
        let s: RunnableStream<u32> = RunnableStream::new(source);
        let result = s.collect_into_vec().await;
        assert!(result.is_err());
        // Only `Ok(1)` and the `Err` may be pulled; the trailing `Ok(3)` must not be.
        assert_eq!(pulled.load(Ordering::SeqCst), 2);
    }
}