Skip to main content

bamboo_agent/agent/core/composition/
condition.rs

1//! Condition predicates for workflow control flow
2//!
3//! This module provides condition types for branching logic in tool compositions.
4
5use crate::agent::core::tools::ToolResult;
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8
9/// Condition for control flow in tool expressions
10///
11/// Conditions are used in `Choice` expressions to determine which branch to execute.
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13#[serde(rename_all = "snake_case", tag = "type")]
14pub enum Condition {
15    /// Check if the result was successful
16    Success,
17    /// Check if JSON path contains a specific value
18    Contains {
19        /// JSON path to check (dot notation, e.g., "data.status")
20        path: String,
21        /// Value to check for
22        value: String,
23    },
24    /// Check if value at JSON path matches a regex pattern
25    Matches {
26        /// JSON path to check (dot notation)
27        path: String,
28        /// Regex pattern to match
29        pattern: String,
30    },
31    /// All conditions must be true
32    And {
33        /// List of conditions (all must be true)
34        conditions: Vec<Condition>,
35    },
36    /// At least one condition must be true
37    Or {
38        /// List of conditions (at least one must be true)
39        conditions: Vec<Condition>,
40    },
41}
42
43impl Condition {
44    /// Evaluate the condition against a tool result
45    pub fn evaluate(&self, result: &ToolResult) -> bool {
46        match self {
47            Condition::Success => result.success,
48            Condition::Contains { path, value } => evaluate_contains(&result.result, path, value),
49            Condition::Matches { path, pattern } => evaluate_matches(&result.result, path, pattern),
50            Condition::And { conditions } => conditions.iter().all(|c| c.evaluate(result)),
51            Condition::Or { conditions } => conditions.iter().any(|c| c.evaluate(result)),
52        }
53    }
54}
55
56/// Extract value at JSON path (simple dot notation)
57fn extract_at_path(json_str: &str, path: &str) -> Option<String> {
58    let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
59    let parts: Vec<&str> = path.split('.').collect();
60
61    let mut current = &value;
62    for part in parts {
63        if let Ok(index) = part.parse::<usize>() {
64            current = current.get(index)?;
65        } else {
66            current = current.get(part)?;
67        }
68    }
69
70    Some(current.to_string().trim_matches('"').to_string())
71}
72
73/// Check if value at path contains the expected value
74fn evaluate_contains(json_str: &str, path: &str, expected: &str) -> bool {
75    if let Some(value) = extract_at_path(json_str, path) {
76        value.contains(expected)
77    } else {
78        false
79    }
80}
81
82/// Check if value at path matches the regex pattern
83fn evaluate_matches(json_str: &str, path: &str, pattern: &str) -> bool {
84    let value = match extract_at_path(json_str, path) {
85        Some(v) => v,
86        None => return false,
87    };
88
89    Regex::new(pattern)
90        .map(|re| re.is_match(&value))
91        .unwrap_or(false)
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ToolResult` with the given JSON payload and success flag.
    fn tool_result(payload: &str, ok: bool) -> ToolResult {
        ToolResult {
            success: ok,
            result: payload.to_string(),
            display_preference: None,
        }
    }

    /// Shorthand for a `Condition::Contains` on `path` looking for `value`.
    fn contains_cond(path: &str, value: &str) -> Condition {
        Condition::Contains {
            path: path.to_string(),
            value: value.to_string(),
        }
    }

    /// Shorthand for a `Condition::Matches` on `path` against `pattern`.
    fn matches_cond(path: &str, pattern: &str) -> Condition {
        Condition::Matches {
            path: path.to_string(),
            pattern: pattern.to_string(),
        }
    }

    #[test]
    fn test_success_condition() {
        let passed = tool_result("{}", true);
        let failed = tool_result("{}", false);

        assert!(Condition::Success.evaluate(&passed));
        assert!(!Condition::Success.evaluate(&failed));
    }

    #[test]
    fn test_contains_condition() {
        let payload = tool_result(r#"{"status": "completed", "data": {"name": "test"}}"#, true);

        // Substring of a top-level string field.
        assert!(contains_cond("status", "complete").evaluate(&payload));
        // Nested field reached via dot notation.
        assert!(contains_cond("data.name", "test").evaluate(&payload));
        // Text that does not occur in the field.
        assert!(!contains_cond("status", "failed").evaluate(&payload));
    }

    #[test]
    fn test_matches_condition() {
        let payload = tool_result(r#"{"email": "user@example.com"}"#, true);

        assert!(matches_cond("email", r"^\S+@\S+\.\S+$").evaluate(&payload));
        assert!(!matches_cond("email", r"^admin@").evaluate(&payload));
    }

    #[test]
    fn test_and_condition() {
        let payload = tool_result(r#"{"status": "ok", "code": 200}"#, true);

        let all_hold = Condition::And {
            conditions: vec![Condition::Success, contains_cond("status", "ok")],
        };
        assert!(all_hold.evaluate(&payload));

        let one_fails = Condition::And {
            conditions: vec![Condition::Success, contains_cond("status", "error")],
        };
        assert!(!one_fails.evaluate(&payload));
    }

    #[test]
    fn test_or_condition() {
        let payload = tool_result(r#"{"status": "warning"}"#, true);

        let either = Condition::Or {
            conditions: vec![contains_cond("status", "ok"), contains_cond("status", "warning")],
        };
        assert!(either.evaluate(&payload));
    }

    #[test]
    fn test_json_serialization() {
        let cond = Condition::And {
            conditions: vec![Condition::Success, contains_cond("status", "ok")],
        };

        let encoded = serde_json::to_string(&cond).unwrap();
        assert!(encoded.contains("\"type\":\"and\"") || encoded.contains("\"type\": \"and\""));

        let decoded: Condition = serde_json::from_str(&encoded).unwrap();
        assert_eq!(cond, decoded);
    }

    #[test]
    fn test_condition_roundtrip() {
        let samples = [
            Condition::Success,
            contains_cond("status", "ok"),
            matches_cond("email", r"^\S+@\S+\.\S+$"),
            Condition::And {
                conditions: vec![Condition::Success, Condition::Success],
            },
            Condition::Or {
                conditions: vec![Condition::Success],
            },
        ];

        for original in &samples {
            let encoded = serde_json::to_string(original).unwrap();
            let decoded: Condition = serde_json::from_str(&encoded).unwrap();
            assert_eq!(*original, decoded);
        }
    }
}