1use super::context::EvaluationContext;
6use crate::{
7 ast::Span, ArithmeticOperation, Expression, ExpressionKind, FactReference, LemmaError,
8 LiteralValue, MathematicalOperator, OperationRecord, OperationResult,
9};
10use rust_decimal::Decimal;
11use std::sync::Arc;
12
13pub fn evaluate_expression(
22 expr: &Expression,
23 context: &mut EvaluationContext,
24 fact_prefix: &[String],
25) -> Result<OperationResult, LemmaError> {
26 context.check_timeout()?;
28
29 match &expr.kind {
30 ExpressionKind::Literal(lit) => {
31 Ok(OperationResult::Value(lit.clone()))
33 }
34
35 ExpressionKind::FactReference(fact_ref) => {
36 let lookup_ref = if !fact_prefix.is_empty() {
38 let mut qualified_reference = fact_prefix.to_vec();
42 qualified_reference.extend_from_slice(&fact_ref.reference);
43 FactReference {
44 reference: qualified_reference,
45 }
46 } else {
47 fact_ref.clone()
49 };
50
51 let value = context.facts.get(&lookup_ref).ok_or_else(|| {
52 LemmaError::Engine(format!("Missing fact: {}", lookup_ref.reference.join(".")))
53 })?;
54
55 context.operations.push(OperationRecord::FactUsed {
57 name: lookup_ref.reference.join("."),
58 value: value.clone(),
59 });
60
61 Ok(OperationResult::Value(value.clone()))
62 }
63 ExpressionKind::RuleReference(rule_ref) => {
64 let rule_path = crate::RulePath::from_reference(
67 &rule_ref.reference,
68 context.current_doc,
69 context.all_documents,
70 )?;
71
72 if let Some(result) = context.rule_results.get(&rule_path) {
74 match result {
75 OperationResult::Veto(msg) => {
76 return Ok(OperationResult::Veto(msg.clone()));
78 }
79 OperationResult::Value(value) => {
80 context.operations.push(OperationRecord::RuleUsed {
82 name: rule_path.to_string(),
83 value: value.clone(),
84 });
85 return Ok(OperationResult::Value(value.clone()));
86 }
87 }
88 }
89
90 Err(LemmaError::Engine(format!("Rule {} not found", rule_path)))
92 }
93
94 ExpressionKind::Arithmetic(left, op, right) => {
95 let left_result = evaluate_expression(left, context, fact_prefix)?;
96 let right_result = evaluate_expression(right, context, fact_prefix)?;
97
98 if let OperationResult::Veto(msg) = left_result {
100 return Ok(OperationResult::Veto(msg));
101 }
102 if let OperationResult::Veto(msg) = right_result {
103 return Ok(OperationResult::Veto(msg));
104 }
105
106 let left_val = left_result.expect_value("arithmetic left operand")?;
108 let right_val = right_result.expect_value("arithmetic right operand")?;
109
110 let result = super::operations::arithmetic_operation(left_val, op, right_val)
112 .map_err(|e| convert_engine_error_to_runtime(e, expr, context))?;
113
114 let op_name = match op {
116 ArithmeticOperation::Add => "add",
117 ArithmeticOperation::Subtract => "subtract",
118 ArithmeticOperation::Multiply => "multiply",
119 ArithmeticOperation::Divide => "divide",
120 ArithmeticOperation::Modulo => "modulo",
121 ArithmeticOperation::Power => "power",
122 };
123
124 context.operations.push(OperationRecord::OperationExecuted {
125 operation: op_name.to_string(),
126 inputs: vec![left_val.clone(), right_val.clone()],
127 result: result.clone(),
128 unless_clause_index: None,
129 });
130
131 Ok(OperationResult::Value(result))
132 }
133
134 ExpressionKind::Comparison(left, op, right) => {
135 let left_result = evaluate_expression(left, context, fact_prefix)?;
136 let right_result = evaluate_expression(right, context, fact_prefix)?;
137
138 if let OperationResult::Veto(msg) = left_result {
140 return Ok(OperationResult::Veto(msg));
141 }
142 if let OperationResult::Veto(msg) = right_result {
143 return Ok(OperationResult::Veto(msg));
144 }
145
146 let left_val = left_result.expect_value("comparison left operand")?;
148 let right_val = right_result.expect_value("comparison right operand")?;
149
150 let result = super::operations::comparison_operation(left_val, op, right_val)?;
151
152 let op_name = match op {
154 crate::ComparisonOperator::GreaterThan => "greater_than",
155 crate::ComparisonOperator::LessThan => "less_than",
156 crate::ComparisonOperator::GreaterThanOrEqual => "greater_than_or_equal",
157 crate::ComparisonOperator::LessThanOrEqual => "less_than_or_equal",
158 crate::ComparisonOperator::Equal => "equal",
159 crate::ComparisonOperator::NotEqual => "not_equal",
160 crate::ComparisonOperator::Is => "is",
161 crate::ComparisonOperator::IsNot => "is_not",
162 };
163
164 context.operations.push(OperationRecord::OperationExecuted {
165 operation: op_name.to_string(),
166 inputs: vec![left_val.clone(), right_val.clone()],
167 result: LiteralValue::Boolean(result),
168 unless_clause_index: None,
169 });
170
171 Ok(OperationResult::Value(LiteralValue::Boolean(result)))
172 }
173
174 ExpressionKind::LogicalAnd(left, right) => {
175 let left_result = evaluate_expression(left, context, fact_prefix)?;
176 let right_result = evaluate_expression(right, context, fact_prefix)?;
177
178 if let OperationResult::Veto(msg) = left_result {
180 return Ok(OperationResult::Veto(msg));
181 }
182 if let OperationResult::Veto(msg) = right_result {
183 return Ok(OperationResult::Veto(msg));
184 }
185
186 let left_val = left_result.expect_value("logical AND left operand")?;
188 let right_val = right_result.expect_value("logical AND right operand")?;
189
190 match (left_val, right_val) {
191 (LiteralValue::Boolean(l), LiteralValue::Boolean(r)) => {
192 Ok(OperationResult::Value(LiteralValue::Boolean(*l && *r)))
194 }
195 _ => Err(LemmaError::Engine(
196 "Logical AND requires boolean operands".to_string(),
197 )),
198 }
199 }
200
201 ExpressionKind::LogicalOr(left, right) => {
202 let left_result = evaluate_expression(left, context, fact_prefix)?;
203 let right_result = evaluate_expression(right, context, fact_prefix)?;
204
205 if let OperationResult::Veto(msg) = left_result {
207 return Ok(OperationResult::Veto(msg));
208 }
209 if let OperationResult::Veto(msg) = right_result {
210 return Ok(OperationResult::Veto(msg));
211 }
212
213 let left_val = left_result.expect_value("logical OR left operand")?;
215 let right_val = right_result.expect_value("logical OR right operand")?;
216
217 match (left_val, right_val) {
218 (LiteralValue::Boolean(l), LiteralValue::Boolean(r)) => {
219 Ok(OperationResult::Value(LiteralValue::Boolean(*l || *r)))
221 }
222 _ => Err(LemmaError::Engine(
223 "Logical OR requires boolean operands".to_string(),
224 )),
225 }
226 }
227
228 ExpressionKind::LogicalNegation(operand, _negation_type) => {
229 let result = evaluate_expression(operand, context, fact_prefix)?;
230
231 if let OperationResult::Veto(msg) = result {
233 return Ok(OperationResult::Veto(msg));
234 }
235
236 let value = result.expect_value("logical negation operand")?;
238
239 match value {
240 LiteralValue::Boolean(b) => Ok(OperationResult::Value(LiteralValue::Boolean(!b))),
241 _ => Err(LemmaError::Engine(
242 "Logical NOT requires boolean operand".to_string(),
243 )),
244 }
245 }
246
247 ExpressionKind::UnitConversion(value_expr, target) => {
248 let result = evaluate_expression(value_expr, context, fact_prefix)?;
249
250 if let OperationResult::Veto(msg) = result {
252 return Ok(OperationResult::Veto(msg));
253 }
254
255 let value = result.expect_value("unit conversion operand")?;
257 let converted = super::units::convert_unit(value, target)?;
258 Ok(OperationResult::Value(converted))
259 }
260
261 ExpressionKind::MathematicalOperator(op, operand) => {
262 evaluate_mathematical_operator(op, operand, context, fact_prefix)
263 }
264
265 ExpressionKind::Veto(veto_expr) => Ok(OperationResult::Veto(veto_expr.message.clone())),
266
267 ExpressionKind::FactHasAnyValue(fact_ref) => {
268 let lookup_ref = if !fact_prefix.is_empty() {
270 let mut qualified_reference = fact_prefix.to_vec();
271 qualified_reference.extend_from_slice(&fact_ref.reference);
272 FactReference {
273 reference: qualified_reference,
274 }
275 } else {
276 fact_ref.clone()
277 };
278 let has_value = context.facts.contains_key(&lookup_ref);
279 Ok(OperationResult::Value(LiteralValue::Boolean(has_value)))
280 }
281 }
282}
283
284fn evaluate_mathematical_operator(
286 op: &MathematicalOperator,
287 operand: &Expression,
288 context: &mut EvaluationContext,
289 fact_prefix: &[String],
290) -> Result<OperationResult, LemmaError> {
291 let result = evaluate_expression(operand, context, fact_prefix)?;
292
293 if let OperationResult::Veto(msg) = result {
295 return Ok(OperationResult::Veto(msg));
296 }
297
298 let value = result.expect_value("mathematical operator operand")?;
300
301 match value {
302 LiteralValue::Number(n) => {
303 use rust_decimal::prelude::ToPrimitive;
304 let float_val = n.to_f64().ok_or_else(|| {
305 LemmaError::Engine("Cannot convert to float for mathematical operation".to_string())
306 })?;
307
308 match op {
309 MathematicalOperator::Sqrt
311 | MathematicalOperator::Sin
312 | MathematicalOperator::Cos
313 | MathematicalOperator::Tan
314 | MathematicalOperator::Asin
315 | MathematicalOperator::Acos
316 | MathematicalOperator::Atan
317 | MathematicalOperator::Log
318 | MathematicalOperator::Exp => {
319 let math_result = match op {
320 MathematicalOperator::Sqrt => float_val.sqrt(),
321 MathematicalOperator::Sin => float_val.sin(),
322 MathematicalOperator::Cos => float_val.cos(),
323 MathematicalOperator::Tan => float_val.tan(),
324 MathematicalOperator::Asin => float_val.asin(),
325 MathematicalOperator::Acos => float_val.acos(),
326 MathematicalOperator::Atan => float_val.atan(),
327 MathematicalOperator::Log => float_val.ln(),
328 MathematicalOperator::Exp => float_val.exp(),
329 _ => unreachable!(),
330 };
331 let decimal_result =
332 Decimal::from_f64_retain(math_result).ok_or_else(|| {
333 LemmaError::Engine(
334 "Mathematical operation result cannot be represented".to_string(),
335 )
336 })?;
337 Ok(OperationResult::Value(LiteralValue::Number(decimal_result)))
338 }
339 MathematicalOperator::Abs => {
341 Ok(OperationResult::Value(LiteralValue::Number(n.abs())))
342 }
343 MathematicalOperator::Floor => {
344 Ok(OperationResult::Value(LiteralValue::Number(n.floor())))
345 }
346 MathematicalOperator::Ceil => {
347 Ok(OperationResult::Value(LiteralValue::Number(n.ceil())))
348 }
349 MathematicalOperator::Round => {
350 Ok(OperationResult::Value(LiteralValue::Number(n.round())))
351 }
352 }
353 }
354 _ => Err(LemmaError::Engine(
355 "Mathematical operators require number operands".to_string(),
356 )),
357 }
358}
359
360fn convert_engine_error_to_runtime(
364 error: LemmaError,
365 expr: &Expression,
366 context: &EvaluationContext,
367) -> LemmaError {
368 match error {
369 LemmaError::Engine(msg) => {
370 let span = expr.span.clone().unwrap_or(Span {
371 start: 0,
372 end: 0,
373 line: 0,
374 col: 0,
375 });
376
377 let source_id = context
378 .current_doc
379 .source
380 .as_ref()
381 .cloned()
382 .unwrap_or_else(|| "<input>".to_string());
383
384 let source_text: Arc<str> = context
385 .sources
386 .get(&source_id)
387 .map(|s| Arc::from(s.as_str()))
388 .unwrap_or_else(|| Arc::from(""));
389
390 let suggestion = if msg.contains("division") || msg.contains("zero") {
391 Some(
392 "Consider using an 'unless' clause to guard against division by zero"
393 .to_string(),
394 )
395 } else if msg.contains("type") || msg.contains("mismatch") {
396 Some("Check that operands have compatible types".to_string())
397 } else {
398 None
399 };
400
401 LemmaError::Runtime(Box::new(crate::error::ErrorDetails {
402 message: msg,
403 span,
404 source_id,
405 source_text,
406 doc_name: context.current_doc.name.clone(),
407 doc_start_line: context.current_doc.start_line,
408 suggestion,
409 }))
410 }
411 other => other,
412 }
413}