cctr_expr/
lib.rs

1//! Expression language parser and evaluator for cctr constraints.
2//!
3//! Supports:
4//! - Numbers: `42`, `-3.14`, `0.5`
5//! - Strings: `"hello"`, `"with \"escapes\""`
6//! - Booleans: `true`, `false`
7//! - Arrays: `[1, 2, 3]`, `["a", "b"]`
8//! - Arithmetic: `+`, `-`, `*`, `/`, `^`
9//! - Comparison: `==`, `!=`, `<`, `<=`, `>`, `>=`
10//! - Logical: `and`, `or`, `not`
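//! - String ops: `contains`, `startswith`, `endswith`, `matches`
//! - Regex literals: `/^hello\d+$/` — parsed as a string, typically the right-hand side of `matches`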
12//! - Membership: `in`
//! - Functions: `len(x)` — length of a string or an array
14//!
15//! # Example
16//!
17//! ```
18//! use cctr_expr::{eval_bool, Value};
19//! use std::collections::HashMap;
20//!
21//! let mut vars = HashMap::new();
22//! vars.insert("n".to_string(), Value::Number(42.0));
23//!
24//! assert!(eval_bool("n > 0 and n < 100", &vars).unwrap());
25//! ```
26
27use std::collections::HashMap;
28use thiserror::Error;
29use winnow::ascii::{digit1, multispace0};
30use winnow::combinator::{alt, delimited, opt, preceded, repeat, separated, terminated};
31use winnow::error::ContextError;
32use winnow::prelude::*;
33use winnow::token::{any, none_of, one_of, take_while};
34
35// ============ Value Types ============
36
/// A runtime value produced by evaluating an expression or supplied as a variable.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
    /// Every number is a 64-bit float.
    Number(f64),
    /// UTF-8 text.
    String(String),
    /// `true` or `false`.
    Bool(bool),
    /// An ordered, possibly heterogeneous list of values.
    Array(Vec<Value>),
}
44
45impl Value {
46    pub fn as_bool(&self) -> Result<bool, EvalError> {
47        match self {
48            Value::Bool(b) => Ok(*b),
49            _ => Err(EvalError::TypeError {
50                expected: "bool",
51                got: self.type_name(),
52            }),
53        }
54    }
55
56    pub fn as_number(&self) -> Result<f64, EvalError> {
57        match self {
58            Value::Number(n) => Ok(*n),
59            _ => Err(EvalError::TypeError {
60                expected: "number",
61                got: self.type_name(),
62            }),
63        }
64    }
65
66    pub fn as_string(&self) -> Result<&str, EvalError> {
67        match self {
68            Value::String(s) => Ok(s),
69            _ => Err(EvalError::TypeError {
70                expected: "string",
71                got: self.type_name(),
72            }),
73        }
74    }
75
76    pub fn as_array(&self) -> Result<&[Value], EvalError> {
77        match self {
78            Value::Array(a) => Ok(a),
79            _ => Err(EvalError::TypeError {
80                expected: "array",
81                got: self.type_name(),
82            }),
83        }
84    }
85
86    fn type_name(&self) -> &'static str {
87        match self {
88            Value::Number(_) => "number",
89            Value::String(_) => "string",
90            Value::Bool(_) => "bool",
91            Value::Array(_) => "array",
92        }
93    }
94}
95
96// ============ AST Types ============
97
98#[derive(Debug, Clone, PartialEq)]
99pub enum Expr {
100    Number(f64),
101    String(String),
102    Bool(bool),
103    Var(String),
104    Array(Vec<Expr>),
105    UnaryOp {
106        op: UnaryOp,
107        expr: Box<Expr>,
108    },
109    BinaryOp {
110        op: BinaryOp,
111        left: Box<Expr>,
112        right: Box<Expr>,
113    },
114    FuncCall {
115        name: String,
116        args: Vec<Expr>,
117    },
118}
119
/// Prefix operators.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum UnaryOp {
    /// Logical negation: `not x`.
    Not,
    /// Arithmetic negation: `-x`.
    Neg,
}
125
/// Infix operators, annotated with their surface syntax.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BinaryOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `^`
    Pow,
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `and`
    And,
    /// `or`
    Or,
    /// `in` (array membership)
    In,
    /// `contains` (substring)
    Contains,
    /// `startswith`
    StartsWith,
    /// `endswith`
    EndsWith,
    /// `matches` (regular expression)
    Matches,
}
147
148#[derive(Error, Debug, Clone, PartialEq)]
149pub enum EvalError {
150    #[error("type error: expected {expected}, got {got}")]
151    TypeError {
152        expected: &'static str,
153        got: &'static str,
154    },
155    #[error("undefined variable: {0}")]
156    UndefinedVariable(String),
157    #[error("undefined function: {0}")]
158    UndefinedFunction(String),
159    #[error("invalid regex: {0}")]
160    InvalidRegex(String),
161    #[error("division by zero")]
162    DivisionByZero,
163    #[error("parse error: {0}")]
164    ParseError(String),
165    #[error("wrong number of arguments for {func}: expected {expected}, got {got}")]
166    WrongArgCount {
167        func: String,
168        expected: usize,
169        got: usize,
170    },
171}
172
173// ============ Parser ============
174
175fn ws<'a, P, O>(p: P) -> impl Parser<&'a str, O, ContextError>
176where
177    P: Parser<&'a str, O, ContextError>,
178{
179    delimited(multispace0, p, multispace0)
180}
181
182fn number(input: &mut &str) -> ModalResult<Expr> {
183    let neg: Option<char> = opt('-').parse_next(input)?;
184    let int_part: &str = digit1.parse_next(input)?;
185    let frac_part: Option<&str> = opt(preceded('.', digit1)).parse_next(input)?;
186
187    let mut s = String::new();
188    if neg.is_some() {
189        s.push('-');
190    }
191    s.push_str(int_part);
192    if let Some(frac) = frac_part {
193        s.push('.');
194        s.push_str(frac);
195    }
196
197    Ok(Expr::Number(s.parse().unwrap()))
198}
199
200fn string_char(input: &mut &str) -> ModalResult<char> {
201    let c: char = none_of('"').parse_next(input)?;
202    if c == '\\' {
203        let escaped: char = any.parse_next(input)?;
204        Ok(match escaped {
205            'n' => '\n',
206            't' => '\t',
207            'r' => '\r',
208            '"' => '"',
209            '\\' => '\\',
210            c => c,
211        })
212    } else {
213        Ok(c)
214    }
215}
216
217fn string_literal(input: &mut &str) -> ModalResult<Expr> {
218    let chars: String = delimited(
219        '"',
220        repeat(0.., string_char).fold(String::new, |mut s, c| {
221            s.push(c);
222            s
223        }),
224        '"',
225    )
226    .parse_next(input)?;
227    Ok(Expr::String(chars))
228}
229
230fn regex_literal(input: &mut &str) -> ModalResult<Expr> {
231    '/'.parse_next(input)?;
232    let mut s = String::new();
233    loop {
234        let c: char = any.parse_next(input)?;
235        if c == '/' {
236            break;
237        }
238        if c == '\\' {
239            let escaped: char = any.parse_next(input)?;
240            s.push('\\');
241            s.push(escaped);
242        } else {
243            s.push(c);
244        }
245    }
246    Ok(Expr::String(s))
247}
248
249fn ident(input: &mut &str) -> ModalResult<String> {
250    let first: char = one_of(|c: char| c.is_ascii_alphabetic() || c == '_').parse_next(input)?;
251    let rest: &str =
252        take_while(0.., |c: char| c.is_ascii_alphanumeric() || c == '_').parse_next(input)?;
253    Ok(format!("{}{}", first, rest))
254}
255
256fn var_or_bool_or_func(input: &mut &str) -> ModalResult<Expr> {
257    let name = ident.parse_next(input)?;
258
259    let _ = multispace0.parse_next(input)?;
260    if input.starts_with('(') {
261        '('.parse_next(input)?;
262        let _ = multispace0.parse_next(input)?;
263        let args: Vec<Expr> = separated(0.., ws(expr), ws(',')).parse_next(input)?;
264        let _ = multispace0.parse_next(input)?;
265        ')'.parse_next(input)?;
266        return Ok(Expr::FuncCall { name, args });
267    }
268
269    match name.as_str() {
270        "true" => Ok(Expr::Bool(true)),
271        "false" => Ok(Expr::Bool(false)),
272        _ => Ok(Expr::Var(name)),
273    }
274}
275
276fn array(input: &mut &str) -> ModalResult<Expr> {
277    let elements: Vec<Expr> = delimited(
278        ('[', multispace0),
279        separated(0.., ws(expr), ws(',')),
280        (multispace0, ']'),
281    )
282    .parse_next(input)?;
283    Ok(Expr::Array(elements))
284}
285
/// Parses a primary expression after skipping leading whitespace.
///
/// A primary expression is one of:
/// - a parenthesised sub-expression,
/// - an array,
/// - a string literal,
/// - a regex literal,
/// - a number,
/// - an identifier-led term (variable, boolean, or function call).
///
/// The `alt` alternatives are tried in the listed order, and the first success wins.
fn atom(input: &mut &str) -> ModalResult<Expr> {
    let _ = multispace0.parse_next(input)?;
    alt((
        // Grouping: `( expr )`, with whitespace allowed just inside the parentheses.
        delimited(('(', multispace0), expr, (multispace0, ')')),
        array,
        string_literal,
        // A `/` at the start of an operand is a regex literal. A `/` after an operand
        // is division and is consumed by `term` before another operand is parsed.
        regex_literal,
        number,
        var_or_bool_or_func,
    ))
    .parse_next(input)
}
298
299fn unary(input: &mut &str) -> ModalResult<Expr> {
300    let _ = multispace0.parse_next(input)?;
301    let neg: Option<char> = opt('-').parse_next(input)?;
302    if neg.is_some() {
303        let e = unary.parse_next(input)?;
304        return Ok(Expr::UnaryOp {
305            op: UnaryOp::Neg,
306            expr: Box::new(e),
307        });
308    }
309    atom(input)
310}
311
312fn pow(input: &mut &str) -> ModalResult<Expr> {
313    let base = unary.parse_next(input)?;
314    let _ = multispace0.parse_next(input)?;
315    let caret: Option<char> = opt('^').parse_next(input)?;
316    if caret.is_some() {
317        let _ = multispace0.parse_next(input)?;
318        let exp = pow.parse_next(input)?;
319        Ok(Expr::BinaryOp {
320            op: BinaryOp::Pow,
321            left: Box::new(base),
322            right: Box::new(exp),
323        })
324    } else {
325        Ok(base)
326    }
327}
328
329fn term(input: &mut &str) -> ModalResult<Expr> {
330    let init = pow.parse_next(input)?;
331
332    repeat(0.., (ws(one_of(['*', '/'])), pow))
333        .fold(
334            move || init.clone(),
335            |acc, (op_char, val): (char, Expr)| {
336                let op = if op_char == '*' {
337                    BinaryOp::Mul
338                } else {
339                    BinaryOp::Div
340                };
341                Expr::BinaryOp {
342                    op,
343                    left: Box::new(acc),
344                    right: Box::new(val),
345                }
346            },
347        )
348        .parse_next(input)
349}
350
351fn arith(input: &mut &str) -> ModalResult<Expr> {
352    let init = term.parse_next(input)?;
353
354    repeat(0.., (ws(one_of(['+', '-'])), term))
355        .fold(
356            move || init.clone(),
357            |acc, (op_char, val): (char, Expr)| {
358                let op = if op_char == '+' {
359                    BinaryOp::Add
360                } else {
361                    BinaryOp::Sub
362                };
363                Expr::BinaryOp {
364                    op,
365                    left: Box::new(acc),
366                    right: Box::new(val),
367                }
368            },
369        )
370        .parse_next(input)
371}
372
373fn peek_non_ident(input: &mut &str) -> ModalResult<()> {
374    let next = input.chars().next();
375    if next
376        .map(|c| c.is_ascii_alphanumeric() || c == '_')
377        .unwrap_or(false)
378    {
379        Err(winnow::error::ErrMode::Backtrack(ContextError::new()))
380    } else {
381        Ok(())
382    }
383}
384
/// Parses a comparison, membership, or string operator token.
///
/// Alternative order matters. The two-character operators `<=` and `>=` must be tried
/// before their one-character prefixes `<` and `>`. Otherwise `a <= b` would match `<`
/// and leave `= b` behind. Keyword operators are followed by `peek_non_ident` so an
/// identifier that merely starts with a keyword (e.g. `index`) is not split.
fn cmp_op(input: &mut &str) -> ModalResult<BinaryOp> {
    alt((
        "==".value(BinaryOp::Eq),
        "!=".value(BinaryOp::Ne),
        "<=".value(BinaryOp::Le),
        ">=".value(BinaryOp::Ge),
        "<".value(BinaryOp::Lt),
        ">".value(BinaryOp::Gt),
        terminated("in", peek_non_ident).value(BinaryOp::In),
        terminated("contains", peek_non_ident).value(BinaryOp::Contains),
        terminated("startswith", peek_non_ident).value(BinaryOp::StartsWith),
        terminated("endswith", peek_non_ident).value(BinaryOp::EndsWith),
        terminated("matches", peek_non_ident).value(BinaryOp::Matches),
    ))
    .parse_next(input)
}
401
402fn comparison(input: &mut &str) -> ModalResult<Expr> {
403    let left = arith.parse_next(input)?;
404    let _ = multispace0.parse_next(input)?;
405
406    let op_opt: Option<BinaryOp> = opt(cmp_op).parse_next(input)?;
407    match op_opt {
408        Some(op) => {
409            let _ = multispace0.parse_next(input)?;
410            let right = arith.parse_next(input)?;
411            Ok(Expr::BinaryOp {
412                op,
413                left: Box::new(left),
414                right: Box::new(right),
415            })
416        }
417        None => Ok(left),
418    }
419}
420
421fn not_expr(input: &mut &str) -> ModalResult<Expr> {
422    let _ = multispace0.parse_next(input)?;
423    let not_kw: Option<&str> = opt(terminated("not", peek_non_ident)).parse_next(input)?;
424    if not_kw.is_some() {
425        let _ = multispace0.parse_next(input)?;
426        let e = not_expr.parse_next(input)?;
427        Ok(Expr::UnaryOp {
428            op: UnaryOp::Not,
429            expr: Box::new(e),
430        })
431    } else {
432        comparison(input)
433    }
434}
435
436fn and_expr(input: &mut &str) -> ModalResult<Expr> {
437    let init = not_expr.parse_next(input)?;
438
439    repeat(
440        0..,
441        preceded((multispace0, "and", peek_non_ident, multispace0), not_expr),
442    )
443    .fold(
444        move || init.clone(),
445        |acc, val| Expr::BinaryOp {
446            op: BinaryOp::And,
447            left: Box::new(acc),
448            right: Box::new(val),
449        },
450    )
451    .parse_next(input)
452}
453
454fn or_expr(input: &mut &str) -> ModalResult<Expr> {
455    let init = and_expr.parse_next(input)?;
456
457    repeat(
458        0..,
459        preceded((multispace0, "or", peek_non_ident, multispace0), and_expr),
460    )
461    .fold(
462        move || init.clone(),
463        |acc, val| Expr::BinaryOp {
464            op: BinaryOp::Or,
465            left: Box::new(acc),
466            right: Box::new(val),
467        },
468    )
469    .parse_next(input)
470}
471
472fn expr(input: &mut &str) -> ModalResult<Expr> {
473    or_expr(input)
474}
475
476pub fn parse(input: &str) -> Result<Expr, EvalError> {
477    let mut input = input.trim();
478    match expr.parse_next(&mut input) {
479        Ok(e) => {
480            let remaining = input.trim();
481            if remaining.is_empty() {
482                Ok(e)
483            } else {
484                Err(EvalError::ParseError(format!(
485                    "unexpected trailing input: {:?}",
486                    remaining
487                )))
488            }
489        }
490        Err(e) => Err(EvalError::ParseError(format!("{:?}", e))),
491    }
492}
493
494// ============ Evaluator ============
495
496pub fn evaluate(expr: &Expr, vars: &HashMap<String, Value>) -> Result<Value, EvalError> {
497    match expr {
498        Expr::Number(n) => Ok(Value::Number(*n)),
499        Expr::String(s) => Ok(Value::String(s.clone())),
500        Expr::Bool(b) => Ok(Value::Bool(*b)),
501        Expr::Var(name) => vars
502            .get(name)
503            .cloned()
504            .ok_or_else(|| EvalError::UndefinedVariable(name.clone())),
505        Expr::Array(elements) => {
506            let values: Result<Vec<_>, _> = elements.iter().map(|e| evaluate(e, vars)).collect();
507            Ok(Value::Array(values?))
508        }
509        Expr::UnaryOp { op, expr } => {
510            let val = evaluate(expr, vars)?;
511            match op {
512                UnaryOp::Not => Ok(Value::Bool(!val.as_bool()?)),
513                UnaryOp::Neg => Ok(Value::Number(-val.as_number()?)),
514            }
515        }
516        Expr::BinaryOp { op, left, right } => eval_binary_op(*op, left, right, vars),
517        Expr::FuncCall { name, args } => eval_func_call(name, args, vars),
518    }
519}
520
521fn eval_func_call(
522    name: &str,
523    args: &[Expr],
524    vars: &HashMap<String, Value>,
525) -> Result<Value, EvalError> {
526    match name {
527        "len" => {
528            if args.len() != 1 {
529                return Err(EvalError::WrongArgCount {
530                    func: name.to_string(),
531                    expected: 1,
532                    got: args.len(),
533                });
534            }
535            let val = evaluate(&args[0], vars)?;
536            match val {
537                Value::String(s) => Ok(Value::Number(s.len() as f64)),
538                Value::Array(a) => Ok(Value::Number(a.len() as f64)),
539                _ => Err(EvalError::TypeError {
540                    expected: "string or array",
541                    got: val.type_name(),
542                }),
543            }
544        }
545        _ => Err(EvalError::UndefinedFunction(name.to_string())),
546    }
547}
548
549fn eval_binary_op(
550    op: BinaryOp,
551    left: &Expr,
552    right: &Expr,
553    vars: &HashMap<String, Value>,
554) -> Result<Value, EvalError> {
555    if op == BinaryOp::And {
556        let l = evaluate(left, vars)?.as_bool()?;
557        if !l {
558            return Ok(Value::Bool(false));
559        }
560        return Ok(Value::Bool(evaluate(right, vars)?.as_bool()?));
561    }
562    if op == BinaryOp::Or {
563        let l = evaluate(left, vars)?.as_bool()?;
564        if l {
565            return Ok(Value::Bool(true));
566        }
567        return Ok(Value::Bool(evaluate(right, vars)?.as_bool()?));
568    }
569
570    let l = evaluate(left, vars)?;
571    let r = evaluate(right, vars)?;
572
573    match op {
574        BinaryOp::Add => Ok(Value::Number(l.as_number()? + r.as_number()?)),
575        BinaryOp::Sub => Ok(Value::Number(l.as_number()? - r.as_number()?)),
576        BinaryOp::Mul => Ok(Value::Number(l.as_number()? * r.as_number()?)),
577        BinaryOp::Div => {
578            let divisor = r.as_number()?;
579            if divisor == 0.0 {
580                Err(EvalError::DivisionByZero)
581            } else {
582                Ok(Value::Number(l.as_number()? / divisor))
583            }
584        }
585        BinaryOp::Pow => Ok(Value::Number(l.as_number()?.powf(r.as_number()?))),
586        BinaryOp::Eq => Ok(Value::Bool(values_equal(&l, &r))),
587        BinaryOp::Ne => Ok(Value::Bool(!values_equal(&l, &r))),
588        BinaryOp::Lt => Ok(Value::Bool(l.as_number()? < r.as_number()?)),
589        BinaryOp::Le => Ok(Value::Bool(l.as_number()? <= r.as_number()?)),
590        BinaryOp::Gt => Ok(Value::Bool(l.as_number()? > r.as_number()?)),
591        BinaryOp::Ge => Ok(Value::Bool(l.as_number()? >= r.as_number()?)),
592        BinaryOp::In => {
593            let arr = r.as_array()?;
594            Ok(Value::Bool(arr.iter().any(|v| values_equal(&l, v))))
595        }
596        BinaryOp::Contains => {
597            let haystack = l.as_string()?;
598            let needle = r.as_string()?;
599            Ok(Value::Bool(haystack.contains(needle)))
600        }
601        BinaryOp::StartsWith => {
602            let s = l.as_string()?;
603            let prefix = r.as_string()?;
604            Ok(Value::Bool(s.starts_with(prefix)))
605        }
606        BinaryOp::EndsWith => {
607            let s = l.as_string()?;
608            let suffix = r.as_string()?;
609            Ok(Value::Bool(s.ends_with(suffix)))
610        }
611        BinaryOp::Matches => {
612            let s = l.as_string()?;
613            let pattern = r.as_string()?;
614            let re =
615                regex::Regex::new(pattern).map_err(|e| EvalError::InvalidRegex(e.to_string()))?;
616            Ok(Value::Bool(re.is_match(s)))
617        }
618        BinaryOp::And | BinaryOp::Or => unreachable!(),
619    }
620}
621
622fn values_equal(a: &Value, b: &Value) -> bool {
623    match (a, b) {
624        (Value::Number(a), Value::Number(b)) => (a - b).abs() < f64::EPSILON,
625        (Value::String(a), Value::String(b)) => a == b,
626        (Value::Bool(a), Value::Bool(b)) => a == b,
627        (Value::Array(a), Value::Array(b)) => {
628            a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y))
629        }
630        _ => false,
631    }
632}
633
634// ============ Public API ============
635
636pub fn eval_bool(expr_str: &str, vars: &HashMap<String, Value>) -> Result<bool, EvalError> {
637    let ast = parse(expr_str)?;
638    let result = evaluate(&ast, vars)?;
639    result.as_bool()
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variable environment from `(name, value)` pairs.
    ///
    /// Later duplicates win.
    fn vars(pairs: &[(&str, Value)]) -> HashMap<String, Value> {
        let mut env = HashMap::with_capacity(pairs.len());
        for (name, value) in pairs {
            env.insert(name.to_string(), value.clone());
        }
        env
    }

    /// Asserts that every expression evaluates to `true` in `env`, naming any failure.
    fn assert_all_true(exprs: &[&str], env: &HashMap<String, Value>) {
        for src in exprs {
            assert!(eval_bool(src, env).unwrap(), "expected `{}` to be true", src);
        }
    }

    #[test]
    fn test_number_parsing() {
        for &(src, expected) in &[("42", 42.0), ("0.5", 0.5)] {
            assert_eq!(parse(src).unwrap(), Expr::Number(expected));
        }
    }

    #[test]
    fn test_string_parsing() {
        let parsed = parse(r#""hello""#).unwrap();
        assert_eq!(parsed, Expr::String("hello".to_string()));
    }

    #[test]
    fn test_arithmetic() {
        assert_all_true(
            &[
                "1 + 2 == 3",
                "10 - 3 == 7",
                "4 * 5 == 20",
                "10 / 2 == 5",
                "2 ^ 3 == 8",
                "1 + 2 * 3 == 7",
                "(1 + 2) * 3 == 9",
            ],
            &vars(&[]),
        );
    }

    #[test]
    fn test_comparisons() {
        assert_all_true(
            &["n > 0", "n < 100", "n >= 42", "n <= 42", "n == 42", "n != 0"],
            &vars(&[("n", Value::Number(42.0))]),
        );
    }

    #[test]
    fn test_boolean_logic() {
        assert_all_true(
            &["n > 0 and n < 100", "n < 0 or n > 0", "not (n < 0)"],
            &vars(&[("n", Value::Number(42.0))]),
        );
    }

    #[test]
    fn test_in_operator() {
        let env = vars(&[("n", Value::Number(2.0))]);
        assert!(eval_bool("n in [1, 2, 3]", &env).unwrap());
        assert!(!eval_bool("n in [4, 5, 6]", &env).unwrap());
    }

    #[test]
    fn test_string_operators() {
        assert_all_true(
            &[
                r#"s contains "world""#,
                r#"s startswith "hello""#,
                r#"s endswith "world""#,
            ],
            &vars(&[("s", Value::String("hello world".to_string()))]),
        );
    }

    #[test]
    fn test_regex_matches() {
        let env = vars(&[("s", Value::String("hello123".to_string()))]);
        assert_all_true(&[r#"s matches /^hello\d+$/"#], &env);
    }

    #[test]
    fn test_len_function() {
        let env = vars(&[("s", Value::String("hello".to_string()))]);
        assert_all_true(&["len(s) == 5"], &env);
    }

    #[test]
    fn test_backslash_in_string() {
        // A Windows-style path containing literal backslashes.
        let env = vars(&[("p", Value::String("C:\\Users\\test".to_string()))]);
        assert_all_true(
            &[
                // Plain substring.
                r#"p contains "test""#,
                // `"\\"` in the expression decodes to a single backslash.
                r#"p contains "\\""#,
                r#"p contains "Users""#,
            ],
            &env,
        );
    }
}