use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use futures::stream::{self, StreamExt};
use uuid::Uuid;
use crate::extensions::Extensions;
use crate::stream::Observer;
/// Per-call configuration threaded through [`Runnable`] invocations.
///
/// Construct it with [`RunnableConfig::default`] / [`RunnableConfig::new`] and
/// the `with_*` builder methods. The default [`Runnable::batch`] derives each
/// per-item config with [`RunnableConfig::clone_for_subcall`], which assigns a
/// fresh `run_id` and links it back to this one through `parent_run_id`.
///
/// `Clone` is derived, so a plain `clone()` keeps the same `run_id`. The tests
/// in this file expect a cloned config's `extras` to come out empty —
/// presumably `Extensions`'s own `Clone` impl does not copy its contents
/// (NOTE(review): confirm in `crate::extensions`).
#[derive(Clone)]
pub struct RunnableConfig {
/// Recursion limit (25 by default). Copied into sub-call configs; nothing in
/// this file enforces it.
pub recursion_limit: u32,
/// Maximum number of `invoke` calls the default [`Runnable::batch`] keeps in
/// flight (values below 1 are treated as 1 there). Defaults to
/// `num_cpus::get()`, clamped to at least 1.
pub max_concurrency: usize,
/// Free-form labels; empty by default and copied into sub-call configs.
pub tags: Vec<String>,
/// Arbitrary JSON metadata; `Value::Null` by default and copied into sub-call
/// configs.
pub metadata: serde_json::Value,
/// Observers notified by [`RunnableConfig::emit`]. Sub-call configs clone the
/// `Arc`s, so they share the same observer instances.
pub observers: Vec<Arc<dyn Observer>>,
/// Identifier of this run. `Default` and `clone_for_subcall` both generate a
/// fresh v4 UUID; the default `stream_events` stamps it on the events it emits.
pub run_id: Uuid,
/// Optional cancellation token read by [`RunnableConfig::is_cancelled`]
/// (`None` means "never cancelled"). Cloned into sub-call configs.
pub cancel_token: Option<tokio_util::sync::CancellationToken>,
/// Optional deadline; `None` by default and copied into sub-call configs.
/// NOTE(review): nothing in this file checks it — confirm where it is enforced.
pub deadline: Option<Instant>,
/// Extension storage, apparently keyed by type (the tests use
/// `insert(42u32)` / `contains::<u32>()`). Starts empty in `Default` and in
/// sub-call configs.
pub extras: Extensions,
/// `run_id` of the config this one was derived from via `clone_for_subcall`
/// (or set with `with_parent_run_id`); `None` for a root config.
pub parent_run_id: Option<Uuid>,
}
impl Default for RunnableConfig {
    /// Builds a root configuration.
    ///
    /// - Recursion limit is 25.
    /// - `max_concurrency` is `num_cpus::get()`, clamped to at least 1.
    /// - There are no tags, observers, cancellation token or deadline.
    /// - Metadata is `Null` and `extras` is empty.
    /// - `run_id` is a freshly generated v4 UUID, with no parent.
    fn default() -> Self {
        let cpus = num_cpus::get();
        Self {
            run_id: Uuid::new_v4(),
            parent_run_id: None,
            recursion_limit: 25,
            max_concurrency: std::cmp::max(cpus, 1),
            cancel_token: None,
            deadline: None,
            tags: vec![],
            observers: vec![],
            metadata: serde_json::Value::Null,
            extras: Extensions::new(),
        }
    }
}
impl RunnableConfig {
    /// Creates a root configuration; identical to [`RunnableConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this config with `recursion_limit` set to `n`.
    pub fn with_recursion_limit(self, n: u32) -> Self {
        Self {
            recursion_limit: n,
            ..self
        }
    }

    /// Returns this config with `max_concurrency` set to `n`.
    ///
    /// The default [`Runnable::batch`] treats 0 as 1.
    pub fn with_max_concurrency(self, n: usize) -> Self {
        Self {
            max_concurrency: n,
            ..self
        }
    }

