1use std::collections::HashMap;
32
33use regex::Regex;
34use thiserror::Error;
35
36use crate::facts::{FactValue, FactValues};
37
38#[derive(Debug, Error)]
41pub enum WhenError {
42 #[error("when parse error at column {pos}: {message}")]
43 Parse { pos: usize, message: String },
44 #[error("when evaluation error: {0}")]
45 Eval(String),
46 #[error("invalid regex in `matches`: {0}")]
47 Regex(String),
48}
49
/// Runtime value produced while evaluating a `when` expression.
#[derive(Clone, Debug)]
pub enum Value {
    /// A `true`/`false` literal, a boolean fact, or a comparison result.
    Bool(bool),
    /// A signed 64-bit integer literal or integer fact.
    Int(i64),
    /// A string literal, a string fact, or any `vars.NAME` value.
    String(String),
    /// A list literal such as `["a", "b"]`.
    List(Vec<Value>),
    /// The `null` literal, or a fact/variable that is not set.
    Null,
}
60
61impl Value {
62 pub fn truthy(&self) -> bool {
63 match self {
64 Self::Bool(b) => *b,
65 Self::Int(n) => *n != 0,
66 Self::String(s) => !s.is_empty(),
67 Self::List(v) => !v.is_empty(),
68 Self::Null => false,
69 }
70 }
71
72 fn type_name(&self) -> &'static str {
73 match self {
74 Self::Bool(_) => "bool",
75 Self::Int(_) => "int",
76 Self::String(_) => "string",
77 Self::List(_) => "list",
78 Self::Null => "null",
79 }
80 }
81}
82
83impl From<&FactValue> for Value {
84 fn from(f: &FactValue) -> Self {
85 match f {
86 FactValue::Bool(b) => Self::Bool(*b),
87 FactValue::Int(n) => Self::Int(*n),
88 FactValue::String(s) => Self::String(s.clone()),
89 }
90 }
91}
92
/// Which lookup table an identifier reads from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Namespace {
    /// `facts.NAME`, resolved against `WhenEnv::facts`.
    Facts,
    /// `vars.NAME`, resolved against `WhenEnv::vars`.
    Vars,
}
100
/// Binary comparison operator between two operands.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CmpOp {
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `in`: list membership, or substring test on strings.
    In,
}
111
112#[derive(Debug, Clone)]
113pub enum WhenExpr {
114 Literal(Value),
115 Ident {
116 ns: Namespace,
117 name: String,
118 },
119 Not(Box<WhenExpr>),
120 And(Box<WhenExpr>, Box<WhenExpr>),
121 Or(Box<WhenExpr>, Box<WhenExpr>),
122 Cmp {
123 left: Box<WhenExpr>,
124 op: CmpOp,
125 right: Box<WhenExpr>,
126 },
127 Matches {
129 left: Box<WhenExpr>,
130 pattern: Regex,
131 },
132 List(Vec<WhenExpr>),
133}
134
135#[derive(Debug)]
138pub struct WhenEnv<'a> {
139 pub facts: &'a FactValues,
140 pub vars: &'a HashMap<String, String>,
141}
142
143pub fn parse(src: &str) -> Result<WhenExpr, WhenError> {
146 let tokens = lex(src)?;
147 let mut p = Parser { tokens, pos: 0 };
148 let expr = p.parse_expr()?;
149 p.expect_eof()?;
150 Ok(expr)
151}
152
153impl WhenExpr {
154 pub fn evaluate(&self, env: &WhenEnv<'_>) -> Result<bool, WhenError> {
155 let v = eval(self, env)?;
156 Ok(v.truthy())
157 }
158}
159
/// Lexical token; `lex` pairs each one with its starting byte offset.
#[derive(Clone, Debug)]
enum Tok {
    // Literals and names.
    Bool(bool),
    Null,
    Int(i64),
    Str(String),
    Ident(String),
    // Punctuation.
    Dot,
    LParen,
    RParen,
    LBracket,
    RBracket,
    Comma,
    // Comparison operators: `==`, `!=`, `<`, `<=`, `>`, `>=`.
    Eq2,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,
    // Keywords.
    KwAnd,
    KwOr,
    KwNot,
    KwIn,
    KwMatches,
}
187
188#[allow(clippy::too_many_lines)]
189fn lex(src: &str) -> Result<Vec<(Tok, usize)>, WhenError> {
190 let bytes = src.as_bytes();
191 let mut out = Vec::new();
192 let mut i = 0;
193 while i < bytes.len() {
194 let c = bytes[i];
195 if c == b' ' || c == b'\t' || c == b'\n' || c == b'\r' {
197 i += 1;
198 continue;
199 }
200 let start = i;
201 match c {
202 b'.' => {
203 out.push((Tok::Dot, start));
204 i += 1;
205 }
206 b'(' => {
207 out.push((Tok::LParen, start));
208 i += 1;
209 }
210 b')' => {
211 out.push((Tok::RParen, start));
212 i += 1;
213 }
214 b'[' => {
215 out.push((Tok::LBracket, start));
216 i += 1;
217 }
218 b']' => {
219 out.push((Tok::RBracket, start));
220 i += 1;
221 }
222 b',' => {
223 out.push((Tok::Comma, start));
224 i += 1;
225 }
226 b'=' => {
227 if bytes.get(i + 1) == Some(&b'=') {
228 out.push((Tok::Eq2, start));
229 i += 2;
230 } else {
231 return Err(WhenError::Parse {
232 pos: start,
233 message: "expected '==' (bare '=' is not an operator)".into(),
234 });
235 }
236 }
237 b'!' => {
238 if bytes.get(i + 1) == Some(&b'=') {
239 out.push((Tok::Ne, start));
240 i += 2;
241 } else {
242 return Err(WhenError::Parse {
243 pos: start,
244 message: "expected '!=' (use 'not' for logical negation)".into(),
245 });
246 }
247 }
248 b'<' => {
249 if bytes.get(i + 1) == Some(&b'=') {
250 out.push((Tok::Le, start));
251 i += 2;
252 } else {
253 out.push((Tok::Lt, start));
254 i += 1;
255 }
256 }
257 b'>' => {
258 if bytes.get(i + 1) == Some(&b'=') {
259 out.push((Tok::Ge, start));
260 i += 2;
261 } else {
262 out.push((Tok::Gt, start));
263 i += 1;
264 }
265 }
266 b'"' | b'\'' => {
267 let quote = c;
268 i += 1;
269 let mut s = String::new();
270 while i < bytes.len() && bytes[i] != quote {
271 if bytes[i] == b'\\' && i + 1 < bytes.len() {
272 let esc = bytes[i + 1];
273 let ch = match esc {
274 b'n' => '\n',
275 b't' => '\t',
276 b'r' => '\r',
277 b'\\' => '\\',
278 b'"' => '"',
279 b'\'' => '\'',
280 _ => {
281 return Err(WhenError::Parse {
282 pos: i,
283 message: format!(
284 "unknown escape \\{} in string literal",
285 esc as char,
286 ),
287 });
288 }
289 };
290 s.push(ch);
291 i += 2;
292 } else {
293 s.push(bytes[i] as char);
294 i += 1;
295 }
296 }
297 if i >= bytes.len() {
298 return Err(WhenError::Parse {
299 pos: start,
300 message: "unterminated string literal".into(),
301 });
302 }
303 i += 1;
304 out.push((Tok::Str(s), start));
305 }
306 c if c.is_ascii_digit() => {
307 let mut j = i;
308 while j < bytes.len() && bytes[j].is_ascii_digit() {
309 j += 1;
310 }
311 let num = std::str::from_utf8(&bytes[i..j])
312 .unwrap()
313 .parse::<i64>()
314 .map_err(|e| WhenError::Parse {
315 pos: start,
316 message: format!("invalid integer: {e}"),
317 })?;
318 out.push((Tok::Int(num), start));
319 i = j;
320 }
321 c if is_ident_start(c) => {
322 let mut j = i;
323 while j < bytes.len() && is_ident_cont(bytes[j]) {
324 j += 1;
325 }
326 let word = &src[i..j];
327 let tok = match word {
328 "true" => Tok::Bool(true),
329 "false" => Tok::Bool(false),
330 "null" => Tok::Null,
331 "and" => Tok::KwAnd,
332 "or" => Tok::KwOr,
333 "not" => Tok::KwNot,
334 "in" => Tok::KwIn,
335 "matches" => Tok::KwMatches,
336 _ => Tok::Ident(word.to_string()),
337 };
338 out.push((tok, start));
339 i = j;
340 }
341 _ => {
342 return Err(WhenError::Parse {
343 pos: start,
344 message: format!("unexpected character {:?}", c as char),
345 });
346 }
347 }
348 }
349 Ok(out)
350}
351
/// Whether byte `c` may begin an identifier or keyword: an ASCII letter or `_`.
fn is_ident_start(c: u8) -> bool {
    matches!(c, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
355
/// Whether byte `c` may continue an identifier: an ASCII letter, digit, or `_`.
fn is_ident_cont(c: u8) -> bool {
    matches!(c, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
359
360struct Parser {
363 tokens: Vec<(Tok, usize)>,
364 pos: usize,
365}
366
367impl Parser {
368 fn peek(&self) -> Option<&Tok> {
369 self.tokens.get(self.pos).map(|(t, _)| t)
370 }
371
372 fn advance(&mut self) -> Option<&(Tok, usize)> {
373 let p = self.pos;
374 self.pos += 1;
375 self.tokens.get(p)
376 }
377
378 fn pos_here(&self) -> usize {
379 self.tokens.get(self.pos).map_or_else(
380 || self.tokens.last().map_or(0, |(_, p)| *p + 1),
381 |(_, p)| *p,
382 )
383 }
384
385 fn err(&self, message: impl Into<String>) -> WhenError {
386 WhenError::Parse {
387 pos: self.pos_here(),
388 message: message.into(),
389 }
390 }
391
392 fn expect_eof(&mut self) -> Result<(), WhenError> {
393 if self.peek().is_some() {
394 Err(self.err("unexpected trailing token"))
395 } else {
396 Ok(())
397 }
398 }
399
400 fn parse_expr(&mut self) -> Result<WhenExpr, WhenError> {
401 self.parse_or()
402 }
403
404 fn parse_or(&mut self) -> Result<WhenExpr, WhenError> {
405 let mut left = self.parse_and()?;
406 while matches!(self.peek(), Some(Tok::KwOr)) {
407 self.advance();
408 let right = self.parse_and()?;
409 left = WhenExpr::Or(Box::new(left), Box::new(right));
410 }
411 Ok(left)
412 }
413
414 fn parse_and(&mut self) -> Result<WhenExpr, WhenError> {
415 let mut left = self.parse_not()?;
416 while matches!(self.peek(), Some(Tok::KwAnd)) {
417 self.advance();
418 let right = self.parse_not()?;
419 left = WhenExpr::And(Box::new(left), Box::new(right));
420 }
421 Ok(left)
422 }
423
424 fn parse_not(&mut self) -> Result<WhenExpr, WhenError> {
425 if matches!(self.peek(), Some(Tok::KwNot)) {
426 self.advance();
427 let inner = self.parse_cmp()?;
428 return Ok(WhenExpr::Not(Box::new(inner)));
429 }
430 self.parse_cmp()
431 }
432
433 fn parse_cmp(&mut self) -> Result<WhenExpr, WhenError> {
434 let left = self.parse_primary()?;
435 let op = match self.peek() {
436 Some(Tok::Eq2) => Some(CmpOp::Eq),
437 Some(Tok::Ne) => Some(CmpOp::Ne),
438 Some(Tok::Lt) => Some(CmpOp::Lt),
439 Some(Tok::Le) => Some(CmpOp::Le),
440 Some(Tok::Gt) => Some(CmpOp::Gt),
441 Some(Tok::Ge) => Some(CmpOp::Ge),
442 Some(Tok::KwIn) => Some(CmpOp::In),
443 _ => None,
444 };
445 if let Some(op) = op {
446 self.advance();
447 let right = self.parse_primary()?;
448 return Ok(WhenExpr::Cmp {
449 left: Box::new(left),
450 op,
451 right: Box::new(right),
452 });
453 }
454 if matches!(self.peek(), Some(Tok::KwMatches)) {
455 self.advance();
456 let pos = self.pos_here();
457 match self.advance() {
458 Some((Tok::Str(s), _)) => {
459 let pattern = Regex::new(s)
460 .map_err(|e| WhenError::Regex(format!("{e} (at column {pos})")))?;
461 return Ok(WhenExpr::Matches {
462 left: Box::new(left),
463 pattern,
464 });
465 }
466 _ => {
467 return Err(WhenError::Parse {
468 pos,
469 message: "`matches` right-hand side must be a string literal".into(),
470 });
471 }
472 }
473 }
474 Ok(left)
475 }
476
477 fn parse_primary(&mut self) -> Result<WhenExpr, WhenError> {
478 let pos = self.pos_here();
479 match self.advance() {
480 Some((Tok::Bool(b), _)) => Ok(WhenExpr::Literal(Value::Bool(*b))),
481 Some((Tok::Null, _)) => Ok(WhenExpr::Literal(Value::Null)),
482 Some((Tok::Int(n), _)) => Ok(WhenExpr::Literal(Value::Int(*n))),
483 Some((Tok::Str(s), _)) => Ok(WhenExpr::Literal(Value::String(s.clone()))),
484 Some((Tok::LParen, _)) => {
485 let inner = self.parse_expr()?;
486 match self.advance() {
487 Some((Tok::RParen, _)) => Ok(inner),
488 _ => Err(WhenError::Parse {
489 pos,
490 message: "expected ')'".into(),
491 }),
492 }
493 }
494 Some((Tok::LBracket, _)) => {
495 let mut items = Vec::new();
496 if !matches!(self.peek(), Some(Tok::RBracket)) {
497 items.push(self.parse_expr()?);
498 while matches!(self.peek(), Some(Tok::Comma)) {
499 self.advance();
500 items.push(self.parse_expr()?);
501 }
502 }
503 match self.advance() {
504 Some((Tok::RBracket, _)) => Ok(WhenExpr::List(items)),
505 _ => Err(WhenError::Parse {
506 pos,
507 message: "expected ']'".into(),
508 }),
509 }
510 }
511 Some((Tok::Ident(name), _)) => {
512 let name_owned = name.clone();
513 let ns = match name_owned.as_str() {
514 "facts" => Namespace::Facts,
515 "vars" => Namespace::Vars,
516 other => {
517 return Err(WhenError::Parse {
518 pos,
519 message: format!(
520 "unknown identifier {other:?}; only `facts.NAME` and `vars.NAME` are allowed"
521 ),
522 });
523 }
524 };
525 if !matches!(self.advance(), Some((Tok::Dot, _))) {
526 return Err(WhenError::Parse {
527 pos,
528 message: format!("expected '.' after {name_owned:?}"),
529 });
530 }
531 let field_pos = self.pos_here();
532 let field = match self.advance() {
533 Some((Tok::Ident(f), _)) => f.clone(),
534 _ => {
535 return Err(WhenError::Parse {
536 pos: field_pos,
537 message: "expected identifier after '.'".into(),
538 });
539 }
540 };
541 Ok(WhenExpr::Ident { ns, name: field })
542 }
543 _ => Err(WhenError::Parse {
544 pos,
545 message: "expected literal, identifier, '(' or '['".into(),
546 }),
547 }
548 }
549}
550
551fn eval(e: &WhenExpr, env: &WhenEnv<'_>) -> Result<Value, WhenError> {
554 match e {
555 WhenExpr::Literal(v) => Ok(v.clone()),
556 WhenExpr::Ident { ns, name } => match ns {
557 Namespace::Facts => match env.facts.get(name) {
558 Some(f) => Ok(Value::from(f)),
559 None => Ok(Value::Null),
560 },
561 Namespace::Vars => match env.vars.get(name) {
562 Some(v) => Ok(Value::String(v.clone())),
563 None => Ok(Value::Null),
564 },
565 },
566 WhenExpr::Not(inner) => Ok(Value::Bool(!eval(inner, env)?.truthy())),
567 WhenExpr::And(l, r) => {
568 let lv = eval(l, env)?;
569 if !lv.truthy() {
570 return Ok(Value::Bool(false));
571 }
572 Ok(Value::Bool(eval(r, env)?.truthy()))
573 }
574 WhenExpr::Or(l, r) => {
575 let lv = eval(l, env)?;
576 if lv.truthy() {
577 return Ok(Value::Bool(true));
578 }
579 Ok(Value::Bool(eval(r, env)?.truthy()))
580 }
581 WhenExpr::Cmp { left, op, right } => {
582 let lv = eval(left, env)?;
583 let rv = eval(right, env)?;
584 Ok(Value::Bool(apply_cmp(&lv, *op, &rv)?))
585 }
586 WhenExpr::Matches { left, pattern } => {
587 let lv = eval(left, env)?;
588 match lv {
589 Value::String(s) => Ok(Value::Bool(pattern.is_match(&s))),
590 other => Err(WhenError::Eval(format!(
591 "`matches` left-hand side must be a string; got {}",
592 other.type_name()
593 ))),
594 }
595 }
596 WhenExpr::List(items) => {
597 let mut out = Vec::with_capacity(items.len());
598 for item in items {
599 out.push(eval(item, env)?);
600 }
601 Ok(Value::List(out))
602 }
603 }
604}
605
606fn apply_cmp(l: &Value, op: CmpOp, r: &Value) -> Result<bool, WhenError> {
607 use Value::{Bool, Int, List, Null, String as S};
608 match op {
609 CmpOp::Eq => Ok(values_equal(l, r)),
610 CmpOp::Ne => Ok(!values_equal(l, r)),
611 CmpOp::Lt | CmpOp::Le | CmpOp::Gt | CmpOp::Ge => match (l, r) {
612 (Int(a), Int(b)) => Ok(cmp_ord(a, b, op)),
613 (S(a), S(b)) => Ok(cmp_ord(&a.as_str(), &b.as_str(), op)),
614 _ => Err(WhenError::Eval(format!(
615 "cannot compare {} with {}",
616 l.type_name(),
617 r.type_name(),
618 ))),
619 },
620 CmpOp::In => match r {
621 List(items) => Ok(items.iter().any(|x| values_equal(l, x))),
622 S(haystack) => match l {
623 S(needle) => Ok(haystack.contains(needle.as_str())),
624 _ => Err(WhenError::Eval(format!(
625 "`in` with a string right-hand side requires a string left; got {}",
626 l.type_name()
627 ))),
628 },
629 _ => {
630 let _ = (Bool(false), Null);
631 Err(WhenError::Eval(format!(
632 "`in` right-hand side must be a list or string; got {}",
633 r.type_name()
634 )))
635 }
636 },
637 }
638}
639
640fn values_equal(a: &Value, b: &Value) -> bool {
641 match (a, b) {
642 (Value::Bool(x), Value::Bool(y)) => x == y,
643 (Value::Int(x), Value::Int(y)) => x == y,
644 (Value::String(x), Value::String(y)) => x == y,
645 (Value::Null, Value::Null) => true,
646 (Value::List(x), Value::List(y)) => {
647 x.len() == y.len() && x.iter().zip(y.iter()).all(|(a, b)| values_equal(a, b))
648 }
649 _ => false,
650 }
651}
652
653fn cmp_ord<T: PartialOrd>(a: &T, b: &T, op: CmpOp) -> bool {
654 match op {
655 CmpOp::Lt => a < b,
656 CmpOp::Le => a <= b,
657 CmpOp::Gt => a > b,
658 CmpOp::Ge => a >= b,
659 _ => unreachable!(),
660 }
661}
662
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed facts and variables shared by the tests below.
    fn env() -> (FactValues, HashMap<String, String>) {
        let mut f = FactValues::new();
        f.insert("is_rust".into(), FactValue::Bool(true));
        f.insert("is_node".into(), FactValue::Bool(false));
        f.insert("n_files".into(), FactValue::Int(42));
        f.insert("primary".into(), FactValue::String("Rust".into()));
        let mut v = HashMap::new();
        v.insert("org".into(), "Acme Corp".into());
        v.insert("year".into(), "2026".into());
        (f, v)
    }

    /// Parses and evaluates `src` against [`env`], panicking on any error.
    fn check(src: &str) -> bool {
        let (facts, vars) = env();
        let expr = parse(src).unwrap();
        expr.evaluate(&WhenEnv {
            facts: &facts,
            vars: &vars,
        })
        .unwrap()
    }

    #[test]
    fn simple_facts() {
        assert!(check("facts.is_rust"));
        assert!(!check("facts.is_node"));
        assert!(check("not facts.is_node"));
    }

    #[test]
    fn integer_comparison() {
        assert!(check("facts.n_files > 0"));
        assert!(check("facts.n_files == 42"));
        assert!(!check("facts.n_files < 10"));
        assert!(check("facts.n_files >= 42"));
    }

    #[test]
    fn string_equality() {
        assert!(check("facts.primary == \"Rust\""));
        assert!(!check("facts.primary == \"Go\""));
    }

    #[test]
    fn logical_ops_short_circuit() {
        assert!(check("facts.is_rust and facts.n_files > 0"));
        assert!(check("facts.is_node or facts.is_rust"));
        assert!(!check("facts.is_node and facts.nonexistent == 5"));
    }

    #[test]
    fn in_list() {
        assert!(check("facts.primary in [\"Rust\", \"Go\"]"));
        assert!(!check("facts.primary in [\"Python\", \"Java\"]"));
    }

    #[test]
    fn in_string_is_substring() {
        assert!(check("\"cme\" in vars.org"));
        assert!(!check("\"Xyz\" in vars.org"));
    }

    #[test]
    fn matches_regex() {
        assert!(check("vars.org matches \"^Acme\""));
        assert!(check("vars.year matches \"^\\\\d{4}$\""));
        assert!(!check("vars.org matches \"^Xyz\""));
    }

    #[test]
    fn parentheses_override_precedence() {
        assert!(check(
            "(facts.is_node or facts.is_rust) and facts.n_files > 0"
        ));
        assert!(!check("facts.is_node or facts.is_rust and facts.is_node"));
    }

    #[test]
    fn unknown_facts_are_null_and_falsy() {
        assert!(!check("facts.nonexistent"));
        assert!(check("not facts.nonexistent"));
    }

    #[test]
    fn unknown_vars_are_null() {
        assert!(!check("vars.not_set"));
    }

    #[test]
    fn null_equals_null() {
        assert!(check("facts.nonexistent == null"));
    }

    #[test]
    fn parse_rejects_bare_equals() {
        let e = parse("facts.x = 1").unwrap_err();
        // `matches!` only returns a bool; it must be asserted to test anything.
        assert!(matches!(e, WhenError::Parse { .. }));
    }

    #[test]
    fn parse_rejects_bang_alone() {
        let e = parse("!facts.x").unwrap_err();
        assert!(matches!(e, WhenError::Parse { .. }));
    }

    #[test]
    fn parse_rejects_invalid_identifier_namespace() {
        let e = parse("ctx.x").unwrap_err();
        let WhenError::Parse { message, .. } = e else {
            panic!();
        };
        assert!(message.contains("facts.NAME"));
    }

    #[test]
    fn parse_rejects_matches_with_non_literal_rhs() {
        let e = parse("vars.org matches vars.pattern").unwrap_err();
        let WhenError::Parse { message, .. } = e else {
            panic!();
        };
        assert!(message.contains("string literal"));
    }

    #[test]
    fn parse_rejects_invalid_regex() {
        let e = parse("vars.org matches \"[unclosed\"").unwrap_err();
        assert!(matches!(e, WhenError::Regex(_)));
    }

    #[test]
    fn evaluate_rejects_ordering_mixed_types() {
        let (facts, vars) = env();
        let expr = parse("facts.primary > facts.n_files").unwrap();
        let result = expr.evaluate(&WhenEnv {
            facts: &facts,
            vars: &vars,
        });
        assert!(result.is_err());
    }

    #[test]
    fn string_escapes() {
        // Plain string literal compared against a variable.
        assert!(check("vars.org == \"Acme Corp\""));
        // `\"` inside a double-quoted literal equals a bare `"` inside a
        // single-quoted one.
        assert!(check(r#""a\"b" == 'a"b'"#));
        // `\\` yields a single backslash.
        assert!(check(r#""a\\b" == 'a\\b'"#));
        // `\t` yields a real tab character, distinct from a space.
        assert!(check("\"x\\ty\" == \"x\ty\""));
        assert!(check(r#""x\ty" != "x y""#));
        // Unknown escapes are rejected at parse time.
        assert!(matches!(parse(r#""\q""#), Err(WhenError::Parse { .. })));
    }

    #[test]
    fn nested_not_and_or() {
        assert!(check(
            "not (facts.is_node or (facts.n_files == 0 and facts.is_rust))"
        ));
    }
}