Skip to main content

paramodel_elements/
expression.rs

1// Copyright (c) Jonathan Shook
2// SPDX-License-Identifier: Apache-2.0
3
4//! Expression language for [`DerivedParameter`](crate::parameter::DerivedParameter).
5//!
6//! A declarative AST — literals, parameter references, binary/unary
7//! ops, built-in functions, and conditionals — serde-able end to end so
8//! plan files fingerprint deterministically. Evaluated against a
9//! [`ValueBindings`] map (the already-bound parameter values).
10//!
11//! Type checking happens at [`Expression::eval`]: an expression is only
12//! statically checked for shape at construction (the AST type system
13//! does that); kind errors surface at bind time with
14//! [`DerivationError::TypeMismatch`]. This matches the upstream Java
15//! behaviour where `compute()` throws at runtime.
16
17use std::collections::HashMap;
18
19use serde::{Deserialize, Serialize};
20
21use crate::names::ParameterName;
22use crate::value::{Value, ValueKind};
23
24/// Parameter-name → value lookup passed to [`Expression::eval`].
25///
26/// Aliased rather than newtyped so callers can build bindings however
27/// is convenient — from iterators, from a hand-constructed `HashMap`,
28/// whatever. A later slice can upgrade this to a trait once the plan-
29/// binder lands.
30pub type ValueBindings = HashMap<ParameterName, Value>;
31
32// ---------------------------------------------------------------------------
33// DerivationError.
34// ---------------------------------------------------------------------------
35
36/// Errors produced at [`Expression::eval`] time.
37#[derive(Debug, thiserror::Error, PartialEq, Eq)]
38pub enum DerivationError {
39    /// A `Ref(name)` pointed at a parameter that isn't in the bindings.
40    #[error("unknown parameter reference: {0}")]
41    UnknownParameter(ParameterName),
42
43    /// An operation received operands of the wrong kind.
44    #[error("type mismatch in {op}: expected {expected}, got {actual}")]
45    TypeMismatch {
46        /// Operator or builtin name.
47        op:       String,
48        /// The kind the operator expected.
49        expected: String,
50        /// The kind that was actually supplied.
51        actual:   String,
52    },
53
54    /// Integer division or modulo by zero.
55    #[error("division by zero")]
56    DivisionByZero,
57
58    /// A builtin received the wrong number of arguments.
59    #[error("{builtin} expects {expected} argument(s), got {actual}")]
60    InvalidArity {
61        /// Builtin name.
62        builtin:  String,
63        /// Expected arity.
64        expected: usize,
65        /// Actual arity.
66        actual:   usize,
67    },
68
69    /// Selection values can't flow through arithmetic/logic expressions.
70    #[error("selection values are not supported in derivation expressions")]
71    SelectionNotSupported,
72}
73
74// ---------------------------------------------------------------------------
75// EvalValue — the typed result of evaluating an Expression.
76// ---------------------------------------------------------------------------
77
/// The result of evaluating an [`Expression`]: a bare, typed scalar.
///
/// Unlike [`Value`] it carries no parameter name — the derived parameter
/// that consumes the result attaches its own name and provenance. Only
/// scalar kinds exist here; selection values are rejected on conversion.
#[derive(Clone, Debug, PartialEq)]
pub enum EvalValue {
    /// Signed 64-bit integer.
    Integer(i64),
    /// 64-bit IEEE-754 floating-point number.
    Double(f64),
    /// Truth value.
    Boolean(bool),
    /// Owned UTF-8 text.
    String(String),
}
94
95impl EvalValue {
96    /// Discriminator.
97    #[must_use]
98    pub const fn kind(&self) -> ValueKind {
99        match self {
100            Self::Integer(_) => ValueKind::Integer,
101            Self::Double(_) => ValueKind::Double,
102            Self::Boolean(_) => ValueKind::Boolean,
103            Self::String(_) => ValueKind::String,
104        }
105    }
106
107    /// `i64` accessor.
108    #[must_use]
109    pub const fn as_integer(&self) -> Option<i64> {
110        if let Self::Integer(v) = self {
111            Some(*v)
112        } else {
113            None
114        }
115    }
116
117    /// `f64` accessor.
118    #[must_use]
119    pub const fn as_double(&self) -> Option<f64> {
120        if let Self::Double(v) = self {
121            Some(*v)
122        } else {
123            None
124        }
125    }
126
127    /// `bool` accessor.
128    #[must_use]
129    pub const fn as_boolean(&self) -> Option<bool> {
130        if let Self::Boolean(v) = self {
131            Some(*v)
132        } else {
133            None
134        }
135    }
136
137    /// `str` accessor.
138    #[must_use]
139    pub fn as_string(&self) -> Option<&str> {
140        if let Self::String(v) = self {
141            Some(v)
142        } else {
143            None
144        }
145    }
146
147    const fn kind_label(&self) -> &'static str {
148        match self {
149            Self::Integer(_) => "integer",
150            Self::Double(_) => "double",
151            Self::Boolean(_) => "boolean",
152            Self::String(_) => "string",
153        }
154    }
155}
156
157impl TryFrom<&Value> for EvalValue {
158    type Error = DerivationError;
159    fn try_from(v: &Value) -> Result<Self, Self::Error> {
160        Ok(match v {
161            Value::Integer(i) => Self::Integer(i.value),
162            Value::Double(d) => Self::Double(d.value),
163            Value::Boolean(b) => Self::Boolean(b.value),
164            Value::String(s) => Self::String(s.value.clone()),
165            Value::Selection(_) => return Err(DerivationError::SelectionNotSupported),
166        })
167    }
168}
169
170// ---------------------------------------------------------------------------
171// Expression AST.
172// ---------------------------------------------------------------------------
173
174/// A literal value embedded in an expression.
175#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
176#[serde(tag = "kind", rename_all = "snake_case")]
177pub enum Literal {
178    /// `i64` literal.
179    Integer {
180        /// The integer.
181        value: i64,
182    },
183    /// `f64` literal. `NaN` is not forbidden here but arithmetic will
184    /// propagate it per IEEE-754.
185    Double {
186        /// The float.
187        value: f64,
188    },
189    /// Boolean literal.
190    Boolean {
191        /// The flag.
192        value: bool,
193    },
194    /// String literal.
195    String {
196        /// The text.
197        value: String,
198    },
199}
200
201/// Binary operators.
202#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
203#[serde(rename_all = "snake_case")]
204pub enum BinOp {
205    /// Arithmetic addition. Numeric operands only.
206    Add,
207    /// Arithmetic subtraction.
208    Sub,
209    /// Arithmetic multiplication.
210    Mul,
211    /// Arithmetic division. Integer division rejects zero divisors.
212    Div,
213    /// Arithmetic modulo. Integer only; rejects zero divisors.
214    Mod,
215    /// Equality. Same-kind operands.
216    Eq,
217    /// Inequality.
218    Ne,
219    /// Less than. Numeric.
220    Lt,
221    /// Less than or equal. Numeric.
222    Le,
223    /// Greater than. Numeric.
224    Gt,
225    /// Greater than or equal. Numeric.
226    Ge,
227    /// Logical AND. Boolean operands.
228    And,
229    /// Logical OR. Boolean operands.
230    Or,
231}
232
233/// Unary operators.
234#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
235#[serde(rename_all = "snake_case")]
236pub enum UnOp {
237    /// Numeric negation.
238    Neg,
239    /// Boolean negation.
240    Not,
241}
242
243/// Built-in functions.
244#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
245#[serde(rename_all = "snake_case")]
246pub enum BuiltinFn {
247    /// `ceil(f64) -> i64`.
248    Ceil,
249    /// `floor(f64) -> i64`.
250    Floor,
251    /// `round(f64) -> i64` (ties-to-even per IEEE).
252    Round,
253    /// `min(numeric, numeric, ...) -> numeric` (same kind).
254    Min,
255    /// `max(numeric, numeric, ...) -> numeric`.
256    Max,
257    /// `abs(numeric) -> numeric`.
258    Abs,
259    /// `pow(f64, f64) -> f64`.
260    Pow,
261    /// `len(string) -> i64` (byte length).
262    Len,
263}
264
265impl BuiltinFn {
266    const fn label(self) -> &'static str {
267        match self {
268            Self::Ceil => "ceil",
269            Self::Floor => "floor",
270            Self::Round => "round",
271            Self::Min => "min",
272            Self::Max => "max",
273            Self::Abs => "abs",
274            Self::Pow => "pow",
275            Self::Len => "len",
276        }
277    }
278}
279
280/// An expression tree. Serde-able end to end.
281#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
282#[serde(tag = "node", rename_all = "snake_case")]
283pub enum Expression {
284    /// Literal value.
285    Literal {
286        /// The literal.
287        value: Literal,
288    },
289    /// Reference to a previously-bound parameter value.
290    Ref {
291        /// The parameter name.
292        name: ParameterName,
293    },
294    /// Binary operation.
295    BinOp {
296        /// Operator.
297        op:  BinOp,
298        /// Left operand.
299        lhs: Box<Self>,
300        /// Right operand.
301        rhs: Box<Self>,
302    },
303    /// Unary operation.
304    UnOp {
305        /// Operator.
306        op:  UnOp,
307        /// Operand.
308        arg: Box<Self>,
309    },
310    /// Built-in function call.
311    Call {
312        /// Builtin.
313        func: BuiltinFn,
314        /// Arguments.
315        args: Vec<Self>,
316    },
317    /// Conditional (ternary).
318    If {
319        /// Boolean condition.
320        cond:  Box<Self>,
321        /// Value when `cond` is true.
322        then_: Box<Self>,
323        /// Value when `cond` is false.
324        else_: Box<Self>,
325    },
326}
327
328impl Expression {
329    /// Convenience constructor for a literal.
330    #[must_use]
331    pub const fn literal(lit: Literal) -> Self {
332        Self::Literal { value: lit }
333    }
334
335    /// Convenience constructor for a parameter reference.
336    #[must_use]
337    pub const fn reference(name: ParameterName) -> Self {
338        Self::Ref { name }
339    }
340
341    /// Convenience constructor for a binary operation.
342    #[must_use]
343    pub fn binop(op: BinOp, lhs: Self, rhs: Self) -> Self {
344        Self::BinOp {
345            op,
346            lhs: Box::new(lhs),
347            rhs: Box::new(rhs),
348        }
349    }
350
351    /// Convenience constructor for a unary operation.
352    #[must_use]
353    pub fn unop(op: UnOp, arg: Self) -> Self {
354        Self::UnOp {
355            op,
356            arg: Box::new(arg),
357        }
358    }
359
360    /// Convenience constructor for a builtin call.
361    #[must_use]
362    pub const fn call(func: BuiltinFn, args: Vec<Self>) -> Self {
363        Self::Call { func, args }
364    }
365
366    /// Convenience constructor for an if-expression.
367    #[must_use]
368    pub fn if_then_else(cond: Self, then_: Self, else_: Self) -> Self {
369        Self::If {
370            cond:  Box::new(cond),
371            then_: Box::new(then_),
372            else_: Box::new(else_),
373        }
374    }
375
376    /// Evaluate this expression against the given bindings.
377    pub fn eval(&self, bindings: &ValueBindings) -> Result<EvalValue, DerivationError> {
378        match self {
379            Self::Literal { value } => Ok(eval_literal(value)),
380            Self::Ref { name } => bindings.get(name).map_or_else(
381                || Err(DerivationError::UnknownParameter(name.clone())),
382                EvalValue::try_from,
383            ),
384            Self::BinOp { op, lhs, rhs } => {
385                let l = lhs.eval(bindings)?;
386                let r = rhs.eval(bindings)?;
387                eval_binop(*op, l, r)
388            }
389            Self::UnOp { op, arg } => {
390                let a = arg.eval(bindings)?;
391                eval_unop(*op, a)
392            }
393            Self::Call { func, args } => {
394                let vs: Result<Vec<EvalValue>, _> =
395                    args.iter().map(|a| a.eval(bindings)).collect();
396                eval_call(*func, vs?)
397            }
398            Self::If { cond, then_, else_ } => {
399                let c = cond.eval(bindings)?;
400                match c {
401                    EvalValue::Boolean(true) => then_.eval(bindings),
402                    EvalValue::Boolean(false) => else_.eval(bindings),
403                    other => Err(type_mismatch("if", "boolean", &other)),
404                }
405            }
406        }
407    }
408}
409
410// ---------------------------------------------------------------------------
411// Evaluation helpers.
412// ---------------------------------------------------------------------------
413
414fn eval_literal(lit: &Literal) -> EvalValue {
415    match lit {
416        Literal::Integer { value } => EvalValue::Integer(*value),
417        Literal::Double { value } => EvalValue::Double(*value),
418        Literal::Boolean { value } => EvalValue::Boolean(*value),
419        Literal::String { value } => EvalValue::String(value.clone()),
420    }
421}
422
423fn type_mismatch(op: &str, expected: &str, actual: &EvalValue) -> DerivationError {
424    DerivationError::TypeMismatch {
425        op:       op.to_owned(),
426        expected: expected.to_owned(),
427        actual:   actual.kind_label().to_owned(),
428    }
429}
430
431fn eval_binop(op: BinOp, l: EvalValue, r: EvalValue) -> Result<EvalValue, DerivationError> {
432    use EvalValue as E;
433    let op_label = binop_label(op);
434    match op {
435        BinOp::Add | BinOp::Sub | BinOp::Mul => match (l, r) {
436            (E::Integer(a), E::Integer(b)) => Ok(E::Integer(match op {
437                BinOp::Add => a.wrapping_add(b),
438                BinOp::Sub => a.wrapping_sub(b),
439                BinOp::Mul => a.wrapping_mul(b),
440                _ => unreachable!(),
441            })),
442            (E::Double(a), E::Double(b)) => Ok(E::Double(match op {
443                BinOp::Add => a + b,
444                BinOp::Sub => a - b,
445                BinOp::Mul => a * b,
446                _ => unreachable!(),
447            })),
448            (a, _) => Err(type_mismatch(op_label, "matching numeric operands", &a)),
449        },
450        BinOp::Div => match (l, r) {
451            (E::Integer(_), E::Integer(0)) => Err(DerivationError::DivisionByZero),
452            (E::Integer(a), E::Integer(b)) => Ok(E::Integer(a / b)),
453            (E::Double(a), E::Double(b)) => Ok(E::Double(a / b)),
454            (a, _) => Err(type_mismatch(op_label, "matching numeric operands", &a)),
455        },
456        BinOp::Mod => match (l, r) {
457            (E::Integer(_), E::Integer(0)) => Err(DerivationError::DivisionByZero),
458            (E::Integer(a), E::Integer(b)) => Ok(E::Integer(a % b)),
459            (a, _) => Err(type_mismatch(op_label, "integer operands", &a)),
460        },
461        BinOp::Eq => Ok(E::Boolean(values_equal(&l, &r)?)),
462        BinOp::Ne => Ok(E::Boolean(!values_equal(&l, &r)?)),
463        BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge => compare_numeric(op, &l, &r),
464        BinOp::And => match (l, r) {
465            (E::Boolean(a), E::Boolean(b)) => Ok(E::Boolean(a && b)),
466            (a, _) => Err(type_mismatch(op_label, "boolean operands", &a)),
467        },
468        BinOp::Or => match (l, r) {
469            (E::Boolean(a), E::Boolean(b)) => Ok(E::Boolean(a || b)),
470            (a, _) => Err(type_mismatch(op_label, "boolean operands", &a)),
471        },
472    }
473}
474
475const fn binop_label(op: BinOp) -> &'static str {
476    match op {
477        BinOp::Add => "+",
478        BinOp::Sub => "-",
479        BinOp::Mul => "*",
480        BinOp::Div => "/",
481        BinOp::Mod => "%",
482        BinOp::Eq => "==",
483        BinOp::Ne => "!=",
484        BinOp::Lt => "<",
485        BinOp::Le => "<=",
486        BinOp::Gt => ">",
487        BinOp::Ge => ">=",
488        BinOp::And => "&&",
489        BinOp::Or => "||",
490    }
491}
492
493fn values_equal(l: &EvalValue, r: &EvalValue) -> Result<bool, DerivationError> {
494    use EvalValue as E;
495    Ok(match (l, r) {
496        (E::Integer(a), E::Integer(b)) => a == b,
497        #[allow(
498            clippy::float_cmp,
499            reason = "IEEE equality is the intended semantics here"
500        )]
501        (E::Double(a), E::Double(b)) => a == b,
502        (E::Boolean(a), E::Boolean(b)) => a == b,
503        (E::String(a), E::String(b)) => a == b,
504        (a, _) => return Err(type_mismatch("eq", "matching operands", a)),
505    })
506}
507
508fn compare_numeric(
509    op: BinOp,
510    l:  &EvalValue,
511    r:  &EvalValue,
512) -> Result<EvalValue, DerivationError> {
513    use std::cmp::Ordering;
514
515    use EvalValue as E;
516    let op_label = binop_label(op);
517    let ord = match (l, r) {
518        (E::Integer(a), E::Integer(b)) => a.cmp(b),
519        (E::Double(a), E::Double(b)) => a.total_cmp(b),
520        (a, _) => return Err(type_mismatch(op_label, "matching numeric operands", a)),
521    };
522    Ok(EvalValue::Boolean(match op {
523        BinOp::Lt => ord == Ordering::Less,
524        BinOp::Le => ord != Ordering::Greater,
525        BinOp::Gt => ord == Ordering::Greater,
526        BinOp::Ge => ord != Ordering::Less,
527        _ => unreachable!(),
528    }))
529}
530
531fn eval_unop(op: UnOp, a: EvalValue) -> Result<EvalValue, DerivationError> {
532    use EvalValue as E;
533    match op {
534        UnOp::Neg => match a {
535            E::Integer(n) => Ok(E::Integer(n.wrapping_neg())),
536            E::Double(n) => Ok(E::Double(-n)),
537            other => Err(type_mismatch("neg", "numeric operand", &other)),
538        },
539        UnOp::Not => match a {
540            E::Boolean(b) => Ok(E::Boolean(!b)),
541            other => Err(type_mismatch("not", "boolean operand", &other)),
542        },
543    }
544}
545
546fn eval_call(func: BuiltinFn, args: Vec<EvalValue>) -> Result<EvalValue, DerivationError> {
547    use EvalValue as E;
548    match func {
549        BuiltinFn::Ceil | BuiltinFn::Floor | BuiltinFn::Round => {
550            check_arity(func, &args, 1)?;
551            match &args[0] {
552                E::Double(v) => {
553                    let folded = match func {
554                        BuiltinFn::Ceil => v.ceil(),
555                        BuiltinFn::Floor => v.floor(),
556                        BuiltinFn::Round => v.round_ties_even(),
557                        _ => unreachable!(),
558                    };
559                    // Clamp to i64 range so we can return Integer. The
560                    // precision/truncation concerns are by construction:
561                    // any Integer-shaped output from ceil/floor/round of
562                    // an f64 rounds through the i64 range.
563                    #[allow(
564                        clippy::cast_precision_loss,
565                        clippy::cast_possible_truncation,
566                        reason = "deliberate i64↔f64 projection for ceil/floor/round"
567                    )]
568                    let clamped = folded.clamp(i64::MIN as f64, i64::MAX as f64) as i64;
569                    Ok(E::Integer(clamped))
570                }
571                other => Err(type_mismatch(func.label(), "double", other)),
572            }
573        }
574        BuiltinFn::Min | BuiltinFn::Max => {
575            if args.len() < 2 {
576                return Err(DerivationError::InvalidArity {
577                    builtin:  func.label().to_owned(),
578                    expected: 2,
579                    actual:   args.len(),
580                });
581            }
582            fold_minmax(func, args)
583        }
584        BuiltinFn::Abs => {
585            check_arity(func, &args, 1)?;
586            match &args[0] {
587                E::Integer(n) => Ok(E::Integer(n.wrapping_abs())),
588                E::Double(n) => Ok(E::Double(n.abs())),
589                other => Err(type_mismatch(func.label(), "numeric", other)),
590            }
591        }
592        BuiltinFn::Pow => {
593            check_arity(func, &args, 2)?;
594            match (&args[0], &args[1]) {
595                (E::Double(b), E::Double(e)) => Ok(E::Double(b.powf(*e))),
596                (a, _) => Err(type_mismatch(func.label(), "double base", a)),
597            }
598        }
599        BuiltinFn::Len => {
600            check_arity(func, &args, 1)?;
601            match &args[0] {
602                E::String(s) => {
603                    let len = i64::try_from(s.len()).expect("string length fits in i64");
604                    Ok(E::Integer(len))
605                }
606                other => Err(type_mismatch(func.label(), "string", other)),
607            }
608        }
609    }
610}
611
612fn check_arity(
613    func:     BuiltinFn,
614    args:     &[EvalValue],
615    expected: usize,
616) -> Result<(), DerivationError> {
617    if args.len() == expected {
618        Ok(())
619    } else {
620        Err(DerivationError::InvalidArity {
621            builtin: func.label().to_owned(),
622            expected,
623            actual:  args.len(),
624        })
625    }
626}
627
628fn fold_minmax(
629    func: BuiltinFn,
630    args: Vec<EvalValue>,
631) -> Result<EvalValue, DerivationError> {
632    use EvalValue as E;
633    let mut iter = args.into_iter();
634    let first = iter.next().expect("≥2 args verified above");
635    iter.try_fold(first, |acc, next| match (acc, next) {
636        (E::Integer(a), E::Integer(b)) => Ok(E::Integer(if func == BuiltinFn::Min {
637            a.min(b)
638        } else {
639            a.max(b)
640        })),
641        (E::Double(a), E::Double(b)) => Ok(E::Double(if func == BuiltinFn::Min {
642            a.min(b)
643        } else {
644            a.max(b)
645        })),
646        (a, _) => Err(type_mismatch(func.label(), "matching numeric arguments", &a)),
647    })
648}
649
650// ---------------------------------------------------------------------------
651// Tests.
652// ---------------------------------------------------------------------------
653
#[cfg(test)]
mod tests {
    use super::*;

