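// enact_core/flow/conditional.rs

//! Conditional Flow - If/then/else branching
//!
//! Route execution based on conditions evaluated against the input: branches
//! are checked in order, the first matching one runs, and an optional default
//! handles inputs that match no branch.
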
5use crate::callable::Callable;
6use std::sync::Arc;
7
/// Predicate over the raw input text that decides whether a branch runs.
///
/// Boxed so branches built from different closure types share one field type;
/// the `Send` and `Sync` bounds allow values holding it to be `Send`/`Sync`.
pub type Condition = Box<dyn Fn(&str) -> bool + Sync + Send>;

11/// A conditional branch
12pub struct Branch<C: Callable> {
13    /// Condition to evaluate
14    pub condition: Condition,
15    /// Callable to execute if condition is true
16    pub callable: Arc<C>,
17    /// Branch name (for debugging/logging)
18    pub name: String,
19}
20
21impl<C: Callable> Branch<C> {
22    /// Create a new branch
23    pub fn new(
24        name: impl Into<String>,
25        condition: impl Fn(&str) -> bool + Send + Sync + 'static,
26        callable: Arc<C>,
27    ) -> Self {
28        Self {
29            condition: Box::new(condition),
30            callable,
31            name: name.into(),
32        }
33    }
34
35    /// Create an "always true" branch (default/else)
36    pub fn default(name: impl Into<String>, callable: Arc<C>) -> Self {
37        Self {
38            condition: Box::new(|_| true),
39            callable,
40            name: name.into(),
41        }
42    }
43
44    /// Check if this branch matches the input
45    pub fn matches(&self, input: &str) -> bool {
46        (self.condition)(input)
47    }
48}
49
50/// Conditional execution flow
51pub struct ConditionalFlow<C: Callable> {
52    /// Ordered list of branches (first match wins)
53    branches: Vec<Branch<C>>,
54    /// Flow name
55    name: String,
56    /// Default branch if no conditions match
57    default: Option<Arc<C>>,
58}
59
60impl<C: Callable> ConditionalFlow<C> {
61    /// Create a new conditional flow
62    pub fn new(name: impl Into<String>) -> Self {
63        Self {
64            branches: Vec::new(),
65            name: name.into(),
66            default: None,
67        }
68    }
69
70    /// Add a conditional branch
71    pub fn add_branch(mut self, branch: Branch<C>) -> Self {
72        self.branches.push(branch);
73        self
74    }
75
76    /// Add a condition with callable
77    pub fn when(
78        mut self,
79        name: impl Into<String>,
80        condition: impl Fn(&str) -> bool + Send + Sync + 'static,
81        callable: Arc<C>,
82    ) -> Self {
83        self.branches.push(Branch::new(name, condition, callable));
84        self
85    }
86
87    /// Set default (else) branch
88    pub fn otherwise(mut self, callable: Arc<C>) -> Self {
89        self.default = Some(callable);
90        self
91    }
92
93    /// Execute the flow - first matching branch wins
94    pub async fn execute(&self, input: &str) -> anyhow::Result<String> {
95        // Find first matching branch
96        for branch in &self.branches {
97            if branch.matches(input) {
98                return branch.callable.run(input).await;
99            }
100        }
101
102        // Fall back to default
103        if let Some(default) = &self.default {
104            return default.run(input).await;
105        }
106
107        anyhow::bail!("No matching branch and no default")
108    }
109
110    /// Get the flow name
111    pub fn name(&self) -> &str {
112        &self.name
113    }
114
115    /// Get branch count
116    pub fn branch_count(&self) -> usize {
117        self.branches.len()
118    }
119}
120
// Common condition builders as free functions

/// Build a condition that accepts any input containing `needle` as a substring.
///
/// The match is case-sensitive, and an empty `needle` accepts every input.
pub fn contains_condition(
    needle: impl Into<String>,
) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    let pattern: String = needle.into();
    let check = move |input: &str| input.contains(pattern.as_str());
    Box::new(check)
}

/// Build a condition that accepts any input beginning with `prefix`.
///
/// The match is case-sensitive, and an empty `prefix` accepts every input.
pub fn starts_with_condition(
    prefix: impl Into<String>,
) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    let expected_head: String = prefix.into();
    Box::new(move |input: &str| input.strip_prefix(expected_head.as_str()).is_some())
}

/// Build a condition that accepts any input ending with `suffix`.
///
/// The match is case-sensitive, and an empty `suffix` accepts every input.
pub fn ends_with_condition(
    suffix: impl Into<String>,
) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    let tail: String = suffix.into();
    let predicate = move |input: &str| -> bool { input.ends_with(tail.as_str()) };
    Box::new(predicate)
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Mock callable that ignores its input and returns a fixed response.
    struct MockCallable {
        name: String,
        response: String,
    }

    impl MockCallable {
        fn new(name: &str, response: &str) -> Self {
            Self {
                name: name.to_string(),
                response: response.to_string(),
            }
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            Ok(self.response.clone())
        }
    }

    /// Callable that echoes its input, used to check the input is forwarded unchanged.
    struct EchoCallable;

    #[async_trait]
    impl Callable for EchoCallable {
        fn name(&self) -> &str {
            "echo"
        }

        async fn run(&self, input: &str) -> anyhow::Result<String> {
            Ok(format!("echo:{input}"))
        }
    }

    #[tokio::test]
    async fn test_conditional_first_match_wins() {
        let flow = ConditionalFlow::new("router")
            .when(
                "branch_a",
                |s| s.contains("foo"),
                Arc::new(MockCallable::new("a", "matched_a")),
            )
            .when(
                "branch_b",
                |s| s.contains("bar"),
                Arc::new(MockCallable::new("b", "matched_b")),
            )
            .otherwise(Arc::new(MockCallable::new("default", "matched_default")));

        // Should match branch_a
        let result = flow.execute("foo").await.unwrap();
        assert_eq!(result, "matched_a");

        // Should match branch_b
        let result = flow.execute("bar").await.unwrap();
        assert_eq!(result, "matched_b");

        // Should match default
        let result = flow.execute("baz").await.unwrap();
        assert_eq!(result, "matched_default");
    }

    #[tokio::test]
    async fn test_conditional_first_match_priority() {
        // If input matches multiple conditions, first one wins
        let flow = ConditionalFlow::new("priority")
            .when(
                "first",
                |s| !s.is_empty(),
                Arc::new(MockCallable::new("a", "first_wins")),
            )
            .when(
                "second",
                |s| s.contains("x"),
                Arc::new(MockCallable::new("b", "second_wins")),
            );

        // Both conditions match "xyz", but first one wins
        let result = flow.execute("xyz").await.unwrap();
        assert_eq!(result, "first_wins");
    }

    #[tokio::test]
    async fn test_conditional_no_match_no_default() {
        let flow: ConditionalFlow<MockCallable> = ConditionalFlow::new("strict").when(
            "only_a",
            |s| s == "a",
            Arc::new(MockCallable::new("a", "matched")),
        );

        let result = flow.execute("b").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("No matching branch"));
    }

    #[tokio::test]
    async fn test_conditional_with_default() {
        let flow = ConditionalFlow::new("with_default")
            .when(
                "specific",
                |s| s == "specific",
                Arc::new(MockCallable::new("s", "specific_response")),
            )
            .otherwise(Arc::new(MockCallable::new("d", "default_response")));

        let result = flow.execute("anything").await.unwrap();
        assert_eq!(result, "default_response");
    }

    // The following tests never await, so they use the plain test harness
    // instead of starting a Tokio runtime.

    #[test]
    fn test_branch_new_and_matches() {
        let callable = Arc::new(MockCallable::new("test", "response"));
        let branch = Branch::new("test_branch", |s| s.starts_with("hello"), callable);

        assert!(branch.matches("hello world"));
        assert!(!branch.matches("world hello"));
        assert_eq!(branch.name, "test_branch");
    }

    #[test]
    fn test_branch_default_always_matches() {
        let callable = Arc::new(MockCallable::new("test", "response"));
        let branch = Branch::default("default_branch", callable);

        assert!(branch.matches("anything"));
        assert!(branch.matches(""));
        assert!(branch.matches("123"));
    }

    #[test]
    fn test_contains_condition() {
        let condition = contains_condition("needle");
        assert!(condition("haystack needle here"));
        assert!(!condition("no match"));
    }

    #[test]
    fn test_starts_with_condition() {
        let condition = starts_with_condition("prefix");
        assert!(condition("prefix_rest"));
        assert!(!condition("no_prefix"));
    }

    #[test]
    fn test_ends_with_condition() {
        let condition = ends_with_condition("suffix");
        assert!(condition("word_suffix"));
        assert!(!condition("suffix_not"));
    }

    #[test]
    fn test_conditional_flow_properties() {
        let flow = ConditionalFlow::new("test_flow")
            .when("b1", |_| true, Arc::new(MockCallable::new("1", "r1")))
            .when("b2", |_| true, Arc::new(MockCallable::new("2", "r2")));

        assert_eq!(flow.name(), "test_flow");
        assert_eq!(flow.branch_count(), 2);
    }

    #[tokio::test]
    async fn test_conditional_error_propagation() {
        struct FailingCallable;

        #[async_trait]
        impl Callable for FailingCallable {
            fn name(&self) -> &str {
                "failing"
            }
            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                anyhow::bail!("Branch failed")
            }
        }

        let flow: ConditionalFlow<FailingCallable> =
            ConditionalFlow::new("failing").when("fail", |_| true, Arc::new(FailingCallable));

        let result = flow.execute("any").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Branch failed"));
    }

    #[tokio::test]
    async fn test_selected_callable_receives_original_input() {
        let flow = ConditionalFlow::new("echo")
            .when("greeting", |s| s.starts_with("hi"), Arc::new(EchoCallable))
            .otherwise(Arc::new(EchoCallable));

        // Both the matching branch and the default get the untouched input.
        assert_eq!(flow.execute("hi there").await.unwrap(), "echo:hi there");
        assert_eq!(flow.execute("bye").await.unwrap(), "echo:bye");
    }

    #[tokio::test]
    async fn test_add_branch_with_catch_all_branch() {
        let flow = ConditionalFlow::new("manual")
            .add_branch(Branch::new(
                "exact",
                |s| s == "hit",
                Arc::new(MockCallable::new("e", "exact")),
            ))
            .add_branch(Branch::default(
                "catch_all",
                Arc::new(MockCallable::new("c", "catch_all")),
            ));

        assert_eq!(flow.branch_count(), 2);
        assert_eq!(flow.execute("hit").await.unwrap(), "exact");
        assert_eq!(flow.execute("miss").await.unwrap(), "catch_all");
    }

    #[tokio::test]
    async fn test_empty_flow_without_default_errors() {
        let flow: ConditionalFlow<MockCallable> = ConditionalFlow::new("empty");

        assert_eq!(flow.branch_count(), 0);
        assert!(flow.execute("anything").await.is_err());
    }

    #[tokio::test]
    async fn test_empty_flow_uses_default() {
        let flow = ConditionalFlow::new("only_default")
            .otherwise(Arc::new(MockCallable::new("d", "fallback")));

        assert_eq!(flow.branch_count(), 0);
        assert_eq!(flow.execute("").await.unwrap(), "fallback");
    }

    #[tokio::test]
    async fn test_condition_builders_plug_into_when() {
        let flow = ConditionalFlow::new("builders")
            .when(
                "starts",
                starts_with_condition("GET "),
                Arc::new(MockCallable::new("g", "get")),
            )
            .when(
                "ends",
                ends_with_condition("?"),
                Arc::new(MockCallable::new("q", "question")),
            )
            .when(
                "contains",
                contains_condition("error"),
                Arc::new(MockCallable::new("e", "error")),
            );

        assert_eq!(flow.execute("GET /index?").await.unwrap(), "get");
        assert_eq!(flow.execute("what?").await.unwrap(), "question");
        assert_eq!(flow.execute("an error occurred").await.unwrap(), "error");
        assert!(flow.execute("nothing").await.is_err());
    }
}