1use regex::Regex;
31
/// A single expectation about Context contents, parsed from one scenario step.
///
/// Keys are context key names such as `"Strategies"`. When the parser cannot resolve a key
/// for a step it stores an empty string, which `extract_dependencies` skips.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Predicate {
    /// At least `min` facts under `key` (from `contains N facts` / `at least N facts`).
    CountAtLeast { key: String, min: usize },

    /// At most `max` facts under `key` (from `at most N facts`).
    CountAtMost { key: String, max: usize },

    /// Facts under `key` must not contain any of the `forbidden` terms
    /// (from a `must not contain` step followed by a `term | reason` table).
    ContentMustNotContain {
        key: String,
        forbidden: Vec<ForbiddenTerm>,
    },

    /// Facts under `key` must contain the field named `required_field`
    /// (from `must contain a field "name"`).
    ContentMustContain { key: String, required_field: String },

    /// Relationship parsed from `for each X fact ... exists a Y fact`.
    ///
    /// Both keys are taken verbatim from the step text (e.g. the singular `Strategy`) and
    /// are not checked against the known context keys.
    CrossReference {
        source_key: String,
        target_key: String,
    },

    /// `key` is expected to hold facts (from a `Given` step mentioning `any fact` / `facts`
    /// under a known key).
    HasFacts { key: String },

    /// Facts under `key` must include each listed field, each paired with a free-text rule
    /// (from `must include:` followed by a `field | rule` table).
    RequiredFields {
        key: String,
        fields: Vec<FieldRequirement>,
    },

    /// Step text that matched no recognised pattern, kept verbatim in `description`.
    Custom { description: String },
}

/// A term that must not appear, plus the reason it is disallowed.
///
/// Read from the first two cells of a data-table row.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForbiddenTerm {
    pub term: String,
    pub reason: String,
}

/// A field that must be present, plus a free-text rule describing acceptable values.
///
/// Read from the first two cells of a data-table row.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldRequirement {
    pub field: String,
    pub rule: String,
}
85
/// Errors produced while turning scenario steps into [`Predicate`]s.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PredicateError {
    /// A step named a context key that is not a known key; carries the offending key.
    UnknownContextKey(String),
    /// Step text could not be parsed; carries a human-readable message.
    ParseError(String),
}

impl std::fmt::Display for PredicateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownContextKey(key) => write!(f, "unknown context key: {key}"),
            Self::ParseError(msg) => write!(f, "parse error: {msg}"),
        }
    }
}

impl std::error::Error for PredicateError {}
105
/// Context keys that scenario steps are allowed to name.
///
/// Lookups against this list are exact and case-sensitive (`"seeds"` is not a known key).
const KNOWN_KEYS: &[&str] = &[
    "Seeds",
    "Hypotheses",
    "Strategies",
    "Constraints",
    "Signals",
    "Competitors",
    "Evaluations",
];

