lemma/evaluator/
expression.rs

1//! Expression evaluation
2//!
3//! Recursively evaluates expressions to produce literal values.
4
5use super::context::EvaluationContext;
6use crate::{
7    ast::Span, ArithmeticOperation, Expression, ExpressionKind, FactReference, LemmaError,
8    LiteralValue, MathematicalOperator, OperationRecord, OperationResult,
9};
10use rust_decimal::Decimal;
11use std::sync::Arc;
12
13/// Evaluate an expression to produce an operation result
14///
15/// This is the core of the evaluator - recursively processes expressions
16/// and records operations for every step.
17///
18/// When evaluating a rule from a document referenced by a fact (e.g., `employee.some_rule?`
19/// where `employee` is a fact with value `doc other_doc`), pass the fact path via `fact_prefix`
20/// to qualify fact lookups within that rule. For local rules, pass an empty slice.
21pub fn evaluate_expression(
22    expr: &Expression,
23    context: &mut EvaluationContext,
24    fact_prefix: &[String],
25) -> Result<OperationResult, LemmaError> {
26    // Check timeout at the start of every expression evaluation
27    context.check_timeout()?;
28
29    match &expr.kind {
30        ExpressionKind::Literal(lit) => {
31            // Literals evaluate to themselves
32            Ok(OperationResult::Value(lit.clone()))
33        }
34
35        ExpressionKind::FactReference(fact_ref) => {
36            // Look up fact in context, prepending the prefix when evaluating a rule from a referenced document
37            let lookup_ref = if !fact_prefix.is_empty() {
38                // Evaluating a rule from a document referenced by a fact: prepend the fact path
39                // E.g., if `employee` references `doc hr_doc` and we're evaluating `employee.salary?`,
40                // fact references within that rule need the `employee` prefix
41                let mut qualified_reference = fact_prefix.to_vec();
42                qualified_reference.extend_from_slice(&fact_ref.reference);
43                FactReference {
44                    reference: qualified_reference,
45                }
46            } else {
47                // Local rule: use fact reference as-is
48                fact_ref.clone()
49            };
50
51            let value = context.facts.get(&lookup_ref).ok_or_else(|| {
52                LemmaError::Engine(format!("Missing fact: {}", lookup_ref.reference.join(".")))
53            })?;
54
55            // Record operation (convert path to string for display)
56            context.operations.push(OperationRecord::FactUsed {
57                name: lookup_ref.reference.join("."),
58                value: value.clone(),
59            });
60
61            Ok(OperationResult::Value(value.clone()))
62        }
63        ExpressionKind::RuleReference(rule_ref) => {
64            // Look up already-computed rule result
65            // Topological sort ensures this rule was computed before us
66            let rule_path = crate::RulePath::from_reference(
67                &rule_ref.reference,
68                context.current_doc,
69                context.all_documents,
70            )?;
71
72            // Check if rule has a result
73            if let Some(result) = context.rule_results.get(&rule_path) {
74                match result {
75                    OperationResult::Veto(msg) => {
76                        // Rule was vetoed - the veto applies to this rule too
77                        return Ok(OperationResult::Veto(msg.clone()));
78                    }
79                    OperationResult::Value(value) => {
80                        // Record operation
81                        context.operations.push(OperationRecord::RuleUsed {
82                            name: rule_path.to_string(),
83                            value: value.clone(),
84                        });
85                        return Ok(OperationResult::Value(value.clone()));
86                    }
87                }
88            }
89
90            // Rule not computed yet
91            Err(LemmaError::Engine(format!("Rule {} not found", rule_path)))
92        }
93
94        ExpressionKind::Arithmetic(left, op, right) => {
95            let left_result = evaluate_expression(left, context, fact_prefix)?;
96            let right_result = evaluate_expression(right, context, fact_prefix)?;
97
98            // If either operand is vetoed, propagate the veto
99            if let OperationResult::Veto(msg) = left_result {
100                return Ok(OperationResult::Veto(msg));
101            }
102            if let OperationResult::Veto(msg) = right_result {
103                return Ok(OperationResult::Veto(msg));
104            }
105
106            // Both operands must have values at this point
107            let left_val = left_result.expect_value("arithmetic left operand")?;
108            let right_val = right_result.expect_value("arithmetic right operand")?;
109
110            // Convert Engine errors to Runtime errors with source location
111            let result = super::operations::arithmetic_operation(left_val, op, right_val)
112                .map_err(|e| convert_engine_error_to_runtime(e, expr, context))?;
113
114            // Record operation
115            let op_name = match op {
116                ArithmeticOperation::Add => "add",
117                ArithmeticOperation::Subtract => "subtract",
118                ArithmeticOperation::Multiply => "multiply",
119                ArithmeticOperation::Divide => "divide",
120                ArithmeticOperation::Modulo => "modulo",
121                ArithmeticOperation::Power => "power",
122            };
123
124            context.operations.push(OperationRecord::OperationExecuted {
125                operation: op_name.to_string(),
126                inputs: vec![left_val.clone(), right_val.clone()],
127                result: result.clone(),
128                unless_clause_index: None,
129            });
130
131            Ok(OperationResult::Value(result))
132        }
133
134        ExpressionKind::Comparison(left, op, right) => {
135            let left_result = evaluate_expression(left, context, fact_prefix)?;
136            let right_result = evaluate_expression(right, context, fact_prefix)?;
137
138            // If either operand is vetoed, propagate the veto
139            if let OperationResult::Veto(msg) = left_result {
140                return Ok(OperationResult::Veto(msg));
141            }
142            if let OperationResult::Veto(msg) = right_result {
143                return Ok(OperationResult::Veto(msg));
144            }
145
146            // Both operands must have values at this point
147            let left_val = left_result.expect_value("comparison left operand")?;
148            let right_val = right_result.expect_value("comparison right operand")?;
149
150            let result = super::operations::comparison_operation(left_val, op, right_val)?;
151
152            // Record operation
153            let op_name = match op {
154                crate::ComparisonOperator::GreaterThan => "greater_than",
155                crate::ComparisonOperator::LessThan => "less_than",
156                crate::ComparisonOperator::GreaterThanOrEqual => "greater_than_or_equal",
157                crate::ComparisonOperator::LessThanOrEqual => "less_than_or_equal",
158                crate::ComparisonOperator::Equal => "equal",
159                crate::ComparisonOperator::NotEqual => "not_equal",
160                crate::ComparisonOperator::Is => "is",
161                crate::ComparisonOperator::IsNot => "is_not",
162            };
163
164            context.operations.push(OperationRecord::OperationExecuted {
165                operation: op_name.to_string(),
166                inputs: vec![left_val.clone(), right_val.clone()],
167                result: LiteralValue::Boolean(result),
168                unless_clause_index: None,
169            });
170
171            Ok(OperationResult::Value(LiteralValue::Boolean(result)))
172        }
173
174        ExpressionKind::LogicalAnd(left, right) => {
175            let left_result = evaluate_expression(left, context, fact_prefix)?;
176            let right_result = evaluate_expression(right, context, fact_prefix)?;
177
178            // If either operand is vetoed, propagate the veto
179            if let OperationResult::Veto(msg) = left_result {
180                return Ok(OperationResult::Veto(msg));
181            }
182            if let OperationResult::Veto(msg) = right_result {
183                return Ok(OperationResult::Veto(msg));
184            }
185
186            // Both operands must have boolean values at this point
187            let left_val = left_result.expect_value("logical AND left operand")?;
188            let right_val = right_result.expect_value("logical AND right operand")?;
189
190            match (left_val, right_val) {
191                (LiteralValue::Boolean(l), LiteralValue::Boolean(r)) => {
192                    // No operation record for logical operations - only record sub-expressions
193                    Ok(OperationResult::Value(LiteralValue::Boolean(*l && *r)))
194                }
195                _ => Err(LemmaError::Engine(
196                    "Logical AND requires boolean operands".to_string(),
197                )),
198            }
199        }
200
201        ExpressionKind::LogicalOr(left, right) => {
202            let left_result = evaluate_expression(left, context, fact_prefix)?;
203            let right_result = evaluate_expression(right, context, fact_prefix)?;
204
205            // If either operand is vetoed, propagate the veto
206            if let OperationResult::Veto(msg) = left_result {
207                return Ok(OperationResult::Veto(msg));
208            }
209            if let OperationResult::Veto(msg) = right_result {
210                return Ok(OperationResult::Veto(msg));
211            }
212
213            // Both operands must have boolean values at this point
214            let left_val = left_result.expect_value("logical OR left operand")?;
215            let right_val = right_result.expect_value("logical OR right operand")?;
216
217            match (left_val, right_val) {
218                (LiteralValue::Boolean(l), LiteralValue::Boolean(r)) => {
219                    // No operation record for logical operations - only record sub-expressions
220                    Ok(OperationResult::Value(LiteralValue::Boolean(*l || *r)))
221                }
222                _ => Err(LemmaError::Engine(
223                    "Logical OR requires boolean operands".to_string(),
224                )),
225            }
226        }
227
228        ExpressionKind::LogicalNegation(operand, _negation_type) => {
229            let result = evaluate_expression(operand, context, fact_prefix)?;
230
231            // If the operand is vetoed, propagate the veto
232            if let OperationResult::Veto(msg) = result {
233                return Ok(OperationResult::Veto(msg));
234            }
235
236            // Operand must have a value at this point
237            let value = result.expect_value("logical negation operand")?;
238
239            match value {
240                LiteralValue::Boolean(b) => Ok(OperationResult::Value(LiteralValue::Boolean(!b))),
241                _ => Err(LemmaError::Engine(
242                    "Logical NOT requires boolean operand".to_string(),
243                )),
244            }
245        }
246
247        ExpressionKind::UnitConversion(value_expr, target) => {
248            let result = evaluate_expression(value_expr, context, fact_prefix)?;
249
250            // If the value is vetoed, propagate the veto
251            if let OperationResult::Veto(msg) = result {
252                return Ok(OperationResult::Veto(msg));
253            }
254
255            // Value must exist at this point
256            let value = result.expect_value("unit conversion operand")?;
257            let converted = super::units::convert_unit(value, target)?;
258            Ok(OperationResult::Value(converted))
259        }
260
261        ExpressionKind::MathematicalOperator(op, operand) => {
262            evaluate_mathematical_operator(op, operand, context, fact_prefix)
263        }
264
265        ExpressionKind::Veto(veto_expr) => Ok(OperationResult::Veto(veto_expr.message.clone())),
266
267        ExpressionKind::FactHasAnyValue(fact_ref) => {
268            // Check if fact exists and has a value, with path prefix applied
269            let lookup_ref = if !fact_prefix.is_empty() {
270                let mut qualified_reference = fact_prefix.to_vec();
271                qualified_reference.extend_from_slice(&fact_ref.reference);
272                FactReference {
273                    reference: qualified_reference,
274                }
275            } else {
276                fact_ref.clone()
277            };
278            let has_value = context.facts.contains_key(&lookup_ref);
279            Ok(OperationResult::Value(LiteralValue::Boolean(has_value)))
280        }
281    }
282}
283
284/// Evaluate a mathematical operator (sqrt, sin, cos, etc.)
285fn evaluate_mathematical_operator(
286    op: &MathematicalOperator,
287    operand: &Expression,
288    context: &mut EvaluationContext,
289    fact_prefix: &[String],
290) -> Result<OperationResult, LemmaError> {
291    let result = evaluate_expression(operand, context, fact_prefix)?;
292
293    // If the operand is vetoed, propagate the veto
294    if let OperationResult::Veto(msg) = result {
295        return Ok(OperationResult::Veto(msg));
296    }
297
298    // Operand must have a numeric value at this point
299    let value = result.expect_value("mathematical operator operand")?;
300
301    match value {
302        LiteralValue::Number(n) => {
303            use rust_decimal::prelude::ToPrimitive;
304            let float_val = n.to_f64().ok_or_else(|| {
305                LemmaError::Engine("Cannot convert to float for mathematical operation".to_string())
306            })?;
307
308            match op {
309                // Float-based functions
310                MathematicalOperator::Sqrt
311                | MathematicalOperator::Sin
312                | MathematicalOperator::Cos
313                | MathematicalOperator::Tan
314                | MathematicalOperator::Asin
315                | MathematicalOperator::Acos
316                | MathematicalOperator::Atan
317                | MathematicalOperator::Log
318                | MathematicalOperator::Exp => {
319                    let math_result = match op {
320                        MathematicalOperator::Sqrt => float_val.sqrt(),
321                        MathematicalOperator::Sin => float_val.sin(),
322                        MathematicalOperator::Cos => float_val.cos(),
323                        MathematicalOperator::Tan => float_val.tan(),
324                        MathematicalOperator::Asin => float_val.asin(),
325                        MathematicalOperator::Acos => float_val.acos(),
326                        MathematicalOperator::Atan => float_val.atan(),
327                        MathematicalOperator::Log => float_val.ln(),
328                        MathematicalOperator::Exp => float_val.exp(),
329                        _ => unreachable!(),
330                    };
331                    let decimal_result =
332                        Decimal::from_f64_retain(math_result).ok_or_else(|| {
333                            LemmaError::Engine(
334                                "Mathematical operation result cannot be represented".to_string(),
335                            )
336                        })?;
337                    Ok(OperationResult::Value(LiteralValue::Number(decimal_result)))
338                }
339                // Decimal-native functions
340                MathematicalOperator::Abs => {
341                    Ok(OperationResult::Value(LiteralValue::Number(n.abs())))
342                }
343                MathematicalOperator::Floor => {
344                    Ok(OperationResult::Value(LiteralValue::Number(n.floor())))
345                }
346                MathematicalOperator::Ceil => {
347                    Ok(OperationResult::Value(LiteralValue::Number(n.ceil())))
348                }
349                MathematicalOperator::Round => {
350                    Ok(OperationResult::Value(LiteralValue::Number(n.round())))
351                }
352            }
353        }
354        _ => Err(LemmaError::Engine(
355            "Mathematical operators require number operands".to_string(),
356        )),
357    }
358}
359
360/// Convert an Engine error to a Runtime error with proper source location
361///
362/// This is used to add span information to errors that occur during expression evaluation.
363fn convert_engine_error_to_runtime(
364    error: LemmaError,
365    expr: &Expression,
366    context: &EvaluationContext,
367) -> LemmaError {
368    match error {
369        LemmaError::Engine(msg) => {
370            let span = expr.span.clone().unwrap_or(Span {
371                start: 0,
372                end: 0,
373                line: 0,
374                col: 0,
375            });
376
377            let source_id = context
378                .current_doc
379                .source
380                .as_ref()
381                .cloned()
382                .unwrap_or_else(|| "<input>".to_string());
383
384            let source_text: Arc<str> = context
385                .sources
386                .get(&source_id)
387                .map(|s| Arc::from(s.as_str()))
388                .unwrap_or_else(|| Arc::from(""));
389
390            let suggestion = if msg.contains("division") || msg.contains("zero") {
391                Some(
392                    "Consider using an 'unless' clause to guard against division by zero"
393                        .to_string(),
394                )
395            } else if msg.contains("type") || msg.contains("mismatch") {
396                Some("Check that operands have compatible types".to_string())
397            } else {
398                None
399            };
400
401            LemmaError::Runtime(Box::new(crate::error::ErrorDetails {
402                message: msg,
403                span,
404                source_id,
405                source_text,
406                doc_name: context.current_doc.name.clone(),
407                doc_start_line: context.current_doc.start_line,
408                suggestion,
409            }))
410        }
411        other => other,
412    }
413}