    fn pname(s: &str) -> ParameterName {
        ParameterName::new(s).unwrap()
    }

    fn empty_bindings() -> ValueBindings {
        ValueBindings::new()
    }

    /// Integer literal expression.
    fn int(value: i64) -> Expression {
        Expression::literal(Literal::Integer { value })
    }

    /// Double literal expression.
    fn dbl(value: f64) -> Expression {
        Expression::literal(Literal::Double { value })
    }

    /// Boolean literal expression.
    fn boolean(value: bool) -> Expression {
        Expression::literal(Literal::Boolean { value })
    }

    /// String literal expression.
    fn string(value: &str) -> Expression {
        Expression::literal(Literal::String {
            value: value.to_owned(),
        })
    }

    // ---------- Literal / Ref ----------

    #[test]
    fn literal_evaluates_to_eval_value() {
        let b = empty_bindings();
        assert_eq!(int(7).eval(&b).unwrap(), EvalValue::Integer(7));
        assert_eq!(boolean(true).eval(&b).unwrap(), EvalValue::Boolean(true));
    }

    #[test]
    fn ref_reads_bindings_or_errors() {
        let mut b = empty_bindings();
        b.insert(pname("threads"), Value::integer(pname("threads"), 8, None));
        let got = Expression::reference(pname("threads")).eval(&b).unwrap();
        assert_eq!(got, EvalValue::Integer(8));

        let err = Expression::reference(pname("missing")).eval(&b).unwrap_err();
        assert!(matches!(err, DerivationError::UnknownParameter(_)));
    }

