// enact_core/workflow/contract.rs

//! Step Contract - I/O validation and status parsing
//!
//! Implements the strict step I/O contract based on the Antfarm patterns:
//! - `STATUS: done|retry|blocked`
//! - Typed key/value outputs
//! - Strict `expects` validation with actionable failure reasons

use std::collections::HashMap;

use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Serialize};

12/// Step execution status
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum StepStatus {
16    /// Step completed successfully
17    Done,
18    /// Step needs to be retried
19    Retry,
20    /// Step is blocked and requires escalation
21    Blocked,
22}
23
24impl StepStatus {
25    /// Parse status from string
26    pub fn parse(s: &str) -> Result<Self> {
27        match s.trim().to_lowercase().as_str() {
28            "done" => Ok(StepStatus::Done),
29            "retry" => Ok(StepStatus::Retry),
30            "blocked" => Ok(StepStatus::Blocked),
31            _ => Err(anyhow!(
32                "Invalid status: {}. Expected: done, retry, or blocked",
33                s
34            )),
35        }
36    }
37
38    /// Convert to string
39    pub fn as_str(&self) -> &'static str {
40        match self {
41            StepStatus::Done => "done",
42            StepStatus::Retry => "retry",
43            StepStatus::Blocked => "blocked",
44        }
45    }
46}
47
48/// Expected output field definition
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct ExpectedField {
51    /// Field name (e.g., "REPO", "BRANCH")
52    pub name: String,
53    /// Field type
54    #[serde(rename = "type")]
55    pub field_type: FieldType,
56    /// Whether this field is required
57    #[serde(default = "default_true")]
58    pub required: bool,
59    /// Optional validation pattern (regex)
60    pub pattern: Option<String>,
61    /// Optional enum values for validation
62    #[serde(default)]
63    pub enum_values: Vec<String>,
64}
65
/// Serde default for `ExpectedField::required`: declared fields are mandatory
/// unless a contract explicitly marks them optional.
fn default_true() -> bool {
    true
}

70/// Field types for validation
71#[derive(Debug, Clone, Serialize, Deserialize)]
72#[serde(rename_all = "lowercase")]
73pub enum FieldType {
74    /// String value
75    String,
76    /// Integer value
77    Integer,
78    /// Floating point value
79    Float,
80    /// Boolean value
81    Boolean,
82    /// JSON object or array
83    Json,
84    /// Array of strings
85    StringArray,
86}
87
88/// Contract expectation definition
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct ContractExpectation {
91    /// Expected status
92    pub status: StepStatus,
93    /// Expected output fields
94    #[serde(default)]
95    pub outputs: Vec<ExpectedField>,
96}
97
98/// Failure handling configuration
99#[derive(Debug, Clone, Serialize, Deserialize)]
100#[serde(tag = "action", rename_all = "snake_case")]
101pub enum FailureAction {
102    /// Retry the step
103    Retry {
104        /// Maximum number of retries
105        max_retries: u32,
106        /// Optional step to retry (defaults to current)
107        #[serde(default)]
108        retry_target: Option<String>,
109        /// Field name containing feedback for retry
110        #[serde(default)]
111        feedback_field: Option<String>,
112        /// Action when retries exhausted
113        #[serde(default)]
114        on_exhausted: Option<Box<FailureAction>>,
115    },
116    /// Escalate to human
117    Escalate {
118        /// Who to escalate to
119        to: String,
120    },
121    /// Skip to next step
122    Skip,
123    /// Fail the workflow
124    Fail,
125}
126
127/// Step contract definition
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct StepContract {
130    /// Expected outputs
131    pub expects: ContractExpectation,
132    /// Failure handling
133    #[serde(default)]
134    pub on_failure: Option<FailureAction>,
135}
136
137/// Parsed step output
138#[derive(Debug, Clone)]
139pub struct ParsedOutput {
140    /// Execution status
141    pub status: StepStatus,
142    /// Parsed key-value pairs
143    pub fields: HashMap<String, serde_json::Value>,
144    /// Raw output text (for debugging)
145    pub raw_output: String,
146}
147
/// Stateless entry point that parses step output and checks it against a
/// `StepContract`.
pub struct ContractParser;

