enact_core/flow/
conditional.rs1use crate::callable::Callable;
6use std::sync::Arc;
7
/// Boxed, thread-safe predicate evaluated against a flow's input string.
///
/// A branch is selected when its condition returns `true`. The `'static` bound is the
/// default object lifetime of `Box<dyn ...>`, written out here so it is visible that a
/// condition cannot borrow short-lived data.
pub type Condition = Box<dyn Fn(&str) -> bool + Send + Sync + 'static>;
10
11pub struct Branch<C: Callable> {
13 pub condition: Condition,
15 pub callable: Arc<C>,
17 pub name: String,
19}
20
21impl<C: Callable> Branch<C> {
22 pub fn new(
24 name: impl Into<String>,
25 condition: impl Fn(&str) -> bool + Send + Sync + 'static,
26 callable: Arc<C>,
27 ) -> Self {
28 Self {
29 condition: Box::new(condition),
30 callable,
31 name: name.into(),
32 }
33 }
34
35 pub fn default(name: impl Into<String>, callable: Arc<C>) -> Self {
37 Self {
38 condition: Box::new(|_| true),
39 callable,
40 name: name.into(),
41 }
42 }
43
44 pub fn matches(&self, input: &str) -> bool {
46 (self.condition)(input)
47 }
48}
49
50pub struct ConditionalFlow<C: Callable> {
52 branches: Vec<Branch<C>>,
54 name: String,
56 default: Option<Arc<C>>,
58}
59
60impl<C: Callable> ConditionalFlow<C> {
61 pub fn new(name: impl Into<String>) -> Self {
63 Self {
64 branches: Vec::new(),
65 name: name.into(),
66 default: None,
67 }
68 }
69
70 pub fn add_branch(mut self, branch: Branch<C>) -> Self {
72 self.branches.push(branch);
73 self
74 }
75
76 pub fn when(
78 mut self,
79 name: impl Into<String>,
80 condition: impl Fn(&str) -> bool + Send + Sync + 'static,
81 callable: Arc<C>,
82 ) -> Self {
83 self.branches.push(Branch::new(name, condition, callable));
84 self
85 }
86
87 pub fn otherwise(mut self, callable: Arc<C>) -> Self {
89 self.default = Some(callable);
90 self
91 }
92
93 pub async fn execute(&self, input: &str) -> anyhow::Result<String> {
95 for branch in &self.branches {
97 if branch.matches(input) {
98 return branch.callable.run(input).await;
99 }
100 }
101
102 if let Some(default) = &self.default {
104 return default.run(input).await;
105 }
106
107 anyhow::bail!("No matching branch and no default")
108 }
109
110 pub fn name(&self) -> &str {
112 &self.name
113 }
114
115 pub fn branch_count(&self) -> usize {
117 self.branches.len()
118 }
119}
120
/// Returns a boxed predicate that accepts any input containing `needle` as a substring.
///
/// The needle is converted to an owned `String` once and moved into the closure, so the
/// predicate is `'static` and can be passed to `ConditionalFlow::when` or `Branch::new`.
pub fn contains_condition(needle: impl Into<String>) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    let pattern: String = needle.into();
    let predicate = move |text: &str| text.contains(pattern.as_str());
    Box::new(predicate)
}
127
/// Returns a boxed predicate that accepts any input beginning with `prefix`.
///
/// An empty prefix accepts every input.
pub fn starts_with_condition(prefix: impl Into<String>) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    let owned_prefix: String = prefix.into();
    Box::new(move |candidate: &str| -> bool { candidate.starts_with(owned_prefix.as_str()) })
}
133
/// Returns a boxed predicate that accepts any input ending with `suffix`.
///
/// An empty suffix accepts every input.
pub fn ends_with_condition(suffix: impl Into<String>) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    let tail: String = suffix.into();
    let accepts = move |text: &str| -> bool { text.ends_with(&*tail) };
    Box::new(accepts)
}
139
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Test double that ignores its input and always returns a fixed response.
    struct MockCallable {
        name: String,
        response: String,
    }

    impl MockCallable {
        fn new(name: &str, response: &str) -> Self {
            Self {
                name: name.to_string(),
                response: response.to_string(),
            }
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            Ok(self.response.clone())
        }
    }

    #[tokio::test]
    async fn test_conditional_first_match_wins() {
        let flow = ConditionalFlow::new("router")
            .when(
                "branch_a",
                |s| s.contains("foo"),
                Arc::new(MockCallable::new("a", "matched_a")),
            )
            .when(
                "branch_b",
                |s| s.contains("bar"),
                Arc::new(MockCallable::new("b", "matched_b")),
            )
            .otherwise(Arc::new(MockCallable::new("default", "matched_default")));

        let result = flow.execute("foo").await.unwrap();
        assert_eq!(result, "matched_a");

        let result = flow.execute("bar").await.unwrap();
        assert_eq!(result, "matched_b");

        let result = flow.execute("baz").await.unwrap();
        assert_eq!(result, "matched_default");
    }

    #[tokio::test]
    async fn test_conditional_first_match_priority() {
        // Both conditions accept "xyz"; the earlier branch must win.
        let flow = ConditionalFlow::new("priority")
            .when(
                "first",
                |s| !s.is_empty(),
                Arc::new(MockCallable::new("a", "first_wins")),
            )
            .when(
                "second",
                |s| s.contains("x"),
                Arc::new(MockCallable::new("b", "second_wins")),
            );

        let result = flow.execute("xyz").await.unwrap();
        assert_eq!(result, "first_wins");
    }

    #[tokio::test]
    async fn test_conditional_no_match_no_default() {
        let flow: ConditionalFlow<MockCallable> = ConditionalFlow::new("strict").when(
            "only_a",
            |s| s == "a",
            Arc::new(MockCallable::new("a", "matched")),
        );

        let result = flow.execute("b").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("No matching branch"));
    }

    #[tokio::test]
    async fn test_conditional_with_default() {
        let flow = ConditionalFlow::new("with_default")
            .when(
                "specific",
                |s| s == "specific",
                Arc::new(MockCallable::new("s", "specific_response")),
            )
            .otherwise(Arc::new(MockCallable::new("d", "default_response")));

        let result = flow.execute("anything").await.unwrap();
        assert_eq!(result, "default_response");
    }

    #[tokio::test]
    async fn test_add_branch_is_evaluated_in_order() {
        let flow = ConditionalFlow::new("prebuilt")
            .add_branch(Branch::new(
                "hello",
                |s| s.starts_with("hello"),
                Arc::new(MockCallable::new("h", "greeting")),
            ))
            .add_branch(Branch::default(
                "fallback",
                Arc::new(MockCallable::new("f", "catch_all")),
            ));

        assert_eq!(flow.branch_count(), 2);
        assert_eq!(flow.execute("hello there").await.unwrap(), "greeting");
        assert_eq!(flow.execute("bye").await.unwrap(), "catch_all");
    }

    #[tokio::test]
    async fn test_otherwise_replaces_previous_default() {
        let flow = ConditionalFlow::new("replace")
            .otherwise(Arc::new(MockCallable::new("old", "old_default")))
            .otherwise(Arc::new(MockCallable::new("new", "new_default")));

        assert_eq!(flow.execute("x").await.unwrap(), "new_default");
    }

    #[tokio::test]
    async fn test_empty_flow_without_default_errors() {
        let flow: ConditionalFlow<MockCallable> = ConditionalFlow::new("empty");

        assert_eq!(flow.branch_count(), 0);
        assert!(flow.execute("anything").await.is_err());
    }

    // The tests below never await, so they use the plain test harness instead of a runtime.

    #[test]
    fn test_branch_new_and_matches() {
        let callable = Arc::new(MockCallable::new("test", "response"));
        let branch = Branch::new("test_branch", |s| s.starts_with("hello"), callable);

        assert!(branch.matches("hello world"));
        assert!(!branch.matches("world hello"));
        assert_eq!(branch.name, "test_branch");
    }

    #[test]
    fn test_branch_default_always_matches() {
        let callable = Arc::new(MockCallable::new("test", "response"));
        let branch = Branch::default("default_branch", callable);

        assert!(branch.matches("anything"));
        assert!(branch.matches(""));
        assert!(branch.matches("123"));
    }

    #[test]
    fn test_contains_condition() {
        let condition = contains_condition("needle");
        assert!(condition("haystack needle here"));
        assert!(!condition("no match"));
    }

    #[test]
    fn test_starts_with_condition() {
        let condition = starts_with_condition("prefix");
        assert!(condition("prefix_rest"));
        assert!(!condition("no_prefix"));
    }

    #[test]
    fn test_ends_with_condition() {
        let condition = ends_with_condition("suffix");
        assert!(condition("word_suffix"));
        assert!(!condition("suffix_not"));
    }

    #[test]
    fn test_conditional_flow_properties() {
        let flow = ConditionalFlow::new("test_flow")
            .when("b1", |_| true, Arc::new(MockCallable::new("1", "r1")))
            .when("b2", |_| true, Arc::new(MockCallable::new("2", "r2")));

        assert_eq!(flow.name(), "test_flow");
        assert_eq!(flow.branch_count(), 2);
    }

    #[tokio::test]
    async fn test_conditional_error_propagation() {
        /// Callable whose `run` always fails, to check that branch errors are propagated.
        struct FailingCallable;

        #[async_trait]
        impl Callable for FailingCallable {
            fn name(&self) -> &str {
                "failing"
            }

            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                anyhow::bail!("Branch failed")
            }
        }

        let flow: ConditionalFlow<FailingCallable> =
            ConditionalFlow::new("failing").when("fail", |_| true, Arc::new(FailingCallable));

        let result = flow.execute("any").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Branch failed"));
    }
}
321}