1use crate::{ArithmeticComputation, Expression, ExpressionKind, FactReference, LemmaError};
7use std::sync::Arc;
8
9pub fn solve_for(expression: &Expression, unknown: &Expression) -> Result<Expression, LemmaError> {
23 let occurrence_count = count_occurrences(expression, unknown);
24
25 if occurrence_count == 0 {
26 let loc = expression
27 .source_location
28 .as_ref()
29 .or(unknown.source_location.as_ref())
30 .unwrap_or_else(|| unreachable!("BUG: solve_for called with missing source_location"));
31 let source_text = std::sync::Arc::from("");
32 return Err(LemmaError::engine(
33 "Unknown not found in expression",
34 loc.span.clone(),
35 loc.attribute.clone(),
36 source_text,
37 loc.doc_name.clone(),
38 1,
39 None::<String>,
40 ));
41 }
42
43 if occurrence_count > 1 {
44 let loc = expression
45 .source_location
46 .as_ref()
47 .or(unknown.source_location.as_ref())
48 .unwrap_or_else(|| unreachable!("BUG: solve_for called with missing source_location"));
49 let source_text = std::sync::Arc::from("");
50 return Err(LemmaError::engine(
51 "Non-linear: unknown appears multiple times",
52 loc.span.clone(),
53 loc.attribute.clone(),
54 source_text,
55 loc.doc_name.clone(),
56 1,
57 None::<String>,
58 ));
59 }
60
61 let value_placeholder = Expression::new(
62 ExpressionKind::FactReference(FactReference::local("value".to_string())),
63 None,
64 );
65
66 isolate(expression, unknown, value_placeholder)
67}
68
69pub fn substitute(expression: &Expression, from: &Expression, to: &Expression) -> Expression {
71 if expression == from {
72 return to.clone();
73 }
74
75 match &expression.kind {
76 ExpressionKind::Arithmetic(left, operation, right) => {
77 let substituted_left = substitute(left, from, to);
78 let substituted_right = substitute(right, from, to);
79 Expression::new(
80 ExpressionKind::Arithmetic(
81 Arc::new(substituted_left),
82 operation.clone(),
83 Arc::new(substituted_right),
84 ),
85 None,
86 )
87 }
88 _ => expression.clone(),
89 }
90}
91
92fn count_occurrences(expression: &Expression, unknown: &Expression) -> usize {
94 if expression == unknown {
95 return 1;
96 }
97
98 match &expression.kind {
99 ExpressionKind::Arithmetic(left, _, right) => {
100 count_occurrences(left, unknown) + count_occurrences(right, unknown)
101 }
102 _ => 0,
103 }
104}
105
106fn isolate(
111 expression: &Expression,
112 unknown: &Expression,
113 result: Expression,
114) -> Result<Expression, LemmaError> {
115 if expression == unknown {
116 return Ok(result);
117 }
118
119 match &expression.kind {
120 ExpressionKind::Arithmetic(left, operation, right) => {
121 let left_count = count_occurrences(left, unknown);
122
123 if left_count > 0 {
124 let new_result = inverse_left(operation.clone(), result, (**right).clone())?;
125 isolate(left, unknown, new_result)
126 } else {
127 let new_result = inverse_right(operation.clone(), result, (**left).clone())?;
128 isolate(right, unknown, new_result)
129 }
130 }
131 _ => {
132 let loc = expression
133 .source_location
134 .as_ref()
135 .or(unknown.source_location.as_ref())
136 .expect("Expression or unknown must have source_location");
137 let source_text = std::sync::Arc::from("");
138 Err(LemmaError::engine(
139 "Unknown not found on this path",
140 loc.span.clone(),
141 loc.attribute.clone(),
142 source_text,
143 loc.doc_name.clone(),
144 1,
145 None::<String>,
146 ))
147 }
148 }
149}
150
151fn inverse_left(
155 operation: ArithmeticComputation,
156 result: Expression,
157 right: Expression,
158) -> Result<Expression, LemmaError> {
159 let inverse_operation = match operation {
160 ArithmeticComputation::Add => ArithmeticComputation::Subtract,
162 ArithmeticComputation::Subtract => ArithmeticComputation::Add,
164 ArithmeticComputation::Multiply => ArithmeticComputation::Divide,
166 ArithmeticComputation::Divide => ArithmeticComputation::Multiply,
168 ArithmeticComputation::Modulo => {
169 let loc = result
170 .source_location
171 .as_ref()
172 .or(right.source_location.as_ref())
173 .expect("Result or right expression must have source_location");
174 let source_text = std::sync::Arc::from("");
175 return Err(LemmaError::engine(
176 "Modulo operation is not invertible",
177 loc.span.clone(),
178 loc.attribute.clone(),
179 source_text,
180 loc.doc_name.clone(),
181 1,
182 None::<String>,
183 ));
184 }
185 ArithmeticComputation::Power => {
186 let loc = result
187 .source_location
188 .as_ref()
189 .or(right.source_location.as_ref())
190 .expect("Result or right expression must have source_location");
191 let source_text = std::sync::Arc::from("");
192 return Err(LemmaError::engine(
193 "Power operation is not invertible",
194 loc.span.clone(),
195 loc.attribute.clone(),
196 source_text,
197 loc.doc_name.clone(),
198 1,
199 None::<String>,
200 ));
201 }
202 };
203
204 Ok(Expression::new(
205 ExpressionKind::Arithmetic(Arc::new(result), inverse_operation, Arc::new(right)),
206 None,
207 ))
208}
209
210fn inverse_right(
217 operation: ArithmeticComputation,
218 result: Expression,
219 left: Expression,
220) -> Result<Expression, LemmaError> {
221 match operation {
222 ArithmeticComputation::Add => Ok(Expression::new(
224 ExpressionKind::Arithmetic(
225 Arc::new(result),
226 ArithmeticComputation::Subtract,
227 Arc::new(left),
228 ),
229 None,
230 )),
231 ArithmeticComputation::Subtract => Ok(Expression::new(
233 ExpressionKind::Arithmetic(
234 Arc::new(left),
235 ArithmeticComputation::Subtract,
236 Arc::new(result),
237 ),
238 None,
239 )),
240 ArithmeticComputation::Multiply => Ok(Expression::new(
242 ExpressionKind::Arithmetic(
243 Arc::new(result),
244 ArithmeticComputation::Divide,
245 Arc::new(left),
246 ),
247 None,
248 )),
249 ArithmeticComputation::Divide => Ok(Expression::new(
251 ExpressionKind::Arithmetic(
252 Arc::new(left),
253 ArithmeticComputation::Divide,
254 Arc::new(result),
255 ),
256 None,
257 )),
258 ArithmeticComputation::Modulo => {
259 let loc = result
260 .source_location
261 .as_ref()
262 .or(left.source_location.as_ref())
263 .expect("Result or left expression must have source_location");
264 let source_text = std::sync::Arc::from("");
265 Err(LemmaError::engine(
266 "Modulo operation is not invertible",
267 loc.span.clone(),
268 loc.attribute.clone(),
269 source_text,
270 loc.doc_name.clone(),
271 1,
272 None::<String>,
273 ))
274 }
275 ArithmeticComputation::Power => {
276 let loc = result
277 .source_location
278 .as_ref()
279 .or(left.source_location.as_ref())
280 .expect("Result or left expression must have source_location");
281 let source_text = std::sync::Arc::from("");
282 Err(LemmaError::engine(
283 "Power operation is not invertible",
284 loc.span.clone(),
285 loc.attribute.clone(),
286 source_text,
287 loc.doc_name.clone(),
288 1,
289 None::<String>,
290 ))
291 }
292 }
293}
294
295#[cfg(test)]
296mod tests {
297 use rust_decimal::Decimal;
298
299 use super::*;
300 use crate::LiteralValue;
301
302 fn placeholder(name: &str) -> Expression {
303 use crate::parsing::ast::Span;
304 use crate::Source;
305 Expression::new(
306 ExpressionKind::FactReference(FactReference::local(name.to_string())),
307 Some(Source::new(
308 "<test>",
309 Span {
310 start: 0,
311 end: name.len(),
312 line: 1,
313 col: 0,
314 },
315 "test",
316 )),
317 )
318 }
319
320 fn number(value: rust_decimal::Decimal) -> Expression {
321 use crate::parsing::ast::Span;
322 use crate::Source;
323 Expression::new(
324 ExpressionKind::Literal(LiteralValue::number(value)),
325 Some(Source::new(
326 "<test>",
327 Span {
328 start: 0,
329 end: 0,
330 line: 1,
331 col: 0,
332 },
333 "test",
334 )),
335 )
336 }
337
338 fn arithmetic(
339 left: Expression,
340 operation: ArithmeticComputation,
341 right: Expression,
342 ) -> Expression {
343 use crate::parsing::ast::Span;
344 use crate::Source;
345 Expression::new(
346 ExpressionKind::Arithmetic(Arc::new(left), operation, Arc::new(right)),
347 Some(Source::new(
348 "<test>",
349 Span {
350 start: 0,
351 end: 0,
352 line: 1,
353 col: 0,
354 },
355 "test",
356 )),
357 )
358 }
359
360 #[test]
361 fn solve_multiply_left() {
362 let x = placeholder("x");
364 let expression = arithmetic(
365 x.clone(),
366 ArithmeticComputation::Multiply,
367 number(Decimal::from(3)),
368 );
369
370 let result = solve_for(&expression, &x).unwrap();
371
372 let expected = arithmetic(
373 placeholder("value"),
374 ArithmeticComputation::Divide,
375 number(Decimal::from(3)),
376 );
377 assert_eq!(result, expected);
378 }
379
380 #[test]
381 fn solve_multiply_right() {
382 let x = placeholder("x");
384 let expression = arithmetic(
385 number(Decimal::from(3)),
386 ArithmeticComputation::Multiply,
387 x.clone(),
388 );
389
390 let result = solve_for(&expression, &x).unwrap();
391
392 let expected = arithmetic(
393 placeholder("value"),
394 ArithmeticComputation::Divide,
395 number(Decimal::from(3)),
396 );
397 assert_eq!(result, expected);
398 }
399
400 #[test]
401 fn solve_divide_left() {
402 let x = placeholder("x");
404 let expression = arithmetic(
405 x.clone(),
406 ArithmeticComputation::Divide,
407 number(Decimal::from(3)),
408 );
409
410 let result = solve_for(&expression, &x).unwrap();
411
412 let expected = arithmetic(
413 placeholder("value"),
414 ArithmeticComputation::Multiply,
415 number(Decimal::from(3)),
416 );
417 assert_eq!(result, expected);
418 }
419
420 #[test]
421 fn solve_divide_right() {
422 let x = placeholder("x");
424 let expression = arithmetic(
425 number(Decimal::from(3)),
426 ArithmeticComputation::Divide,
427 x.clone(),
428 );
429
430 let result = solve_for(&expression, &x).unwrap();
431
432 let expected = arithmetic(
433 number(Decimal::from(3)),
434 ArithmeticComputation::Divide,
435 placeholder("value"),
436 );
437 assert_eq!(result, expected);
438 }
439
440 #[test]
441 fn solve_add_left() {
442 let x = placeholder("x");
444 let expression = arithmetic(
445 x.clone(),
446 ArithmeticComputation::Add,
447 number(Decimal::from(3)),
448 );
449
450 let result = solve_for(&expression, &x).unwrap();
451
452 let expected = arithmetic(
453 placeholder("value"),
454 ArithmeticComputation::Subtract,
455 number(Decimal::from(3)),
456 );
457 assert_eq!(result, expected);
458 }
459
460 #[test]
461 fn solve_subtract_left() {
462 let x = placeholder("x");
464 let expression = arithmetic(
465 x.clone(),
466 ArithmeticComputation::Subtract,
467 number(Decimal::from(3)),
468 );
469
470 let result = solve_for(&expression, &x).unwrap();
471
472 let expected = arithmetic(
473 placeholder("value"),
474 ArithmeticComputation::Add,
475 number(Decimal::from(3)),
476 );
477 assert_eq!(result, expected);
478 }
479
480 #[test]
481 fn solve_subtract_right() {
482 let x = placeholder("x");
484 let expression = arithmetic(
485 number(Decimal::from(3)),
486 ArithmeticComputation::Subtract,
487 x.clone(),
488 );
489
490 let result = solve_for(&expression, &x).unwrap();
491
492 let expected = arithmetic(
493 number(Decimal::from(3)),
494 ArithmeticComputation::Subtract,
495 placeholder("value"),
496 );
497 assert_eq!(result, expected);
498 }
499
500 #[test]
501 fn solve_compound_fahrenheit_to_celsius() {
502 let celsius = placeholder("celsius");
505 let nine_fifths = arithmetic(
506 number(Decimal::from(9)),
507 ArithmeticComputation::Divide,
508 number(Decimal::from(5)),
509 );
510 let expression = arithmetic(
511 arithmetic(
512 celsius.clone(),
513 ArithmeticComputation::Multiply,
514 nine_fifths,
515 ),
516 ArithmeticComputation::Add,
517 number(Decimal::from(32)),
518 );
519
520 let result = solve_for(&expression, &celsius).unwrap();
521
522 let expected_nine_fifths = arithmetic(
524 number(Decimal::from(9)),
525 ArithmeticComputation::Divide,
526 number(Decimal::from(5)),
527 );
528 let expected = arithmetic(
529 arithmetic(
530 placeholder("value"),
531 ArithmeticComputation::Subtract,
532 number(Decimal::from(32)),
533 ),
534 ArithmeticComputation::Divide,
535 expected_nine_fifths,
536 );
537 assert_eq!(result, expected);
538 }
539
540 #[test]
541 fn solve_with_fact_reference() {
542 let x = placeholder("x");
544 let offset = placeholder("offset");
545 let nine_fifths = arithmetic(
546 number(Decimal::from(9)),
547 ArithmeticComputation::Divide,
548 number(Decimal::from(5)),
549 );
550 let expression = arithmetic(
551 arithmetic(x.clone(), ArithmeticComputation::Multiply, nine_fifths),
552 ArithmeticComputation::Add,
553 offset.clone(),
554 );
555
556 let result = solve_for(&expression, &x).unwrap();
557
558 let expected_nine_fifths = arithmetic(
560 number(Decimal::from(9)),
561 ArithmeticComputation::Divide,
562 number(Decimal::from(5)),
563 );
564 let expected = arithmetic(
565 arithmetic(
566 placeholder("value"),
567 ArithmeticComputation::Subtract,
568 offset,
569 ),
570 ArithmeticComputation::Divide,
571 expected_nine_fifths,
572 );
573 assert_eq!(result, expected);
574 }
575
576 #[test]
577 fn error_unknown_not_found() {
578 let x = placeholder("x");
579 let y = placeholder("y");
580 let expression = arithmetic(y, ArithmeticComputation::Multiply, number(Decimal::from(3)));
581
582 let result = solve_for(&expression, &x);
583
584 assert!(result.is_err());
585 assert!(result
586 .unwrap_err()
587 .to_string()
588 .contains("Unknown not found"));
589 }
590
591 #[test]
592 fn error_non_linear() {
593 let x = placeholder("x");
595 let expression = arithmetic(x.clone(), ArithmeticComputation::Multiply, x.clone());
596
597 let result = solve_for(&expression, &x);
598
599 assert!(result.is_err());
600 let error_msg = result.unwrap_err().to_string();
601 assert!(
602 error_msg.contains("Non-linear")
603 || error_msg.contains("non-linear")
604 || error_msg.contains("multiple times")
605 );
606 }
607
608 #[test]
609 fn error_modulo_not_invertible() {
610 let x = placeholder("x");
611 let expression = arithmetic(
612 x.clone(),
613 ArithmeticComputation::Modulo,
614 number(Decimal::from(3)),
615 );
616
617 let result = solve_for(&expression, &x);
618
619 assert!(result.is_err());
620 let error_msg = result.unwrap_err().to_string();
621 assert!(
622 error_msg.contains("Modulo operation is not invertible")
623 || error_msg.contains("not invertible")
624 );
625 }
626
627 #[test]
628 fn error_power_not_invertible() {
629 let x = placeholder("x");
630 let expression = arithmetic(
631 x.clone(),
632 ArithmeticComputation::Power,
633 number(Decimal::from(2)),
634 );
635
636 let result = solve_for(&expression, &x);
637
638 assert!(result.is_err());
639 let error_msg = result.unwrap_err().to_string();
640 assert!(
641 error_msg.contains("Power operation is not invertible")
642 || error_msg.contains("not invertible")
643 );
644 }
645
646 #[test]
647 fn substitute_simple() {
648 let x = placeholder("x");
649 let replacement = number(Decimal::from(5));
650
651 let expression = arithmetic(
652 x.clone(),
653 ArithmeticComputation::Multiply,
654 number(Decimal::from(3)),
655 );
656
657 let result = substitute(&expression, &x, &replacement);
658
659 let expected = arithmetic(
660 number(Decimal::from(5)),
661 ArithmeticComputation::Multiply,
662 number(Decimal::from(3)),
663 );
664 assert_eq!(result, expected);
665 }
666
667 #[test]
668 fn substitute_nested() {
669 let x = placeholder("x");
671 let replacement = number(Decimal::from(5));
672
673 let inner = arithmetic(
674 x.clone(),
675 ArithmeticComputation::Add,
676 number(Decimal::from(2)),
677 );
678 let expression = arithmetic(
679 inner,
680 ArithmeticComputation::Multiply,
681 number(Decimal::from(3)),
682 );
683
684 let result = substitute(&expression, &x, &replacement);
685
686 let expected_inner = arithmetic(
687 number(Decimal::from(5)),
688 ArithmeticComputation::Add,
689 number(Decimal::from(2)),
690 );
691 let expected = arithmetic(
692 expected_inner,
693 ArithmeticComputation::Multiply,
694 number(Decimal::from(3)),
695 );
696 assert_eq!(result, expected);
697 }
698
699 #[test]
700 fn substitute_chained_units() {
701 let kilogram = placeholder("kilogram");
705 let gram = placeholder("gram");
706
707 let kilogram_definition = arithmetic(
708 number(Decimal::from(1000)),
709 ArithmeticComputation::Multiply,
710 gram.clone(),
711 );
712 let milligram_expression = arithmetic(
713 kilogram.clone(),
714 ArithmeticComputation::Divide,
715 number(Decimal::from(1_000_000)),
716 );
717
718 let result = substitute(&milligram_expression, &kilogram, &kilogram_definition);
719
720 let expected = arithmetic(
721 arithmetic(
722 number(Decimal::from(1000)),
723 ArithmeticComputation::Multiply,
724 gram,
725 ),
726 ArithmeticComputation::Divide,
727 number(Decimal::from(1_000_000)),
728 );
729 assert_eq!(result, expected);
730 }
731}