Skip to main content

lemma/computation/
algebra.rs

1//! Algebraic equation solving for expression trees.
2//!
3//! Solves single-variable linear equations by isolating the unknown
4//! through inverse operations.
5
6use crate::{ArithmeticComputation, Expression, ExpressionKind, FactReference, LemmaError};
7use std::sync::Arc;
8
9/// Solve for an unknown expression within a larger expression.
10///
11/// Given an expression tree and a target unknown, returns a new expression
12/// that computes the unknown's value from a "value" placeholder.
13///
14/// The unknown must appear exactly once in the expression (linear equations only).
15///
16/// # Errors
17///
18/// Returns an error if:
19/// - The unknown does not appear in the expression
20/// - The unknown appears more than once (non-linear)
21/// - The expression contains unsupported operations (Modulo, Power)
22pub fn solve_for(expression: &Expression, unknown: &Expression) -> Result<Expression, LemmaError> {
23    let occurrence_count = count_occurrences(expression, unknown);
24
25    if occurrence_count == 0 {
26        let loc = expression
27            .source_location
28            .as_ref()
29            .or(unknown.source_location.as_ref())
30            .unwrap_or_else(|| unreachable!("BUG: solve_for called with missing source_location"));
31        let source_text = std::sync::Arc::from("");
32        return Err(LemmaError::engine(
33            "Unknown not found in expression",
34            loc.span.clone(),
35            loc.attribute.clone(),
36            source_text,
37            loc.doc_name.clone(),
38            1,
39            None::<String>,
40        ));
41    }
42
43    if occurrence_count > 1 {
44        let loc = expression
45            .source_location
46            .as_ref()
47            .or(unknown.source_location.as_ref())
48            .unwrap_or_else(|| unreachable!("BUG: solve_for called with missing source_location"));
49        let source_text = std::sync::Arc::from("");
50        return Err(LemmaError::engine(
51            "Non-linear: unknown appears multiple times",
52            loc.span.clone(),
53            loc.attribute.clone(),
54            source_text,
55            loc.doc_name.clone(),
56            1,
57            None::<String>,
58        ));
59    }
60
61    let value_placeholder = Expression::new(
62        ExpressionKind::FactReference(FactReference::local("value".to_string())),
63        None,
64    );
65
66    isolate(expression, unknown, value_placeholder)
67}
68
69/// Replace all occurrences of `from` with `to` in the expression tree.
70pub fn substitute(expression: &Expression, from: &Expression, to: &Expression) -> Expression {
71    if expression == from {
72        return to.clone();
73    }
74
75    match &expression.kind {
76        ExpressionKind::Arithmetic(left, operation, right) => {
77            let substituted_left = substitute(left, from, to);
78            let substituted_right = substitute(right, from, to);
79            Expression::new(
80                ExpressionKind::Arithmetic(
81                    Arc::new(substituted_left),
82                    operation.clone(),
83                    Arc::new(substituted_right),
84                ),
85                None,
86            )
87        }
88        _ => expression.clone(),
89    }
90}
91
92/// Count how many times the unknown expression appears in the expression tree.
93fn count_occurrences(expression: &Expression, unknown: &Expression) -> usize {
94    if expression == unknown {
95        return 1;
96    }
97
98    match &expression.kind {
99        ExpressionKind::Arithmetic(left, _, right) => {
100            count_occurrences(left, unknown) + count_occurrences(right, unknown)
101        }
102        _ => 0,
103    }
104}
105
106/// Isolate the unknown by walking the tree and applying inverse operations.
107///
108/// At each arithmetic node, determines which side contains the unknown,
109/// applies the inverse operation to the accumulated result, and recurses.
110fn isolate(
111    expression: &Expression,
112    unknown: &Expression,
113    result: Expression,
114) -> Result<Expression, LemmaError> {
115    if expression == unknown {
116        return Ok(result);
117    }
118
119    match &expression.kind {
120        ExpressionKind::Arithmetic(left, operation, right) => {
121            let left_count = count_occurrences(left, unknown);
122
123            if left_count > 0 {
124                let new_result = inverse_left(operation.clone(), result, (**right).clone())?;
125                isolate(left, unknown, new_result)
126            } else {
127                let new_result = inverse_right(operation.clone(), result, (**left).clone())?;
128                isolate(right, unknown, new_result)
129            }
130        }
131        _ => {
132            let loc = expression
133                .source_location
134                .as_ref()
135                .or(unknown.source_location.as_ref())
136                .expect("Expression or unknown must have source_location");
137            let source_text = std::sync::Arc::from("");
138            Err(LemmaError::engine(
139                "Unknown not found on this path",
140                loc.span.clone(),
141                loc.attribute.clone(),
142                source_text,
143                loc.doc_name.clone(),
144                1,
145                None::<String>,
146            ))
147        }
148    }
149}
150
151/// Apply inverse operation when unknown is on the left side of the operator.
152///
153/// Given `left op right = result`, solves for `left`.
154fn inverse_left(
155    operation: ArithmeticComputation,
156    result: Expression,
157    right: Expression,
158) -> Result<Expression, LemmaError> {
159    let inverse_operation = match operation {
160        // left + right = result → left = result - right
161        ArithmeticComputation::Add => ArithmeticComputation::Subtract,
162        // left - right = result → left = result + right
163        ArithmeticComputation::Subtract => ArithmeticComputation::Add,
164        // left * right = result → left = result / right
165        ArithmeticComputation::Multiply => ArithmeticComputation::Divide,
166        // left / right = result → left = result * right
167        ArithmeticComputation::Divide => ArithmeticComputation::Multiply,
168        ArithmeticComputation::Modulo => {
169            let loc = result
170                .source_location
171                .as_ref()
172                .or(right.source_location.as_ref())
173                .expect("Result or right expression must have source_location");
174            let source_text = std::sync::Arc::from("");
175            return Err(LemmaError::engine(
176                "Modulo operation is not invertible",
177                loc.span.clone(),
178                loc.attribute.clone(),
179                source_text,
180                loc.doc_name.clone(),
181                1,
182                None::<String>,
183            ));
184        }
185        ArithmeticComputation::Power => {
186            let loc = result
187                .source_location
188                .as_ref()
189                .or(right.source_location.as_ref())
190                .expect("Result or right expression must have source_location");
191            let source_text = std::sync::Arc::from("");
192            return Err(LemmaError::engine(
193                "Power operation is not invertible",
194                loc.span.clone(),
195                loc.attribute.clone(),
196                source_text,
197                loc.doc_name.clone(),
198                1,
199                None::<String>,
200            ));
201        }
202    };
203
204    Ok(Expression::new(
205        ExpressionKind::Arithmetic(Arc::new(result), inverse_operation, Arc::new(right)),
206        None,
207    ))
208}
209
210/// Apply inverse operation when unknown is on the right side of the operator.
211///
212/// Given `left op right = result`, solves for `right`.
213///
214/// Note: For non-commutative operations (subtract, divide), the inverse
215/// is different than when unknown is on the left.
216fn inverse_right(
217    operation: ArithmeticComputation,
218    result: Expression,
219    left: Expression,
220) -> Result<Expression, LemmaError> {
221    match operation {
222        // left + right = result → right = result - left
223        ArithmeticComputation::Add => Ok(Expression::new(
224            ExpressionKind::Arithmetic(
225                Arc::new(result),
226                ArithmeticComputation::Subtract,
227                Arc::new(left),
228            ),
229            None,
230        )),
231        // left - right = result → right = left - result (different!)
232        ArithmeticComputation::Subtract => Ok(Expression::new(
233            ExpressionKind::Arithmetic(
234                Arc::new(left),
235                ArithmeticComputation::Subtract,
236                Arc::new(result),
237            ),
238            None,
239        )),
240        // left * right = result → right = result / left
241        ArithmeticComputation::Multiply => Ok(Expression::new(
242            ExpressionKind::Arithmetic(
243                Arc::new(result),
244                ArithmeticComputation::Divide,
245                Arc::new(left),
246            ),
247            None,
248        )),
249        // left / right = result → right = left / result (different!)
250        ArithmeticComputation::Divide => Ok(Expression::new(
251            ExpressionKind::Arithmetic(
252                Arc::new(left),
253                ArithmeticComputation::Divide,
254                Arc::new(result),
255            ),
256            None,
257        )),
258        ArithmeticComputation::Modulo => {
259            let loc = result
260                .source_location
261                .as_ref()
262                .or(left.source_location.as_ref())
263                .expect("Result or left expression must have source_location");
264            let source_text = std::sync::Arc::from("");
265            Err(LemmaError::engine(
266                "Modulo operation is not invertible",
267                loc.span.clone(),
268                loc.attribute.clone(),
269                source_text,
270                loc.doc_name.clone(),
271                1,
272                None::<String>,
273            ))
274        }
275        ArithmeticComputation::Power => {
276            let loc = result
277                .source_location
278                .as_ref()
279                .or(left.source_location.as_ref())
280                .expect("Result or left expression must have source_location");
281            let source_text = std::sync::Arc::from("");
282            Err(LemmaError::engine(
283                "Power operation is not invertible",
284                loc.span.clone(),
285                loc.attribute.clone(),
286                source_text,
287                loc.doc_name.clone(),
288                1,
289                None::<String>,
290            ))
291        }
292    }
293}
294
295#[cfg(test)]
296mod tests {
297    use rust_decimal::Decimal;
298
299    use super::*;
300    use crate::LiteralValue;
301
302    fn placeholder(name: &str) -> Expression {
303        use crate::parsing::ast::Span;
304        use crate::Source;
305        Expression::new(
306            ExpressionKind::FactReference(FactReference::local(name.to_string())),
307            Some(Source::new(
308                "<test>",
309                Span {
310                    start: 0,
311                    end: name.len(),
312                    line: 1,
313                    col: 0,
314                },
315                "test",
316            )),
317        )
318    }
319
320    fn number(value: rust_decimal::Decimal) -> Expression {
321        use crate::parsing::ast::Span;
322        use crate::Source;
323        Expression::new(
324            ExpressionKind::Literal(LiteralValue::number(value)),
325            Some(Source::new(
326                "<test>",
327                Span {
328                    start: 0,
329                    end: 0,
330                    line: 1,
331                    col: 0,
332                },
333                "test",
334            )),
335        )
336    }
337
338    fn arithmetic(
339        left: Expression,
340        operation: ArithmeticComputation,
341        right: Expression,
342    ) -> Expression {
343        use crate::parsing::ast::Span;
344        use crate::Source;
345        Expression::new(
346            ExpressionKind::Arithmetic(Arc::new(left), operation, Arc::new(right)),
347            Some(Source::new(
348                "<test>",
349                Span {
350                    start: 0,
351                    end: 0,
352                    line: 1,
353                    col: 0,
354                },
355                "test",
356            )),
357        )
358    }
359
360    #[test]
361    fn solve_multiply_left() {
362        // x * 3 = value → x = value / 3
363        let x = placeholder("x");
364        let expression = arithmetic(
365            x.clone(),
366            ArithmeticComputation::Multiply,
367            number(Decimal::from(3)),
368        );
369
370        let result = solve_for(&expression, &x).unwrap();
371
372        let expected = arithmetic(
373            placeholder("value"),
374            ArithmeticComputation::Divide,
375            number(Decimal::from(3)),
376        );
377        assert_eq!(result, expected);
378    }
379
380    #[test]
381    fn solve_multiply_right() {
382        // 3 * x = value → x = value / 3
383        let x = placeholder("x");
384        let expression = arithmetic(
385            number(Decimal::from(3)),
386            ArithmeticComputation::Multiply,
387            x.clone(),
388        );
389
390        let result = solve_for(&expression, &x).unwrap();
391
392        let expected = arithmetic(
393            placeholder("value"),
394            ArithmeticComputation::Divide,
395            number(Decimal::from(3)),
396        );
397        assert_eq!(result, expected);
398    }
399
400    #[test]
401    fn solve_divide_left() {
402        // x / 3 = value → x = value * 3
403        let x = placeholder("x");
404        let expression = arithmetic(
405            x.clone(),
406            ArithmeticComputation::Divide,
407            number(Decimal::from(3)),
408        );
409
410        let result = solve_for(&expression, &x).unwrap();
411
412        let expected = arithmetic(
413            placeholder("value"),
414            ArithmeticComputation::Multiply,
415            number(Decimal::from(3)),
416        );
417        assert_eq!(result, expected);
418    }
419
420    #[test]
421    fn solve_divide_right() {
422        // 3 / x = value → x = 3 / value
423        let x = placeholder("x");
424        let expression = arithmetic(
425            number(Decimal::from(3)),
426            ArithmeticComputation::Divide,
427            x.clone(),
428        );
429
430        let result = solve_for(&expression, &x).unwrap();
431
432        let expected = arithmetic(
433            number(Decimal::from(3)),
434            ArithmeticComputation::Divide,
435            placeholder("value"),
436        );
437        assert_eq!(result, expected);
438    }
439
440    #[test]
441    fn solve_add_left() {
442        // x + 3 = value → x = value - 3
443        let x = placeholder("x");
444        let expression = arithmetic(
445            x.clone(),
446            ArithmeticComputation::Add,
447            number(Decimal::from(3)),
448        );
449
450        let result = solve_for(&expression, &x).unwrap();
451
452        let expected = arithmetic(
453            placeholder("value"),
454            ArithmeticComputation::Subtract,
455            number(Decimal::from(3)),
456        );
457        assert_eq!(result, expected);
458    }
459
460    #[test]
461    fn solve_subtract_left() {
462        // x - 3 = value → x = value + 3
463        let x = placeholder("x");
464        let expression = arithmetic(
465            x.clone(),
466            ArithmeticComputation::Subtract,
467            number(Decimal::from(3)),
468        );
469
470        let result = solve_for(&expression, &x).unwrap();
471
472        let expected = arithmetic(
473            placeholder("value"),
474            ArithmeticComputation::Add,
475            number(Decimal::from(3)),
476        );
477        assert_eq!(result, expected);
478    }
479
480    #[test]
481    fn solve_subtract_right() {
482        // 3 - x = value → x = 3 - value
483        let x = placeholder("x");
484        let expression = arithmetic(
485            number(Decimal::from(3)),
486            ArithmeticComputation::Subtract,
487            x.clone(),
488        );
489
490        let result = solve_for(&expression, &x).unwrap();
491
492        let expected = arithmetic(
493            number(Decimal::from(3)),
494            ArithmeticComputation::Subtract,
495            placeholder("value"),
496        );
497        assert_eq!(result, expected);
498    }
499
500    #[test]
501    fn solve_compound_fahrenheit_to_celsius() {
502        // fahrenheit = celsius * 9/5 + 32
503        // Solve for celsius: celsius = (value - 32) * 5/9
504        let celsius = placeholder("celsius");
505        let nine_fifths = arithmetic(
506            number(Decimal::from(9)),
507            ArithmeticComputation::Divide,
508            number(Decimal::from(5)),
509        );
510        let expression = arithmetic(
511            arithmetic(
512                celsius.clone(),
513                ArithmeticComputation::Multiply,
514                nine_fifths,
515            ),
516            ArithmeticComputation::Add,
517            number(Decimal::from(32)),
518        );
519
520        let result = solve_for(&expression, &celsius).unwrap();
521
522        // Expected: (value - 32) / (9/5)
523        let expected_nine_fifths = arithmetic(
524            number(Decimal::from(9)),
525            ArithmeticComputation::Divide,
526            number(Decimal::from(5)),
527        );
528        let expected = arithmetic(
529            arithmetic(
530                placeholder("value"),
531                ArithmeticComputation::Subtract,
532                number(Decimal::from(32)),
533            ),
534            ArithmeticComputation::Divide,
535            expected_nine_fifths,
536        );
537        assert_eq!(result, expected);
538    }
539
540    #[test]
541    fn solve_with_fact_reference() {
542        // x * 9/5 + offset = value → x = (value - offset) / (9/5)
543        let x = placeholder("x");
544        let offset = placeholder("offset");
545        let nine_fifths = arithmetic(
546            number(Decimal::from(9)),
547            ArithmeticComputation::Divide,
548            number(Decimal::from(5)),
549        );
550        let expression = arithmetic(
551            arithmetic(x.clone(), ArithmeticComputation::Multiply, nine_fifths),
552            ArithmeticComputation::Add,
553            offset.clone(),
554        );
555
556        let result = solve_for(&expression, &x).unwrap();
557
558        // Expected: (value - offset) / (9/5)
559        let expected_nine_fifths = arithmetic(
560            number(Decimal::from(9)),
561            ArithmeticComputation::Divide,
562            number(Decimal::from(5)),
563        );
564        let expected = arithmetic(
565            arithmetic(
566                placeholder("value"),
567                ArithmeticComputation::Subtract,
568                offset,
569            ),
570            ArithmeticComputation::Divide,
571            expected_nine_fifths,
572        );
573        assert_eq!(result, expected);
574    }
575
576    #[test]
577    fn error_unknown_not_found() {
578        let x = placeholder("x");
579        let y = placeholder("y");
580        let expression = arithmetic(y, ArithmeticComputation::Multiply, number(Decimal::from(3)));
581
582        let result = solve_for(&expression, &x);
583
584        assert!(result.is_err());
585        assert!(result
586            .unwrap_err()
587            .to_string()
588            .contains("Unknown not found"));
589    }
590
591    #[test]
592    fn error_non_linear() {
593        // x * x is non-linear
594        let x = placeholder("x");
595        let expression = arithmetic(x.clone(), ArithmeticComputation::Multiply, x.clone());
596
597        let result = solve_for(&expression, &x);
598
599        assert!(result.is_err());
600        let error_msg = result.unwrap_err().to_string();
601        assert!(
602            error_msg.contains("Non-linear")
603                || error_msg.contains("non-linear")
604                || error_msg.contains("multiple times")
605        );
606    }
607
608    #[test]
609    fn error_modulo_not_invertible() {
610        let x = placeholder("x");
611        let expression = arithmetic(
612            x.clone(),
613            ArithmeticComputation::Modulo,
614            number(Decimal::from(3)),
615        );
616
617        let result = solve_for(&expression, &x);
618
619        assert!(result.is_err());
620        let error_msg = result.unwrap_err().to_string();
621        assert!(
622            error_msg.contains("Modulo operation is not invertible")
623                || error_msg.contains("not invertible")
624        );
625    }
626
627    #[test]
628    fn error_power_not_invertible() {
629        let x = placeholder("x");
630        let expression = arithmetic(
631            x.clone(),
632            ArithmeticComputation::Power,
633            number(Decimal::from(2)),
634        );
635
636        let result = solve_for(&expression, &x);
637
638        assert!(result.is_err());
639        let error_msg = result.unwrap_err().to_string();
640        assert!(
641            error_msg.contains("Power operation is not invertible")
642                || error_msg.contains("not invertible")
643        );
644    }
645
646    #[test]
647    fn substitute_simple() {
648        let x = placeholder("x");
649        let replacement = number(Decimal::from(5));
650
651        let expression = arithmetic(
652            x.clone(),
653            ArithmeticComputation::Multiply,
654            number(Decimal::from(3)),
655        );
656
657        let result = substitute(&expression, &x, &replacement);
658
659        let expected = arithmetic(
660            number(Decimal::from(5)),
661            ArithmeticComputation::Multiply,
662            number(Decimal::from(3)),
663        );
664        assert_eq!(result, expected);
665    }
666
667    #[test]
668    fn substitute_nested() {
669        // (x + 2) * 3 with x replaced by 5 → (5 + 2) * 3
670        let x = placeholder("x");
671        let replacement = number(Decimal::from(5));
672
673        let inner = arithmetic(
674            x.clone(),
675            ArithmeticComputation::Add,
676            number(Decimal::from(2)),
677        );
678        let expression = arithmetic(
679            inner,
680            ArithmeticComputation::Multiply,
681            number(Decimal::from(3)),
682        );
683
684        let result = substitute(&expression, &x, &replacement);
685
686        let expected_inner = arithmetic(
687            number(Decimal::from(5)),
688            ArithmeticComputation::Add,
689            number(Decimal::from(2)),
690        );
691        let expected = arithmetic(
692            expected_inner,
693            ArithmeticComputation::Multiply,
694            number(Decimal::from(3)),
695        );
696        assert_eq!(result, expected);
697    }
698
699    #[test]
700    fn substitute_chained_units() {
701        // milligram = kilogram / 1000000
702        // kilogram = 1000 * gram
703        // Substitute kilogram → (1000 * gram) / 1000000
704        let kilogram = placeholder("kilogram");
705        let gram = placeholder("gram");
706
707        let kilogram_definition = arithmetic(
708            number(Decimal::from(1000)),
709            ArithmeticComputation::Multiply,
710            gram.clone(),
711        );
712        let milligram_expression = arithmetic(
713            kilogram.clone(),
714            ArithmeticComputation::Divide,
715            number(Decimal::from(1_000_000)),
716        );
717
718        let result = substitute(&milligram_expression, &kilogram, &kilogram_definition);
719
720        let expected = arithmetic(
721            arithmetic(
722                number(Decimal::from(1000)),
723                ArithmeticComputation::Multiply,
724                gram,
725            ),
726            ArithmeticComputation::Divide,
727            number(Decimal::from(1_000_000)),
728        );
729        assert_eq!(result, expected);
730    }
731}