ricecoder_workflows/
condition.rs

1//! Condition evaluation for conditional branching steps
2
3use crate::error::{WorkflowError, WorkflowResult};
4use crate::models::{ConditionStep, Workflow, WorkflowState};
5use serde_json::Value;
6
7/// Evaluates conditions and determines branching paths
8///
9/// Handles:
10/// - Evaluating condition expressions based on previous step results
11/// - Determining which branch (then or else) to execute
12/// - Supporting nested conditions
13pub struct ConditionEvaluator;
14
15impl ConditionEvaluator {
16    /// Evaluate a condition step and determine the execution path
17    ///
18    /// Returns the list of step IDs to execute based on the condition evaluation.
19    /// If condition is true, returns then_steps; otherwise returns else_steps.
20    pub fn evaluate_condition(
21        workflow: &Workflow,
22        state: &WorkflowState,
23        condition_step: &ConditionStep,
24    ) -> WorkflowResult<Vec<String>> {
25        // Evaluate the condition expression
26        let result = Self::evaluate_expression(&condition_step.condition, workflow, state)?;
27
28        // Return appropriate branch based on result
29        if result {
30            Ok(condition_step.then_steps.clone())
31        } else {
32            Ok(condition_step.else_steps.clone())
33        }
34    }
35
36    /// Evaluate a condition expression
37    ///
38    /// Supports simple expressions like:
39    /// - "step_id.output.field == value"
40    /// - "step_id.status == 'completed'"
41    /// - "step_id.output.count > 5"
42    ///
43    /// Returns true if the condition is satisfied, false otherwise.
44    fn evaluate_expression(
45        expression: &str,
46        workflow: &Workflow,
47        state: &WorkflowState,
48    ) -> WorkflowResult<bool> {
49        let expression = expression.trim();
50
51        // Handle not equal (check before ==)
52        if let Some(pos) = expression.find("!=") {
53            let left = expression[..pos].trim();
54            let right = expression[pos + 2..].trim();
55            let equal = Self::evaluate_equality(left, right, workflow, state)?;
56            return Ok(!equal);
57        }
58
59        // Handle simple equality comparisons
60        if let Some(pos) = expression.find("==") {
61            let left = expression[..pos].trim();
62            let right = expression[pos + 2..].trim();
63            return Self::evaluate_equality(left, right, workflow, state);
64        }
65
66        // Handle greater than or equal (check before >)
67        if let Some(pos) = expression.find(">=") {
68            let left = expression[..pos].trim();
69            let right = expression[pos + 2..].trim();
70            return Self::evaluate_greater_equal(left, right, workflow, state);
71        }
72
73        // Handle less than or equal (check before <)
74        if let Some(pos) = expression.find("<=") {
75            let left = expression[..pos].trim();
76            let right = expression[pos + 2..].trim();
77            return Self::evaluate_less_equal(left, right, workflow, state);
78        }
79
80        // Handle greater than comparisons
81        if let Some(pos) = expression.find('>') {
82            let left = expression[..pos].trim();
83            let right = expression[pos + 1..].trim();
84            return Self::evaluate_greater_than(left, right, workflow, state);
85        }
86
87        // Handle less than comparisons
88        if let Some(pos) = expression.find('<') {
89            let left = expression[..pos].trim();
90            let right = expression[pos + 1..].trim();
91            return Self::evaluate_less_than(left, right, workflow, state);
92        }
93
94        Err(WorkflowError::Invalid(format!(
95            "Unsupported condition expression: {}",
96            expression
97        )))
98    }
99
100    /// Evaluate equality comparison
101    fn evaluate_equality(
102        left: &str,
103        right: &str,
104        workflow: &Workflow,
105        state: &WorkflowState,
106    ) -> WorkflowResult<bool> {
107        let left_value = Self::resolve_value(left, workflow, state)?;
108        let right_value = Self::parse_value(right);
109
110        Ok(left_value == right_value)
111    }
112
113    /// Evaluate greater than comparison
114    fn evaluate_greater_than(
115        left: &str,
116        right: &str,
117        workflow: &Workflow,
118        state: &WorkflowState,
119    ) -> WorkflowResult<bool> {
120        let left_value = Self::resolve_value(left, workflow, state)?;
121        let right_value = Self::parse_value(right);
122
123        match (left_value, right_value) {
124            (Value::Number(l), Value::Number(r)) => {
125                let l_f64 = l.as_f64().unwrap_or(0.0);
126                let r_f64 = r.as_f64().unwrap_or(0.0);
127                Ok(l_f64 > r_f64)
128            }
129            _ => Err(WorkflowError::Invalid(
130                "Cannot compare non-numeric values with >".to_string(),
131            )),
132        }
133    }
134
135    /// Evaluate less than comparison
136    fn evaluate_less_than(
137        left: &str,
138        right: &str,
139        workflow: &Workflow,
140        state: &WorkflowState,
141    ) -> WorkflowResult<bool> {
142        let left_value = Self::resolve_value(left, workflow, state)?;
143        let right_value = Self::parse_value(right);
144
145        match (left_value, right_value) {
146            (Value::Number(l), Value::Number(r)) => {
147                let l_f64 = l.as_f64().unwrap_or(0.0);
148                let r_f64 = r.as_f64().unwrap_or(0.0);
149                Ok(l_f64 < r_f64)
150            }
151            _ => Err(WorkflowError::Invalid(
152                "Cannot compare non-numeric values with <".to_string(),
153            )),
154        }
155    }
156
157    /// Evaluate greater than or equal comparison
158    fn evaluate_greater_equal(
159        left: &str,
160        right: &str,
161        workflow: &Workflow,
162        state: &WorkflowState,
163    ) -> WorkflowResult<bool> {
164        let left_value = Self::resolve_value(left, workflow, state)?;
165        let right_value = Self::parse_value(right);
166
167        match (left_value, right_value) {
168            (Value::Number(l), Value::Number(r)) => {
169                let l_f64 = l.as_f64().unwrap_or(0.0);
170                let r_f64 = r.as_f64().unwrap_or(0.0);
171                Ok(l_f64 >= r_f64)
172            }
173            _ => Err(WorkflowError::Invalid(
174                "Cannot compare non-numeric values with >=".to_string(),
175            )),
176        }
177    }
178
179    /// Evaluate less than or equal comparison
180    fn evaluate_less_equal(
181        left: &str,
182        right: &str,
183        workflow: &Workflow,
184        state: &WorkflowState,
185    ) -> WorkflowResult<bool> {
186        let left_value = Self::resolve_value(left, workflow, state)?;
187        let right_value = Self::parse_value(right);
188
189        match (left_value, right_value) {
190            (Value::Number(l), Value::Number(r)) => {
191                let l_f64 = l.as_f64().unwrap_or(0.0);
192                let r_f64 = r.as_f64().unwrap_or(0.0);
193                Ok(l_f64 <= r_f64)
194            }
195            _ => Err(WorkflowError::Invalid(
196                "Cannot compare non-numeric values with <=".to_string(),
197            )),
198        }
199    }
200
201    /// Resolve a value reference (e.g., "step_id.output.field")
202    fn resolve_value(
203        reference: &str,
204        _workflow: &Workflow,
205        state: &WorkflowState,
206    ) -> WorkflowResult<Value> {
207        let parts: Vec<&str> = reference.split('.').collect();
208
209        if parts.is_empty() {
210            return Err(WorkflowError::Invalid(
211                "Invalid value reference".to_string(),
212            ));
213        }
214
215        let step_id = parts[0];
216
217        // Get the step result
218        let step_result = state.step_results.get(step_id).ok_or_else(|| {
219            WorkflowError::StateError(format!("Step {} has not been executed", step_id))
220        })?;
221
222        // Start with null
223        let mut value = Value::Null;
224        let mut is_first = true;
225
226        // Navigate through the path
227        for (i, part) in parts.iter().enumerate() {
228            if i == 0 {
229                // Skip the step_id itself
230                continue;
231            }
232
233            if part.is_empty() {
234                continue;
235            }
236
237            // Handle special fields of the step result (only on first access after step_id)
238            if is_first && *part == "output" {
239                value = step_result.output.clone().unwrap_or(Value::Null);
240                is_first = false;
241            } else if is_first && *part == "status" {
242                value = Value::String(format!("{:?}", step_result.status));
243                is_first = false;
244            } else if is_first && *part == "error" {
245                value = step_result
246                    .error
247                    .as_ref()
248                    .map(|e| Value::String(e.clone()))
249                    .unwrap_or(Value::Null);
250                is_first = false;
251            } else if is_first && *part == "duration_ms" {
252                value = Value::Number(serde_json::Number::from(step_result.duration_ms));
253                is_first = false;
254            } else {
255                // Navigate through the JSON object
256                is_first = false;
257                // Handle array indexing like "field[0]"
258                if let Some(bracket_pos) = part.find('[') {
259                    let field_name = &part[..bracket_pos];
260                    let index_str = &part[bracket_pos + 1..part.len() - 1];
261
262                    if let Ok(index) = index_str.parse::<usize>() {
263                        value = value[field_name][index].clone();
264                    } else {
265                        return Err(WorkflowError::Invalid(format!(
266                            "Invalid array index: {}",
267                            index_str
268                        )));
269                    }
270                } else {
271                    value = value[part].clone();
272                }
273            }
274        }
275
276        Ok(value)
277    }
278
279    /// Parse a literal value (e.g., "5", "'completed'", "true")
280    fn parse_value(value_str: &str) -> Value {
281        let trimmed = value_str.trim();
282
283        // Handle string literals (quoted)
284        if (trimmed.starts_with('\'') && trimmed.ends_with('\''))
285            || (trimmed.starts_with('"') && trimmed.ends_with('"'))
286        {
287            let unquoted = &trimmed[1..trimmed.len() - 1];
288            return Value::String(unquoted.to_string());
289        }
290
291        // Handle boolean
292        if trimmed == "true" {
293            return Value::Bool(true);
294        }
295        if trimmed == "false" {
296            return Value::Bool(false);
297        }
298
299        // Handle null
300        if trimmed == "null" {
301            return Value::Null;
302        }
303
304        // Handle numbers
305        if let Ok(int_val) = trimmed.parse::<i64>() {
306            return Value::Number(serde_json::Number::from(int_val));
307        }
308
309        if let Ok(float_val) = trimmed.parse::<f64>() {
310            if let Some(num) = serde_json::Number::from_f64(float_val) {
311                return Value::Number(num);
312            }
313        }
314
315        // Default to string
316        Value::String(trimmed.to_string())
317    }
318
319    /// Get the next steps to execute after a condition
320    ///
321    /// Returns the list of step IDs that should be executed based on the condition result.
322    pub fn get_next_steps(
323        workflow: &Workflow,
324        state: &WorkflowState,
325        condition_step: &ConditionStep,
326    ) -> WorkflowResult<Vec<String>> {
327        Self::evaluate_condition(workflow, state, condition_step)
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{
        AgentStep, ErrorAction, RiskFactors, StepConfig, StepStatus, StepType, WorkflowConfig,
        WorkflowStep,
    };

    /// Builds an agent step that uses the shared test agent and task.
    fn agent_step(id: &str, name: &str, dependencies: &[&str]) -> WorkflowStep {
        WorkflowStep {
            id: id.to_string(),
            name: name.to_string(),
            step_type: StepType::Agent(AgentStep {
                agent_id: "test-agent".to_string(),
                task: "test-task".to_string(),
            }),
            config: StepConfig {
                config: serde_json::json!({}),
            },
            dependencies: dependencies.iter().map(|dep| dep.to_string()).collect(),
            approval_required: false,
            on_error: ErrorAction::Fail,
            risk_score: None,
            risk_factors: RiskFactors::default(),
        }
    }

    /// The condition used throughout: branch to step2 when step1's count exceeds 5,
    /// otherwise to step3.
    fn count_condition() -> ConditionStep {
        ConditionStep {
            condition: "step1.output.count > 5".to_string(),
            then_steps: vec!["step2".to_string()],
            else_steps: vec!["step3".to_string()],
        }
    }

    /// Workflow fixture: step1 -> condition -> (step2 | step3).
    fn create_test_workflow() -> Workflow {
        let condition = WorkflowStep {
            step_type: StepType::Condition(count_condition()),
            ..agent_step("condition", "Condition", &["step1"])
        };

        Workflow {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            description: "A test workflow".to_string(),
            parameters: vec![],
            steps: vec![
                agent_step("step1", "Step 1", &[]),
                condition,
                agent_step("step2", "Step 2", &["condition"]),
                agent_step("step3", "Step 3", &["condition"]),
            ],
            config: WorkflowConfig {
                timeout_ms: None,
                max_parallel: None,
            },
        }
    }

    /// Records a completed result for "step1" carrying the given output.
    fn record_step1(state: &mut WorkflowState, output: Value) {
        state.step_results.insert(
            "step1".to_string(),
            crate::models::StepResult {
                status: StepStatus::Completed,
                output: Some(output),
                error: None,
                duration_ms: 100,
            },
        );
    }

    #[test]
    fn test_parse_value_string() {
        let parsed = ConditionEvaluator::parse_value("'hello'");
        assert_eq!(parsed, Value::String("hello".to_string()));
    }

    #[test]
    fn test_parse_value_number() {
        let parsed = ConditionEvaluator::parse_value("42");
        assert_eq!(parsed.as_i64(), Some(42));
    }

    #[test]
    fn test_parse_value_boolean() {
        let parsed = ConditionEvaluator::parse_value("true");
        assert_eq!(parsed, Value::Bool(true));
    }

    #[test]
    fn test_parse_value_null() {
        let parsed = ConditionEvaluator::parse_value("null");
        assert_eq!(parsed, Value::Null);
    }

    #[test]
    fn test_evaluate_equality_true() {
        let workflow = create_test_workflow();
        let mut state = crate::state::StateManager::create_state(&workflow);
        record_step1(&mut state, serde_json::json!({"status": "completed"}));

        let outcome = ConditionEvaluator::evaluate_equality(
            "step1.output.status",
            "'completed'",
            &workflow,
            &state,
        );
        assert!(matches!(outcome, Ok(true)));
    }

    #[test]
    fn test_evaluate_equality_false() {
        let workflow = create_test_workflow();
        let mut state = crate::state::StateManager::create_state(&workflow);
        record_step1(&mut state, serde_json::json!({"status": "failed"}));

        let outcome = ConditionEvaluator::evaluate_equality(
            "step1.output.status",
            "'completed'",
            &workflow,
            &state,
        );
        assert!(matches!(outcome, Ok(false)));
    }

    #[test]
    fn test_evaluate_greater_than_true() {
        let workflow = create_test_workflow();
        let mut state = crate::state::StateManager::create_state(&workflow);
        record_step1(&mut state, serde_json::json!({"count": 10}));

        let outcome =
            ConditionEvaluator::evaluate_greater_than("step1.output.count", "5", &workflow, &state);
        assert!(matches!(outcome, Ok(true)));
    }

    #[test]
    fn test_evaluate_greater_than_false() {
        let workflow = create_test_workflow();
        let mut state = crate::state::StateManager::create_state(&workflow);
        record_step1(&mut state, serde_json::json!({"count": 3}));

        let outcome =
            ConditionEvaluator::evaluate_greater_than("step1.output.count", "5", &workflow, &state);
        assert!(matches!(outcome, Ok(false)));
    }

    #[test]
    fn test_evaluate_condition_then_branch() {
        let workflow = create_test_workflow();
        let mut state = crate::state::StateManager::create_state(&workflow);
        record_step1(&mut state, serde_json::json!({"count": 10}));

        let chosen =
            ConditionEvaluator::evaluate_condition(&workflow, &state, &count_condition());
        assert!(chosen.is_ok());
        assert_eq!(chosen.unwrap(), vec!["step2".to_string()]);
    }

    #[test]
    fn test_evaluate_condition_else_branch() {
        let workflow = create_test_workflow();
        let mut state = crate::state::StateManager::create_state(&workflow);
        record_step1(&mut state, serde_json::json!({"count": 3}));

        let chosen =
            ConditionEvaluator::evaluate_condition(&workflow, &state, &count_condition());
        assert!(chosen.is_ok());
        assert_eq!(chosen.unwrap(), vec!["step3".to_string()]);
    }

    #[test]
    fn test_evaluate_not_equal() {
        let workflow = create_test_workflow();
        let mut state = crate::state::StateManager::create_state(&workflow);
        record_step1(&mut state, serde_json::json!({"status": "failed"}));

        let outcome = ConditionEvaluator::evaluate_expression(
            "step1.output.status != 'completed'",
            &workflow,
            &state,
        );
        assert!(matches!(outcome, Ok(true)));
    }

    #[test]
    fn test_evaluate_less_than() {
        let workflow = create_test_workflow();
        let mut state = crate::state::StateManager::create_state(&workflow);
        record_step1(&mut state, serde_json::json!({"count": 3}));

        let outcome =
            ConditionEvaluator::evaluate_expression("step1.output.count < 5", &workflow, &state);
        assert!(matches!(outcome, Ok(true)));
    }

    #[test]
    fn test_evaluate_greater_equal() {
        let workflow = create_test_workflow();
        let mut state = crate::state::StateManager::create_state(&workflow);
        record_step1(&mut state, serde_json::json!({"count": 5}));

        let outcome =
            ConditionEvaluator::evaluate_expression("step1.output.count >= 5", &workflow, &state);
        assert!(matches!(outcome, Ok(true)));
    }

    #[test]
    fn test_evaluate_less_equal() {
        let workflow = create_test_workflow();
        let mut state = crate::state::StateManager::create_state(&workflow);
        record_step1(&mut state, serde_json::json!({"count": 5}));

        let outcome =
            ConditionEvaluator::evaluate_expression("step1.output.count <= 5", &workflow, &state);
        assert!(matches!(outcome, Ok(true)));
    }
}