1use anyhow::{anyhow, Context, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum StepStatus {
16 Done,
18 Retry,
20 Blocked,
22}
23
24impl StepStatus {
25 pub fn parse(s: &str) -> Result<Self> {
27 match s.trim().to_lowercase().as_str() {
28 "done" => Ok(StepStatus::Done),
29 "retry" => Ok(StepStatus::Retry),
30 "blocked" => Ok(StepStatus::Blocked),
31 _ => Err(anyhow!(
32 "Invalid status: {}. Expected: done, retry, or blocked",
33 s
34 )),
35 }
36 }
37
38 pub fn as_str(&self) -> &'static str {
40 match self {
41 StepStatus::Done => "done",
42 StepStatus::Retry => "retry",
43 StepStatus::Blocked => "blocked",
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct ExpectedField {
51 pub name: String,
53 #[serde(rename = "type")]
55 pub field_type: FieldType,
56 #[serde(default = "default_true")]
58 pub required: bool,
59 pub pattern: Option<String>,
61 #[serde(default)]
63 pub enum_values: Vec<String>,
64}
65
66fn default_true() -> bool {
67 true
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72#[serde(rename_all = "lowercase")]
73pub enum FieldType {
74 String,
76 Integer,
78 Float,
80 Boolean,
82 Json,
84 StringArray,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct ContractExpectation {
91 pub status: StepStatus,
93 #[serde(default)]
95 pub outputs: Vec<ExpectedField>,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100#[serde(tag = "action", rename_all = "snake_case")]
101pub enum FailureAction {
102 Retry {
104 max_retries: u32,
106 #[serde(default)]
108 retry_target: Option<String>,
109 #[serde(default)]
111 feedback_field: Option<String>,
112 #[serde(default)]
114 on_exhausted: Option<Box<FailureAction>>,
115 },
116 Escalate {
118 to: String,
120 },
121 Skip,
123 Fail,
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct StepContract {
130 pub expects: ContractExpectation,
132 #[serde(default)]
134 pub on_failure: Option<FailureAction>,
135}
136
137#[derive(Debug, Clone)]
139pub struct ParsedOutput {
140 pub status: StepStatus,
142 pub fields: HashMap<String, serde_json::Value>,
144 pub raw_output: String,
146}
147
148pub struct ContractParser;
150
151impl ContractParser {
152 pub fn parse(output: &str, contract: &StepContract) -> Result<ParsedOutput> {
154 let status = Self::extract_status(output)?;
156
157 let fields = Self::parse_fields(output)?;
159
160 if status == contract.expects.status {
162 Self::validate_fields(&fields, &contract.expects.outputs)?;
163 }
164
165 Ok(ParsedOutput {
166 status,
167 fields,
168 raw_output: output.to_string(),
169 })
170 }
171
172 fn extract_status(output: &str) -> Result<StepStatus> {
174 for line in output.lines() {
175 let line = line.trim();
176 if let Some(status_str) = line.strip_prefix("STATUS:") {
177 return StepStatus::parse(status_str.trim());
178 }
179 }
180
181 Err(anyhow!(
182 "Missing STATUS field. Expected: STATUS: done|retry|blocked\n\nOutput:\n{}",
183 output
184 ))
185 }
186
187 fn parse_fields(output: &str) -> Result<HashMap<String, serde_json::Value>> {
189 let mut fields = HashMap::new();
190
191 for line in output.lines() {
192 let line = line.trim();
193
194 if line.is_empty() || line.starts_with('#') {
196 continue;
197 }
198
199 if let Some(pos) = line.find(':') {
201 let key = line[..pos].trim();
202 let value = line[pos + 1..].trim();
203
204 if key == "STATUS" {
206 continue;
207 }
208
209 let parsed_value = if (value.starts_with('[') && value.ends_with(']'))
211 || (value.starts_with('{') && value.ends_with('}'))
212 {
213 serde_json::from_str(value)
214 .unwrap_or_else(|_| serde_json::Value::String(value.to_string()))
215 } else {
216 serde_json::Value::String(value.to_string())
217 };
218
219 fields.insert(key.to_string(), parsed_value);
220 }
221 }
222
223 Ok(fields)
224 }
225
226 fn validate_fields(
228 fields: &HashMap<String, serde_json::Value>,
229 expected: &[ExpectedField],
230 ) -> Result<()> {
231 let mut errors = Vec::new();
232
233 for field_def in expected {
234 let field_name = &field_def.name;
235
236 match fields.get(field_name) {
237 Some(value) => {
238 if let Err(e) = Self::validate_type(value, &field_def.field_type) {
240 errors.push(format!(
241 "Field '{}' type mismatch: expected {}, got error: {}",
242 field_name,
243 format!("{:?}", field_def.field_type).to_lowercase(),
244 e
245 ));
246 }
247
248 if let Some(pattern) = &field_def.pattern {
250 let value_str = match value {
251 serde_json::Value::String(s) => s.clone(),
252 other => other.to_string(),
253 };
254
255 let regex = regex::Regex::new(pattern).context("Invalid pattern regex")?;
256
257 if !regex.is_match(&value_str) {
258 errors.push(format!(
259 "Field '{}' value '{}' does not match pattern '{}'",
260 field_name, value_str, pattern
261 ));
262 }
263 }
264
265 if !field_def.enum_values.is_empty() {
267 let value_str = match value {
268 serde_json::Value::String(s) => s.clone(),
269 other => other.to_string(),
270 };
271
272 if !field_def.enum_values.contains(&value_str) {
273 errors.push(format!(
274 "Field '{}' value '{}' not in allowed values: {:?}",
275 field_name, value_str, field_def.enum_values
276 ));
277 }
278 }
279 }
280 None => {
281 if field_def.required {
282 errors.push(format!(
283 "Missing required field: {} (type: {:?})",
284 field_name, field_def.field_type
285 ));
286 }
287 }
288 }
289 }
290
291 if errors.is_empty() {
292 Ok(())
293 } else {
294 Err(anyhow!(
295 "Contract validation failed:\n{}",
296 errors.join("\n")
297 ))
298 }
299 }
300
301 fn validate_type(value: &serde_json::Value, expected: &FieldType) -> Result<()> {
303 match expected {
304 FieldType::String => {
305 if !value.is_string() {
306 return Err(anyhow!("Expected string, got {}", value));
307 }
308 }
309 FieldType::Integer => {
310 if !value.is_i64() && !value.is_u64() {
311 return Err(anyhow!("Expected integer, got {}", value));
312 }
313 }
314 FieldType::Float => {
315 if !value.is_f64() && !value.is_i64() && !value.is_u64() {
316 return Err(anyhow!("Expected number, got {}", value));
317 }
318 }
319 FieldType::Boolean => {
320 if !value.is_boolean() {
321 return Err(anyhow!("Expected boolean, got {}", value));
322 }
323 }
324 FieldType::Json => {
325 }
327 FieldType::StringArray => {
328 if let serde_json::Value::Array(arr) = value {
329 for (i, item) in arr.iter().enumerate() {
330 if !item.is_string() {
331 return Err(anyhow!(
332 "Expected string array, but item {} is {}",
333 i,
334 item
335 ));
336 }
337 }
338 } else {
339 return Err(anyhow!("Expected array, got {}", value));
340 }
341 }
342 }
343
344 Ok(())
345 }
346
347 pub fn get_feedback(output: &str, field_name: &str) -> Option<String> {
349 let fields = Self::parse_fields(output).ok()?;
350 fields.get(field_name).map(|v| v.to_string())
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a required, unconstrained `string` field expectation.
    fn required_string(name: &str) -> ExpectedField {
        ExpectedField {
            name: name.to_string(),
            field_type: FieldType::String,
            required: true,
            pattern: None,
            enum_values: Vec::new(),
        }
    }

    /// Contract expecting `STATUS: done` together with `REPO` and `BRANCH`.
    fn repo_branch_contract() -> StepContract {
        StepContract {
            expects: ContractExpectation {
                status: StepStatus::Done,
                outputs: vec![required_string("REPO"), required_string("BRANCH")],
            },
            on_failure: None,
        }
    }

    /// Looks up `key` and returns it as `&str`, panicking if absent or non-string.
    fn str_field<'a>(fields: &'a HashMap<String, serde_json::Value>, key: &str) -> &'a str {
        fields
            .get(key)
            .and_then(serde_json::Value::as_str)
            .unwrap_or_else(|| panic!("field {} missing or not a string", key))
    }

    #[test]
    fn test_parse_status() {
        let cases = [
            ("STATUS: done\nREPO: /path/to/repo", StepStatus::Done),
            ("STATUS: retry\nISSUES: something went wrong", StepStatus::Retry),
            ("STATUS: blocked\nREASON: need permission", StepStatus::Blocked),
        ];

        for &(output, expected) in &cases {
            assert_eq!(ContractParser::extract_status(output).unwrap(), expected);
        }
    }

    #[test]
    fn test_parse_fields() {
        let output = "\nSTATUS: done\nREPO: /path/to/repo\nBRANCH: feature-branch\nCOUNT: 42\n";

        let fields = ContractParser::parse_fields(output).unwrap();

        assert_eq!(str_field(&fields, "REPO"), "/path/to/repo");
        assert_eq!(str_field(&fields, "BRANCH"), "feature-branch");
        // Scalars are kept as strings by the parser.
        assert_eq!(str_field(&fields, "COUNT"), "42");
    }

    #[test]
    fn test_parse_json_field() {
        let output = r#"
STATUS: done
STORIES_JSON: [{"id": 1, "title": "Story 1"}, {"id": 2, "title": "Story 2"}]
"#;

        let fields = ContractParser::parse_fields(output).unwrap();
        let stories = fields["STORIES_JSON"]
            .as_array()
            .expect("STORIES_JSON should decode as a JSON array");
        assert_eq!(stories.len(), 2);
    }

    #[test]
    fn test_validate_contract() {
        let output = "\nSTATUS: done\nREPO: /path/to/repo\nBRANCH: feature-branch\n";

        let parsed = ContractParser::parse(output, &repo_branch_contract())
            .expect("output satisfying the contract should parse");

        assert_eq!(parsed.status, StepStatus::Done);
        assert_eq!(str_field(&parsed.fields, "REPO"), "/path/to/repo");
    }

    #[test]
    fn test_validate_missing_field() {
        let output = "\nSTATUS: done\nREPO: /path/to/repo\n";

        let err = ContractParser::parse(output, &repo_branch_contract())
            .expect_err("a missing BRANCH field should fail validation");

        assert!(err.to_string().contains("BRANCH"));
    }

    #[test]
    fn test_get_feedback() {
        let output = "\nSTATUS: retry\nISSUES: The test is failing due to missing imports\n";

        let feedback = ContractParser::get_feedback(output, "ISSUES")
            .expect("ISSUES field should be present");

        assert!(feedback.contains("missing imports"));
    }
}