Skip to main content

cognis_core/compose/
branch.rs

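//! Predicate-based dispatch: route each input to the first runnable whose
//! predicate accepts it, falling back to a default runnable otherwise.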
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use crate::runnable::{Runnable, RunnableConfig};
8use crate::Result;
9
/// Type-erased case predicate: inspects a borrowed input and reports whether
/// the associated runnable should handle it. The `Send + Sync` bounds let a
/// predicate be shared (behind an `Arc`) across threads.
type Predicate<I> = dyn Fn(&I) -> bool + Sync + Send;
11
12/// One arm of a [`Branch`]: a sync predicate paired with a runnable.
13pub struct BranchCase<I, O> {
14    /// Predicate inspected without consuming the input.
15    pub predicate: Arc<Predicate<I>>,
16    /// Runnable invoked when the predicate returns `true`.
17    pub runnable: Arc<dyn Runnable<I, O>>,
18}
19
20/// Conditional dispatch: try each case's predicate in order; the first
21/// match runs. If none match, the default runs.
22pub struct Branch<I, O> {
23    cases: Vec<BranchCase<I, O>>,
24    default: Arc<dyn Runnable<I, O>>,
25}
26
27impl<I, O> Branch<I, O>
28where
29    I: Send + 'static,
30    O: Send + 'static,
31{
32    /// Build with a default runnable used when no case matches.
33    pub fn new(default: Arc<dyn Runnable<I, O>>) -> Self {
34        Self {
35            cases: Vec::new(),
36            default,
37        }
38    }
39
40    /// Add a case.
41    pub fn case<P>(mut self, predicate: P, runnable: Arc<dyn Runnable<I, O>>) -> Self
42    where
43        P: Fn(&I) -> bool + Send + Sync + 'static,
44    {
45        self.cases.push(BranchCase {
46            predicate: Arc::new(predicate),
47            runnable,
48        });
49        self
50    }
51}
52
53#[async_trait]
54impl<I, O> Runnable<I, O> for Branch<I, O>
55where
56    I: Send + 'static,
57    O: Send + 'static,
58{
59    async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
60        for c in &self.cases {
61            if (c.predicate)(&input) {
62                return c.runnable.invoke(input, config).await;
63            }
64        }
65        self.default.invoke(input, config).await
66    }
67    fn name(&self) -> &str {
68        "Branch"
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Test runnable that ignores its input and always yields the same value,
    /// so the output identifies which arm was dispatched to.
    struct Const(u32);

    #[async_trait]
    impl Runnable<u32, u32> for Const {
        async fn invoke(&self, _: u32, _: RunnableConfig) -> Result<u32> {
            Ok(self.0)
        }
    }

    #[tokio::test]
    async fn dispatches_to_first_match() {
        let b: Branch<u32, u32> = Branch::new(Arc::new(Const(0)))
            .case(|i| *i < 10, Arc::new(Const(1)))
            .case(|i| *i < 100, Arc::new(Const(2)));
        // 5 satisfies both predicates; the earlier case must win.
        assert_eq!(b.invoke(5, RunnableConfig::default()).await.unwrap(), 1);
        assert_eq!(b.invoke(50, RunnableConfig::default()).await.unwrap(), 2);
        // Boundaries: the predicates are strict, so 10 falls to the second case
        // and 100 falls through to the default.
        assert_eq!(b.invoke(10, RunnableConfig::default()).await.unwrap(), 2);
        assert_eq!(b.invoke(100, RunnableConfig::default()).await.unwrap(), 0);
        // No predicate matches, so the default runs.
        assert_eq!(b.invoke(500, RunnableConfig::default()).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn falls_back_to_default_without_cases() {
        // With no cases registered, every input must reach the default.
        let b: Branch<u32, u32> = Branch::new(Arc::new(Const(7)));
        assert_eq!(b.invoke(0, RunnableConfig::default()).await.unwrap(), 7);
        assert_eq!(
            b.invoke(u32::MAX, RunnableConfig::default()).await.unwrap(),
            7
        );
    }
}