    #[test]
    fn ref_rejects_selection_values() {
        use crate::value::SelectionItem;
        use indexmap::IndexSet;
        let mut set = IndexSet::new();
        set.insert(SelectionItem::new("a").unwrap());
        let mut b = empty_bindings();
        b.insert(pname("pick"), Value::selection(pname("pick"), set, None));
        let err = Expression::reference(pname("pick")).eval(&b).unwrap_err();
        assert_eq!(err, DerivationError::SelectionNotSupported);
    }

    // ---------- Arithmetic ----------

    #[test]
    fn integer_arithmetic() {
        let b = empty_bindings();
        let e = Expression::binop(BinOp::Add, int(3), int(4));
        assert_eq!(e.eval(&b).unwrap(), EvalValue::Integer(7));

        let e = Expression::binop(BinOp::Mul, int(6), int(7));
        assert_eq!(e.eval(&b).unwrap(), EvalValue::Integer(42));
    }

    #[test]
    fn integer_division_truncates_toward_zero() {
        let b = empty_bindings();
        let e = Expression::binop(BinOp::Div, int(-7), int(2));
        assert_eq!(e.eval(&b).unwrap(), EvalValue::Integer(-3));
    }

    #[test]
    fn double_arithmetic() {
        let b = empty_bindings();
        let e = Expression::binop(BinOp::Sub, dbl(1.5), dbl(0.5));
        assert_eq!(e.eval(&b).unwrap(), EvalValue::Double(1.0));
    }

