1use crate::expr::Expression;
7use crate::symbol::{Seft, Symbol};
8
9#[derive(Clone)]
11enum SolveNodeKind {
12 Atom,
13 Unary(Symbol, Box<SolveNode>),
14 Binary(Symbol, Box<SolveNode>, Box<SolveNode>),
15}
16
17#[derive(Clone)]
19struct SolveNode {
20 expr: Expression,
21 x_count: u32,
22 kind: SolveNodeKind,
23}
24
25pub fn solve_for_x_rhs_expression(lhs: &Expression, rhs: &Expression) -> Option<Expression> {
31 if lhs.count_symbol(Symbol::X) != 1 {
32 return None;
33 }
34
35 let mut node = build_solve_ast(lhs)?;
36 if node.x_count != 1 {
37 return None;
38 }
39 let mut rhs_expr = rhs.clone();
40
41 loop {
42 match node.kind {
43 SolveNodeKind::Atom => {
44 return is_x_atom(&node.expr).then_some(rhs_expr);
46 }
47 SolveNodeKind::Unary(op, child) => {
48 if child.x_count != 1 {
49 return None;
50 }
51 rhs_expr = unary_inverse_expression(op, &rhs_expr)?;
52 node = *child;
53 }
54 SolveNodeKind::Binary(op, left, right) => {
55 let lx = left.x_count;
56 let rx = right.x_count;
57 if lx + rx != 1 {
58 return None;
59 }
60
61 if lx == 1 {
62 rhs_expr = invert_binary_left(op, &rhs_expr, &right.expr)?;
63 node = *left;
64 } else {
65 rhs_expr = invert_binary_right(op, &left.expr, &rhs_expr)?;
66 node = *right;
67 }
68 }
69 }
70 }
71}
72
73pub fn canonical_expression_key(expression: &Expression) -> Option<String> {
78 let node = build_solve_ast(expression)?;
79 Some(canonical_node_key(&node))
80}
81
82fn build_solve_ast(expression: &Expression) -> Option<SolveNode> {
83 let mut stack: Vec<SolveNode> = Vec::with_capacity(expression.len());
84
85 for &sym in expression.symbols() {
86 match sym.seft() {
87 Seft::A => {
88 let mut e = Expression::new();
89 e.push(sym);
90 stack.push(SolveNode {
91 expr: e,
92 x_count: u32::from(sym == Symbol::X),
93 kind: SolveNodeKind::Atom,
94 });
95 }
96 Seft::B => {
97 let arg = stack.pop()?;
98 let mut e = arg.expr.clone();
99 e.push(sym);
100 stack.push(SolveNode {
101 expr: e,
102 x_count: arg.x_count,
103 kind: SolveNodeKind::Unary(sym, Box::new(arg)),
104 });
105 }
106 Seft::C => {
107 let rhs = stack.pop()?;
108 let lhs = stack.pop()?;
109 let mut e = Expression::new();
110 for &s in lhs.expr.symbols() {
111 e.push(s);
112 }
113 for &s in rhs.expr.symbols() {
114 e.push(s);
115 }
116 e.push(sym);
117 stack.push(SolveNode {
118 expr: e,
119 x_count: lhs.x_count.saturating_add(rhs.x_count),
120 kind: SolveNodeKind::Binary(sym, Box::new(lhs), Box::new(rhs)),
121 });
122 }
123 }
124 }
125
126 if stack.len() == 1 {
127 stack.pop()
128 } else {
129 None
130 }
131}
132
133fn unary_inverse_expression(op: Symbol, rhs_value: &Expression) -> Option<Expression> {
141 Some(match op {
142 Symbol::Neg => append_unary_expression(rhs_value, Symbol::Neg),
143 Symbol::Recip => append_unary_expression(rhs_value, Symbol::Recip),
144 Symbol::Square => append_unary_expression(rhs_value, Symbol::Sqrt), Symbol::Sqrt => append_unary_expression(rhs_value, Symbol::Square), Symbol::Ln => append_unary_expression(rhs_value, Symbol::Exp),
147 Symbol::Exp => append_unary_expression(rhs_value, Symbol::Ln),
148 Symbol::TanPi => {
149 let one = constant_expression(Symbol::One);
152 let atan = combine_binary_expressions(rhs_value, &one, Symbol::Atan2);
153 let pi = constant_expression(Symbol::Pi);
154 combine_binary_expressions(&atan, &pi, Symbol::Div)
155 }
156 Symbol::SinPi => {
157 let one = constant_expression(Symbol::One);
160 let rhs_sq = append_unary_expression(rhs_value, Symbol::Square);
161 let inner = combine_binary_expressions(&one, &rhs_sq, Symbol::Sub);
162 let denom = append_unary_expression(&inner, Symbol::Sqrt);
163 let atan = combine_binary_expressions(rhs_value, &denom, Symbol::Atan2);
164 let pi = constant_expression(Symbol::Pi);
165 combine_binary_expressions(&atan, &pi, Symbol::Div)
166 }
167 Symbol::CosPi => {
168 let one = constant_expression(Symbol::One);
171 let rhs_sq = append_unary_expression(rhs_value, Symbol::Square);
172 let inner = combine_binary_expressions(&one, &rhs_sq, Symbol::Sub);
173 let numer = append_unary_expression(&inner, Symbol::Sqrt);
174 let atan = combine_binary_expressions(&numer, rhs_value, Symbol::Atan2);
175 let pi = constant_expression(Symbol::Pi);
176 combine_binary_expressions(&atan, &pi, Symbol::Div)
177 }
178 Symbol::LambertW => {
179 let exp_rhs = append_unary_expression(rhs_value, Symbol::Exp);
182 combine_binary_expressions(rhs_value, &exp_rhs, Symbol::Mul)
183 }
184 _ => return None,
185 })
186}
187
188fn invert_binary_left(
189 op: Symbol,
190 rhs_value: &Expression,
191 known_right: &Expression,
192) -> Option<Expression> {
193 Some(match op {
194 Symbol::Add => combine_binary_expressions(rhs_value, known_right, Symbol::Sub),
195 Symbol::Sub => combine_binary_expressions(rhs_value, known_right, Symbol::Add),
196 Symbol::Mul => combine_binary_expressions(rhs_value, known_right, Symbol::Div),
197 Symbol::Div => combine_binary_expressions(rhs_value, known_right, Symbol::Mul),
198 Symbol::Pow => combine_binary_expressions(known_right, rhs_value, Symbol::Root),
199 Symbol::Root => combine_binary_expressions(rhs_value, known_right, Symbol::Log),
200 Symbol::Log => combine_binary_expressions(rhs_value, known_right, Symbol::Root),
201 _ => return None,
202 })
203}
204
205fn invert_binary_right(
206 op: Symbol,
207 known_left: &Expression,
208 rhs_value: &Expression,
209) -> Option<Expression> {
210 Some(match op {
211 Symbol::Add => combine_binary_expressions(rhs_value, known_left, Symbol::Sub),
212 Symbol::Sub => combine_binary_expressions(known_left, rhs_value, Symbol::Sub),
213 Symbol::Mul => combine_binary_expressions(rhs_value, known_left, Symbol::Div),
214 Symbol::Div => combine_binary_expressions(known_left, rhs_value, Symbol::Div),
215 Symbol::Pow => combine_binary_expressions(known_left, rhs_value, Symbol::Log),
216 Symbol::Root => combine_binary_expressions(rhs_value, known_left, Symbol::Pow),
217 Symbol::Log => combine_binary_expressions(known_left, rhs_value, Symbol::Pow),
218 _ => return None,
219 })
220}
221
222fn append_unary_expression(base: &Expression, op: Symbol) -> Expression {
223 let mut out = base.clone();
224 out.push(op);
225 out
226}
227
228fn combine_binary_expressions(lhs: &Expression, rhs: &Expression, op: Symbol) -> Expression {
229 let mut out = Expression::new();
230 for &sym in lhs.symbols() {
231 out.push(sym);
232 }
233 for &sym in rhs.symbols() {
234 out.push(sym);
235 }
236 out.push(op);
237 out
238}
239
240fn constant_expression(sym: Symbol) -> Expression {
241 let mut out = Expression::new();
242 out.push(sym);
243 out
244}
245
246fn is_x_atom(expression: &Expression) -> bool {
247 expression.len() == 1
248 && expression
249 .symbols()
250 .first()
251 .is_some_and(|sym| *sym == Symbol::X)
252}
253
254fn canonical_node_key(node: &SolveNode) -> String {
255 match &node.kind {
256 SolveNodeKind::Atom => node.expr.to_postfix(),
257 SolveNodeKind::Unary(op, child) => {
258 format!("{}({})", symbol_key(*op), canonical_node_key(child))
259 }
260 SolveNodeKind::Binary(op, left, right) => {
261 let mut lk = canonical_node_key(left);
262 let mut rk = canonical_node_key(right);
263 if matches!(op, Symbol::Add | Symbol::Mul) && lk > rk {
265 std::mem::swap(&mut lk, &mut rk);
266 }
267 format!("({}{}{})", lk, symbol_key(*op), rk)
268 }
269 }
270}
271
272fn symbol_key(sym: Symbol) -> String {
273 let byte = sym as u8;
274 if byte.is_ascii_graphic() {
275 (byte as char).to_string()
276 } else {
277 format!("#{}", byte)
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a postfix test expression, panicking if it is invalid.
    fn expr(s: &str) -> Expression {
        Expression::parse(s).expect("valid expression")
    }

    /// Solves `lhs = rhs` (both given as postfix strings) and renders the solution as postfix.
    fn solve(lhs: &str, rhs: &str) -> Option<String> {
        solve_for_x_rhs_expression(&expr(lhs), &expr(rhs)).map(|solved| solved.to_postfix())
    }

    // ---- solve_for_x_rhs_expression ----

    #[test]
    fn test_solve_simple_addition() {
        // x + 1 = 2  =>  x = 2 - 1
        assert_eq!(solve("x1+", "2").as_deref(), Some("21-"));
    }

    #[test]
    fn test_solve_simple_subtraction() {
        // x - 1 = 2  =>  x = 2 + 1
        assert_eq!(solve("x1-", "2").as_deref(), Some("21+"));
    }

    #[test]
    fn test_solve_simple_multiplication() {
        // 2 * x = 6  =>  x = 6 / 2
        assert_eq!(solve("2x*", "6").as_deref(), Some("62/"));
    }

    #[test]
    fn test_solve_simple_division() {
        // x / 2 = 3  =>  x = 3 * 2
        assert_eq!(solve("x2/", "3").as_deref(), Some("32*"));
    }

    #[test]
    fn test_solve_square() {
        // x^2 = 4  =>  x = sqrt(4)
        assert_eq!(solve("xs", "4").as_deref(), Some("4q"));
    }

    #[test]
    fn test_solve_sqrt() {
        // sqrt(x) = 4  =>  x = 4^2
        assert_eq!(solve("xq", "4").as_deref(), Some("4s"));
    }

    #[test]
    fn test_solve_negation() {
        assert_eq!(solve("xn", "3").as_deref(), Some("3n"));
    }

    #[test]
    fn test_solve_reciprocal() {
        assert_eq!(solve("xr", "2").as_deref(), Some("2r"));
    }

    #[test]
    fn test_solve_ln() {
        // ln(x) = 2  =>  x = e^2
        assert_eq!(solve("xl", "2").as_deref(), Some("2E"));
    }

    #[test]
    fn test_solve_exp() {
        // e^x = 2  =>  x = ln(2)
        assert_eq!(solve("xE", "2").as_deref(), Some("2l"));
    }

    #[test]
    fn test_solve_nested_expression() {
        // (x + 1) * 2 = 6  =>  x = 6 / 2 - 1
        assert_eq!(solve("x1+2*", "6").as_deref(), Some("62/1-"));
    }

    #[test]
    fn test_solve_x_on_right_side() {
        // 1 + x = 3  =>  x = 3 - 1
        assert_eq!(solve("1x+", "3").as_deref(), Some("31-"));
    }

    #[test]
    fn test_solve_fails_multiple_x() {
        assert!(
            solve("xx*", "4").is_none(),
            "Expected None for expression with multiple x"
        );
    }

    #[test]
    fn test_solve_fails_no_x() {
        assert!(
            solve("23+", "5").is_none(),
            "Expected None for expression with no x"
        );
    }

    #[test]
    fn test_solve_trig_functions() {
        // sin(pi x), cos(pi x) and tan(pi x) are all inverted as atan2(..) / pi.
        for op in [Symbol::SinPi, Symbol::CosPi, Symbol::TanPi] {
            let lhs = Expression::from_symbols(&[Symbol::X, op]);
            let solved = solve_for_x_rhs_expression(&lhs, &expr("1"))
                .expect("trig function of x should be invertible");
            let symbols = solved.symbols();
            assert!(symbols.contains(&Symbol::Atan2));
            assert!(symbols.contains(&Symbol::Pi));
            assert!(symbols.last() == Some(&Symbol::Div));
        }
    }

    #[test]
    fn test_solve_lambert_w() {
        // W(x) = 2  =>  x = 2 * e^2
        let lhs = Expression::from_symbols(&[Symbol::X, Symbol::LambertW]);
        let solved = solve_for_x_rhs_expression(&lhs, &expr("2"))
            .expect("W(x) = 2 should be solvable");
        assert_eq!(solved.to_postfix(), "22E*");
    }

    #[test]
    fn test_solve_power() {
        // x^2 = 4  =>  x = Root(2, 4), the square root of 4
        let solved = solve_for_x_rhs_expression(&expr("x2^"), &expr("4"))
            .expect("x^2 = 4 should be solvable");
        assert!(solved.to_postfix().starts_with("24"));
        assert!(solved.symbols().last() == Some(&Symbol::Root));
    }

    #[test]
    fn test_solve_right_operand_x() {
        // 2^x = 8  =>  x = Log(2, 8), log base 2 of 8
        let solved = solve_for_x_rhs_expression(&expr("2x^"), &expr("8"))
            .expect("2^x = 8 should be solvable");
        assert!(solved.to_postfix().starts_with("28"));
        assert!(solved.symbols().last() == Some(&Symbol::Log));
    }

    // ---- canonical_expression_key ----

    #[test]
    fn test_canonical_key_atom() {
        let key = canonical_expression_key(&expr("x"));
        assert_eq!(key.as_deref(), Some("x"));
    }

    #[test]
    fn test_canonical_key_commutativity_addition() {
        let key1 = canonical_expression_key(&expr("x1+"));
        let key2 = canonical_expression_key(&expr("1x+"));
        assert_eq!(key1, key2, "x+1 and 1+x should have same canonical key");
    }

    #[test]
    fn test_canonical_key_commutativity_multiplication() {
        let key1 = canonical_expression_key(&expr("x2*"));
        let key2 = canonical_expression_key(&expr("2x*"));
        assert_eq!(key1, key2, "x*2 and 2*x should have same canonical key");
    }

    #[test]
    fn test_canonical_key_non_commutative() {
        let key1 = canonical_expression_key(&expr("x1-"));
        let key2 = canonical_expression_key(&expr("1x-"));
        assert_ne!(
            key1, key2,
            "x-1 and 1-x should have different canonical keys"
        );
    }

    #[test]
    fn test_canonical_key_nested() {
        let key1 = canonical_expression_key(&expr("x1+2*"));
        let key2 = canonical_expression_key(&expr("1x+2*"));
        assert_eq!(key1, key2, "nested commutative expressions should match");
    }

    // ---- unary_inverse_expression ----

    #[test]
    fn test_unary_inverse_negation() {
        let result = unary_inverse_expression(Symbol::Neg, &expr("3"));
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("3n"));
    }

    #[test]
    fn test_unary_inverse_reciprocal() {
        let result = unary_inverse_expression(Symbol::Recip, &expr("3"));
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("3r"));
    }

    #[test]
    fn test_unary_inverse_square_sqrt() {
        let rhs = expr("4");

        let result = unary_inverse_expression(Symbol::Square, &rhs);
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("4q"));

        let result = unary_inverse_expression(Symbol::Sqrt, &rhs);
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("4s"));
    }

    #[test]
    fn test_unary_inverse_ln_exp() {
        let rhs = expr("2");

        let result = unary_inverse_expression(Symbol::Ln, &rhs);
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("2E"));

        let result = unary_inverse_expression(Symbol::Exp, &rhs);
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("2l"));
    }

    // ---- invert_binary_left / invert_binary_right ----

    #[test]
    fn test_binary_inverse_add_left() {
        let result = invert_binary_left(Symbol::Add, &expr("5"), &expr("2"));
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("52-"));
    }

    #[test]
    fn test_binary_inverse_sub_left() {
        let result = invert_binary_left(Symbol::Sub, &expr("3"), &expr("2"));
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("32+"));
    }

    #[test]
    fn test_binary_inverse_mul_left() {
        let result = invert_binary_left(Symbol::Mul, &expr("6"), &expr("2"));
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("62/"));
    }

    #[test]
    fn test_binary_inverse_div_left() {
        let result = invert_binary_left(Symbol::Div, &expr("3"), &expr("2"));
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("32*"));
    }

    #[test]
    fn test_binary_inverse_sub_right() {
        // 5 - x = 2  =>  x = 5 - 2
        let result = invert_binary_right(Symbol::Sub, &expr("5"), &expr("2"));
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("52-"));
    }

    #[test]
    fn test_binary_inverse_div_right() {
        // 5 / x = 2  =>  x = 5 / 2
        let result = invert_binary_right(Symbol::Div, &expr("5"), &expr("2"));
        assert_eq!(result.map(|e| e.to_postfix()).as_deref(), Some("52/"));
    }

    // ---- edge cases and helpers ----

    #[test]
    fn test_solve_empty_expression() {
        let lhs = Expression::new();
        let result = solve_for_x_rhs_expression(&lhs, &expr("1"));
        assert!(result.is_none());
    }

    #[test]
    fn test_canonical_empty_expression() {
        let empty = Expression::new();
        assert!(canonical_expression_key(&empty).is_none());
    }

    #[test]
    fn test_build_solve_ast_malformed() {
        let malformed = Expression::from_symbols(&[Symbol::Add]);
        assert!(
            build_solve_ast(&malformed).is_none(),
            "Malformed expression should return None"
        );
    }

    #[test]
    fn test_build_solve_ast_incomplete() {
        let incomplete = Expression::from_symbols(&[Symbol::One, Symbol::Two, Symbol::Three]);
        assert!(
            build_solve_ast(&incomplete).is_none(),
            "Incomplete expression should return None"
        );
    }

    #[test]
    fn test_is_x_atom() {
        let x_expr = expr("x");
        assert!(is_x_atom(&x_expr));

        let not_x = expr("1");
        assert!(!is_x_atom(&not_x));

        let complex = expr("x1+");
        assert!(!is_x_atom(&complex));
    }
}