use async_trait::async_trait;
use futures::stream::{self, StreamExt};
use serde_json::Value;
use crate::error::Result;
use super::config::RunnableConfig;
use super::RunnableStream;
#[async_trait]
/// Core abstraction for a unit of work that maps a JSON `Value` to a JSON `Value`.
///
/// Implementors must provide [`name`](Runnable::name) and
/// [`invoke`](Runnable::invoke); every other method has a default
/// implementation built on top of `invoke`.
pub trait Runnable: Send + Sync {
    /// Human-readable identifier used in schemas and diagnostics.
    fn name(&self) -> &str;

    /// Runs this runnable on a single input, producing a single output.
    ///
    /// # Errors
    /// Returns whatever error the concrete implementation produces.
    async fn invoke(&self, input: Value, config: Option<&RunnableConfig>) -> Result<Value>;

    /// Runs `invoke` over each input **sequentially**, short-circuiting on the
    /// first error. For concurrent execution with per-item results, use
    /// [`abatch`](Runnable::abatch).
    ///
    /// # Errors
    /// Returns the first error encountered; later inputs are not processed.
    async fn batch(
        &self,
        inputs: Vec<Value>,
        config: Option<&RunnableConfig>,
    ) -> Result<Vec<Value>> {
        let mut results = Vec::with_capacity(inputs.len());
        for input in inputs {
            results.push(self.invoke(input, config).await?);
        }
        Ok(results)
    }

    /// Runs `invoke` over all inputs concurrently, returning one `Result` per
    /// input in the **same order** as `inputs`.
    ///
    /// Concurrency is capped by `config.max_concurrency` when set, otherwise
    /// all inputs run at once. The cap is clamped to at least 1: a buffer of 0
    /// would never poll any future and `abatch` would hang forever on
    /// non-empty input (this also keeps `buffer_unordered` valid when
    /// `inputs` is empty).
    async fn abatch(
        &self,
        inputs: Vec<Value>,
        config: Option<&RunnableConfig>,
    ) -> Vec<Result<Value>> {
        let concurrency = config
            .and_then(|c| c.max_concurrency)
            .unwrap_or(inputs.len())
            .max(1);
        // Tag each input with its index so completion order doesn't matter;
        // `buffer_unordered` retires finished futures immediately, freeing
        // slots for pending work regardless of input position.
        let mut indexed: Vec<(usize, Result<Value>)> =
            stream::iter(inputs.into_iter().enumerate())
                .map(|(idx, input)| async move {
                    let result = self.invoke(input, config).await;
                    (idx, result)
                })
                .buffer_unordered(concurrency)
                .collect()
                .await;
        // Restore input order; indices are unique, so unstable sort is fine.
        indexed.sort_unstable_by_key(|(idx, _)| *idx);
        indexed.into_iter().map(|(_, result)| result).collect()
    }

    /// Default streaming adapter: invokes once and yields the single result
    /// as a one-item stream. Implementors with true incremental output should
    /// override this.
    ///
    /// # Errors
    /// The default implementation itself never fails; any `invoke` error is
    /// delivered through the stream item.
    async fn stream(
        &self,
        input: Value,
        config: Option<&RunnableConfig>,
    ) -> Result<RunnableStream> {
        let result = self.invoke(input, config).await;
        Ok(Box::pin(stream::once(async { result })))
    }

    /// Minimal JSON schema describing this runnable's input (description only).
    fn input_schema(&self) -> Value {
        serde_json::json!({
            "description": format!("Input for {}", self.name())
        })
    }