/// Returns `true` when `key` is exactly one of the entries in [`KNOWN_KEYS`].
fn is_valid_key(key: &str) -> bool {
    KNOWN_KEYS.iter().any(|&known| known == key)
}
121
122pub fn parse_steps(
143 steps: &[(&str, &str, Vec<Vec<String>>)],
144) -> Result<Vec<Predicate>, PredicateError> {
145 let mut predicates = Vec::new();
146
147 let mut current_key: Option<String> = None;
149
150 for (step_type, text, table) in steps {
151 match *step_type {
152 "Given" => {
153 if text.contains("engine halts") || text.contains("engine is") {
155 continue;
156 }
157
158 if let Some(key) = extract_context_key(text) {
160 if is_valid_key(&key) {
161 current_key = Some(key.clone());
162 if text.contains("any fact") || text.contains("facts") {
164 predicates.push(Predicate::HasFacts { key });
165 }
166 }
167 }
170 }
171 "Then" => {
172 let pred = parse_then_step(text, table, ¤t_key)?;
173 predicates.push(pred);
174 }
175 "And" => {
176 if text.contains("must include") || text.contains("must contain") {
178 let pred = parse_then_step(text, table, ¤t_key)?;
179 predicates.push(pred);
180 }
181 }
182 _ => {} }
184 }
185
186 Ok(predicates)
187}
188
189fn parse_then_step(
191 text: &str,
192 table: &[Vec<String>],
193 current_key: &Option<String>,
194) -> Result<Predicate, PredicateError> {
195 let count_at_least = Regex::new(r#"(?:contains?|at least)\s+(\d+)\s+facts?"#).unwrap();
197 if let Some(caps) = count_at_least.captures(text) {
198 let min: usize = caps[1].parse().unwrap_or(1);
199 let key = extract_context_key(text)
200 .or_else(|| current_key.clone())
201 .unwrap_or_default();
202 if !key.is_empty() {
203 validate_key(&key)?;
204 }
205 return Ok(Predicate::CountAtLeast { key, min });
206 }
207
208 let count_at_most = Regex::new(r#"at most\s+(\d+)\s+facts?"#).unwrap();
210 if let Some(caps) = count_at_most.captures(text) {
211 let max: usize = caps[1].parse().unwrap_or(1);
212 let key = extract_context_key(text)
213 .or_else(|| current_key.clone())
214 .unwrap_or_default();
215 if !key.is_empty() {
216 validate_key(&key)?;
217 }
218 return Ok(Predicate::CountAtMost { key, max });
219 }
220
221 if text.contains("must not contain") {
223 let key = current_key.clone().unwrap_or_default();
224 let forbidden = parse_forbidden_terms(table);
225 return Ok(Predicate::ContentMustNotContain { key, forbidden });
226 }
227
228 let cross_ref =
230 Regex::new(r#"for each\s+(\w+)\s+fact.*?exists?\s+(?:a |an )?(\w+)\s+fact"#).unwrap();
231 if let Some(caps) = cross_ref.captures(text) {
232 let source_key = caps[1].to_string();
233 let target_key = caps[2].to_string();
234 return Ok(Predicate::CrossReference {
235 source_key,
236 target_key,
237 });
238 }
239
240 if (text.contains("must include") || text.contains("must contain a field")) && !table.is_empty()
242 {
243 let key = current_key.clone().unwrap_or_default();
244 let fields = parse_field_requirements(table);
245 return Ok(Predicate::RequiredFields { key, fields });
246 }
247
248 let field_pattern = Regex::new(r#"must contain (?:a )?field\s+"(\w+)""#).unwrap();
250 if let Some(caps) = field_pattern.captures(text) {
251 let key = current_key.clone().unwrap_or_default();
252 return Ok(Predicate::ContentMustContain {
253 key,
254 required_field: caps[1].to_string(),
255 });
256 }
257
258 Ok(Predicate::Custom {
260 description: text.to_string(),
261 })
262}
263
264fn extract_context_key(text: &str) -> Option<String> {
266 let re = Regex::new(r#""(\w+)""#).unwrap();
267 re.captures(text).map(|caps| caps[1].to_string())
268}
269
270fn validate_key(key: &str) -> Result<(), PredicateError> {
272 if is_valid_key(key) {
273 Ok(())
274 } else {
275 Err(PredicateError::UnknownContextKey(key.to_string()))
276 }
277}
278
279fn parse_forbidden_terms(table: &[Vec<String>]) -> Vec<ForbiddenTerm> {
281 table
282 .iter()
283 .filter(|row| row.len() >= 2)
284 .map(|row| ForbiddenTerm {
285 term: row[0].clone(),
286 reason: row[1].clone(),
287 })
288 .collect()
289}
290
291fn parse_field_requirements(table: &[Vec<String>]) -> Vec<FieldRequirement> {
293 table
294 .iter()
295 .filter(|row| row.len() >= 2)
296 .map(|row| FieldRequirement {
297 field: row[0].clone(),
298 rule: row[1].clone(),
299 })
300 .collect()
301}
302
303pub fn extract_dependencies(predicates: &[Predicate]) -> Vec<String> {
308 let mut deps = std::collections::BTreeSet::new();
309
310 for pred in predicates {
311 match pred {
312 Predicate::CountAtLeast { key, .. }
313 | Predicate::CountAtMost { key, .. }
314 | Predicate::ContentMustNotContain { key, .. }
315 | Predicate::ContentMustContain { key, .. }
316 | Predicate::HasFacts { key }
317 | Predicate::RequiredFields { key, .. } => {
318 if !key.is_empty() {
319 deps.insert(key.clone());
320 }
321 }
322 Predicate::CrossReference {
323 source_key,
324 target_key,
325 } => {
326 deps.insert(source_key.clone());
327 deps.insert(target_key.clone());
328 }
329 Predicate::Custom { .. } => {}
330 }
331 }
332
333 deps.into_iter().collect()
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Then` step that names its key inline yields exactly one `CountAtLeast`.
    #[test]
    fn parse_count_at_least() {
        let steps = vec![(
            "Then",
            r#"the Context key "Strategies" contains at least 2 facts"#,
            vec![],
        )];
        assert_eq!(
            parse_steps(&steps).unwrap(),
            vec![Predicate::CountAtLeast {
                key: "Strategies".to_string(),
                min: 2,
            }]
        );
    }

    /// Engine-state `Given` steps are skipped, so only the count predicate remains.
    #[test]
    fn parse_count_at_least_with_given_context() {
        let steps = vec![
            (
                "Given",
                r#"the engine halts with reason "Converged""#,
                vec![],
            ),
            (
                "Then",
                r#"the Context key "Strategies" contains at least 2 facts"#,
                vec![],
            ),
        ];
        assert_eq!(
            parse_steps(&steps).unwrap(),
            vec![Predicate::CountAtLeast {
                key: "Strategies".to_string(),
                min: 2,
            }]
        );
    }

    /// A `Given` naming a key scopes the following `must not contain` table to that key.
    #[test]
    fn parse_forbidden_terms_with_table() {
        let steps = vec![
            ("Given", r#"any fact under key "Strategies""#, vec![]),
            (
                "Then",
                "it must not contain any forbidden term:",
                vec![
                    vec!["spam".to_string(), "illegal marketing".to_string()],
                    vec!["bot army".to_string(), "fake engagement".to_string()],
                ],
            ),
        ];
        assert_eq!(
            parse_steps(&steps).unwrap(),
            vec![
                Predicate::HasFacts {
                    key: "Strategies".to_string(),
                },
                Predicate::ContentMustNotContain {
                    key: "Strategies".to_string(),
                    forbidden: vec![
                        ForbiddenTerm {
                            term: "spam".to_string(),
                            reason: "illegal marketing".to_string(),
                        },
                        ForbiddenTerm {
                            term: "bot army".to_string(),
                            reason: "fake engagement".to_string(),
                        },
                    ],
                },
            ]
        );
    }

    /// Cross-reference keys are taken verbatim from the step text.
    #[test]
    fn parse_cross_reference() {
        let steps = vec![(
            "Then",
            "for each Strategy fact there exists an Evaluation fact referencing it",
            vec![],
        )];
        assert_eq!(
            parse_steps(&steps).unwrap(),
            vec![Predicate::CrossReference {
                source_key: "Strategy".to_string(),
                target_key: "Evaluation".to_string(),
            }]
        );
    }

    /// `must include:` with a table yields `RequiredFields` for the current key.
    #[test]
    fn parse_required_fields_with_table() {
        let steps = vec![
            ("Given", r#"any fact under key "Evaluations""#, vec![]),
            (
                "Then",
                "it must include:",
                vec![
                    vec!["score".to_string(), "integer between 0..100".to_string()],
                    vec!["rationale".to_string(), "non-empty string".to_string()],
                ],
            ),
        ];
        assert_eq!(
            parse_steps(&steps).unwrap(),
            vec![
                Predicate::HasFacts {
                    key: "Evaluations".to_string(),
                },
                Predicate::RequiredFields {
                    key: "Evaluations".to_string(),
                    fields: vec![
                        FieldRequirement {
                            field: "score".to_string(),
                            rule: "integer between 0..100".to_string(),
                        },
                        FieldRequirement {
                            field: "rationale".to_string(),
                            rule: "non-empty string".to_string(),
                        },
                    ],
                },
            ]
        );
    }

    /// A quoted field name without a table yields `ContentMustContain`.
    #[test]
    fn parse_content_must_contain_field() {
        let steps = vec![
            ("Given", r#"any fact under key "Strategies""#, vec![]),
            (
                "Then",
                r#"it must contain a field "compliance_ref" with a non-empty value"#,
                vec![],
            ),
        ];
        assert_eq!(
            parse_steps(&steps).unwrap(),
            vec![
                Predicate::HasFacts {
                    key: "Strategies".to_string(),
                },
                Predicate::ContentMustContain {
                    key: "Strategies".to_string(),
                    required_field: "compliance_ref".to_string(),
                },
            ]
        );
    }

    /// A `Given` mentioning facts under a known key yields `HasFacts`.
    #[test]
    fn parse_has_facts() {
        let steps = vec![(
            "Given",
            r#"the Context contains facts under key "Signals""#,
            vec![],
        )];
        assert_eq!(
            parse_steps(&steps).unwrap(),
            vec![Predicate::HasFacts {
                key: "Signals".to_string(),
            }]
        );
    }

    /// Unrecognised assertion text is preserved verbatim as `Custom`.
    #[test]
    fn unrecognized_step_becomes_custom() {
        let steps = vec![("Then", "something completely different happens", vec![])];
        assert_eq!(
            parse_steps(&steps).unwrap(),
            vec![Predicate::Custom {
                description: "something completely different happens".to_string(),
            }]
        );
    }

    /// Count assertions naming an unknown key are rejected.
    #[test]
    fn unknown_context_key_in_then_step_error() {
        let steps = vec![(
            "Then",
            r#"the Context key "Widgets" contains at least 2 facts"#,
            vec![],
        )];
        let result = parse_steps(&steps);
        assert!(
            matches!(&result, Err(PredicateError::UnknownContextKey(key)) if key == "Widgets"),
            "unexpected result: {result:?}"
        );
    }

    /// Unknown keys in `Given` steps are silently ignored.
    #[test]
    fn unknown_key_in_given_is_ignored() {
        let steps = vec![("Given", r#"any fact under key "Widgets""#, vec![])];
        assert!(parse_steps(&steps).unwrap().is_empty());
    }

    #[test]
    fn empty_steps_produces_no_predicates() {
        let steps: Vec<(&str, &str, Vec<Vec<String>>)> = vec![];
        assert!(parse_steps(&steps).unwrap().is_empty());
    }

    /// Dependencies are collected from every keyed predicate and returned sorted.
    #[test]
    fn extract_deps_from_predicates() {
        let preds = vec![
            Predicate::CountAtLeast {
                key: "Strategies".to_string(),
                min: 2,
            },
            Predicate::CrossReference {
                source_key: "Strategies".to_string(),
                target_key: "Evaluations".to_string(),
            },
            Predicate::HasFacts {
                key: "Seeds".to_string(),
            },
        ];
        assert_eq!(
            extract_dependencies(&preds),
            vec!["Evaluations", "Seeds", "Strategies"]
        );
    }

    #[test]
    fn extract_deps_deduplicates() {
        let preds = vec![
            Predicate::HasFacts {
                key: "Strategies".to_string(),
            },
            Predicate::CountAtLeast {
                key: "Strategies".to_string(),
                min: 1,
            },
        ];
        assert_eq!(extract_dependencies(&preds), vec!["Strategies"]);
    }

    #[test]
    fn custom_predicates_have_no_deps() {
        let preds = vec![Predicate::Custom {
            description: "something".to_string(),
        }];
        assert!(extract_dependencies(&preds).is_empty());
    }

    mod property_tests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // Arbitrary printable text must never panic the parser.
            #[test]
            fn any_step_produces_predicate_or_error(text in "\\PC{1,100}") {
                let steps = vec![("Then", text.as_str(), vec![])];
                let _ = parse_steps(&steps);
            }

            // Every known key and in-range count round-trips through the count pattern.
            #[test]
            fn count_pattern_always_parses(n in 1usize..1000, key in prop::sample::select(KNOWN_KEYS)) {
                let text = format!(r#"the Context key "{key}" contains at least {n} facts"#);
                let steps = vec![("Then", text.as_str(), vec![])];
                let preds = parse_steps(&steps);
                prop_assert!(
                    matches!(
                        preds.as_deref(),
                        Ok([Predicate::CountAtLeast { key: parsed, min }]) if *min == n && parsed == key
                    ),
                    "unexpected parse result: {:?}",
                    preds
                );
            }

            // Dependency extraction returns the input keys sorted and deduplicated.
            #[test]
            fn dependency_extraction_never_crashes(
                keys in proptest::collection::vec("[A-Z][a-z]{3,10}", 0..5)
            ) {
                let preds: Vec<Predicate> = keys
                    .iter()
                    .map(|k| Predicate::HasFacts { key: k.clone() })
                    .collect();
                let mut expected = keys.clone();
                expected.sort();
                expected.dedup();
                prop_assert_eq!(extract_dependencies(&preds), expected);
            }
        }
    }
}