    /// Appends `o` to the observers notified by [`RunnableConfig::emit`].
    pub fn with_observer(self, o: Arc<dyn Observer>) -> Self {
        let mut cfg = self;
        cfg.observers.push(o);
        cfg
    }

    /// Appends `tag` to the config's tags.
    pub fn with_tag(self, tag: impl Into<String>) -> Self {
        let mut cfg = self;
        cfg.tags.push(tag.into());
        cfg
    }

    /// Installs `t` as the token consulted by [`RunnableConfig::is_cancelled`],
    /// replacing any previous token.
    pub fn with_cancel_token(self, t: tokio_util::sync::CancellationToken) -> Self {
        Self {
            cancel_token: Some(t),
            ..self
        }
    }

    /// Records `id` as this config's `parent_run_id`.
    pub fn with_parent_run_id(self, id: Uuid) -> Self {
        Self {
            parent_run_id: Some(id),
            ..self
        }
    }

    /// Passes `event` to every registered observer, in registration order.
    pub fn emit(&self, event: &crate::stream::Event) {
        self.observers
            .iter()
            .for_each(|observer| observer.on_event(event));
    }

    /// Reports whether the configured cancellation token has been cancelled.
    ///
    /// Without a token the config is never considered cancelled.
    pub fn is_cancelled(&self) -> bool {
        match &self.cancel_token {
            Some(token) => token.is_cancelled(),
            None => false,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `RunnableConfig` defaults, builders and sub-call cloning.
    use super::*;

    #[test]
    fn defaults_sane() {
        let cfg = RunnableConfig::default();
        assert_eq!(cfg.recursion_limit, 25);
        assert!(cfg.max_concurrency > 0);
        assert!(cfg.observers.is_empty());
    }

    #[test]
    fn builder_chains() {
        let cfg = RunnableConfig::new()
            .with_recursion_limit(10)
            .with_max_concurrency(4)
            .with_tag("prod");
        assert_eq!((cfg.recursion_limit, cfg.max_concurrency), (10, 4));
        assert_eq!(cfg.tags, ["prod"]);
    }

    #[test]
    fn cancel_default_false() {
        assert!(!RunnableConfig::default().is_cancelled());
    }

    #[test]
    fn config_clones_with_extras_emptied() {
        let mut source = RunnableConfig::default()
            .with_recursion_limit(50)
            .with_max_concurrency(8)
            .with_tag("test");
        source.extras.insert(42u32);
        assert!(source.extras.contains::<u32>());

        let copy = source.clone();
        assert_eq!((copy.recursion_limit, copy.max_concurrency), (50, 8));
        assert_eq!(copy.tags, ["test"]);
        assert!(copy.extras.is_empty());
    }

    #[test]
    fn parent_run_id_default_is_none() {
        assert_eq!(RunnableConfig::default().parent_run_id, None);
    }

    #[test]
    fn clone_for_subcall_sets_parent_run_id_to_self() {
        let parent = Arc::new(RunnableConfig::default());
        let child = RunnableConfig::clone_for_subcall(&parent);
        assert_eq!(child.parent_run_id, Some(parent.run_id));
        assert_ne!(child.run_id, parent.run_id);
    }

    #[test]
    fn with_parent_run_id_builder() {
        let parent_id = Uuid::new_v4();
        let cfg = RunnableConfig::default().with_parent_run_id(parent_id);
        assert_eq!(cfg.parent_run_id, Some(parent_id));
    }
}
/// An asynchronous unit of work that turns an input `I` into an output `O`.
///
/// Implementors must provide [`Runnable::invoke`]. Every other method has a
/// default implementation built on top of it.
#[async_trait]
pub trait Runnable<I, O>: Send + Sync
where
    I: Send + 'static,
    O: Send + 'static,
{
    /// Runs this runnable once on `input` using `config`.
    async fn invoke(&self, input: I, config: RunnableConfig) -> crate::Result<O>;

    /// Invokes the runnable on every element of `inputs` concurrently.
    ///
    /// At most `config.max_concurrency` calls are in flight at once; 0 is
    /// treated as 1. Each call receives its own config from
    /// [`RunnableConfig::clone_for_subcall`]: a fresh `run_id` whose
    /// `parent_run_id` is `config.run_id`.
    ///
    /// Outputs are returned in the same order as `inputs` (`result[i]` is the
    /// output for `inputs[i]`), regardless of the order in which the
    /// individual calls complete.
    ///
    /// # Errors
    ///
    /// If any call fails, returns the error of the failing input that appears
    /// first in `inputs`. All calls are still awaited before returning.
    async fn batch(&self, inputs: Vec<I>, config: RunnableConfig) -> crate::Result<Vec<O>>
    where
        I: 'static,
        O: 'static,
        Self: Sized + Sync,
    {
        // Guard against a configured concurrency of 0.
        let concurrency = config.max_concurrency.max(1);
        let cfg = Arc::new(config);
        // `buffer_unordered` yields results in *completion* order, so tag each
        // input with its position and restore input order afterwards.
        let mut indexed = stream::iter(inputs.into_iter().enumerate())
            .map(|(idx, input)| {
                let cfg = Arc::clone(&cfg);
                async move {
                    let out = self
                        .invoke(input, RunnableConfig::clone_for_subcall(&cfg))
                        .await;
                    (idx, out)
                }
            })
            .buffer_unordered(concurrency)
            .collect::<Vec<_>>()
            .await;
        // Indices are unique, so an unstable sort is deterministic here.
        indexed.sort_unstable_by_key(|(idx, _)| *idx);
        indexed.into_iter().map(|(_, out)| out).collect()
    }

    /// Default streaming: awaits [`Runnable::invoke`] and wraps its result
    /// (success or error) in a single-item [`RunnableStream`] via
    /// `RunnableStream::once`.
    ///
    /// This method itself always returns `Ok`. An error from `invoke` is
    /// handed to `RunnableStream::once` — presumably surfaced when the stream
    /// is consumed (NOTE(review): confirm in `crate::stream`).
    async fn stream(&self, input: I, config: RunnableConfig) -> crate::Result<RunnableStream<O>>
    where
        Self: Sized + Sync,
    {
        let result = self.invoke(input, config).await;
        Ok(RunnableStream::once(result))
    }

    /// Default event stream: exactly two events, both stamped with
    /// `config.run_id`.
    ///
    /// 1. `OnStart`, carrying the serialized input.
    /// 2. Either `OnEnd` with the serialized output, or `OnError` with the
    ///    error's `to_string()`.
    ///
    /// Both events are produced only after `invoke` has finished, so nothing
    /// is observed while the call is running. JSON serialization failures are
    /// replaced with `serde_json::Value::Null`. The events are not passed to
    /// `config.emit` here.
    async fn stream_events(&self, input: I, config: RunnableConfig) -> crate::Result<EventStream>
    where
        I: serde::Serialize,
        O: serde::Serialize,
        Self: Sized + Sync,
    {
        let runnable = self.name().to_string();
        let run_id = config.run_id;
        let input_json = serde_json::to_value(&input).unwrap_or(serde_json::Value::Null);
        let on_start = Event::OnStart {
            runnable: runnable.clone(),
            run_id,
            input: input_json,
        };
        let result = self.invoke(input, config).await;
        let on_end_or_err = match &result {
            Ok(o) => Event::OnEnd {
                runnable,
                run_id,
                output: serde_json::to_value(o).unwrap_or(serde_json::Value::Null),
            },
            Err(e) => Event::OnError {
                error: e.to_string(),
                run_id,
            },
        };
        Ok(EventStream::new(stream::iter(vec![
            on_start,
            on_end_or_err,
        ])))
    }

    /// Human-readable name used in emitted events.
    ///
    /// Defaults to `std::any::type_name::<Self>()`, whose exact format is not
    /// guaranteed by the standard library.
    fn name(&self) -> &str {
        std::any::type_name::<Self>()
    }

    /// Optional JSON schema describing the input; `None` by default.
    fn input_schema(&self) -> Option<serde_json::Value> {
        None
    }

    /// Optional JSON schema describing the output; `None` by default.
    fn output_schema(&self) -> Option<serde_json::Value> {
        None
    }
}
use crate::stream::{Event, EventStream, RunnableStream};
impl RunnableConfig {
    /// Derives the configuration for a nested call made on behalf of `parent`.
    ///
    /// - Identity: the child gets a fresh v4 `run_id` and records
    ///   `parent.run_id` as its `parent_run_id`.
    /// - Copied values: `recursion_limit`, `max_concurrency`, `deadline`,
    ///   `tags` and `metadata`.
    /// - Shared handles: `observers` (the same `Arc`s) and `cancel_token`
    ///   (cloned, so it follows the parent's token).
    /// - `extras` always starts empty.
    pub fn clone_for_subcall(parent: &Arc<RunnableConfig>) -> RunnableConfig {
        let p: &RunnableConfig = parent;
        RunnableConfig {
            // New identity, linked to the parent.
            run_id: Uuid::new_v4(),
            parent_run_id: Some(p.run_id),
            extras: Extensions::new(),
            // Plain values carried over unchanged.
            recursion_limit: p.recursion_limit,
            max_concurrency: p.max_concurrency,
            deadline: p.deadline,
            tags: p.tags.clone(),
            metadata: p.metadata.clone(),
            // Handles shared with the parent.
            cancel_token: p.cancel_token.clone(),
            observers: p.observers.clone(),
        }
    }
}
#[cfg(test)]
mod runnable_tests {
    //! Tests for the default method implementations of `Runnable`.
    use super::*;
    use async_trait::async_trait;

    /// Test runnable whose output is twice its input.
    struct Doubler;

    #[async_trait]
    impl Runnable<u32, u32> for Doubler {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> crate::Result<u32> {
            Ok(2 * input)
        }
    }

    /// Returns `values` sorted in ascending order.
    fn ascending(mut values: Vec<u32>) -> Vec<u32> {
        values.sort_unstable();
        values
    }

    #[tokio::test]
    async fn invoke_works() {
        let runnable = Doubler;
        let doubled = runnable
            .invoke(5, RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(doubled, 10);
    }

    #[tokio::test]
    async fn default_batch_runs_each() {
        let runnable = Doubler;
        let outputs = runnable
            .batch(vec![1, 2, 3, 4], RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(ascending(outputs), vec![2, 4, 6, 8]);
    }

    #[tokio::test]
    async fn default_stream_emits_one_item() {
        let runnable = Doubler;
        let out_stream = runnable
            .stream(7, RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out_stream.collect_into_vec().await.unwrap(), vec![14]);
    }

    #[tokio::test]
    async fn default_stream_events_emits_start_end() {
        use futures::StreamExt;
        let runnable = Doubler;
        let mut event_stream = runnable
            .stream_events(3, RunnableConfig::default())
            .await
            .unwrap();
        let mut seen = Vec::new();
        while let Some(event) = event_stream.next().await {
            seen.push(event);
        }
        assert!(matches!(
            seen.as_slice(),
            [Event::OnStart { .. }, Event::OnEnd { .. }]
        ));
    }

    #[tokio::test]
    async fn batch_respects_max_concurrency() {
        let runnable = Doubler;
        let serial = RunnableConfig::default().with_max_concurrency(1);
        let outputs = runnable.batch(vec![1, 2, 3], serial).await.unwrap();
        assert_eq!(ascending(outputs), vec![2, 4, 6]);
    }
}