    #[test]
    fn integer_division_by_zero_errors() {
        let b = empty_bindings();
        let e = Expression::binop(BinOp::Div, int(1), int(0));
        assert_eq!(e.eval(&b).unwrap_err(), DerivationError::DivisionByZero);
    }

    #[test]
    fn mixed_numeric_arithmetic_is_a_type_error() {
        let b = empty_bindings();
        let e = Expression::binop(BinOp::Add, int(1), dbl(2.0));
        assert!(matches!(
            e.eval(&b).unwrap_err(),
            DerivationError::TypeMismatch { .. }
        ));
    }

    #[test]
    fn mod_rejects_non_integer() {
        let b = empty_bindings();
        let e = Expression::binop(BinOp::Mod, dbl(1.0), dbl(2.0));
        assert!(matches!(
            e.eval(&b).unwrap_err(),
            DerivationError::TypeMismatch { .. }
        ));
    }

    // ---------- Comparison / boolean ----------

    #[test]
    fn comparisons_yield_boolean() {
        let b = empty_bindings();
        let e = Expression::binop(BinOp::Lt, int(3), int(4));
        assert_eq!(e.eval(&b).unwrap(), EvalValue::Boolean(true));
    }

    #[test]
    fn string_equality_same_kind() {
        let b = empty_bindings();
        let e = Expression::binop(BinOp::Eq, string("a"), string("a"));
        assert_eq!(e.eval(&b).unwrap(), EvalValue::Boolean(true));
    }

