1use crate::evaluation::OperationResult;
4use crate::planning::semantics::{
5 primitive_number, ArithmeticComputation, LiteralValue, ValueKind,
6};
7use rust_decimal::Decimal;
8
9pub fn arithmetic_operation(
11 left: &LiteralValue,
12 op: &ArithmeticComputation,
13 right: &LiteralValue,
14) -> OperationResult {
15 match (&left.value, &right.value) {
16 (ValueKind::Number(l), ValueKind::Number(r)) => match number_arithmetic(*l, op, *r) {
17 Ok(result) => OperationResult::Value(Box::new(LiteralValue::number_with_type(
18 result,
19 left.lemma_type.clone(),
20 ))),
21 Err(msg) => OperationResult::Veto(Some(msg)),
22 },
23
24 (ValueKind::Date(_), _) | (_, ValueKind::Date(_)) => {
25 super::datetime::datetime_arithmetic(left, op, right)
26 }
27
28 (ValueKind::Time(_), _) | (_, ValueKind::Time(_)) => {
29 super::datetime::time_arithmetic(left, op, right)
30 }
31
32 (ValueKind::Duration(l, lu), ValueKind::Duration(r, ru)) => {
34 let left_seconds = super::units::duration_to_seconds(*l, lu);
35 let right_seconds = super::units::duration_to_seconds(*r, ru);
36 match op {
37 ArithmeticComputation::Add => {
38 let result_seconds = left_seconds + right_seconds;
39 let result_value = super::units::seconds_to_duration(result_seconds, lu);
40 OperationResult::Value(Box::new(LiteralValue::duration_with_type(
41 result_value,
42 lu.clone(),
43 left.lemma_type.clone(),
44 )))
45 }
46 ArithmeticComputation::Subtract => {
47 let result_seconds = left_seconds - right_seconds;
48 let result_value = super::units::seconds_to_duration(result_seconds, lu);
49 OperationResult::Value(Box::new(LiteralValue::duration_with_type(
50 result_value,
51 lu.clone(),
52 left.lemma_type.clone(),
53 )))
54 }
55 _ => OperationResult::Veto(Some(format!(
56 "Operation {:?} not supported for durations",
57 op
58 ))),
59 }
60 }
61
62 (ValueKind::Duration(value, unit), ValueKind::Number(n)) => {
64 match number_arithmetic(*value, op, *n) {
65 Ok(result) => OperationResult::Value(Box::new(LiteralValue::duration_with_type(
66 result,
67 unit.clone(),
68 left.lemma_type.clone(),
69 ))),
70 Err(msg) => OperationResult::Veto(Some(msg)),
71 }
72 }
73
74 (ValueKind::Number(n), ValueKind::Duration(value, unit)) => {
76 match number_arithmetic(*n, op, *value) {
77 Ok(result) => OperationResult::Value(Box::new(LiteralValue::duration_with_type(
78 result,
79 unit.clone(),
80 right.lemma_type.clone(),
81 ))),
82 Err(msg) => OperationResult::Veto(Some(msg)),
83 }
84 }
85
86 (ValueKind::Ratio(r, _), ValueKind::Number(n)) => match op {
88 ArithmeticComputation::Add => {
89 let result = *n * (Decimal::ONE + *r);
90 OperationResult::Value(Box::new(LiteralValue::number_with_type(
91 result,
92 primitive_number().clone(),
93 )))
94 }
95 ArithmeticComputation::Subtract => {
96 let result = *n * (Decimal::ONE - *r);
97 OperationResult::Value(Box::new(LiteralValue::number_with_type(
98 result,
99 primitive_number().clone(),
100 )))
101 }
102 _ => match number_arithmetic(*r, op, *n) {
103 Ok(result) => OperationResult::Value(Box::new(LiteralValue::number_with_type(
104 result,
105 primitive_number().clone(),
106 ))),
107 Err(msg) => OperationResult::Veto(Some(msg)),
108 },
109 },
110
111 (ValueKind::Number(n), ValueKind::Ratio(r, _)) => match op {
113 ArithmeticComputation::Add => {
114 let result = *n * (Decimal::ONE + *r);
115 OperationResult::Value(Box::new(LiteralValue::number_with_type(
116 result,
117 primitive_number().clone(),
118 )))
119 }
120 ArithmeticComputation::Subtract => {
121 let result = *n * (Decimal::ONE - *r);
122 OperationResult::Value(Box::new(LiteralValue::number_with_type(
123 result,
124 primitive_number().clone(),
125 )))
126 }
127 _ => match number_arithmetic(*n, op, *r) {
128 Ok(result) => OperationResult::Value(Box::new(LiteralValue::number_with_type(
129 result,
130 primitive_number().clone(),
131 ))),
132 Err(msg) => OperationResult::Veto(Some(msg)),
133 },
134 },
135
136 (ValueKind::Duration(value, unit), ValueKind::Ratio(r, _)) => match op {
138 ArithmeticComputation::Add => {
139 let result = *value * (Decimal::ONE + *r);
140 OperationResult::Value(Box::new(LiteralValue::duration_with_type(
141 result,
142 unit.clone(),
143 left.lemma_type.clone(),
144 )))
145 }
146 ArithmeticComputation::Subtract => {
147 let result = *value * (Decimal::ONE - *r);
148 OperationResult::Value(Box::new(LiteralValue::duration_with_type(
149 result,
150 unit.clone(),
151 left.lemma_type.clone(),
152 )))
153 }
154 _ => match number_arithmetic(*value, op, *r) {
155 Ok(result) => OperationResult::Value(Box::new(LiteralValue::duration_with_type(
156 result,
157 unit.clone(),
158 left.lemma_type.clone(),
159 ))),
160 Err(msg) => OperationResult::Veto(Some(msg)),
161 },
162 },
163
164 (ValueKind::Ratio(r, _), ValueKind::Duration(value, unit)) => match op {
166 ArithmeticComputation::Add => {
167 let result = *value * (Decimal::ONE + *r);
168 OperationResult::Value(Box::new(LiteralValue::duration_with_type(
169 result,
170 unit.clone(),
171 right.lemma_type.clone(),
172 )))
173 }
174 ArithmeticComputation::Subtract => {
175 let result = *value * (Decimal::ONE - *r);
176 OperationResult::Value(Box::new(LiteralValue::duration_with_type(
177 result,
178 unit.clone(),
179 right.lemma_type.clone(),
180 )))
181 }
182 _ => match number_arithmetic(*r, op, *value) {
183 Ok(result) => OperationResult::Value(Box::new(LiteralValue::duration_with_type(
184 result,
185 unit.clone(),
186 right.lemma_type.clone(),
187 ))),
188 Err(msg) => OperationResult::Veto(Some(msg)),
189 },
190 },
191 (ValueKind::Ratio(l, lu), ValueKind::Ratio(r, _ru)) => {
193 match number_arithmetic(*l, op, *r) {
195 Ok(result) => OperationResult::Value(Box::new(LiteralValue::ratio_with_type(
196 result,
197 lu.clone(),
198 left.lemma_type.clone(),
199 ))),
200 Err(msg) => OperationResult::Veto(Some(msg)),
201 }
202 }
203 (ValueKind::Scale(l_val, l_unit), ValueKind::Scale(r_val, r_unit)) => {
205 if l_unit != r_unit
207 && (matches!(
208 op,
209 ArithmeticComputation::Add | ArithmeticComputation::Subtract
210 ))
211 {
212 return OperationResult::Veto(Some(format!(
213 "Cannot apply '{}' to values with different units: {:?} and {:?}",
214 op, l_unit, r_unit
215 )));
216 }
217 let preserved_unit = l_unit.clone();
219 match number_arithmetic(*l_val, op, *r_val) {
220 Ok(result) => OperationResult::Value(Box::new(LiteralValue::scale_with_type(
221 result,
222 preserved_unit,
223 left.lemma_type.clone(),
224 ))),
225 Err(msg) => OperationResult::Veto(Some(msg)),
226 }
227 }
228 (ValueKind::Ratio(ratio_val, _), ValueKind::Scale(scale_val, scale_unit)) => {
230 match op {
231 ArithmeticComputation::Multiply => {
232 match number_arithmetic(*ratio_val, op, *scale_val) {
233 Ok(result) => {
234 OperationResult::Value(Box::new(LiteralValue::scale_with_type(
235 result,
236 scale_unit.clone(),
237 right.lemma_type.clone(),
238 )))
239 }
240 Err(msg) => OperationResult::Veto(Some(msg)),
241 }
242 }
243 ArithmeticComputation::Divide => {
244 if *scale_val == Decimal::ZERO {
245 return OperationResult::Veto(Some("Division by zero".to_string()));
246 }
247 match number_arithmetic(*ratio_val, op, *scale_val) {
248 Ok(result) => {
249 OperationResult::Value(Box::new(LiteralValue::scale_with_type(
250 result,
251 scale_unit.clone(),
252 right.lemma_type.clone(),
253 )))
254 }
255 Err(msg) => OperationResult::Veto(Some(msg)),
256 }
257 }
258 ArithmeticComputation::Add | ArithmeticComputation::Subtract => {
259 let ratio_amount = *scale_val * *ratio_val;
261 let result = match op {
262 ArithmeticComputation::Add => *scale_val + ratio_amount,
263 ArithmeticComputation::Subtract => *scale_val - ratio_amount,
264 _ => {
265 return OperationResult::Veto(Some(format!(
266 "Operation '{}' not supported for ratio and scale",
267 op
268 )))
269 }
270 };
271 OperationResult::Value(Box::new(LiteralValue::scale_with_type(
272 result,
273 scale_unit.clone(), right.lemma_type.clone(),
275 )))
276 }
277 _ => OperationResult::Veto(Some(format!(
278 "Operation {:?} not supported for ratio and scale",
279 op
280 ))),
281 }
282 }
283 (ValueKind::Scale(scale_val, scale_unit), ValueKind::Ratio(ratio_val, _)) => {
285 match op {
286 ArithmeticComputation::Multiply => {
287 match number_arithmetic(*scale_val, op, *ratio_val) {
288 Ok(result) => {
289 OperationResult::Value(Box::new(LiteralValue::scale_with_type(
290 result,
291 scale_unit.clone(),
292 left.lemma_type.clone(),
293 )))
294 }
295 Err(msg) => OperationResult::Veto(Some(msg)),
296 }
297 }
298 ArithmeticComputation::Divide => {
299 if *ratio_val == Decimal::ZERO {
300 return OperationResult::Veto(Some("Division by zero".to_string()));
301 }
302 match number_arithmetic(*scale_val, op, *ratio_val) {
303 Ok(result) => {
304 OperationResult::Value(Box::new(LiteralValue::scale_with_type(
305 result,
306 scale_unit.clone(),
307 left.lemma_type.clone(), )))
309 }
310 Err(msg) => OperationResult::Veto(Some(msg)),
311 }
312 }
313 ArithmeticComputation::Add | ArithmeticComputation::Subtract => {
314 let ratio_amount = *scale_val * *ratio_val;
316 let result = match op {
317 ArithmeticComputation::Add => *scale_val + ratio_amount,
318 ArithmeticComputation::Subtract => *scale_val - ratio_amount,
319 _ => {
320 return OperationResult::Veto(Some(format!(
321 "Operation '{}' not supported for scale and ratio",
322 op
323 )))
324 }
325 };
326 OperationResult::Value(Box::new(LiteralValue::scale_with_type(
327 result,
328 scale_unit.clone(), left.lemma_type.clone(),
330 )))
331 }
332 _ => OperationResult::Veto(Some(format!(
333 "Operation {:?} not supported for scale and ratio",
334 op
335 ))),
336 }
337 }
338
339 (ValueKind::Scale(scale_val, scale_unit), ValueKind::Number(n)) => {
341 match number_arithmetic(*scale_val, op, *n) {
342 Ok(result) => OperationResult::Value(Box::new(LiteralValue::scale_with_type(
343 result,
344 scale_unit.clone(),
345 left.lemma_type.clone(),
346 ))),
347 Err(msg) => OperationResult::Veto(Some(msg)),
348 }
349 }
350 (ValueKind::Number(n), ValueKind::Scale(scale_val, scale_unit)) => {
352 match number_arithmetic(*n, op, *scale_val) {
353 Ok(result) => OperationResult::Value(Box::new(LiteralValue::scale_with_type(
354 result,
355 scale_unit.clone(),
356 right.lemma_type.clone(),
357 ))),
358 Err(msg) => OperationResult::Veto(Some(msg)),
359 }
360 }
361 (ValueKind::Scale(scale_val, _), ValueKind::Duration(dur_val, _)) => {
363 match number_arithmetic(*scale_val, op, *dur_val) {
364 Ok(result) => OperationResult::Value(Box::new(LiteralValue::number_with_type(
365 result,
366 primitive_number().clone(),
367 ))),
368 Err(msg) => OperationResult::Veto(Some(msg)),
369 }
370 }
371
372 (ValueKind::Duration(dur_val, _), ValueKind::Scale(scale_val, _)) => {
374 match number_arithmetic(*dur_val, op, *scale_val) {
375 Ok(result) => OperationResult::Value(Box::new(LiteralValue::number_with_type(
376 result,
377 primitive_number().clone(),
378 ))),
379 Err(msg) => OperationResult::Veto(Some(msg)),
380 }
381 }
382 _ => OperationResult::Veto(Some(format!(
383 "Arithmetic operation {:?} not supported for types {:?} and {:?}",
384 op,
385 type_name(left),
386 type_name(right)
387 ))),
388 }
389}
390
391fn number_arithmetic(
392 left: Decimal,
393 op: &ArithmeticComputation,
394 right: Decimal,
395) -> Result<Decimal, String> {
396 use rust_decimal::prelude::ToPrimitive;
397
398 match op {
399 ArithmeticComputation::Add => Ok(left + right),
400 ArithmeticComputation::Subtract => Ok(left - right),
401 ArithmeticComputation::Multiply => Ok(left * right),
402 ArithmeticComputation::Divide => {
403 if right == Decimal::ZERO {
404 return Err("Division by zero".to_string());
405 }
406 Ok(left / right)
407 }
408 ArithmeticComputation::Modulo => {
409 if right == Decimal::ZERO {
410 return Err("Division by zero (modulo)".to_string());
411 }
412 Ok(left % right)
413 }
414 ArithmeticComputation::Power => {
415 let base = left
416 .to_f64()
417 .ok_or_else(|| "Cannot convert base to float".to_string())?;
418 let exp = right
419 .to_f64()
420 .ok_or_else(|| "Cannot convert exponent to float".to_string())?;
421 let result = base.powf(exp);
422 Decimal::from_f64_retain(result)
423 .ok_or_else(|| "Power result cannot be represented".to_string())
424 }
425 }
426}
427
428fn type_name(value: &LiteralValue) -> String {
429 value.get_type().name().to_string()
430}