use std::pin::Pin;
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// An observable event tagged with the `run_id` of the run it belongs to.
///
/// Serialized by serde as an internally tagged object (`#[serde(tag = "type")]`):
/// the variant name is written under a `"type"` key next to the variant's own
/// fields. For example, an [`Event::OnLlmToken`] serializes with
/// `"type":"OnLlmToken"` and `"token":…` entries.
///
/// NOTE(review): the per-variant descriptions below are read from the variant
/// and field names only. When each event is actually emitted is decided by the
/// producers of these events, which are not visible here — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
    /// Start of a run of the runnable named `runnable`, with its `input`.
    OnStart {
        runnable: String,
        run_id: Uuid,
        input: serde_json::Value,
    },
    /// A node named `node` starting at step number `step`.
    OnNodeStart {
        node: String,
        step: u64,
        run_id: Uuid,
    },
    /// A node named `node` finishing at step `step`, with the `output` it produced.
    OnNodeEnd {
        node: String,
        step: u64,
        output: serde_json::Value,
        run_id: Uuid,
    },
    /// A single LLM `token` (presumably streamed incrementally — confirm with producers).
    OnLlmToken {
        token: String,
        run_id: Uuid,
    },
    /// A tool named `tool` being invoked with JSON `args`.
    OnToolStart {
        tool: String,
        args: serde_json::Value,
        run_id: Uuid,
    },
    /// A tool named `tool` returning a JSON `result`.
    OnToolEnd {
        tool: String,
        result: serde_json::Value,
        run_id: Uuid,
    },
    /// An error, carried as a rendered message string rather than a typed error.
    OnError {
        error: String,
        run_id: Uuid,
    },
    /// End of a run of the runnable named `runnable`, with its final `output`.
    OnEnd {
        runnable: String,
        run_id: Uuid,
        output: serde_json::Value,
    },
    /// A checkpoint recorded at step `step`.
    OnCheckpoint {
        step: u64,
        run_id: Uuid,
    },
    /// An extension point: a caller-chosen `kind` label plus an arbitrary JSON `payload`.
    Custom {
        kind: String,
        payload: serde_json::Value,
        run_id: Uuid,
    },
}
/// A receiver of [`Event`]s.
///
/// Implementors must be `Send + Sync`, so one observer can be shared across
/// threads (e.g. behind an `Arc<dyn Observer>`).
pub trait Observer: Send + Sync {
    /// Handles one event. The event is borrowed; clone it if it must be kept.
    fn on_event(&self, event: &Event);
}

/// Lets any thread-safe closure taking `&Event` act as an [`Observer`] directly.
impl<F: Fn(&Event) + Send + Sync> Observer for F {
    fn on_event(&self, ev: &Event) {
        (*self)(ev);
    }
}
/// A type-erased, heap-pinned stream of [`Event`]s.
///
/// Wraps any `Send + 'static` event stream behind one concrete type, so it can
/// be returned from functions or stored in structs without naming the
/// underlying stream type.
pub struct EventStream(Pin<Box<dyn Stream<Item = Event> + Send>>);

impl EventStream {
    /// Boxes and pins `s` so it can be held as an `EventStream`.
    pub fn new(s: impl Stream<Item = Event> + Send + 'static) -> Self {
        Self(Box::pin(s))
    }
}

impl Stream for EventStream {
    type Item = Event;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // `EventStream` is `Unpin` (its only field is a `Pin<Box<_>>`), so the
        // field is reachable through `Pin<&mut Self>` and can be re-pinned with `as_mut`.
        self.0.as_mut().poll_next(cx)
    }

    /// Forwards the inner stream's bounds. Without this override the trait
    /// default reports `(0, None)`, which hides size information from consumers.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}
/// A type-erased, heap-pinned stream of fallible outputs of type `O`.
///
/// Every item is a `crate::Result<O>`. [`collect_into_vec`](Self::collect_into_vec)
/// stops at the first `Err`.
pub struct RunnableStream<O> {
    inner: Pin<Box<dyn Stream<Item = crate::Result<O>> + Send>>,
}