    #[test]
    fn equality_rejects_mixed_kinds() {
        let b = empty_bindings();
        let e = Expression::binop(BinOp::Eq, int(1), string("1"));
        assert!(matches!(
            e.eval(&b).unwrap_err(),
            DerivationError::TypeMismatch { .. }
        ));
    }

    #[test]
    fn logical_and_or_and_not() {
        let b = empty_bindings();
        assert_eq!(
            Expression::binop(BinOp::And, boolean(true), boolean(false))
                .eval(&b)
                .unwrap(),
            EvalValue::Boolean(false)
        );
        assert_eq!(
            Expression::binop(BinOp::Or, boolean(true), boolean(false))
                .eval(&b)
                .unwrap(),
            EvalValue::Boolean(true)
        );
        assert_eq!(
            Expression::unop(UnOp::Not, boolean(false)).eval(&b).unwrap(),
            EvalValue::Boolean(true)
        );
    }

    // ---------- Builtins ----------

    #[test]
    fn min_and_max() {
        let b = empty_bindings();
        let args = vec![int(3), int(1), int(2)];
        assert_eq!(
            Expression::call(BuiltinFn::Min, args.clone()).eval(&b).unwrap(),
            EvalValue::Integer(1)
        );
        assert_eq!(
            Expression::call(BuiltinFn::Max, args).eval(&b).unwrap(),
            EvalValue::Integer(3)
        );
    }

