cognis_core/compose/
pipe.rs1use std::marker::PhantomData;
4
5use async_trait::async_trait;
6
7use crate::runnable::{Runnable, RunnableConfig};
8use crate::Result;
9
/// Sequential composition of two runnables: the output of `a` becomes the
/// input of `b`.
///
/// A `Pipe<A, B, I, M, O>` is itself a `Runnable<I, O>` when
/// `A: Runnable<I, M>`, `B: Runnable<M, O>`, and `I`, `M`, `O` are all
/// `Send + 'static`. Build one with [`Pipe::new`] or [`pipe()`].
///
/// `I`, `M` and `O` exist only at the type level; no values of those types
/// are stored inside the pipe.
pub struct Pipe<A, B, I, M, O> {
    /// First stage; receives the pipe's input.
    a: A,
    /// Second stage; receives the first stage's output.
    b: B,
    // `fn(I) -> O` instead of `(I, O)` so the marker does not claim ownership
    // of `I`/`O`: function pointers are always `Send + Sync`, so the pipe's
    // auto traits depend only on `A` and `B`, and drop-check does not assume
    // the pipe drops values of these types.
    _types: PhantomData<fn(I) -> O>,
    // Same reasoning for the intermediate type `M`.
    _mid: PhantomData<fn() -> M>,
}

// Written by hand instead of derived: `#[derive(Clone)]` would also require
// `I: Clone, M: Clone, O: Clone`, although only `a` and `b` hold data.
impl<A: Clone, B: Clone, I, M, O> Clone for Pipe<A, B, I, M, O> {
    fn clone(&self) -> Self {
        Self {
            a: self.a.clone(),
            b: self.b.clone(),
            _types: PhantomData,
            _mid: PhantomData,
        }
    }
}

// Written by hand for the same reason as `Clone`: only the two stored stages
// need to be `Debug`; the phantom markers are omitted from the output.
impl<A, B, I, M, O> std::fmt::Debug for Pipe<A, B, I, M, O>
where
    A: std::fmt::Debug,
    B: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Pipe")
            .field("a", &self.a)
            .field("b", &self.b)
            .finish()
    }
}
17
18impl<A, B, I, M, O> Pipe<A, B, I, M, O>
19where
20 A: Runnable<I, M>,
21 B: Runnable<M, O>,
22 I: Send + 'static,
23 M: Send + 'static,
24 O: Send + 'static,
25{
26 pub fn new(a: A, b: B) -> Self {
28 Self {
29 a,
30 b,
31 _types: PhantomData,
32 _mid: PhantomData,
33 }
34 }
35}
36
37#[async_trait]
38impl<A, B, I, M, O> Runnable<I, O> for Pipe<A, B, I, M, O>
39where
40 A: Runnable<I, M>,
41 B: Runnable<M, O>,
42 I: Send + 'static,
43 M: Send + 'static,
44 O: Send + 'static,
45{
46 async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
47 let mid = self.a.invoke(input, config.clone()).await?;
48 self.b.invoke(mid, config).await
49 }
50
51 fn name(&self) -> &str {
52 "Pipe"
53 }
54
55 fn input_schema(&self) -> Option<serde_json::Value> {
56 self.a.input_schema()
57 }
58
59 fn output_schema(&self) -> Option<serde_json::Value> {
60 self.b.output_schema()
61 }
62}
63
64pub fn pipe<A, B, I, M, O>(a: A, b: B) -> Pipe<A, B, I, M, O>
67where
68 A: Runnable<I, M>,
69 B: Runnable<M, O>,
70 I: Send + 'static,
71 M: Send + 'static,
72 O: Send + 'static,
73{
74 Pipe::new(a, b)
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// Test stage: adds a fixed amount to its input.
    struct Add(u32);

    #[async_trait]
    impl Runnable<u32, u32> for Add {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            Ok(input + self.0)
        }
    }

    /// Test stage: multiplies its input by a fixed factor.
    struct Mul(u32);

    #[async_trait]
    impl Runnable<u32, u32> for Mul {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            Ok(input * self.0)
        }
    }

    /// Identity stage that advertises fixed schemas, so the pipe's schema
    /// delegation can be observed.
    struct Schema {
        input: serde_json::Value,
        output: serde_json::Value,
    }

    #[async_trait]
    impl Runnable<u32, u32> for Schema {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            Ok(input)
        }

        fn input_schema(&self) -> Option<serde_json::Value> {
            Some(self.input.clone())
        }

        fn output_schema(&self) -> Option<serde_json::Value> {
            Some(self.output.clone())
        }
    }

    #[tokio::test]
    async fn pipe_method_chains() {
        let chain = Pipe::new(Add(2), Mul(3));
        let out = chain.invoke(4, RunnableConfig::default()).await.unwrap();
        assert_eq!(out, (4 + 2) * 3);
    }

    #[tokio::test]
    async fn pipe_fn_works() {
        let chain = pipe(Add(1), Mul(10));
        let out = chain.invoke(2, RunnableConfig::default()).await.unwrap();
        assert_eq!(out, (2 + 1) * 10);
    }

    #[tokio::test]
    async fn pipes_nest() {
        // A pipe is itself a runnable, so it can serve as a stage of another pipe.
        let chain: Pipe<_, Add, u32, u32, u32> = pipe(pipe(Add(1), Mul(2)), Add(3));
        let out = chain.invoke(5, RunnableConfig::default()).await.unwrap();
        assert_eq!(out, (5 + 1) * 2 + 3);
    }

    #[test]
    fn name_is_pipe() {
        let chain: Pipe<Add, Mul, u32, u32, u32> = pipe(Add(1), Mul(2));
        assert_eq!(chain.name(), "Pipe");
    }

    #[test]
    fn schemas_come_from_outer_stages() {
        let first = Schema {
            input: serde_json::json!({ "stage": "first-in" }),
            output: serde_json::json!({ "stage": "first-out" }),
        };
        let second = Schema {
            input: serde_json::json!({ "stage": "second-in" }),
            output: serde_json::json!({ "stage": "second-out" }),
        };
        let chain: Pipe<Schema, Schema, u32, u32, u32> = pipe(first, second);
        // Input schema comes from the first stage, output schema from the second.
        assert_eq!(
            chain.input_schema(),
            Some(serde_json::json!({ "stage": "first-in" }))
        );
        assert_eq!(
            chain.output_schema(),
            Some(serde_json::json!({ "stage": "second-out" }))
        );
    }
}