151impl ContractParser {
152    /// Parse step output according to contract
153    pub fn parse(output: &str, contract: &StepContract) -> Result<ParsedOutput> {
154        // Extract STATUS line
155        let status = Self::extract_status(output)?;
156
157        // Parse key-value pairs
158        let fields = Self::parse_fields(output)?;
159
160        // Validate against contract
161        if status == contract.expects.status {
162            Self::validate_fields(&fields, &contract.expects.outputs)?;
163        }
164
165        Ok(ParsedOutput {
166            status,
167            fields,
168            raw_output: output.to_string(),
169        })
170    }
171
172    /// Extract STATUS from output
173    fn extract_status(output: &str) -> Result<StepStatus> {
174        for line in output.lines() {
175            let line = line.trim();
176            if let Some(status_str) = line.strip_prefix("STATUS:") {
177                return StepStatus::parse(status_str.trim());
178            }
179        }
180
181        Err(anyhow!(
182            "Missing STATUS field. Expected: STATUS: done|retry|blocked\n\nOutput:\n{}",
183            output
184        ))
185    }
186
187    /// Parse key-value pairs from output
188    fn parse_fields(output: &str) -> Result<HashMap<String, serde_json::Value>> {
189        let mut fields = HashMap::new();
190
191        for line in output.lines() {
192            let line = line.trim();
193
194            // Skip empty lines and comments
195            if line.is_empty() || line.starts_with('#') {
196                continue;
197            }
198
199            // Look for KEY: value pattern
200            if let Some(pos) = line.find(':') {
201                let key = line[..pos].trim();
202                let value = line[pos + 1..].trim();
203
204                // Skip if it's a status line (handled separately)
205                if key == "STATUS" {
206                    continue;
207                }
208
209                // Try to parse as JSON if it looks like JSON
210                let parsed_value = if (value.starts_with('[') && value.ends_with(']'))
211                    || (value.starts_with('{') && value.ends_with('}'))
212                {
213                    serde_json::from_str(value)
214                        .unwrap_or_else(|_| serde_json::Value::String(value.to_string()))
215                } else {
216                    serde_json::Value::String(value.to_string())
217                };
218
219                fields.insert(key.to_string(), parsed_value);
220            }
221        }
222
223        Ok(fields)
224    }
225
226    /// Validate fields against expected definitions
227    fn validate_fields(
228        fields: &HashMap<String, serde_json::Value>,
229        expected: &[ExpectedField],
230    ) -> Result<()> {
231        let mut errors = Vec::new();
232
233        for field_def in expected {
234            let field_name = &field_def.name;
235
236            match fields.get(field_name) {
237                Some(value) => {
238                    // Validate type
239                    if let Err(e) = Self::validate_type(value, &field_def.field_type) {
240                        errors.push(format!(
241                            "Field '{}' type mismatch: expected {}, got error: {}",
242                            field_name,
243                            format!("{:?}", field_def.field_type).to_lowercase(),
244                            e
245                        ));
246                    }
247
248                    // Validate pattern if specified
249                    if let Some(pattern) = &field_def.pattern {
250                        let value_str = match value {
251                            serde_json::Value::String(s) => s.clone(),
252                            other => other.to_string(),
253                        };
254
255                        let regex = regex::Regex::new(pattern).context("Invalid pattern regex")?;
256
257                        if !regex.is_match(&value_str) {
258                            errors.push(format!(
259                                "Field '{}' value '{}' does not match pattern '{}'",
260                                field_name, value_str, pattern
261                            ));
262                        }
263                    }
264
265                    // Validate enum values if specified
266                    if !field_def.enum_values.is_empty() {
267                        let value_str = match value {
268                            serde_json::Value::String(s) => s.clone(),
269                            other => other.to_string(),
270                        };
271
272                        if !field_def.enum_values.contains(&value_str) {
273                            errors.push(format!(
274                                "Field '{}' value '{}' not in allowed values: {:?}",
275                                field_name, value_str, field_def.enum_values
276                            ));
277                        }
278                    }
279                }
280                None => {
281                    if field_def.required {
282                        errors.push(format!(
283                            "Missing required field: {} (type: {:?})",
284                            field_name, field_def.field_type
285                        ));
286                    }
287                }
288            }
289        }
290
291        if errors.is_empty() {
292            Ok(())
293        } else {
294            Err(anyhow!(
295                "Contract validation failed:\n{}",
296                errors.join("\n")
297            ))
298        }
299    }
300
301    /// Validate a value matches expected type
302    fn validate_type(value: &serde_json::Value, expected: &FieldType) -> Result<()> {
303        match expected {
304            FieldType::String => {
305                if !value.is_string() {
306                    return Err(anyhow!("Expected string, got {}", value));
307                }
308            }
309            FieldType::Integer => {
310                if !value.is_i64() && !value.is_u64() {
311                    return Err(anyhow!("Expected integer, got {}", value));
312                }
313            }
314            FieldType::Float => {
315                if !value.is_f64() && !value.is_i64() && !value.is_u64() {
316                    return Err(anyhow!("Expected number, got {}", value));
317                }
318            }
319            FieldType::Boolean => {
320                if !value.is_boolean() {
321                    return Err(anyhow!("Expected boolean, got {}", value));
322                }
323            }
324            FieldType::Json => {
325                // Any JSON value is valid
326            }
327            FieldType::StringArray => {
328                if let serde_json::Value::Array(arr) = value {
329                    for (i, item) in arr.iter().enumerate() {
330                        if !item.is_string() {
331                            return Err(anyhow!(
332                                "Expected string array, but item {} is {}",
333                                i,
334                                item
335                            ));
336                        }
337                    }
338                } else {
339                    return Err(anyhow!("Expected array, got {}", value));
340                }
341            }
342        }
343
344        Ok(())
345    }
346
347    /// Get feedback field value from output
348    pub fn get_feedback(output: &str, field_name: &str) -> Option<String> {
349        let fields = Self::parse_fields(output).ok()?;
350        fields.get(field_name).map(|v| v.to_string())
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a required string field declaration with no pattern or enum constraint.
    fn required_string(name: &str) -> ExpectedField {
        ExpectedField {
            name: name.to_string(),
            field_type: FieldType::String,
            required: true,
            pattern: None,
            enum_values: Vec::new(),
        }
    }

    /// Contract expecting `done` with required string fields `REPO` and `BRANCH`.
    fn repo_branch_contract() -> StepContract {
        StepContract {
            expects: ContractExpectation {
                status: StepStatus::Done,
                outputs: vec![required_string("REPO"), required_string("BRANCH")],
            },
            on_failure: None,
        }
    }

    #[test]
    fn test_parse_status() {
        let cases = [
            ("STATUS: done\nREPO: /path/to/repo", StepStatus::Done),
            ("STATUS: retry\nISSUES: something went wrong", StepStatus::Retry),
            ("STATUS: blocked\nREASON: need permission", StepStatus::Blocked),
        ];

        for &(output, expected) in cases.iter() {
            let status = ContractParser::extract_status(output).unwrap();
            assert_eq!(status, expected, "output: {:?}", output);
        }
    }

    #[test]
    fn test_parse_fields() {
        let output = "\nSTATUS: done\nREPO: /path/to/repo\nBRANCH: feature-branch\nCOUNT: 42\n";
        let fields = ContractParser::parse_fields(output).unwrap();

        // Plain scalars (including numbers) are kept as strings.
        let expected = [
            ("REPO", "/path/to/repo"),
            ("BRANCH", "feature-branch"),
            ("COUNT", "42"),
        ];
        for &(key, value) in expected.iter() {
            assert_eq!(
                fields.get(key).and_then(|v| v.as_str()),
                Some(value),
                "field {}",
                key
            );
        }
    }

    #[test]
    fn test_parse_json_field() {
        let output = concat!(
            "\nSTATUS: done\n",
            r#"STORIES_JSON: [{"id": 1, "title": "Story 1"}, {"id": 2, "title": "Story 2"}]"#,
            "\n"
        );
        let fields = ContractParser::parse_fields(output).unwrap();

        let stories = fields
            .get("STORIES_JSON")
            .and_then(|v| v.as_array())
            .expect("STORIES_JSON should decode to a JSON array");
        assert_eq!(stories.len(), 2);
    }

    #[test]
    fn test_validate_contract() {
        let output = "\nSTATUS: done\nREPO: /path/to/repo\nBRANCH: feature-branch\n";

        let parsed = ContractParser::parse(output, &repo_branch_contract())
            .expect("output satisfies the contract");

        assert_eq!(parsed.status, StepStatus::Done);
        assert_eq!(
            parsed.fields.get("REPO").and_then(|v| v.as_str()),
            Some("/path/to/repo")
        );
    }

    #[test]
    fn test_validate_missing_field() {
        let output = "\nSTATUS: done\nREPO: /path/to/repo\n";

        let err = ContractParser::parse(output, &repo_branch_contract())
            .expect_err("BRANCH is required but absent");

        assert!(err.to_string().contains("BRANCH"));
    }

    #[test]
    fn test_get_feedback() {
        let output = "\nSTATUS: retry\nISSUES: The test is failing due to missing imports\n";

        let feedback =
            ContractParser::get_feedback(output, "ISSUES").expect("ISSUES should be present");

        assert!(feedback.contains("missing imports"));
    }
}