use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use crate::error::Result;
use super::base::Runnable;
use super::config::RunnableConfig;
/// Two-stage runnable: the output of `first` becomes the input of `second`.
///
/// Both stages receive the same optional `RunnableConfig` when invoked.
pub struct RunnablePipe {
/// Stage executed first; receives the caller's input.
first: Arc<dyn Runnable>,
/// Stage executed second; receives the output produced by `first`.
second: Arc<dyn Runnable>,
}
impl RunnablePipe {
pub fn new(first: Arc<dyn Runnable>, second: Arc<dyn Runnable>) -> Self {
Self { first, second }
}
}
#[async_trait]
impl Runnable for RunnablePipe {
    /// Runs `first` on `input`, then feeds its output to `second`.
    ///
    /// The same `config` is forwarded to both stages. If `first` fails, its
    /// error is returned immediately and `second` is never invoked.
    async fn invoke(&self, input: Value, config: Option<&RunnableConfig>) -> Result<Value> {
        let handoff = self.first.invoke(input, config).await?;
        self.second.invoke(handoff, config).await
    }

    /// Batch counterpart of `invoke`: `first.batch` processes every input, and
    /// its complete result vector is then handed to `second.batch`.
    ///
    /// Stage two only starts after stage one has finished the whole batch; an
    /// error from either stage aborts the call.
    async fn batch(
        &self,
        inputs: Vec<Value>,
        config: Option<&RunnableConfig>,
    ) -> Result<Vec<Value>> {
        let staged = self.first.batch(inputs, config).await?;
        self.second.batch(staged, config).await
    }

    /// Fixed type name. `RunnablePipe::pipe_name` provides the composed
    /// "first | second" label instead.
    fn name(&self) -> &str {
        "RunnablePipe"
    }
}
impl RunnablePipe {
    /// Human-readable label made of both stage names joined by " | ",
    /// e.g. `"add_one | double"`.
    pub fn pipe_name(&self) -> String {
        [self.first.name(), self.second.name()].join(" | ")
    }
}
/// Convenience constructor: composes `first` then `second` into a [`RunnablePipe`]
/// and returns it type-erased behind `Arc<dyn Runnable>`.
pub fn pipe(first: Arc<dyn Runnable>, second: Arc<dyn Runnable>) -> Arc<dyn Runnable> {
    let composed = RunnablePipe::new(first, second);
    Arc::new(composed)
}
/// Fluent builder that collects runnables in call order and composes them via `build`.
///
/// `build` returns the sole step unchanged for one step, a `RunnablePipe` for two
/// steps, and a `RunnableSequence` for three or more.
pub struct PipeBuilder {
/// Steps in execution order; `new` seeds it with one step, so it is never empty.
steps: Vec<Arc<dyn Runnable>>,
}
impl PipeBuilder {
pub fn new(first: Arc<dyn Runnable>) -> Self {
Self { steps: vec![first] }
}
pub fn pipe(mut self, next: Arc<dyn Runnable>) -> Self {
self.steps.push(next);
self
}
pub fn build(self) -> Result<Arc<dyn Runnable>> {
match self.steps.len() {
0 => unreachable!("PipeBuilder always has at least one step"),
1 => Ok(self.steps.into_iter().next().unwrap()),
2 => {
let mut iter = self.steps.into_iter();
let first = iter.next().unwrap();
let second = iter.next().unwrap();
Ok(Arc::new(RunnablePipe::new(first, second)))
}
_ => {
let seq = super::sequence::RunnableSequence::new(self.steps)?;
Ok(Arc::new(seq))
}
}
}
pub fn name(&self) -> String {
self.steps
.iter()
.map(|s| s.name())
.collect::<Vec<_>>()
.join(" | ")
}
}
/// Composable handle to a runnable that supports chaining with the `|` operator.
///
/// `a | b | c` concatenates the step lists of both operands, so chains stay flat
/// instead of nesting.
#[derive(Clone)]
pub struct RunnableRef {
/// Runnable this handle stands for: the wrapped runnable for a single step, or the
/// `RunnableSequence` assembled by `|` for a multi-step chain.
pub(crate) inner: Arc<dyn Runnable>,
/// Flattened list of the individual steps, in execution order (never empty).
steps: Vec<Arc<dyn Runnable>>,
}
impl RunnableRef {
pub fn new(r: Arc<dyn Runnable>) -> Self {
Self {
inner: r.clone(),
steps: vec![r],
}
}
pub fn into_inner(self) -> Arc<dyn Runnable> {
if self.steps.len() <= 1 {
self.inner
} else {
self.build_sequence()
}
}
pub fn runnable(&self) -> Arc<dyn Runnable> {
if self.steps.len() <= 1 {
self.inner.clone()
} else {
self.build_sequence()
}
}
fn build_sequence(&self) -> Arc<dyn Runnable> {
match crate::runnables::RunnableSequence::new(self.steps.clone()) {
Ok(seq) => Arc::new(seq),
Err(_) => self.inner.clone(),
}
}
}
impl std::ops::BitOr for RunnableRef {
    type Output = RunnableRef;

