Skip to main content

punch_kernel/
workflow_conditions.rs

//! Conditional branching for workflow steps.
//!
//! Each workflow step may carry an optional [`Condition`] that determines
//! whether the step should execute based on the results of prior steps.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9use crate::workflow::StepResult;
10
11/// A condition that gates step execution.
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13#[serde(rename_all = "snake_case")]
14pub enum Condition {
15    /// Always execute.
16    Always,
17    /// Execute only if the named step's output contains the given substring.
18    IfOutput { step: String, contains: String },
19    /// Execute only if the named step completed successfully.
20    IfSuccess { step: String },
21    /// Execute only if the named step failed.
22    IfFailure { step: String },
23    /// Simple expression evaluator (supports basic boolean logic).
24    Expression(String),
25}
26
27/// Evaluate a [`Condition`] against the current set of completed step results.
28///
29/// Returns `true` if the step should execute, `false` if it should be skipped.
30pub fn evaluate_condition(
31    condition: &Condition,
32    step_results: &HashMap<String, StepResult>,
33) -> bool {
34    match condition {
35        Condition::Always => true,
36        Condition::IfOutput { step, contains } => step_results
37            .get(step)
38            .map(|r| r.response.contains(contains.as_str()))
39            .unwrap_or(false),
40        Condition::IfSuccess { step } => step_results
41            .get(step)
42            .map(|r| r.error.is_none())
43            .unwrap_or(false),
44        Condition::IfFailure { step } => step_results
45            .get(step)
46            .map(|r| r.error.is_some())
47            .unwrap_or(false),
48        Condition::Expression(expr) => evaluate_expression(expr, step_results),
49    }
50}
51
52/// Evaluate a simple boolean expression.
53///
54/// Supports:
55/// - `step_name.success` — true if step succeeded
56/// - `step_name.failed` — true if step failed
57/// - `step_name.output contains "text"` — true if output contains text
58/// - `not <expr>` — negation
59/// - `<expr> and <expr>` — conjunction
60/// - `<expr> or <expr>` — disjunction
61/// - `true` / `false` — literals
62fn evaluate_expression(expr: &str, step_results: &HashMap<String, StepResult>) -> bool {
63    let expr = expr.trim();
64
65    // Handle `true` / `false` literals
66    if expr.eq_ignore_ascii_case("true") {
67        return true;
68    }
69    if expr.eq_ignore_ascii_case("false") {
70        return false;
71    }
72
73    // Handle `not` prefix
74    if let Some(rest) = expr.strip_prefix("not ") {
75        return !evaluate_expression(rest, step_results);
76    }
77
78    // Handle `and` (lowest precedence after `or`)
79    // We split on ` or ` first (lower precedence)
80    if let Some(pos) = expr.find(" or ") {
81        let left = &expr[..pos];
82        let right = &expr[pos + 4..];
83        return evaluate_expression(left, step_results) || evaluate_expression(right, step_results);
84    }
85
86    // Then split on ` and `
87    if let Some(pos) = expr.find(" and ") {
88        let left = &expr[..pos];
89        let right = &expr[pos + 5..];
90        return evaluate_expression(left, step_results) && evaluate_expression(right, step_results);
91    }
92
93    // Handle `step_name.success`
94    if let Some(step_name) = expr.strip_suffix(".success") {
95        return step_results
96            .get(step_name)
97            .map(|r| r.error.is_none())
98            .unwrap_or(false);
99    }
100
101    // Handle `step_name.failed`
102    if let Some(step_name) = expr.strip_suffix(".failed") {
103        return step_results
104            .get(step_name)
105            .map(|r| r.error.is_some())
106            .unwrap_or(false);
107    }
108
109    // Handle `step_name.output contains "text"`
110    if let Some(contains_pos) = expr.find(".output contains ") {
111        let step_name = &expr[..contains_pos];
112        let rest = &expr[contains_pos + ".output contains ".len()..];
113        let text = rest.trim_matches('"');
114        return step_results
115            .get(step_name)
116            .map(|r| r.response.contains(text))
117            .unwrap_or(false);
118    }
119
120    // Unknown expression — default to false
121    false
122}
123
124// ---------------------------------------------------------------------------
125// Tests
126// ---------------------------------------------------------------------------
127
128#[cfg(test)]
129mod tests {
130    use super::*;
131
132    fn make_results() -> HashMap<String, StepResult> {
133        let mut results = HashMap::new();
134        results.insert(
135            "analyze".to_string(),
136            StepResult {
137                step_name: "analyze".to_string(),
138                response: "The code has 3 bugs and needs refactoring".to_string(),
139                tokens_used: 100,
140                duration_ms: 500,
141                error: None,
142                status: crate::workflow::StepStatus::Completed,
143                started_at: None,
144                completed_at: None,
145            },
146        );
147        results.insert(
148            "build".to_string(),
149            StepResult {
150                step_name: "build".to_string(),
151                response: String::new(),
152                tokens_used: 0,
153                duration_ms: 200,
154                error: Some("compilation failed".to_string()),
155                status: crate::workflow::StepStatus::Failed,
156                started_at: None,
157                completed_at: None,
158            },
159        );
160        results
161    }
162
163    #[test]
164    fn condition_always() {
165        let results = make_results();
166        assert!(evaluate_condition(&Condition::Always, &results));
167    }
168
169    #[test]
170    fn condition_if_output_match() {
171        let results = make_results();
172        let cond = Condition::IfOutput {
173            step: "analyze".to_string(),
174            contains: "bugs".to_string(),
175        };
176        assert!(evaluate_condition(&cond, &results));
177    }
178
179    #[test]
180    fn condition_if_output_no_match() {
181        let results = make_results();
182        let cond = Condition::IfOutput {
183            step: "analyze".to_string(),
184            contains: "perfect".to_string(),
185        };
186        assert!(!evaluate_condition(&cond, &results));
187    }
188
189    #[test]
190    fn condition_if_output_missing_step() {
191        let results = make_results();
192        let cond = Condition::IfOutput {
193            step: "nonexistent".to_string(),
194            contains: "anything".to_string(),
195        };
196        assert!(!evaluate_condition(&cond, &results));
197    }
198
199    #[test]
200    fn condition_if_success() {
201        let results = make_results();
202        assert!(evaluate_condition(
203            &Condition::IfSuccess {
204                step: "analyze".to_string()
205            },
206            &results
207        ));
208        assert!(!evaluate_condition(
209            &Condition::IfSuccess {
210                step: "build".to_string()
211            },
212            &results
213        ));
214    }
215
216    #[test]
217    fn condition_if_failure() {
218        let results = make_results();
219        assert!(!evaluate_condition(
220            &Condition::IfFailure {
221                step: "analyze".to_string()
222            },
223            &results
224        ));
225        assert!(evaluate_condition(
226            &Condition::IfFailure {
227                step: "build".to_string()
228            },
229            &results
230        ));
231    }
232
233    #[test]
234    fn condition_if_success_missing_step() {
235        let results = make_results();
236        assert!(!evaluate_condition(
237            &Condition::IfSuccess {
238                step: "nonexistent".to_string()
239            },
240            &results
241        ));
242    }
243
244    #[test]
245    fn expression_true_false_literals() {
246        let results = HashMap::new();
247        assert!(evaluate_condition(
248            &Condition::Expression("true".to_string()),
249            &results
250        ));
251        assert!(!evaluate_condition(
252            &Condition::Expression("false".to_string()),
253            &results
254        ));
255    }
256
257    #[test]
258    fn expression_step_success() {
259        let results = make_results();
260        assert!(evaluate_condition(
261            &Condition::Expression("analyze.success".to_string()),
262            &results
263        ));
264        assert!(!evaluate_condition(
265            &Condition::Expression("build.success".to_string()),
266            &results
267        ));
268    }
269
270    #[test]
271    fn expression_step_failed() {
272        let results = make_results();
273        assert!(evaluate_condition(
274            &Condition::Expression("build.failed".to_string()),
275            &results
276        ));
277        assert!(!evaluate_condition(
278            &Condition::Expression("analyze.failed".to_string()),
279            &results
280        ));
281    }
282
283    #[test]
284    fn expression_not() {
285        let results = make_results();
286        assert!(!evaluate_condition(
287            &Condition::Expression("not analyze.success".to_string()),
288            &results
289        ));
290        assert!(evaluate_condition(
291            &Condition::Expression("not build.success".to_string()),
292            &results
293        ));
294    }
295
296    #[test]
297    fn expression_and() {
298        let results = make_results();
299        assert!(!evaluate_condition(
300            &Condition::Expression("analyze.success and build.success".to_string()),
301            &results
302        ));
303        assert!(evaluate_condition(
304            &Condition::Expression("analyze.success and build.failed".to_string()),
305            &results
306        ));
307    }
308
309    #[test]
310    fn expression_or() {
311        let results = make_results();
312        assert!(evaluate_condition(
313            &Condition::Expression("analyze.success or build.success".to_string()),
314            &results
315        ));
316        assert!(!evaluate_condition(
317            &Condition::Expression("analyze.failed or build.success".to_string()),
318            &results
319        ));
320    }
321
322    #[test]
323    fn expression_output_contains() {
324        let results = make_results();
325        assert!(evaluate_condition(
326            &Condition::Expression("analyze.output contains \"3 bugs\"".to_string()),
327            &results
328        ));
329        assert!(!evaluate_condition(
330            &Condition::Expression("analyze.output contains \"no issues\"".to_string()),
331            &results
332        ));
333    }
334
335    #[test]
336    fn expression_unknown_defaults_false() {
337        let results = make_results();
338        assert!(!evaluate_condition(
339            &Condition::Expression("unknown_garbage".to_string()),
340            &results
341        ));
342    }
343
344    #[test]
345    fn condition_serialization_roundtrip() {
346        let cond = Condition::IfOutput {
347            step: "step1".to_string(),
348            contains: "hello".to_string(),
349        };
350        let json = serde_json::to_string(&cond).expect("serialize");
351        let deser: Condition = serde_json::from_str(&json).expect("deserialize");
352        assert_eq!(cond, deser);
353    }
354
355    #[test]
356    fn condition_always_serialization() {
357        let cond = Condition::Always;
358        let json = serde_json::to_string(&cond).expect("serialize");
359        let deser: Condition = serde_json::from_str(&json).expect("deserialize");
360        assert_eq!(cond, deser);
361    }
362
363    #[test]
364    fn condition_expression_serialization() {
365        let cond = Condition::Expression("step1.success and step2.failed".to_string());
366        let json = serde_json::to_string(&cond).expect("serialize");
367        let deser: Condition = serde_json::from_str(&json).expect("deserialize");
368        assert_eq!(cond, deser);
369    }
370}