    #[test]
    fn ceil_floor_round_return_integer() {
        let b = empty_bindings();
        assert_eq!(
            Expression::call(BuiltinFn::Ceil, vec![dbl(1.2)]).eval(&b).unwrap(),
            EvalValue::Integer(2)
        );
        assert_eq!(
            Expression::call(BuiltinFn::Floor, vec![dbl(1.9)]).eval(&b).unwrap(),
            EvalValue::Integer(1)
        );
        assert_eq!(
            Expression::call(BuiltinFn::Round, vec![dbl(1.5)]).eval(&b).unwrap(),
            EvalValue::Integer(2)
        );
    }

    #[test]
    fn abs_pow_len() {
        let b = empty_bindings();
        assert_eq!(
            Expression::call(BuiltinFn::Abs, vec![int(-5)]).eval(&b).unwrap(),
            EvalValue::Integer(5)
        );
        assert_eq!(
            Expression::call(BuiltinFn::Pow, vec![dbl(2.0), dbl(10.0)])
                .eval(&b)
                .unwrap(),
            EvalValue::Double(1024.0)
        );
        assert_eq!(
            Expression::call(BuiltinFn::Len, vec![string("hello")])
                .eval(&b)
                .unwrap(),
            EvalValue::Integer(5)
        );
    }

