lemma/evaluator/
expression.rs

1//! Expression evaluation
2//!
3//! Recursively evaluates expressions to produce literal values.
4
5use super::context::EvaluationContext;
6use crate::{
7    ast::Span, ArithmeticOperation, Expression, ExpressionKind, LemmaError, LiteralValue,
8    MathematicalOperator, OperationRecord, OperationResult,
9};
10use rust_decimal::Decimal;
11use std::sync::Arc;
12
13/// Evaluate an expression to produce an operation result
14///
15/// This is the core of the evaluator - recursively processes expressions
16/// and records operations for every step.
17pub fn evaluate_expression(
18    expr: &Expression,
19    context: &mut EvaluationContext,
20) -> Result<OperationResult, LemmaError> {
21    match &expr.kind {
22        ExpressionKind::Literal(lit) => {
23            // Literals evaluate to themselves
24            Ok(OperationResult::Value(lit.clone()))
25        }
26
27        ExpressionKind::FactReference(fact_ref) => {
28            // Look up fact in context
29            let fact_name = fact_ref.reference.join(".");
30            let value = context
31                .facts
32                .get(&fact_name)
33                .ok_or_else(|| LemmaError::Engine(format!("Missing fact: {}", fact_name)))?;
34
35            // Record operation
36            context.operations.push(OperationRecord::FactUsed {
37                name: fact_name,
38                value: value.clone(),
39            });
40
41            Ok(OperationResult::Value(value.clone()))
42        }
43
44        ExpressionKind::RuleReference(rule_ref) => {
45            // Look up already-computed rule result
46            // Topological sort ensures this rule was computed before us
47            let rule_name = rule_ref.reference.join(".");
48
49            // Check if rule has a result
50            if let Some(result) = context.rule_results.get(&rule_name) {
51                match result {
52                    OperationResult::Veto(msg) => {
53                        // Rule was vetoed - the veto applies to this rule too
54                        return Ok(OperationResult::Veto(msg.clone()));
55                    }
56                    OperationResult::Value(value) => {
57                        // Record operation
58                        context.operations.push(OperationRecord::RuleUsed {
59                            name: rule_name,
60                            value: value.clone(),
61                        });
62                        return Ok(OperationResult::Value(value.clone()));
63                    }
64                }
65            }
66
67            // Rule not computed yet
68            Err(LemmaError::Engine(format!(
69                "Rule {} not yet computed",
70                rule_name
71            )))
72        }
73
74        ExpressionKind::Arithmetic(left, op, right) => {
75            let left_result = evaluate_expression(left, context)?;
76            let right_result = evaluate_expression(right, context)?;
77
78            // If either operand is vetoed, propagate the veto
79            if let OperationResult::Veto(msg) = left_result {
80                return Ok(OperationResult::Veto(msg));
81            }
82            if let OperationResult::Veto(msg) = right_result {
83                return Ok(OperationResult::Veto(msg));
84            }
85
86            let left_val = left_result.value().unwrap();
87            let right_val = right_result.value().unwrap();
88
89            // Convert Engine errors to Runtime errors with source location
90            let result = super::operations::arithmetic_operation(left_val, op, right_val)
91                .map_err(|e| convert_engine_error_to_runtime(e, expr, context))?;
92
93            // Record operation
94            let op_name = match op {
95                ArithmeticOperation::Add => "add",
96                ArithmeticOperation::Subtract => "subtract",
97                ArithmeticOperation::Multiply => "multiply",
98                ArithmeticOperation::Divide => "divide",
99                ArithmeticOperation::Modulo => "modulo",
100                ArithmeticOperation::Power => "power",
101            };
102
103            context.operations.push(OperationRecord::OperationExecuted {
104                operation: op_name.to_string(),
105                inputs: vec![left_val.clone(), right_val.clone()],
106                result: result.clone(),
107                unless_clause_index: None,
108            });
109
110            Ok(OperationResult::Value(result))
111        }
112
113        ExpressionKind::Comparison(left, op, right) => {
114            let left_result = evaluate_expression(left, context)?;
115            let right_result = evaluate_expression(right, context)?;
116
117            // If either operand is vetoed, propagate the veto
118            if let OperationResult::Veto(msg) = left_result {
119                return Ok(OperationResult::Veto(msg));
120            }
121            if let OperationResult::Veto(msg) = right_result {
122                return Ok(OperationResult::Veto(msg));
123            }
124
125            let left_val = left_result.value().unwrap();
126            let right_val = right_result.value().unwrap();
127
128            let result = super::operations::comparison_operation(left_val, op, right_val)?;
129
130            // Record operation
131            let op_name = match op {
132                crate::ComparisonOperator::GreaterThan => "greater_than",
133                crate::ComparisonOperator::LessThan => "less_than",
134                crate::ComparisonOperator::GreaterThanOrEqual => "greater_than_or_equal",
135                crate::ComparisonOperator::LessThanOrEqual => "less_than_or_equal",
136                crate::ComparisonOperator::Equal => "equal",
137                crate::ComparisonOperator::NotEqual => "not_equal",
138                crate::ComparisonOperator::Is => "is",
139                crate::ComparisonOperator::IsNot => "is_not",
140            };
141
142            context.operations.push(OperationRecord::OperationExecuted {
143                operation: op_name.to_string(),
144                inputs: vec![left_val.clone(), right_val.clone()],
145                result: LiteralValue::Boolean(result),
146                unless_clause_index: None,
147            });
148
149            Ok(OperationResult::Value(LiteralValue::Boolean(result)))
150        }
151
152        ExpressionKind::LogicalAnd(left, right) => {
153            let left_result = evaluate_expression(left, context)?;
154            let right_result = evaluate_expression(right, context)?;
155
156            // If either operand is vetoed, propagate the veto
157            if let OperationResult::Veto(msg) = left_result {
158                return Ok(OperationResult::Veto(msg));
159            }
160            if let OperationResult::Veto(msg) = right_result {
161                return Ok(OperationResult::Veto(msg));
162            }
163
164            let left_val = left_result.value().unwrap();
165            let right_val = right_result.value().unwrap();
166
167            match (left_val, right_val) {
168                (LiteralValue::Boolean(l), LiteralValue::Boolean(r)) => {
169                    // No operation record for logical operations - only record sub-expressions
170                    Ok(OperationResult::Value(LiteralValue::Boolean(*l && *r)))
171                }
172                _ => Err(LemmaError::Engine(
173                    "Logical AND requires boolean operands".to_string(),
174                )),
175            }
176        }
177
178        ExpressionKind::LogicalOr(left, right) => {
179            let left_result = evaluate_expression(left, context)?;
180            let right_result = evaluate_expression(right, context)?;
181
182            // If either operand is vetoed, propagate the veto
183            if let OperationResult::Veto(msg) = left_result {
184                return Ok(OperationResult::Veto(msg));
185            }
186            if let OperationResult::Veto(msg) = right_result {
187                return Ok(OperationResult::Veto(msg));
188            }
189
190            let left_val = left_result.value().unwrap();
191            let right_val = right_result.value().unwrap();
192
193            match (left_val, right_val) {
194                (LiteralValue::Boolean(l), LiteralValue::Boolean(r)) => {
195                    // No operation record for logical operations - only record sub-expressions
196                    Ok(OperationResult::Value(LiteralValue::Boolean(*l || *r)))
197                }
198                _ => Err(LemmaError::Engine(
199                    "Logical OR requires boolean operands".to_string(),
200                )),
201            }
202        }
203
204        ExpressionKind::LogicalNegation(inner, _negation_type) => {
205            let result = evaluate_expression(inner, context)?;
206
207            // If the operand is vetoed, propagate the veto
208            if let OperationResult::Veto(msg) = result {
209                return Ok(OperationResult::Veto(msg));
210            }
211
212            let value = result.value().unwrap();
213
214            match value {
215                LiteralValue::Boolean(b) => Ok(OperationResult::Value(LiteralValue::Boolean(!b))),
216                _ => Err(LemmaError::Engine(
217                    "Logical NOT requires boolean operand".to_string(),
218                )),
219            }
220        }
221
222        ExpressionKind::UnitConversion(value_expr, target) => {
223            let result = evaluate_expression(value_expr, context)?;
224
225            // If the operand is vetoed, propagate the veto
226            if let OperationResult::Veto(msg) = result {
227                return Ok(OperationResult::Veto(msg));
228            }
229
230            let value = result.value().unwrap();
231            let converted = super::units::convert_unit(value, target)?;
232            Ok(OperationResult::Value(converted))
233        }
234
235        ExpressionKind::MathematicalOperator(op, operand) => {
236            evaluate_mathematical_operator(op, operand, context)
237        }
238
239        ExpressionKind::Veto(veto_expr) => Ok(OperationResult::Veto(veto_expr.message.clone())),
240
241        ExpressionKind::FactHasAnyValue(fact_ref) => {
242            // Check if fact exists and has a value
243            let fact_name = fact_ref.reference.join(".");
244            let has_value = context.facts.contains_key(&fact_name);
245            Ok(OperationResult::Value(LiteralValue::Boolean(has_value)))
246        }
247    }
248}
249
250/// Evaluate a mathematical operator (sqrt, sin, cos, etc.)
251fn evaluate_mathematical_operator(
252    op: &MathematicalOperator,
253    operand: &Expression,
254    context: &mut EvaluationContext,
255) -> Result<OperationResult, LemmaError> {
256    let result = evaluate_expression(operand, context)?;
257
258    // If the operand is vetoed, propagate the veto
259    if let OperationResult::Veto(msg) = result {
260        return Ok(OperationResult::Veto(msg));
261    }
262
263    let value = result.value().unwrap();
264
265    match value {
266        LiteralValue::Number(n) => {
267            use rust_decimal::prelude::ToPrimitive;
268            let float_val = n.to_f64().ok_or_else(|| {
269                LemmaError::Engine("Cannot convert to float for mathematical operation".to_string())
270            })?;
271
272            let math_result = match op {
273                MathematicalOperator::Sqrt => float_val.sqrt(),
274                MathematicalOperator::Sin => float_val.sin(),
275                MathematicalOperator::Cos => float_val.cos(),
276                MathematicalOperator::Tan => float_val.tan(),
277                MathematicalOperator::Asin => float_val.asin(),
278                MathematicalOperator::Acos => float_val.acos(),
279                MathematicalOperator::Atan => float_val.atan(),
280                MathematicalOperator::Log => float_val.ln(),
281                MathematicalOperator::Exp => float_val.exp(),
282            };
283
284            let decimal_result = Decimal::from_f64_retain(math_result).ok_or_else(|| {
285                LemmaError::Engine(
286                    "Mathematical operation result cannot be represented".to_string(),
287                )
288            })?;
289
290            Ok(OperationResult::Value(LiteralValue::Number(decimal_result)))
291        }
292        _ => Err(LemmaError::Engine(
293            "Mathematical operators require number operands".to_string(),
294        )),
295    }
296}
297
298/// Convert an Engine error to a Runtime error with proper source location
299///
300/// This is used to add span information to errors that occur during expression evaluation.
301fn convert_engine_error_to_runtime(
302    error: LemmaError,
303    expr: &Expression,
304    context: &EvaluationContext,
305) -> LemmaError {
306    match error {
307        LemmaError::Engine(msg) => {
308            let span = expr.span.clone().unwrap_or(Span {
309                start: 0,
310                end: 0,
311                line: 0,
312                col: 0,
313            });
314
315            let source_id = context
316                .current_doc
317                .source
318                .as_ref()
319                .cloned()
320                .unwrap_or_else(|| "<input>".to_string());
321
322            let source_text: Arc<str> = context
323                .sources
324                .get(&source_id)
325                .map(|s| Arc::from(s.as_str()))
326                .unwrap_or_else(|| Arc::from(""));
327
328            let suggestion = if msg.contains("division") || msg.contains("zero") {
329                Some(
330                    "Consider using an 'unless' clause to guard against division by zero"
331                        .to_string(),
332                )
333            } else if msg.contains("type") || msg.contains("mismatch") {
334                Some("Check that operands have compatible types".to_string())
335            } else {
336                None
337            };
338
339            LemmaError::Runtime(Box::new(crate::error::ErrorDetails {
340                message: msg,
341                span,
342                source_id,
343                source_text,
344                doc_name: context.current_doc.name.clone(),
345                doc_start_line: context.current_doc.start_line,
346                suggestion,
347            }))
348        }
349        other => other,
350    }
351}