    /// Minimal JSON schema describing this runnable's output (description only).
    fn output_schema(&self) -> Value {
        serde_json::json!({
            "description": format!("Output of {}", self.name())
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnables::{RunnableExt, RunnableLambda};
    use serde_json::json;
    use std::time::{Duration, Instant};

    /// Builds a runnable that sleeps for `delay_ms` and then doubles its
    /// integer input.
    fn slow_double(delay_ms: u64) -> RunnableLambda {
        RunnableLambda::new("slow_double", move |v: Value| async move {
            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            let n = v.as_i64().unwrap();
            Ok(json!(n * 2))
        })
    }

    #[tokio::test]
    async fn test_abatch_concurrent_faster_than_sequential() {
        let delay_ms = 50;
        let doubler = slow_double(delay_ms);
        let payload: Vec<Value> = (0..4).map(|i| json!(i)).collect();

        // Sequential baseline through `batch`.
        let seq_start = Instant::now();
        let seq_results = doubler.batch(payload.clone(), None).await.unwrap();
        let seq_elapsed = seq_start.elapsed();

        // Concurrent run through `abatch`.
        let conc_start = Instant::now();
        let conc_results = doubler.abatch(payload.clone(), None).await;
        let conc_elapsed = conc_start.elapsed();

        assert_eq!(seq_results.len(), 4);
        assert_eq!(conc_results.len(), 4);
        for (idx, item) in conc_results.iter().enumerate() {
            assert_eq!(item.as_ref().unwrap(), &json!(idx as i64 * 2));
        }
        assert!(
            conc_elapsed < seq_elapsed,
            "abatch ({:?}) should be faster than batch ({:?})",
            conc_elapsed,
            seq_elapsed
        );
    }

    #[tokio::test]
    async fn test_abatch_concurrency_limit_1_is_sequential() {
        let delay_ms = 50;
        let doubler = slow_double(delay_ms);
        let payload: Vec<Value> = (0..4).map(|i| json!(i)).collect();
        let mut cfg = RunnableConfig::default();
        cfg.max_concurrency = Some(1);

        let started = Instant::now();
        let results = doubler.abatch(payload, Some(&cfg)).await;
        let elapsed = started.elapsed();

        assert_eq!(results.len(), 4);
        // With one slot, total time must approximate the sum of all delays
        // (a small tolerance is allowed for timer slop).
        assert!(
            elapsed >= Duration::from_millis(delay_ms * 4 - 20),
            "concurrency=1 should be sequential-like, elapsed: {:?}",
            elapsed
        );
        for (idx, item) in results.iter().enumerate() {
            assert_eq!(item.as_ref().unwrap(), &json!(idx as i64 * 2));
        }
    }

    #[tokio::test]
    async fn test_abatch_concurrency_limit_2_on_4_inputs() {
        let delay_ms = 50;
        let doubler = slow_double(delay_ms);
        let payload: Vec<Value> = (0..4).map(|i| json!(i)).collect();
        let mut cfg = RunnableConfig::default();
        cfg.max_concurrency = Some(2);

        let started = Instant::now();
        let results = doubler.abatch(payload, Some(&cfg)).await;
        let elapsed = started.elapsed();

        assert_eq!(results.len(), 4);
        for (idx, item) in results.iter().enumerate() {
            assert_eq!(item.as_ref().unwrap(), &json!(idx as i64 * 2));
        }
        // Two slots over four items should clearly beat a serial run.
        assert!(
            elapsed < Duration::from_millis(delay_ms * 4 - 20),
            "concurrency=2 should be faster than sequential, elapsed: {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_abatch_preserves_input_order() {
        // Later inputs finish sooner: delay shrinks as the value grows, so
        // completion order is the reverse of input order.
        let reverse_delay = RunnableLambda::new("delay_by_value", |v: Value| async move {
            let n = v.as_i64().unwrap();
            let delay = (4 - n) as u64 * 20;
            tokio::time::sleep(Duration::from_millis(delay)).await;
            Ok(json!(n * 10))
        });
        let payload: Vec<Value> = (0..5).map(|i| json!(i)).collect();
        let results = reverse_delay.abatch(payload, None).await;

        assert_eq!(results.len(), 5);
        for (idx, item) in results.iter().enumerate() {
            assert_eq!(
                item.as_ref().unwrap(),
                &json!(idx as i64 * 10),
                "Result at index {} should be {} but got {:?}",
                idx,
                idx * 10,
                item
            );
        }
    }

    #[tokio::test]
    async fn test_with_concurrency_extension() {
        let delay_ms = 50;
        let bound = slow_double(delay_ms).with_concurrency(2);

        // Single invoke still works through the wrapper.
        let single = bound.invoke(json!(5), None).await.unwrap();
        assert_eq!(single, json!(10));

        let payload: Vec<Value> = (0..4).map(|i| json!(i)).collect();
        let started = Instant::now();
        let results = bound.abatch(payload, None).await;
        let elapsed = started.elapsed();

        assert_eq!(results.len(), 4);
        for (idx, item) in results.iter().enumerate() {
            assert_eq!(item.as_ref().unwrap(), &json!(idx as i64 * 2));
        }
        assert!(
            elapsed < Duration::from_millis(delay_ms * 4 - 20),
            "with_concurrency(2) should limit but still parallelize, elapsed: {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_abatch_empty_inputs() {
        // An empty batch must yield an empty result set without hanging.
        let results = slow_double(10).abatch(vec![], None).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_abatch_single_input() {
        let results = slow_double(10).abatch(vec![json!(7)], None).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].as_ref().unwrap(), &json!(14));
    }

    #[test]
    fn test_default_input_schema() {
        let identity = RunnableLambda::new("test_fn", |v: Value| async move { Ok(v) });
        let schema = identity.input_schema();
        // Default schema carries only a description — no "type" field.
        assert!(schema.get("type").is_none());
        assert!(schema["description"].as_str().unwrap().contains("test_fn"));
    }

    #[test]
    fn test_default_output_schema() {
        let identity = RunnableLambda::new("test_fn", |v: Value| async move { Ok(v) });
        let schema = identity.output_schema();
        // Default schema carries only a description — no "type" field.
        assert!(schema.get("type").is_none());
        assert!(schema["description"].as_str().unwrap().contains("test_fn"));
    }
}