    #[test]
    fn arity_errors() {
        let b = empty_bindings();
        let err = Expression::call(BuiltinFn::Abs, vec![]).eval(&b).unwrap_err();
        assert!(matches!(err, DerivationError::InvalidArity { .. }));
        let err = Expression::call(BuiltinFn::Min, vec![int(1)])
            .eval(&b)
            .unwrap_err();
        assert!(matches!(err, DerivationError::InvalidArity { .. }));
    }

    // ---------- If ----------

    #[test]
    fn if_expression_picks_branch() {
        let b = empty_bindings();
        let e = Expression::if_then_else(boolean(true), int(1), int(2));
        assert_eq!(e.eval(&b).unwrap(), EvalValue::Integer(1));
    }

    #[test]
    fn if_expression_picks_else_branch() {
        let b = empty_bindings();
        let e = Expression::if_then_else(boolean(false), int(1), int(2));
        assert_eq!(e.eval(&b).unwrap(), EvalValue::Integer(2));
    }

    #[test]
    fn if_condition_must_be_boolean() {
        let b = empty_bindings();
        let e = Expression::if_then_else(int(1), int(1), int(2));
        assert!(matches!(
            e.eval(&b).unwrap_err(),
            DerivationError::TypeMismatch { .. }
        ));
    }

    // ---------- serde ----------

    #[test]
    fn expression_serde_roundtrip() {
        let e = Expression::if_then_else(
            Expression::binop(BinOp::Lt, Expression::reference(pname("threads")), int(16)),
            int(8),
            int(16),
        );
        let json = serde_json::to_string(&e).unwrap();
        let back: Expression = serde_json::from_str(&json).unwrap();
        assert_eq!(e, back);
    }

    #[test]
    fn literal_serde_uses_kind_tag() {
        let json = serde_json::to_string(&Literal::Integer { value: 3 }).unwrap();
        assert_eq!(json, r#"{"kind":"integer","value":3}"#);
    }
}