1use super::context::EvaluationContext;
6use crate::{
7 ast::Span, ArithmeticOperation, Expression, ExpressionKind, LemmaError, LiteralValue,
8 MathematicalOperator, OperationRecord, OperationResult,
9};
10use rust_decimal::Decimal;
11use std::sync::Arc;
12
13pub fn evaluate_expression(
18 expr: &Expression,
19 context: &mut EvaluationContext,
20) -> Result<OperationResult, LemmaError> {
21 match &expr.kind {
22 ExpressionKind::Literal(lit) => {
23 Ok(OperationResult::Value(lit.clone()))
25 }
26
27 ExpressionKind::FactReference(fact_ref) => {
28 let fact_name = fact_ref.reference.join(".");
30 let value = context
31 .facts
32 .get(&fact_name)
33 .ok_or_else(|| LemmaError::Engine(format!("Missing fact: {}", fact_name)))?;
34
35 context.operations.push(OperationRecord::FactUsed {
37 name: fact_name,
38 value: value.clone(),
39 });
40
41 Ok(OperationResult::Value(value.clone()))
42 }
43
44 ExpressionKind::RuleReference(rule_ref) => {
45 let rule_name = rule_ref.reference.join(".");
48
49 if let Some(result) = context.rule_results.get(&rule_name) {
51 match result {
52 OperationResult::Veto(msg) => {
53 return Ok(OperationResult::Veto(msg.clone()));
55 }
56 OperationResult::Value(value) => {
57 context.operations.push(OperationRecord::RuleUsed {
59 name: rule_name,
60 value: value.clone(),
61 });
62 return Ok(OperationResult::Value(value.clone()));
63 }
64 }
65 }
66
67 Err(LemmaError::Engine(format!(
69 "Rule {} not yet computed",
70 rule_name
71 )))
72 }
73
74 ExpressionKind::Arithmetic(left, op, right) => {
75 let left_result = evaluate_expression(left, context)?;
76 let right_result = evaluate_expression(right, context)?;
77
78 if let OperationResult::Veto(msg) = left_result {
80 return Ok(OperationResult::Veto(msg));
81 }
82 if let OperationResult::Veto(msg) = right_result {
83 return Ok(OperationResult::Veto(msg));
84 }
85
86 let left_val = left_result.value().unwrap();
87 let right_val = right_result.value().unwrap();
88
89 let result = super::operations::arithmetic_operation(left_val, op, right_val)
91 .map_err(|e| convert_engine_error_to_runtime(e, expr, context))?;
92
93 let op_name = match op {
95 ArithmeticOperation::Add => "add",
96 ArithmeticOperation::Subtract => "subtract",
97 ArithmeticOperation::Multiply => "multiply",
98 ArithmeticOperation::Divide => "divide",
99 ArithmeticOperation::Modulo => "modulo",
100 ArithmeticOperation::Power => "power",
101 };
102
103 context.operations.push(OperationRecord::OperationExecuted {
104 operation: op_name.to_string(),
105 inputs: vec![left_val.clone(), right_val.clone()],
106 result: result.clone(),
107 unless_clause_index: None,
108 });
109
110 Ok(OperationResult::Value(result))
111 }
112
113 ExpressionKind::Comparison(left, op, right) => {
114 let left_result = evaluate_expression(left, context)?;
115 let right_result = evaluate_expression(right, context)?;
116
117 if let OperationResult::Veto(msg) = left_result {
119 return Ok(OperationResult::Veto(msg));
120 }
121 if let OperationResult::Veto(msg) = right_result {
122 return Ok(OperationResult::Veto(msg));
123 }
124
125 let left_val = left_result.value().unwrap();
126 let right_val = right_result.value().unwrap();
127
128 let result = super::operations::comparison_operation(left_val, op, right_val)?;
129
130 let op_name = match op {
132 crate::ComparisonOperator::GreaterThan => "greater_than",
133 crate::ComparisonOperator::LessThan => "less_than",
134 crate::ComparisonOperator::GreaterThanOrEqual => "greater_than_or_equal",
135 crate::ComparisonOperator::LessThanOrEqual => "less_than_or_equal",
136 crate::ComparisonOperator::Equal => "equal",
137 crate::ComparisonOperator::NotEqual => "not_equal",
138 crate::ComparisonOperator::Is => "is",
139 crate::ComparisonOperator::IsNot => "is_not",
140 };
141
142 context.operations.push(OperationRecord::OperationExecuted {
143 operation: op_name.to_string(),
144 inputs: vec![left_val.clone(), right_val.clone()],
145 result: LiteralValue::Boolean(result),
146 unless_clause_index: None,
147 });
148
149 Ok(OperationResult::Value(LiteralValue::Boolean(result)))
150 }
151
152 ExpressionKind::LogicalAnd(left, right) => {
153 let left_result = evaluate_expression(left, context)?;
154 let right_result = evaluate_expression(right, context)?;
155
156 if let OperationResult::Veto(msg) = left_result {
158 return Ok(OperationResult::Veto(msg));
159 }
160 if let OperationResult::Veto(msg) = right_result {
161 return Ok(OperationResult::Veto(msg));
162 }
163
164 let left_val = left_result.value().unwrap();
165 let right_val = right_result.value().unwrap();
166
167 match (left_val, right_val) {
168 (LiteralValue::Boolean(l), LiteralValue::Boolean(r)) => {
169 Ok(OperationResult::Value(LiteralValue::Boolean(*l && *r)))
171 }
172 _ => Err(LemmaError::Engine(
173 "Logical AND requires boolean operands".to_string(),
174 )),
175 }
176 }
177
178 ExpressionKind::LogicalOr(left, right) => {
179 let left_result = evaluate_expression(left, context)?;
180 let right_result = evaluate_expression(right, context)?;
181
182 if let OperationResult::Veto(msg) = left_result {
184 return Ok(OperationResult::Veto(msg));
185 }
186 if let OperationResult::Veto(msg) = right_result {
187 return Ok(OperationResult::Veto(msg));
188 }
189
190 let left_val = left_result.value().unwrap();
191 let right_val = right_result.value().unwrap();
192
193 match (left_val, right_val) {
194 (LiteralValue::Boolean(l), LiteralValue::Boolean(r)) => {
195 Ok(OperationResult::Value(LiteralValue::Boolean(*l || *r)))
197 }
198 _ => Err(LemmaError::Engine(
199 "Logical OR requires boolean operands".to_string(),
200 )),
201 }
202 }
203
204 ExpressionKind::LogicalNegation(inner, _negation_type) => {
205 let result = evaluate_expression(inner, context)?;
206
207 if let OperationResult::Veto(msg) = result {
209 return Ok(OperationResult::Veto(msg));
210 }
211
212 let value = result.value().unwrap();
213
214 match value {
215 LiteralValue::Boolean(b) => Ok(OperationResult::Value(LiteralValue::Boolean(!b))),
216 _ => Err(LemmaError::Engine(
217 "Logical NOT requires boolean operand".to_string(),
218 )),
219 }
220 }
221
222 ExpressionKind::UnitConversion(value_expr, target) => {
223 let result = evaluate_expression(value_expr, context)?;
224
225 if let OperationResult::Veto(msg) = result {
227 return Ok(OperationResult::Veto(msg));
228 }
229
230 let value = result.value().unwrap();
231 let converted = super::units::convert_unit(value, target)?;
232 Ok(OperationResult::Value(converted))
233 }
234
235 ExpressionKind::MathematicalOperator(op, operand) => {
236 evaluate_mathematical_operator(op, operand, context)
237 }
238
239 ExpressionKind::Veto(veto_expr) => Ok(OperationResult::Veto(veto_expr.message.clone())),
240
241 ExpressionKind::FactHasAnyValue(fact_ref) => {
242 let fact_name = fact_ref.reference.join(".");
244 let has_value = context.facts.contains_key(&fact_name);
245 Ok(OperationResult::Value(LiteralValue::Boolean(has_value)))
246 }
247 }
248}
249
250fn evaluate_mathematical_operator(
252 op: &MathematicalOperator,
253 operand: &Expression,
254 context: &mut EvaluationContext,
255) -> Result<OperationResult, LemmaError> {
256 let result = evaluate_expression(operand, context)?;
257
258 if let OperationResult::Veto(msg) = result {
260 return Ok(OperationResult::Veto(msg));
261 }
262
263 let value = result.value().unwrap();
264
265 match value {
266 LiteralValue::Number(n) => {
267 use rust_decimal::prelude::ToPrimitive;
268 let float_val = n.to_f64().ok_or_else(|| {
269 LemmaError::Engine("Cannot convert to float for mathematical operation".to_string())
270 })?;
271
272 let math_result = match op {
273 MathematicalOperator::Sqrt => float_val.sqrt(),
274 MathematicalOperator::Sin => float_val.sin(),
275 MathematicalOperator::Cos => float_val.cos(),
276 MathematicalOperator::Tan => float_val.tan(),
277 MathematicalOperator::Asin => float_val.asin(),
278 MathematicalOperator::Acos => float_val.acos(),
279 MathematicalOperator::Atan => float_val.atan(),
280 MathematicalOperator::Log => float_val.ln(),
281 MathematicalOperator::Exp => float_val.exp(),
282 };
283
284 let decimal_result = Decimal::from_f64_retain(math_result).ok_or_else(|| {
285 LemmaError::Engine(
286 "Mathematical operation result cannot be represented".to_string(),
287 )
288 })?;
289
290 Ok(OperationResult::Value(LiteralValue::Number(decimal_result)))
291 }
292 _ => Err(LemmaError::Engine(
293 "Mathematical operators require number operands".to_string(),
294 )),
295 }
296}
297
298fn convert_engine_error_to_runtime(
302 error: LemmaError,
303 expr: &Expression,
304 context: &EvaluationContext,
305) -> LemmaError {
306 match error {
307 LemmaError::Engine(msg) => {
308 let span = expr.span.clone().unwrap_or(Span {
309 start: 0,
310 end: 0,
311 line: 0,
312 col: 0,
313 });
314
315 let source_id = context
316 .current_doc
317 .source
318 .as_ref()
319 .cloned()
320 .unwrap_or_else(|| "<input>".to_string());
321
322 let source_text: Arc<str> = context
323 .sources
324 .get(&source_id)
325 .map(|s| Arc::from(s.as_str()))
326 .unwrap_or_else(|| Arc::from(""));
327
328 let suggestion = if msg.contains("division") || msg.contains("zero") {
329 Some(
330 "Consider using an 'unless' clause to guard against division by zero"
331 .to_string(),
332 )
333 } else if msg.contains("type") || msg.contains("mismatch") {
334 Some("Check that operands have compatible types".to_string())
335 } else {
336 None
337 };
338
339 LemmaError::Runtime(Box::new(crate::error::ErrorDetails {
340 message: msg,
341 span,
342 source_id,
343 source_text,
344 doc_name: context.current_doc.name.clone(),
345 doc_start_line: context.current_doc.start_line,
346 suggestion,
347 }))
348 }
349 other => other,
350 }
351}