impl<O> RunnableStream<O>
where
    O: Send + 'static,
{
    /// Boxes and pins `s` so it can be held as a `RunnableStream`.
    pub fn new(s: impl Stream<Item = crate::Result<O>> + Send + 'static) -> Self {
        Self { inner: Box::pin(s) }
    }

    /// A stream that yields exactly `value` once and then ends.
    pub fn once(value: crate::Result<O>) -> Self {
        Self::new(futures::stream::once(async move { value }))
    }

    /// Drains the stream into a `Vec`, preserving item order.
    ///
    /// # Errors
    ///
    /// Returns the first `Err` item encountered. Items after it are never polled.
    pub async fn collect_into_vec(mut self) -> crate::Result<Vec<O>> {
        // Pre-size from the stream's lower bound, as `Iterator::collect` does,
        // to avoid regrowth when the length is known up front.
        let mut out = Vec::with_capacity(self.inner.size_hint().0);
        while let Some(item) = self.inner.next().await {
            out.push(item?);
        }
        Ok(out)
    }

    /// Returns a stream that calls `f` on each `Ok` value as it passes through.
    ///
    /// `Err` items are forwarded unchanged and do not invoke `f`. The callback
    /// runs lazily, only as the returned stream is polled. `f` may be `FnMut`
    /// (e.g. a local counter), and it only needs to be `Send`: the boxed stream is
    /// `Send` and never shared by reference across threads.
    pub fn with_callback<F>(self, mut f: F) -> Self
    where
        F: FnMut(&O) + Send + 'static,
    {
        let inner = self.inner.map(move |item| {
            if let Ok(v) = &item {
                f(v);
            }
            item
        });
        Self::new(inner)
    }
}

impl<O> Stream for RunnableStream<O> {
    type Item = crate::Result<O>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // `RunnableStream` is `Unpin` (its only field is a `Pin<Box<_>>`), so the
        // field is reachable through `Pin<&mut Self>` and can be re-pinned with `as_mut`.
        self.inner.as_mut().poll_next(cx)
    }

    /// Forwards the inner stream's bounds instead of the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn fn_observer_works() {
        let count = Arc::new(AtomicUsize::new(0));
        let count2 = count.clone();
        let observer: Arc<dyn Observer> = Arc::new(move |e: &Event| {
            if matches!(e, Event::OnStart { .. } | Event::OnEnd { .. }) {
                count2.fetch_add(1, Ordering::SeqCst);
            }
        });
        let e = Event::OnStart {
            runnable: "x".into(),
            run_id: Uuid::nil(),
            input: serde_json::json!({}),
        };
        observer.on_event(&e);
        observer.on_event(&e);
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn event_serialization_tagged() {
        let e = Event::OnLlmToken {
            token: "hi".into(),
            run_id: Uuid::nil(),
        };
        let s = serde_json::to_string(&e).unwrap();
        assert!(s.contains("\"type\":\"OnLlmToken\""));
        assert!(s.contains("\"token\":\"hi\""));
    }

    /// The internally tagged representation must also deserialize back to the
    /// same variant with the same field values.
    #[test]
    fn event_serialization_roundtrip() {
        let e = Event::OnNodeEnd {
            node: "n".into(),
            step: 3,
            output: serde_json::json!({ "a": 1 }),
            run_id: Uuid::nil(),
        };
        let s = serde_json::to_string(&e).unwrap();
        let back: Event = serde_json::from_str(&s).unwrap();
        assert!(matches!(
            back,
            Event::OnNodeEnd { ref node, step: 3, ref output, .. }
                if node == "n" && *output == serde_json::json!({ "a": 1 })
        ));
    }

    #[tokio::test]
    async fn runnable_stream_collect() {
        let s = RunnableStream::new(futures::stream::iter(vec![Ok(1u32), Ok(2), Ok(3)]));
        let v = s.collect_into_vec().await.unwrap();
        assert_eq!(v, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn runnable_stream_callback() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter2 = counter.clone();
        let s = RunnableStream::new(futures::stream::iter(vec![Ok(10u32), Ok(20)]))
            .with_callback(move |v| {
                counter2.fetch_add(*v as usize, Ordering::SeqCst);
            });
        let _ = s.collect_into_vec().await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 30);
    }

    /// Collection must stop at the first error. `Ok(3)` after the error must
    /// never be pulled, which the pass-through callback count makes observable.
    #[tokio::test]
    async fn runnable_stream_short_circuits_on_error() {
        let seen = Arc::new(AtomicUsize::new(0));
        let seen2 = seen.clone();
        let s: RunnableStream<u32> = RunnableStream::new(futures::stream::iter(vec![
            Ok(1),
            Err(crate::CognisError::Internal("stop".into())),
            Ok(3),
        ]))
        .with_callback(move |_| {
            seen2.fetch_add(1, Ordering::SeqCst);
        });
        let result = s.collect_into_vec().await;
        assert!(result.is_err());
        assert_eq!(seen.load(Ordering::SeqCst), 1);
    }
}