Skip to main content

lemma/computation/
arithmetic.rs

1//! Type-aware arithmetic operations
2
3use crate::evaluation::OperationResult;
4use crate::planning::semantics::{
5    primitive_number, ArithmeticComputation, LiteralValue, ValueKind,
6};
7use rust_decimal::Decimal;
8
9/// Perform type-aware arithmetic operation, returning OperationResult (Veto for runtime errors)
10pub fn arithmetic_operation(
11    left: &LiteralValue,
12    op: &ArithmeticComputation,
13    right: &LiteralValue,
14) -> OperationResult {
15    match (&left.value, &right.value) {
16        (ValueKind::Number(l), ValueKind::Number(r)) => match number_arithmetic(*l, op, *r) {
17            Ok(result) => OperationResult::Value(Box::new(LiteralValue::number_with_type(
18                result,
19                left.lemma_type.clone(),
20            ))),
21            Err(msg) => OperationResult::Veto(Some(msg)),
22        },
23
24        (ValueKind::Date(_), _) | (_, ValueKind::Date(_)) => {
25            super::datetime::datetime_arithmetic(left, op, right)
26        }
27
28        (ValueKind::Time(_), _) | (_, ValueKind::Time(_)) => {
29            super::datetime::time_arithmetic(left, op, right)
30        }
31
32        // Duration arithmetic
33        (ValueKind::Duration(l, lu), ValueKind::Duration(r, ru)) => {
34            let left_seconds = super::units::duration_to_seconds(*l, lu);
35            let right_seconds = super::units::duration_to_seconds(*r, ru);
36            match op {
37                ArithmeticComputation::Add => {
38                    let result_seconds = left_seconds + right_seconds;
39                    let result_value = super::units::seconds_to_duration(result_seconds, lu);
40                    OperationResult::Value(Box::new(LiteralValue::duration_with_type(
41                        result_value,
42                        lu.clone(),
43                        left.lemma_type.clone(),
44                    )))
45                }
46                ArithmeticComputation::Subtract => {
47                    let result_seconds = left_seconds - right_seconds;
48                    let result_value = super::units::seconds_to_duration(result_seconds, lu);
49                    OperationResult::Value(Box::new(LiteralValue::duration_with_type(
50                        result_value,
51                        lu.clone(),
52                        left.lemma_type.clone(),
53                    )))
54                }
55                _ => OperationResult::Veto(Some(format!(
56                    "Operation {:?} not supported for durations",
57                    op
58                ))),
59            }
60        }
61
62        // Duration with Number → Duration
63        (ValueKind::Duration(value, unit), ValueKind::Number(n)) => {
64            match number_arithmetic(*value, op, *n) {
65                Ok(result) => OperationResult::Value(Box::new(LiteralValue::duration_with_type(
66                    result,
67                    unit.clone(),
68                    left.lemma_type.clone(),
69                ))),
70                Err(msg) => OperationResult::Veto(Some(msg)),
71            }
72        }
73
74        // Number with Duration → Duration (commutative for Multiply/Add)
75        (ValueKind::Number(n), ValueKind::Duration(value, unit)) => {
76            match number_arithmetic(*n, op, *value) {
77                Ok(result) => OperationResult::Value(Box::new(LiteralValue::duration_with_type(
78                    result,
79                    unit.clone(),
80                    right.lemma_type.clone(),
81                ))),
82                Err(msg) => OperationResult::Veto(Some(msg)),
83            }
84        }
85
86        // Ratio with Number → Number (ratio semantics: add/subtract apply as multiplier)
87        (ValueKind::Ratio(r, _), ValueKind::Number(n)) => match op {
88            ArithmeticComputation::Add => {
89                let result = *n * (Decimal::ONE + *r);
90                OperationResult::Value(Box::new(LiteralValue::number_with_type(
91                    result,
92                    primitive_number().clone(),
93                )))
94            }
95            ArithmeticComputation::Subtract => {
96                let result = *n * (Decimal::ONE - *r);
97                OperationResult::Value(Box::new(LiteralValue::number_with_type(
98                    result,
99                    primitive_number().clone(),
100                )))
101            }
102            _ => match number_arithmetic(*r, op, *n) {
103                Ok(result) => OperationResult::Value(Box::new(LiteralValue::number_with_type(
104                    result,
105                    primitive_number().clone(),
106                ))),
107                Err(msg) => OperationResult::Veto(Some(msg)),
108            },
109        },
110
111        // Number with Ratio → Number (ratio semantics: add/subtract apply as multiplier)
112        (ValueKind::Number(n), ValueKind::Ratio(r, _)) => match op {
113            ArithmeticComputation::Add => {
114                let result = *n * (Decimal::ONE + *r);
115                OperationResult::Value(Box::new(LiteralValue::number_with_type(
116                    result,
117                    primitive_number().clone(),
118                )))
119            }
120            ArithmeticComputation::Subtract => {
121                let result = *n * (Decimal::ONE - *r);
122                OperationResult::Value(Box::new(LiteralValue::number_with_type(
123                    result,
124                    primitive_number().clone(),
125                )))
126            }
127            _ => match number_arithmetic(*n, op, *r) {
128                Ok(result) => OperationResult::Value(Box::new(LiteralValue::number_with_type(
129                    result,
130                    primitive_number().clone(),
131                ))),
132                Err(msg) => OperationResult::Veto(Some(msg)),
133            },
134        },
135
136        // Duration with Ratio → Duration (ratio semantics: add/subtract apply as multiplier)
137        (ValueKind::Duration(value, unit), ValueKind::Ratio(r, _)) => match op {
138            ArithmeticComputation::Add => {
139                let result = *value * (Decimal::ONE + *r);
140                OperationResult::Value(Box::new(LiteralValue::duration_with_type(
141                    result,
142                    unit.clone(),
143                    left.lemma_type.clone(),
144                )))
145            }
146            ArithmeticComputation::Subtract => {
147                let result = *value * (Decimal::ONE - *r);
148                OperationResult::Value(Box::new(LiteralValue::duration_with_type(
149                    result,
150                    unit.clone(),
151                    left.lemma_type.clone(),
152                )))
153            }
154            _ => match number_arithmetic(*value, op, *r) {
155                Ok(result) => OperationResult::Value(Box::new(LiteralValue::duration_with_type(
156                    result,
157                    unit.clone(),
158                    left.lemma_type.clone(),
159                ))),
160                Err(msg) => OperationResult::Veto(Some(msg)),
161            },
162        },
163
164        // Ratio with Duration → Duration (commutative for multiply)
165        (ValueKind::Ratio(r, _), ValueKind::Duration(value, unit)) => match op {
166            ArithmeticComputation::Add => {
167                let result = *value * (Decimal::ONE + *r);
168                OperationResult::Value(Box::new(LiteralValue::duration_with_type(
169                    result,
170                    unit.clone(),
171                    right.lemma_type.clone(),
172                )))
173            }
174            ArithmeticComputation::Subtract => {
175                let result = *value * (Decimal::ONE - *r);
176                OperationResult::Value(Box::new(LiteralValue::duration_with_type(
177                    result,
178                    unit.clone(),
179                    right.lemma_type.clone(),
180                )))
181            }
182            _ => match number_arithmetic(*r, op, *value) {
183                Ok(result) => OperationResult::Value(Box::new(LiteralValue::duration_with_type(
184                    result,
185                    unit.clone(),
186                    right.lemma_type.clone(),
187                ))),
188                Err(msg) => OperationResult::Veto(Some(msg)),
189            },
190        },
191        // Ratio op Ratio → Ratio
192        (ValueKind::Ratio(l, lu), ValueKind::Ratio(r, _ru)) => {
193            // Preserve unit from left operand
194            match number_arithmetic(*l, op, *r) {
195                Ok(result) => OperationResult::Value(Box::new(LiteralValue::ratio_with_type(
196                    result,
197                    lu.clone(),
198                    left.lemma_type.clone(),
199                ))),
200                Err(msg) => OperationResult::Veto(Some(msg)),
201            }
202        }
203        // Scale operations with Scale
204        (ValueKind::Scale(l_val, l_unit), ValueKind::Scale(r_val, r_unit)) => {
205            // Units must match for addition/subtraction
206            if l_unit != r_unit
207                && (matches!(
208                    op,
209                    ArithmeticComputation::Add | ArithmeticComputation::Subtract
210                ))
211            {
212                return OperationResult::Veto(Some(format!(
213                    "Cannot apply '{}' to values with different units: {:?} and {:?}",
214                    op, l_unit, r_unit
215                )));
216            }
217            // Preserve unit from left
218            let preserved_unit = l_unit.clone();
219            match number_arithmetic(*l_val, op, *r_val) {
220                Ok(result) => OperationResult::Value(Box::new(LiteralValue::scale_with_type(
221                    result,
222                    preserved_unit,
223                    left.lemma_type.clone(),
224                ))),
225                Err(msg) => OperationResult::Veto(Some(msg)),
226            }
227        }
228        // Ratio op Scale → Scale (inherits Scale type and unit)
229        (ValueKind::Ratio(ratio_val, _), ValueKind::Scale(scale_val, scale_unit)) => {
230            match op {
231                ArithmeticComputation::Multiply => {
232                    match number_arithmetic(*ratio_val, op, *scale_val) {
233                        Ok(result) => {
234                            OperationResult::Value(Box::new(LiteralValue::scale_with_type(
235                                result,
236                                scale_unit.clone(),
237                                right.lemma_type.clone(),
238                            )))
239                        }
240                        Err(msg) => OperationResult::Veto(Some(msg)),
241                    }
242                }
243                ArithmeticComputation::Divide => {
244                    if *scale_val == Decimal::ZERO {
245                        return OperationResult::Veto(Some("Division by zero".to_string()));
246                    }
247                    match number_arithmetic(*ratio_val, op, *scale_val) {
248                        Ok(result) => {
249                            OperationResult::Value(Box::new(LiteralValue::scale_with_type(
250                                result,
251                                scale_unit.clone(),
252                                right.lemma_type.clone(),
253                            )))
254                        }
255                        Err(msg) => OperationResult::Veto(Some(msg)),
256                    }
257                }
258                ArithmeticComputation::Add | ArithmeticComputation::Subtract => {
259                    // Scale +/- Ratio applies ratio semantics: scale +/- (scale * ratio) = scale * (1 +/- ratio)
260                    let ratio_amount = *scale_val * *ratio_val;
261                    let result = match op {
262                        ArithmeticComputation::Add => *scale_val + ratio_amount,
263                        ArithmeticComputation::Subtract => *scale_val - ratio_amount,
264                        _ => {
265                            return OperationResult::Veto(Some(format!(
266                                "Operation '{}' not supported for ratio and scale",
267                                op
268                            )))
269                        }
270                    };
271                    OperationResult::Value(Box::new(LiteralValue::scale_with_type(
272                        result,
273                        scale_unit.clone(), // Preserve Scale unit
274                        right.lemma_type.clone(),
275                    )))
276                }
277                _ => OperationResult::Veto(Some(format!(
278                    "Operation {:?} not supported for ratio and scale",
279                    op
280                ))),
281            }
282        }
283        // Scale op Ratio → Scale (inherits Scale type and unit)
284        (ValueKind::Scale(scale_val, scale_unit), ValueKind::Ratio(ratio_val, _)) => {
285            match op {
286                ArithmeticComputation::Multiply => {
287                    match number_arithmetic(*scale_val, op, *ratio_val) {
288                        Ok(result) => {
289                            OperationResult::Value(Box::new(LiteralValue::scale_with_type(
290                                result,
291                                scale_unit.clone(),
292                                left.lemma_type.clone(),
293                            )))
294                        }
295                        Err(msg) => OperationResult::Veto(Some(msg)),
296                    }
297                }
298                ArithmeticComputation::Divide => {
299                    if *ratio_val == Decimal::ZERO {
300                        return OperationResult::Veto(Some("Division by zero".to_string()));
301                    }
302                    match number_arithmetic(*scale_val, op, *ratio_val) {
303                        Ok(result) => {
304                            OperationResult::Value(Box::new(LiteralValue::scale_with_type(
305                                result,
306                                scale_unit.clone(),
307                                left.lemma_type.clone(), // Inherit Scale type
308                            )))
309                        }
310                        Err(msg) => OperationResult::Veto(Some(msg)),
311                    }
312                }
313                ArithmeticComputation::Add | ArithmeticComputation::Subtract => {
314                    // Scale +/- Ratio applies ratio semantics: scale +/- (scale * ratio) = scale * (1 +/- ratio)
315                    let ratio_amount = *scale_val * *ratio_val;
316                    let result = match op {
317                        ArithmeticComputation::Add => *scale_val + ratio_amount,
318                        ArithmeticComputation::Subtract => *scale_val - ratio_amount,
319                        _ => {
320                            return OperationResult::Veto(Some(format!(
321                                "Operation '{}' not supported for scale and ratio",
322                                op
323                            )))
324                        }
325                    };
326                    OperationResult::Value(Box::new(LiteralValue::scale_with_type(
327                        result,
328                        scale_unit.clone(), // Preserve Scale unit
329                        left.lemma_type.clone(),
330                    )))
331                }
332                _ => OperationResult::Veto(Some(format!(
333                    "Operation {:?} not supported for scale and ratio",
334                    op
335                ))),
336            }
337        }
338
339        // Scale op Number → Scale (preserves unit)
340        (ValueKind::Scale(scale_val, scale_unit), ValueKind::Number(n)) => {
341            match number_arithmetic(*scale_val, op, *n) {
342                Ok(result) => OperationResult::Value(Box::new(LiteralValue::scale_with_type(
343                    result,
344                    scale_unit.clone(),
345                    left.lemma_type.clone(),
346                ))),
347                Err(msg) => OperationResult::Veto(Some(msg)),
348            }
349        }
350        // Number op Scale → Scale (preserves unit)
351        (ValueKind::Number(n), ValueKind::Scale(scale_val, scale_unit)) => {
352            match number_arithmetic(*n, op, *scale_val) {
353                Ok(result) => OperationResult::Value(Box::new(LiteralValue::scale_with_type(
354                    result,
355                    scale_unit.clone(),
356                    right.lemma_type.clone(),
357                ))),
358                Err(msg) => OperationResult::Veto(Some(msg)),
359            }
360        }
361        // Scale with Duration → Number
362        (ValueKind::Scale(scale_val, _), ValueKind::Duration(dur_val, _)) => {
363            match number_arithmetic(*scale_val, op, *dur_val) {
364                Ok(result) => OperationResult::Value(Box::new(LiteralValue::number_with_type(
365                    result,
366                    primitive_number().clone(),
367                ))),
368                Err(msg) => OperationResult::Veto(Some(msg)),
369            }
370        }
371
372        // Duration with Scale → Number
373        (ValueKind::Duration(dur_val, _), ValueKind::Scale(scale_val, _)) => {
374            match number_arithmetic(*dur_val, op, *scale_val) {
375                Ok(result) => OperationResult::Value(Box::new(LiteralValue::number_with_type(
376                    result,
377                    primitive_number().clone(),
378                ))),
379                Err(msg) => OperationResult::Veto(Some(msg)),
380            }
381        }
382        _ => OperationResult::Veto(Some(format!(
383            "Arithmetic operation {:?} not supported for types {:?} and {:?}",
384            op,
385            type_name(left),
386            type_name(right)
387        ))),
388    }
389}
390
391fn number_arithmetic(
392    left: Decimal,
393    op: &ArithmeticComputation,
394    right: Decimal,
395) -> Result<Decimal, String> {
396    use rust_decimal::prelude::ToPrimitive;
397
398    match op {
399        ArithmeticComputation::Add => Ok(left + right),
400        ArithmeticComputation::Subtract => Ok(left - right),
401        ArithmeticComputation::Multiply => Ok(left * right),
402        ArithmeticComputation::Divide => {
403            if right == Decimal::ZERO {
404                return Err("Division by zero".to_string());
405            }
406            Ok(left / right)
407        }
408        ArithmeticComputation::Modulo => {
409            if right == Decimal::ZERO {
410                return Err("Division by zero (modulo)".to_string());
411            }
412            Ok(left % right)
413        }
414        ArithmeticComputation::Power => {
415            let base = left
416                .to_f64()
417                .ok_or_else(|| "Cannot convert base to float".to_string())?;
418            let exp = right
419                .to_f64()
420                .ok_or_else(|| "Cannot convert exponent to float".to_string())?;
421            let result = base.powf(exp);
422            Decimal::from_f64_retain(result)
423                .ok_or_else(|| "Power result cannot be represented".to_string())
424        }
425    }
426}
427
428fn type_name(value: &LiteralValue) -> String {
429    value.get_type().name().to_string()
430}