Skip to main content

lemma/computation/
arithmetic.rs

1//! Type-aware arithmetic operations
2
3use crate::evaluation::OperationResult;
4use crate::semantic::standard_number;
5use crate::{ArithmeticComputation, LiteralValue, Value};
6use rust_decimal::Decimal;
7
8/// Perform type-aware arithmetic operation, returning OperationResult (Veto for runtime errors)
9pub fn arithmetic_operation(
10    left: &LiteralValue,
11    op: &ArithmeticComputation,
12    right: &LiteralValue,
13) -> OperationResult {
14    match (&left.value, &right.value) {
15        (Value::Number(l), Value::Number(r)) => match number_arithmetic(*l, op, *r) {
16            Ok(result) => OperationResult::Value(LiteralValue::number_with_type(
17                result,
18                left.lemma_type.clone(),
19            )),
20            Err(msg) => OperationResult::Veto(Some(msg)),
21        },
22
23        (Value::Date(_), _) | (_, Value::Date(_)) => {
24            super::datetime::datetime_arithmetic(left, op, right)
25        }
26
27        (Value::Time(_), _) | (_, Value::Time(_)) => {
28            super::datetime::time_arithmetic(left, op, right)
29        }
30
31        // Duration arithmetic
32        (Value::Duration(l, lu), Value::Duration(r, ru)) => {
33            let left_seconds = super::units::duration_to_seconds(*l, lu);
34            let right_seconds = super::units::duration_to_seconds(*r, ru);
35            match op {
36                ArithmeticComputation::Add => {
37                    let result_seconds = left_seconds + right_seconds;
38                    let result_value = super::units::seconds_to_duration(result_seconds, lu);
39                    OperationResult::Value(LiteralValue::duration_with_type(
40                        result_value,
41                        lu.clone(),
42                        left.lemma_type.clone(),
43                    ))
44                }
45                ArithmeticComputation::Subtract => {
46                    let result_seconds = left_seconds - right_seconds;
47                    let result_value = super::units::seconds_to_duration(result_seconds, lu);
48                    OperationResult::Value(LiteralValue::duration_with_type(
49                        result_value,
50                        lu.clone(),
51                        left.lemma_type.clone(),
52                    ))
53                }
54                _ => OperationResult::Veto(Some(format!(
55                    "Operation {:?} not supported for durations",
56                    op
57                ))),
58            }
59        }
60
61        // Duration with number
62        (Value::Duration(value, unit), Value::Number(n)) => match op {
63            ArithmeticComputation::Multiply => OperationResult::Value(
64                LiteralValue::duration_with_type(value * n, unit.clone(), left.lemma_type.clone()),
65            ),
66            ArithmeticComputation::Divide => {
67                if *n == Decimal::ZERO {
68                    return OperationResult::Veto(Some("Division by zero".to_string()));
69                }
70                OperationResult::Value(LiteralValue::duration_with_type(
71                    value / n,
72                    unit.clone(),
73                    left.lemma_type.clone(),
74                ))
75            }
76            _ => OperationResult::Veto(Some(format!(
77                "Operation {:?} not supported for duration and number",
78                op
79            ))),
80        },
81
82        (Value::Number(n), Value::Duration(value, unit)) => match op {
83            ArithmeticComputation::Multiply => OperationResult::Value(
84                LiteralValue::duration_with_type(n * value, unit.clone(), left.lemma_type.clone()),
85            ),
86            _ => OperationResult::Veto(Some(format!(
87                "Operation {:?} not supported for number and duration",
88                op
89            ))),
90        },
91
92        // Ratio operations
93        // Ratio op Number → Number (ratio semantics: ratio + number = number * (1 + ratio))
94        (Value::Ratio(r, _), Value::Number(n)) if right.get_type().is_number() => {
95            match op {
96                ArithmeticComputation::Add => {
97                    // ratio + number = number * (1 + ratio)
98                    let result = *n * (Decimal::ONE + *r);
99                    OperationResult::Value(LiteralValue::number_with_type(
100                        result,
101                        standard_number().clone(),
102                    ))
103                }
104                ArithmeticComputation::Subtract => {
105                    // ratio - number = number * (1 - ratio)
106                    let result = *n * (Decimal::ONE - *r);
107                    OperationResult::Value(LiteralValue::number_with_type(
108                        result,
109                        standard_number().clone(),
110                    ))
111                }
112                ArithmeticComputation::Multiply => match number_arithmetic(*r, op, *n) {
113                    Ok(result) => OperationResult::Value(LiteralValue::number_with_type(
114                        result,
115                        standard_number().clone(),
116                    )),
117                    Err(msg) => OperationResult::Veto(Some(msg)),
118                },
119                ArithmeticComputation::Divide => {
120                    if *n == Decimal::ZERO {
121                        return OperationResult::Veto(Some("Division by zero".to_string()));
122                    }
123                    match number_arithmetic(*r, op, *n) {
124                        Ok(result) => OperationResult::Value(LiteralValue::number_with_type(
125                            result,
126                            standard_number().clone(),
127                        )),
128                        Err(msg) => OperationResult::Veto(Some(msg)),
129                    }
130                }
131                _ => OperationResult::Veto(Some(format!(
132                    "Operation {:?} not supported for ratio and number",
133                    op
134                ))),
135            }
136        }
137        // Number op Ratio → Number (ratio semantics: number + ratio = number * (1 + ratio))
138        (Value::Number(n), Value::Ratio(r, _)) if left.get_type().is_number() => {
139            match op {
140                ArithmeticComputation::Add => {
141                    // number + ratio = number * (1 + ratio)
142                    let result = *n * (Decimal::ONE + *r);
143                    OperationResult::Value(LiteralValue::number_with_type(
144                        result,
145                        standard_number().clone(),
146                    ))
147                }
148                ArithmeticComputation::Subtract => {
149                    // number - ratio = number * (1 - ratio)
150                    let result = *n * (Decimal::ONE - *r);
151                    OperationResult::Value(LiteralValue::number_with_type(
152                        result,
153                        standard_number().clone(),
154                    ))
155                }
156                ArithmeticComputation::Multiply => match number_arithmetic(*n, op, *r) {
157                    Ok(result) => OperationResult::Value(LiteralValue::number_with_type(
158                        result,
159                        standard_number().clone(),
160                    )),
161                    Err(msg) => OperationResult::Veto(Some(msg)),
162                },
163                ArithmeticComputation::Divide => {
164                    if *r == Decimal::ZERO {
165                        return OperationResult::Veto(Some("Division by zero".to_string()));
166                    }
167                    match number_arithmetic(*n, op, *r) {
168                        Ok(result) => OperationResult::Value(LiteralValue::number_with_type(
169                            result,
170                            standard_number().clone(),
171                        )),
172                        Err(msg) => OperationResult::Veto(Some(msg)),
173                    }
174                }
175                _ => OperationResult::Veto(Some(format!(
176                    "Operation {:?} not supported for number and ratio",
177                    op
178                ))),
179            }
180        }
181        // Ratio op Ratio → Ratio
182        (Value::Ratio(l, lu), Value::Ratio(r, ru)) => {
183            // Preserve unit from left operand, or right if left is None
184            let preserved_unit = lu.clone().or_else(|| ru.clone());
185            match number_arithmetic(*l, op, *r) {
186                Ok(result) => OperationResult::Value(LiteralValue::ratio_with_type(
187                    result,
188                    preserved_unit,
189                    left.lemma_type.clone(),
190                )),
191                Err(msg) => OperationResult::Veto(Some(msg)),
192            }
193        }
194        // Scale operations with Scale
195        (Value::Scale(l_val, l_unit), Value::Scale(r_val, r_unit)) => {
196            // Units must match for addition/subtraction
197            if l_unit != r_unit
198                && (matches!(
199                    op,
200                    ArithmeticComputation::Add | ArithmeticComputation::Subtract
201                ))
202            {
203                return OperationResult::Veto(Some(format!(
204                    "Cannot apply '{}' to values with different units: {:?} and {:?}",
205                    op, l_unit, r_unit
206                )));
207            }
208            // Preserve unit from left
209            let preserved_unit = l_unit.clone();
210            match number_arithmetic(*l_val, op, *r_val) {
211                Ok(result) => OperationResult::Value(LiteralValue::scale_with_type(
212                    result,
213                    preserved_unit,
214                    left.lemma_type.clone(),
215                )),
216                Err(msg) => OperationResult::Veto(Some(msg)),
217            }
218        }
219        // Ratio op Scale → Scale (inherits Scale type and unit)
220        (Value::Ratio(ratio_val, _), Value::Scale(scale_val, scale_unit)) => {
221            match op {
222                ArithmeticComputation::Multiply => {
223                    match number_arithmetic(*ratio_val, op, *scale_val) {
224                        Ok(result) => OperationResult::Value(LiteralValue::scale_with_type(
225                            result,
226                            scale_unit.clone(),
227                            right.lemma_type.clone(),
228                        )),
229                        Err(msg) => OperationResult::Veto(Some(msg)),
230                    }
231                }
232                ArithmeticComputation::Divide => {
233                    if *scale_val == Decimal::ZERO {
234                        return OperationResult::Veto(Some("Division by zero".to_string()));
235                    }
236                    match number_arithmetic(*ratio_val, op, *scale_val) {
237                        Ok(result) => OperationResult::Value(LiteralValue::scale_with_type(
238                            result,
239                            scale_unit.clone(),
240                            right.lemma_type.clone(),
241                        )),
242                        Err(msg) => OperationResult::Veto(Some(msg)),
243                    }
244                }
245                ArithmeticComputation::Add | ArithmeticComputation::Subtract => {
246                    // Scale +/- Ratio applies ratio semantics: scale +/- (scale * ratio) = scale * (1 +/- ratio)
247                    let ratio_amount = *scale_val * *ratio_val;
248                    let result = match op {
249                        ArithmeticComputation::Add => *scale_val + ratio_amount,
250                        ArithmeticComputation::Subtract => *scale_val - ratio_amount,
251                        _ => {
252                            return OperationResult::Veto(Some(format!(
253                                "Operation '{}' not supported for ratio and scale",
254                                op
255                            )))
256                        }
257                    };
258                    OperationResult::Value(LiteralValue::scale_with_type(
259                        result,
260                        scale_unit.clone(), // Preserve Scale unit
261                        right.lemma_type.clone(),
262                    ))
263                }
264                _ => OperationResult::Veto(Some(format!(
265                    "Operation {:?} not supported for ratio and scale",
266                    op
267                ))),
268            }
269        }
270        // Scale op Ratio → Scale (inherits Scale type and unit)
271        (Value::Scale(scale_val, scale_unit), Value::Ratio(ratio_val, _)) => {
272            match op {
273                ArithmeticComputation::Multiply => {
274                    match number_arithmetic(*scale_val, op, *ratio_val) {
275                        Ok(result) => OperationResult::Value(LiteralValue::scale_with_type(
276                            result,
277                            scale_unit.clone(),
278                            left.lemma_type.clone(),
279                        )),
280                        Err(msg) => OperationResult::Veto(Some(msg)),
281                    }
282                }
283                ArithmeticComputation::Divide => {
284                    if *ratio_val == Decimal::ZERO {
285                        return OperationResult::Veto(Some("Division by zero".to_string()));
286                    }
287                    match number_arithmetic(*scale_val, op, *ratio_val) {
288                        Ok(result) => OperationResult::Value(LiteralValue::scale_with_type(
289                            result,
290                            scale_unit.clone(),
291                            left.lemma_type.clone(), // Inherit Scale type
292                        )),
293                        Err(msg) => OperationResult::Veto(Some(msg)),
294                    }
295                }
296                ArithmeticComputation::Add | ArithmeticComputation::Subtract => {
297                    // Scale +/- Ratio applies ratio semantics: scale +/- (scale * ratio) = scale * (1 +/- ratio)
298                    let ratio_amount = *scale_val * *ratio_val;
299                    let result = match op {
300                        ArithmeticComputation::Add => *scale_val + ratio_amount,
301                        ArithmeticComputation::Subtract => *scale_val - ratio_amount,
302                        _ => {
303                            return OperationResult::Veto(Some(format!(
304                                "Operation '{}' not supported for scale and ratio",
305                                op
306                            )))
307                        }
308                    };
309                    OperationResult::Value(LiteralValue::scale_with_type(
310                        result,
311                        scale_unit.clone(), // Preserve Scale unit
312                        left.lemma_type.clone(),
313                    ))
314                }
315                _ => OperationResult::Veto(Some(format!(
316                    "Operation {:?} not supported for scale and ratio",
317                    op
318                ))),
319            }
320        }
321
322        // Scale op Number → Scale (preserves unit)
323        (Value::Scale(scale_val, scale_unit), Value::Number(n)) => {
324            match number_arithmetic(*scale_val, op, *n) {
325                Ok(result) => OperationResult::Value(LiteralValue::scale_with_type(
326                    result,
327                    scale_unit.clone(),
328                    left.lemma_type.clone(),
329                )),
330                Err(msg) => OperationResult::Veto(Some(msg)),
331            }
332        }
333        // Number op Scale → Scale (preserves unit)
334        (Value::Number(n), Value::Scale(scale_val, scale_unit)) => {
335            match number_arithmetic(*n, op, *scale_val) {
336                Ok(result) => OperationResult::Value(LiteralValue::scale_with_type(
337                    result,
338                    scale_unit.clone(),
339                    right.lemma_type.clone(),
340                )),
341                Err(msg) => OperationResult::Veto(Some(msg)),
342            }
343        }
344        // Scale op Duration - not supported
345        (Value::Scale(_scale_val, _scale_unit), Value::Duration(_d_val, _d_unit)) => match op {
346            ArithmeticComputation::Multiply => {
347                OperationResult::Veto(Some("Cannot multiply scale and duration".to_string()))
348            }
349            _ => OperationResult::Veto(Some(format!(
350                "Operation {:?} not supported for scale and duration",
351                op
352            ))),
353        },
354        // Duration op Scale - not supported
355        (Value::Duration(_d_val, _d_unit), Value::Scale(_scale_val, _scale_unit)) => match op {
356            ArithmeticComputation::Multiply => {
357                OperationResult::Veto(Some("Cannot multiply duration and scale".to_string()))
358            }
359            _ => OperationResult::Veto(Some(format!(
360                "Operation {:?} not supported for duration and scale",
361                op
362            ))),
363        },
364        _ => OperationResult::Veto(Some(format!(
365            "Arithmetic operation {:?} not supported for types {:?} and {:?}",
366            op,
367            type_name(left),
368            type_name(right)
369        ))),
370    }
371}
372
373fn number_arithmetic(
374    left: Decimal,
375    op: &ArithmeticComputation,
376    right: Decimal,
377) -> Result<Decimal, String> {
378    use rust_decimal::prelude::ToPrimitive;
379
380    match op {
381        ArithmeticComputation::Add => Ok(left + right),
382        ArithmeticComputation::Subtract => Ok(left - right),
383        ArithmeticComputation::Multiply => Ok(left * right),
384        ArithmeticComputation::Divide => {
385            if right == Decimal::ZERO {
386                return Err("Division by zero".to_string());
387            }
388            Ok(left / right)
389        }
390        ArithmeticComputation::Modulo => {
391            if right == Decimal::ZERO {
392                return Err("Division by zero (modulo)".to_string());
393            }
394            Ok(left % right)
395        }
396        ArithmeticComputation::Power => {
397            let base = left
398                .to_f64()
399                .ok_or_else(|| "Cannot convert base to float".to_string())?;
400            let exp = right
401                .to_f64()
402                .ok_or_else(|| "Cannot convert exponent to float".to_string())?;
403            let result = base.powf(exp);
404            Decimal::from_f64_retain(result)
405                .ok_or_else(|| "Power result cannot be represented".to_string())
406        }
407    }
408}
409
410fn type_name(value: &LiteralValue) -> String {
411    value.get_type().name().to_string()
412}