    /// Chains `self` and `rhs` into a single flat multi-step `RunnableRef`.
    ///
    /// The step lists of both operands are concatenated, so `a | b | c` yields three
    /// steps rather than a nested pair. A `RunnableSequence` over those steps becomes
    /// the new `inner`.
    ///
    /// # Panics
    ///
    /// Panics if `RunnableSequence::new` rejects the combined steps. Every
    /// `RunnableRef` carries at least one step, so the combination always has at
    /// least two. A failure therefore signals a broken invariant, and the
    /// underlying error is included in the panic message rather than discarded.
    fn bitor(self, rhs: RunnableRef) -> RunnableRef {
        let mut steps = self.steps;
        steps.extend(rhs.steps);
        // `steps` is kept on the result, so the sequence gets its own copy
        // (cloning only bumps the `Arc` refcounts).
        let sequence = crate::runnables::RunnableSequence::new(steps.clone()).unwrap_or_else(
            |err| {
                unreachable!(
                    "pipe operator always provides at least 2 steps, \
                     but RunnableSequence::new failed: {}",
                    err
                )
            },
        );
        let inner: Arc<dyn Runnable> = Arc::new(sequence);
        RunnableRef { inner, steps }
    }
}
// Unit tests for RunnablePipe and PipeBuilder, including their interaction with
// RunnableLambda, RunnableParallel, config forwarding, and error propagation.
#[cfg(test)]
mod tests {
use super::*;
use crate::runnables::{RunnableLambda, RunnableParallel};
use serde_json::json;
use std::collections::HashMap;
// Lambda adding 1 to an integer input; panics (unwrap) if the input is not an i64.
fn add_one() -> Arc<dyn Runnable> {
Arc::new(RunnableLambda::new("add_one", |v: Value| async move {
let n = v.as_i64().unwrap();
Ok(json!(n + 1))
}))
}
// Lambda doubling an integer input; panics (unwrap) if the input is not an i64.
fn double() -> Arc<dyn Runnable> {
Arc::new(RunnableLambda::new("double", |v: Value| async move {
let n = v.as_i64().unwrap();
Ok(json!(n * 2))
}))
}
// Lambda turning an integer into the string "result:<n>".
fn to_string_runnable() -> Arc<dyn Runnable> {
Arc::new(RunnableLambda::new("to_string", |v: Value| async move {
let n = v.as_i64().unwrap();
Ok(json!(format!("result:{}", n)))
}))
}
// Lambda parsing a JSON string into an i64; parse failures become CognisError::Other.
fn parse_int() -> Arc<dyn Runnable> {
Arc::new(RunnableLambda::new("parse_int", |v: Value| async move {
let s = v.as_str().unwrap();
let n: i64 = s.parse().map_err(|e: std::num::ParseIntError| {
crate::error::CognisError::Other(e.to_string())
})?;
Ok(json!(n))
}))
}
// Lambda that always fails with the message "intentional failure".
fn failing_runnable() -> Arc<dyn Runnable> {
Arc::new(RunnableLambda::new("fail", |_v: Value| async move {
Err(crate::error::CognisError::Other(
"intentional failure".into(),
))
}))
}
// (5 + 1) * 2 = 12.
#[tokio::test]
async fn test_pipe_basic_two_lambdas() {
let piped = RunnablePipe::new(add_one(), double());
let result = piped.invoke(json!(5), None).await.unwrap();
assert_eq!(result, json!(12));
}
// A pipe nested as the first stage of another pipe: ((5 + 1) * 2) + 1 = 13.
#[tokio::test]
async fn test_pipe_triple_chain() {
let first = RunnablePipe::new(add_one(), double());
let chained: Arc<dyn Runnable> = Arc::new(first);
let piped = RunnablePipe::new(chained, add_one());
let result = piped.invoke(json!(5), None).await.unwrap();
assert_eq!(result, json!(13));
}
// pipe_name composes stage names; name() stays the fixed type name.
#[tokio::test]
async fn test_pipe_name_formatting() {
let piped = RunnablePipe::new(add_one(), double());
assert_eq!(piped.pipe_name(), "add_one | double");
assert_eq!(piped.name(), "RunnablePipe");
}
// Three steps take the RunnableSequence path: 3 -> 4 -> 8 -> 9.
#[tokio::test]
async fn test_pipe_builder_multi_step() {
let chain = PipeBuilder::new(add_one())
.pipe(double())
.pipe(add_one())
.build()
.unwrap();
let result = chain.invoke(json!(3), None).await.unwrap();
assert_eq!(result, json!(9));
}
// Builder name joins every step name with " | ".
#[tokio::test]
async fn test_pipe_builder_name() {
let builder = PipeBuilder::new(add_one())
.pipe(double())
.pipe(to_string_runnable());
assert_eq!(builder.name(), "add_one | double | to_string");
}
// RunnableParallel fans the same input out to every named branch.
#[tokio::test]
async fn test_parallel_multiple_branches() {
let mut steps = HashMap::new();
steps.insert("added".to_string(), add_one());
steps.insert("doubled".to_string(), double());
let parallel = RunnableParallel::new(steps);
let result = parallel.invoke(json!(5), None).await.unwrap();
assert_eq!(result["added"], json!(6));
assert_eq!(result["doubled"], json!(10));
}
#[tokio::test]
async fn test_parallel_single_branch() {
let mut steps = HashMap::new();
steps.insert("only".to_string(), add_one());
let parallel = RunnableParallel::new(steps);
let result = parallel.invoke(json!(10), None).await.unwrap();
assert_eq!(result["only"], json!(11));
}
// Two-step builder (RunnablePipe path) whose first stage is a parallel fan-out.
#[tokio::test]
async fn test_pipe_with_parallel() {
let mut steps = HashMap::new();
steps.insert("added".to_string(), add_one());
steps.insert("doubled".to_string(), double());
let parallel: Arc<dyn Runnable> = Arc::new(RunnableParallel::new(steps));
let extract = Arc::new(RunnableLambda::new(
"extract_added",
|v: Value| async move { Ok(v["added"].clone()) },
));
let chain = PipeBuilder::new(parallel).pipe(extract).build().unwrap();
let result = chain.invoke(json!(5), None).await.unwrap();
assert_eq!(result, json!(6));
}
// An error in the first stage surfaces unchanged from the pipe.
#[tokio::test]
async fn test_pipe_error_propagation() {
let piped = RunnablePipe::new(failing_runnable(), double());
let result = piped.invoke(json!(5), None).await;
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("intentional failure"));
}
#[tokio::test]
async fn test_pipe_error_in_second_step() {
let piped = RunnablePipe::new(add_one(), failing_runnable());
let result = piped.invoke(json!(5), None).await;
assert!(result.is_err());
}
// The caller's config (here its run_name) must reach the second stage.
#[tokio::test]
async fn test_pipe_config_passing() {
let config_reader = Arc::new(RunnableLambda::with_config(
"config_reader",
|v: Value, config: Option<RunnableConfig>| async move {
let run_name = config
.and_then(|c| c.run_name)
.unwrap_or_else(|| "none".to_string());
Ok(json!({
"input": v,
"run_name": run_name,
}))
},
));
let piped = RunnablePipe::new(add_one(), config_reader);
let mut cfg = RunnableConfig::default();
cfg.run_name = Some("test_run".to_string());
let result = piped.invoke(json!(5), Some(&cfg)).await.unwrap();
assert_eq!(result["input"], json!(6));
assert_eq!(result["run_name"], json!("test_run"));
}
// String -> integer -> string across three stages: "7" -> 7 -> 14 -> "result:14".
#[tokio::test]
async fn test_pipe_type_transformation() {
let chain = PipeBuilder::new(parse_int())
.pipe(double())
.pipe(to_string_runnable())
.build()
.unwrap();
let result = chain.invoke(json!("7"), None).await.unwrap();
assert_eq!(result, json!("result:14"));
}
// A single-step builder returns that step unchanged.
#[tokio::test]
async fn test_pipe_builder_single_step() {
let chain = PipeBuilder::new(add_one()).build().unwrap();
let result = chain.invoke(json!(5), None).await.unwrap();
assert_eq!(result, json!(6));
}
// Batch runs every input through both stages: (x + 1) * 2.
#[tokio::test]
async fn test_pipe_batch() {
let piped = RunnablePipe::new(add_one(), double());
let inputs = vec![json!(1), json!(2), json!(3)];
let results = piped.batch(inputs, None).await.unwrap();
assert_eq!(results, vec![json!(4), json!(6), json!(8)]);
}
}
#[cfg(test)]
mod pipe_operator_tests {
    use super::*;
    use crate::runnables::lambda::RunnableLambda;
    use serde_json::{json, Value};

    /// Wraps an integer operation as a named `RunnableRef`; non-integer input is read as 0.
    fn int_step(name: &'static str, op: fn(i64) -> i64) -> RunnableRef {
        RunnableRef::new(Arc::new(RunnableLambda::new(
            name,
            move |v: Value| async move {
                let n = v.as_i64().unwrap_or(0);
                Ok(json!(op(n)))
            },
        )))
    }

    /// `add_one | double` applied to 5: (5 + 1) * 2 = 12.
    #[tokio::test]
    async fn test_pipe_operator_two_steps() {
        let chain = int_step("add_one", |n| n + 1) | int_step("double", |n| n * 2);
        let output = chain.runnable().invoke(json!(5), None).await.unwrap();
        assert_eq!(output, json!(12));
    }

    /// `a | b | c` applied to 5: ((5 + 1) * 2) - 3 = 9.
    #[tokio::test]
    async fn test_pipe_operator_three_steps() {
        let a = int_step("a", |n| n + 1);
        let b = int_step("b", |n| n * 2);
        let c = int_step("c", |n| n - 3);
        let chain = a | b | c;
        let output = chain.runnable().invoke(json!(5), None).await.unwrap();
        assert_eq!(output, json!(9));
    }
}