use crate::ast::{BinaryOp, Equation, Expression, UnaryOp, Variable};
use crate::resolution_path::{Operation, ResolutionPath, ResolutionPathBuilder, ResolutionStep};
use std::collections::HashMap;
/// Errors produced by the solvers in this module.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum SolverError {
    /// No value of the variable satisfies the equation.
    NoSolution,
    /// Every value of the variable satisfies the equation.
    InfiniteSolutions,
    /// The solver could not handle the equation; the payload explains why.
    CannotSolve(String),
    /// The equation's form is outside what the chosen solver accepts.
    UnsupportedEquationType,
    /// A division by zero was encountered while solving.
    DivisionByZero,
    /// Any other failure; the payload is displayed verbatim.
    Other(String),
}

impl std::fmt::Display for SolverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Payload-carrying variants format their payload; the rest map to a fixed message.
        let message = match self {
            SolverError::CannotSolve(reason) => return write!(f, "Cannot solve: {}", reason),
            SolverError::Other(text) => return f.write_str(text),
            SolverError::NoSolution => "Equation has no solution",
            SolverError::InfiniteSolutions => "Equation has infinite solutions",
            SolverError::UnsupportedEquationType => "Equation type is not supported",
            SolverError::DivisionByZero => "Division by zero encountered",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SolverError {}

/// Result alias returned by every solver entry point.
pub type SolverResult<T> = Result<T, SolverError>;
/// Outcome of solving an equation for one variable.
#[derive(Debug, Clone, PartialEq)]
pub enum Solution {
/// A single solution (returned e.g. by the linear solver and for a zero-discriminant quadratic).
Unique(Expression),
/// Several solutions, e.g. the roots of a polynomial. Repeated roots may be listed more
/// than once (the cubic solver repeats a triple root three times).
Multiple(Vec<Expression>),
/// A solution expression that holds under the listed constraints.
/// NOTE(review): not constructed anywhere in the visible part of this file.
Parametric {
expression: Expression,
constraints: Vec<Constraint>,
},
/// No solution. NOTE(review): the visible solvers report this case through
/// `SolverError::NoSolution` instead of this variant.
None,
/// Every value of the variable is a solution (e.g. the equation reduces to `0 = 0`).
Infinite,
}
/// A condition attached to a `Solution::Parametric` result.
#[derive(Debug, Clone, PartialEq)]
pub struct Constraint {
/// The variable the condition restricts.
pub variable: Variable,
/// The condition itself. NOTE(review): how it is interpreted (boolean predicate,
/// `expr = 0`, inequality, ...) is not established in this file — confirm with callers.
pub condition: Expression,
}
/// Common interface implemented by the equation solvers in this module.
pub trait Solver {
/// Solves `equation` for `variable`.
///
/// On success returns the [`Solution`] together with the [`ResolutionPath`] recording
/// the steps taken.
///
/// # Errors
/// Returns a [`SolverError`] when this solver cannot handle the equation; the concrete
/// variant depends on the implementation.
fn solve(
&self,
equation: &Equation,
variable: &Variable,
) -> SolverResult<(Solution, ResolutionPath)>;
/// Quick structural check of whether this solver is suited to `equation`.
///
/// The implementations in this file only inspect the expression shape, so `true`
/// does not guarantee that `solve` succeeds.
fn can_solve(&self, equation: &Equation) -> bool;
}
fn contains_variable(expr: &Expression, var: &str) -> bool {
expr.contains_variable(var)
}
/// Returns the coefficient `k` when `expr` is exactly `var`, `var * k` or `k * var`,
/// where `k` does not mention `var`; returns `None` for every other shape.
///
/// A bare `var` yields `Expression::Integer(1)`. For a product, the left operand is
/// tried as the variable first.
fn extract_coefficient(expr: &Expression, var: &str) -> Option<Expression> {
    let is_target = |candidate: &Expression| {
        matches!(candidate, Expression::Variable(v) if v.name == var)
    };
    match expr {
        Expression::Variable(v) if v.name == var => Some(Expression::Integer(1)),
        Expression::Binary(BinaryOp::Mul, lhs, rhs) => {
            let (lhs, rhs) = (lhs.as_ref(), rhs.as_ref());
            if is_target(lhs) && !contains_variable(rhs, var) {
                Some(rhs.clone())
            } else if is_target(rhs) && !contains_variable(lhs, var) {
                Some(lhs.clone())
            } else {
                None
            }
        }
        _ => None,
    }
}
fn evaluate_constants(expr: &Expression) -> Expression {
let simplified = expr.simplify();
if !has_any_variable(&simplified) {
if let Some(value) = simplified.evaluate(&HashMap::new()) {
if value.fract().abs() < 1e-10 {
return Expression::Integer(value.round() as i64);
} else {
return Expression::Float(value);
}
}
}
simplified
}
fn isolate_variable(
equation: &Equation,
var: &str,
_path: &mut ResolutionPathBuilder,
) -> Result<Expression, SolverError> {
let left = &equation.left;
let right = &equation.right;
if !contains_variable(left, var) && !contains_variable(right, var) {
return Err(SolverError::CannotSolve(format!(
"Variable '{}' not found in equation",
var
)));
}
if let Expression::Variable(v) = left {
if v.name == var && !contains_variable(right, var) {
return Ok(right.clone());
}
}
if let Expression::Variable(v) = right {
if v.name == var && !contains_variable(left, var) {
return Ok(left.clone());
}
}
if let Some(coeff) = extract_coefficient(left, var) {
if !contains_variable(right, var) {
let result = Expression::Binary(
BinaryOp::Div,
Box::new(right.clone()),
Box::new(coeff.clone()),
)
.simplify();
let evaluated = evaluate_constants(&result);
return Ok(evaluated);
}
}
if let Some(coeff) = extract_coefficient(right, var) {
if !contains_variable(left, var) {
let result = Expression::Binary(
BinaryOp::Div,
Box::new(left.clone()),
Box::new(coeff.clone()),
)
.simplify();
let evaluated = evaluate_constants(&result);
return Ok(evaluated);
}
}
if let Expression::Binary(BinaryOp::Add, l, r) = left {
if let Expression::Variable(v) = l.as_ref() {
if v.name == var && !contains_variable(r, var) && !contains_variable(right, var) {
let result = Expression::Binary(
BinaryOp::Sub,
Box::new(right.clone()),
Box::new(r.as_ref().clone()),
)
.simplify();
let evaluated = evaluate_constants(&result);
return Ok(evaluated);
}
}
if let Expression::Variable(v) = r.as_ref() {
if v.name == var && !contains_variable(l, var) && !contains_variable(right, var) {
let result = Expression::Binary(
BinaryOp::Sub,
Box::new(right.clone()),
Box::new(l.as_ref().clone()),
)
.simplify();
let evaluated = evaluate_constants(&result);
return Ok(evaluated);
}
}
}
if let Expression::Binary(BinaryOp::Add, l, r) = right {
if let Expression::Variable(v) = l.as_ref() {
if v.name == var && !contains_variable(r, var) && !contains_variable(left, var) {
let result = Expression::Binary(
BinaryOp::Sub,
Box::new(left.clone()),
Box::new(r.as_ref().clone()),
)
.simplify();
let evaluated = evaluate_constants(&result);
return Ok(evaluated);
}
}
if let Expression::Variable(v) = r.as_ref() {
if v.name == var && !contains_variable(l, var) && !contains_variable(left, var) {
let result = Expression::Binary(
BinaryOp::Sub,
Box::new(left.clone()),
Box::new(l.as_ref().clone()),
)
.simplify();
let evaluated = evaluate_constants(&result);
return Ok(evaluated);
}
}
}
if let Expression::Binary(BinaryOp::Add, l, r) = left {
if let Some(coeff) = extract_coefficient(l, var) {
if !contains_variable(r, var) && !contains_variable(right, var) {
let numerator = Expression::Binary(
BinaryOp::Sub,
Box::new(right.clone()),
Box::new(r.as_ref().clone()),
);
let result =
Expression::Binary(BinaryOp::Div, Box::new(numerator), Box::new(coeff))
.simplify();
let evaluated = evaluate_constants(&result);
return Ok(evaluated);
}
}
if let Some(coeff) = extract_coefficient(r, var) {
if !contains_variable(l, var) && !contains_variable(right, var) {
let numerator = Expression::Binary(
BinaryOp::Sub,
Box::new(right.clone()),
Box::new(l.as_ref().clone()),
);
let result =
Expression::Binary(BinaryOp::Div, Box::new(numerator), Box::new(coeff))
.simplify();
let evaluated = evaluate_constants(&result);
return Ok(evaluated);
}
}
}
Err(SolverError::CannotSolve(
"Equation pattern not yet supported for Phase 1".to_string(),
))
}
/// Solver for equations that are linear in the target variable.
#[derive(Debug, Default)]
pub struct LinearSolver;

impl LinearSolver {
    /// Creates a linear solver (the type carries no state).
    pub fn new() -> Self {
        LinearSolver
    }
}
impl Solver for LinearSolver {
    /// Solves an equation that is linear in `variable` by isolating it on one side.
    ///
    /// # Errors
    /// * `CannotSolve` if `variable` appears on neither side.
    /// * `UnsupportedEquationType` if either side is not linear in `variable`.
    /// * Any error from `isolate_variable` is propagated unchanged.
    fn solve(
        &self,
        equation: &Equation,
        variable: &Variable,
    ) -> SolverResult<(Solution, ResolutionPath)> {
        let name = &variable.name;
        let difference = Expression::Binary(
            BinaryOp::Sub,
            Box::new(equation.left.clone()),
            Box::new(equation.right.clone()),
        );
        let mut builder = ResolutionPathBuilder::new(difference);

        let sides = [&equation.left, &equation.right];
        if !sides.iter().any(|side| contains_variable(side, name)) {
            return Err(SolverError::CannotSolve(format!(
                "Variable '{}' not found in equation",
                name
            )));
        }
        if !sides.iter().all(|side| is_linear_in_variable(side, name)) {
            return Err(SolverError::UnsupportedEquationType);
        }

        let isolated = isolate_variable(equation, name, &mut builder)?;
        let description = format!("Isolate {} on one side", variable);
        let path = builder
            .step(
                Operation::Isolate(variable.clone()),
                description,
                isolated.clone(),
            )
            .finish(isolated.clone());
        Ok((Solution::Unique(isolated), path))
    }

    /// Accepts equations where neither side shows an obvious nonlinearity (such as `x^2`).
    fn can_solve(&self, equation: &Equation) -> bool {
        !(has_obvious_nonlinearity(&equation.left) || has_obvious_nonlinearity(&equation.right))
    }
}
/// Cheap structural nonlinearity test that does not depend on which variable is solved for.
///
/// Reports `true` only when some power has a base containing any variable and an exponent
/// that evaluates to a constant greater than 1 (e.g. `x^2`, `(x+1)^3`). Products such as
/// `x*x`, negative or fractional exponents, and function applications like `sin(x)` are
/// not flagged by themselves; arguments and operands are searched recursively.
fn has_obvious_nonlinearity(expr: &Expression) -> bool {
    match expr {
        Expression::Binary(_, lhs, rhs) => {
            has_obvious_nonlinearity(lhs) || has_obvious_nonlinearity(rhs)
        }
        Expression::Unary(_, operand) => has_obvious_nonlinearity(operand),
        Expression::Function(_, args) => args.iter().any(has_obvious_nonlinearity),
        Expression::Power(base, exponent) => {
            let is_higher_power = has_any_variable(base)
                && exponent
                    .evaluate(&HashMap::new())
                    .map_or(false, |value| value > 1.0);
            is_higher_power
                || has_obvious_nonlinearity(base)
                || has_obvious_nonlinearity(exponent)
        }
        _ => false,
    }
}
/// Returns `true` if `expr` mentions any variable at all, whatever its name.
fn has_any_variable(expr: &Expression) -> bool {
    match expr {
        Expression::Variable(_) => true,
        Expression::Binary(_, lhs, rhs) => has_any_variable(lhs) || has_any_variable(rhs),
        Expression::Power(base, exponent) => {
            has_any_variable(base) || has_any_variable(exponent)
        }
        Expression::Unary(_, operand) => has_any_variable(operand),
        Expression::Function(_, args) => args.iter().any(|arg| has_any_variable(arg)),
        _ => false,
    }
}
/// Reads `expr` as `a·var² + b·var + c` and returns `(a, b, c)`.
///
/// Terms the extractor does not recognise contribute nothing.
fn extract_quadratic_coefficients(expr: &Expression, var: &str) -> (f64, f64, f64) {
    let (mut quadratic, mut linear, mut constant) = (0.0, 0.0, 0.0);
    extract_poly_coefficients_recursive(
        expr,
        var,
        1.0,
        &mut quadratic,
        &mut linear,
        &mut constant,
    );
    (quadratic, linear, constant)
}
fn extract_poly_coefficients_recursive(
expr: &Expression,
var: &str,
multiplier: f64,
a: &mut f64,
b: &mut f64,
c: &mut f64,
) {
match expr {
Expression::Integer(n) => *c += (*n as f64) * multiplier,
Expression::Float(f) => *c += f * multiplier,
Expression::Rational(r) => *c += (*r.numer() as f64 / *r.denom() as f64) * multiplier,
Expression::Variable(v) if v.name == var => *b += multiplier,
Expression::Variable(_) | Expression::Constant(_) => {
if let Some(val) = expr.evaluate(&std::collections::HashMap::new()) {
*c += val * multiplier;
}
}
Expression::Unary(UnaryOp::Neg, inner) => {
extract_poly_coefficients_recursive(inner, var, -multiplier, a, b, c);
}
Expression::Binary(BinaryOp::Add, left, right) => {
extract_poly_coefficients_recursive(left, var, multiplier, a, b, c);
extract_poly_coefficients_recursive(right, var, multiplier, a, b, c);
}
Expression::Binary(BinaryOp::Sub, left, right) => {
extract_poly_coefficients_recursive(left, var, multiplier, a, b, c);
extract_poly_coefficients_recursive(right, var, -multiplier, a, b, c);
}
Expression::Binary(BinaryOp::Mul, left, right) => {
let left_val = left.evaluate(&std::collections::HashMap::new());
let right_val = right.evaluate(&std::collections::HashMap::new());
match (left_val, right_val) {
(Some(lv), None) => {
extract_poly_coefficients_recursive(right, var, multiplier * lv, a, b, c);
}
(None, Some(rv)) => {
extract_poly_coefficients_recursive(left, var, multiplier * rv, a, b, c);
}
(Some(lv), Some(rv)) => {
*c += lv * rv * multiplier;
}
(None, None) => {
if matches!(&**left, Expression::Variable(v) if v.name == var)
&& matches!(&**right, Expression::Variable(v) if v.name == var)
{
*a += multiplier;
} else if matches!(&**left, Expression::Variable(v) if v.name == var) {
if let Some(rv) = right.evaluate(&std::collections::HashMap::new()) {
*b += multiplier * rv;
}
} else if matches!(&**right, Expression::Variable(v) if v.name == var) {
if let Some(lv) = left.evaluate(&std::collections::HashMap::new()) {
*b += multiplier * lv;
}
}
}
}
}
Expression::Power(base, exp) => {
if matches!(&**base, Expression::Variable(v) if v.name == var) {
if let Some(exp_val) = exp.evaluate(&std::collections::HashMap::new()) {
if (exp_val - 2.0).abs() < 1e-10 {
*a += multiplier;
} else if (exp_val - 1.0).abs() < 1e-10 {
*b += multiplier;
} else if exp_val.abs() < 1e-10 {
*c += multiplier;
}
}
}
}
_ => {
if let Some(val) = expr.evaluate(&std::collections::HashMap::new()) {
*c += val * multiplier;
}
}
}
}
/// Converts a numeric result into an expression, preferring an integer when possible.
///
/// Returns `Expression::Integer` when `val` lies within `1e-10` of an integer whose
/// magnitude is below `i64::MAX`; otherwise returns `Expression::Float(val)` unchanged
/// (this includes NaN and infinities).
fn simplify_numeric_expression(val: f64) -> Expression {
    let nearest = val.round();
    let is_integral = (val - nearest).abs() < 1e-10;
    let fits_in_i64 = nearest.abs() < i64::MAX as f64;
    match (is_integral, fits_in_i64) {
        (true, true) => Expression::Integer(nearest as i64),
        _ => Expression::Float(val),
    }
}
/// Returns the coefficients of `expr` as a polynomial in `var`, lowest degree first.
///
/// The result always has `max_degree + 1` entries; terms of higher degree are not recorded.
fn extract_polynomial_coefficients(expr: &Expression, var: &str, max_degree: usize) -> Vec<f64> {
    let size = max_degree + 1;
    let mut coefficients = Vec::with_capacity(size);
    coefficients.resize(size, 0.0);
    extract_general_poly_coefficients(expr, var, 1.0, coefficients.as_mut_slice());
    coefficients
}
fn extract_general_poly_coefficients(
expr: &Expression,
var: &str,
multiplier: f64,
coeffs: &mut [f64],
) {
match expr {
Expression::Integer(n) => coeffs[0] += (*n as f64) * multiplier,
Expression::Float(f) => coeffs[0] += f * multiplier,
Expression::Rational(r) => {
coeffs[0] += (*r.numer() as f64 / *r.denom() as f64) * multiplier
}
Expression::Variable(v) if v.name == var => {
if coeffs.len() > 1 {
coeffs[1] += multiplier;
}
}
Expression::Variable(_) | Expression::Constant(_) => {
if let Some(val) = expr.evaluate(&std::collections::HashMap::new()) {
coeffs[0] += val * multiplier;
}
}
Expression::Unary(UnaryOp::Neg, inner) => {
extract_general_poly_coefficients(inner, var, -multiplier, coeffs);
}
Expression::Binary(BinaryOp::Add, left, right) => {
extract_general_poly_coefficients(left, var, multiplier, coeffs);
extract_general_poly_coefficients(right, var, multiplier, coeffs);
}
Expression::Binary(BinaryOp::Sub, left, right) => {
extract_general_poly_coefficients(left, var, multiplier, coeffs);
extract_general_poly_coefficients(right, var, -multiplier, coeffs);
}
Expression::Binary(BinaryOp::Mul, left, right) => {
let left_val = left.evaluate(&std::collections::HashMap::new());
let right_val = right.evaluate(&std::collections::HashMap::new());
match (left_val, right_val) {
(Some(lv), None) => {
extract_general_poly_coefficients(right, var, multiplier * lv, coeffs);
}
(None, Some(rv)) => {
extract_general_poly_coefficients(left, var, multiplier * rv, coeffs);
}
(Some(lv), Some(rv)) => {
coeffs[0] += lv * rv * multiplier;
}
(None, None) => {
if matches!(&**left, Expression::Variable(v) if v.name == var)
&& matches!(&**right, Expression::Variable(v) if v.name == var)
{
if coeffs.len() > 2 {
coeffs[2] += multiplier;
}
}
}
}
}
Expression::Power(base, exp) => {
if matches!(&**base, Expression::Variable(v) if v.name == var) {
if let Some(exp_val) = exp.evaluate(&std::collections::HashMap::new()) {
let degree = exp_val.round() as usize;
if degree < coeffs.len() {
coeffs[degree] += multiplier;
}
}
}
}
_ => {
if let Some(val) = expr.evaluate(&std::collections::HashMap::new()) {
coeffs[0] += val * multiplier;
}
}
}
}
/// Estimates the degree of `expr` as a polynomial in `var` from its syntax alone.
///
/// The rules are:
/// * products add the degrees of their factors;
/// * sums and differences take the larger degree;
/// * a non-negative integer power multiplies the base degree.
///
/// Anything it cannot classify counts as degree 0: division by an expression containing
/// `var`, functions, negative, fractional or non-constant exponents, and unary operators
/// other than negation. Cancelling terms are not detected, so the estimate may exceed
/// the true degree.
fn get_polynomial_degree(expr: &Expression, var: &str) -> usize {
    /// Converts an evaluated exponent into a natural number, if it is one.
    fn natural(value: Option<f64>) -> Option<usize> {
        value
            .filter(|&v| v >= 0.0 && (v - v.round()).abs() < 1e-10)
            .map(|v| v.round() as usize)
    }

    match expr {
        Expression::Variable(v) => usize::from(v.name == var),
        Expression::Unary(UnaryOp::Neg, inner) => get_polynomial_degree(inner, var),
        Expression::Binary(BinaryOp::Add | BinaryOp::Sub, lhs, rhs) => {
            let left_degree = get_polynomial_degree(lhs, var);
            let right_degree = get_polynomial_degree(rhs, var);
            left_degree.max(right_degree)
        }
        Expression::Binary(BinaryOp::Mul, lhs, rhs) => {
            get_polynomial_degree(lhs, var) + get_polynomial_degree(rhs, var)
        }
        Expression::Binary(BinaryOp::Div, numerator, denominator) => {
            if contains_variable(denominator, var) {
                0
            } else {
                get_polynomial_degree(numerator, var)
            }
        }
        Expression::Power(base, exponent) => {
            let power = natural(exponent.evaluate(&HashMap::new()));
            let base_is_var = matches!(base.as_ref(), Expression::Variable(v) if v.name == var);
            if let (true, Some(n)) = (base_is_var, power) {
                return n;
            }
            let base_degree = get_polynomial_degree(base, var);
            match (base_degree, power) {
                (0, _) | (_, None) => 0,
                (degree, Some(n)) => degree * n,
            }
        }
        _ => 0,
    }
}
fn is_polynomial_expression(expr: &Expression) -> bool {
match expr {
Expression::Integer(_)
| Expression::Rational(_)
| Expression::Float(_)
| Expression::Complex(_)
| Expression::Constant(_)
| Expression::Variable(_) => true,
Expression::Unary(_, inner) => is_polynomial_expression(inner),
Expression::Binary(_, left, right) => {
is_polynomial_expression(left) && is_polynomial_expression(right)
}
Expression::Power(base, exp) => {
if !is_polynomial_expression(base) {
return false;
}
if let Some(exp_val) = exp.evaluate(&HashMap::new()) {
exp_val >= 0.0 && (exp_val - exp_val.round()).abs() < 1e-10
} else {
is_polynomial_expression(exp)
}
}
Expression::Function(_, _) => false, }
}
/// Solves `a·x³ + b·x² + c·x + d = 0` in closed form.
///
/// `coeffs` lists coefficients lowest degree first (`[d, c, b, a, ...]`). The cubic is
/// normalised, depressed with `x = t - p/3`, and solved by:
/// * the triple-root case,
/// * the repeated-root formulas,
/// * the trigonometric form (three real roots), or
/// * Cardano's formula (one real root and a complex pair).
///
/// Three roots are always returned; repeated roots appear repeatedly.
///
/// # Errors
/// `CannotSolve` when fewer than four coefficients are given or the leading one is ~0.
fn solve_cubic(
    coeffs: &[f64],
    _var: &str,
    mut path: ResolutionPathBuilder,
) -> SolverResult<(Solution, ResolutionPath)> {
    if coeffs.len() < 4 {
        return Err(SolverError::CannotSolve(
            "Not a cubic polynomial".to_string(),
        ));
    }
    let d = coeffs[0];
    let c = coeffs[1];
    let b = coeffs[2];
    let a = coeffs[3];
    if a.abs() < 1e-15 {
        return Err(SolverError::CannotSolve(
            "Leading coefficient is zero".to_string(),
        ));
    }
    let p = b / a;
    let q = c / a;
    let r = d / a;
    path = path.step(
        Operation::Simplify,
        format!("Normalized cubic: x³ + {}x² + {}x + {} = 0", p, q, r),
        Expression::Integer(0),
    );
    let dep_p = q - p * p / 3.0;
    let dep_q = r - p * q / 3.0 + 2.0 * p * p * p / 27.0;
    path = path.step(
        Operation::Simplify,
        format!("Depressed cubic: t³ + {}t + {} = 0", dep_p, dep_q),
        Expression::Integer(0),
    );
    let discriminant = -4.0 * dep_p * dep_p * dep_p - 27.0 * dep_q * dep_q;
    path = path.step(
        Operation::Simplify,
        format!("Discriminant: Δ = {}", discriminant),
        Expression::Integer(0),
    );
    let shift = -p / 3.0;
    // Absolute tolerance for treating depressed-cubic quantities as zero.
    let eps = 1e-10;
    let roots: Vec<Expression> = if dep_p.abs() < eps && dep_q.abs() < eps {
        // t³ = 0: triple root.
        let root = simplify_numeric_expression(shift);
        vec![root.clone(), root.clone(), root]
    } else if discriminant.abs() < eps && dep_p.abs() >= eps {
        // Repeated root. These formulas divide by `dep_p`, so it must be clearly non-zero;
        // a tiny `dep_p` with non-zero `dep_q` falls through to Cardano below instead.
        let t1 = 3.0 * dep_q / dep_p;
        let t2 = -3.0 * dep_q / (2.0 * dep_p);
        vec![
            simplify_numeric_expression(t1 + shift),
            simplify_numeric_expression(t2 + shift),
            simplify_numeric_expression(t2 + shift),
        ]
    } else if discriminant > 0.0 {
        // Three distinct real roots (trigonometric form).
        let m = 2.0 * (-dep_p / 3.0).sqrt();
        // Rounding can push the ratio just outside [-1, 1], where acos returns NaN.
        let theta = (3.0 * dep_q / (dep_p * m)).clamp(-1.0, 1.0).acos() / 3.0;
        let t1 = m * theta.cos();
        let t2 = m * (theta - 2.0 * std::f64::consts::PI / 3.0).cos();
        let t3 = m * (theta - 4.0 * std::f64::consts::PI / 3.0).cos();
        vec![
            simplify_numeric_expression(t1 + shift),
            simplify_numeric_expression(t2 + shift),
            simplify_numeric_expression(t3 + shift),
        ]
    } else {
        // One real root plus a complex-conjugate pair (Cardano). The radicand equals
        // -Δ/108 ≥ 0, but rounding can make it slightly negative.
        let sqrt_term = (dep_q * dep_q / 4.0 + dep_p * dep_p * dep_p / 27.0)
            .max(0.0)
            .sqrt();
        let u = (-dep_q / 2.0 + sqrt_term).cbrt();
        let v = (-dep_q / 2.0 - sqrt_term).cbrt();
        let t_real = u + v;
        let real_part = -0.5 * t_real + shift;
        let imag_part = (3.0_f64).sqrt() / 2.0 * (u - v);
        vec![
            simplify_numeric_expression(t_real + shift),
            Expression::Complex(num_complex::Complex64::new(real_part, imag_part)),
            Expression::Complex(num_complex::Complex64::new(real_part, -imag_part)),
        ]
    };
    path = path.step(
        Operation::Simplify,
        "Applied Cardano's formula".to_string(),
        roots[0].clone(),
    );
    let resolution_path = path.finish(roots[0].clone());
    Ok((Solution::Multiple(roots), resolution_path))
}
fn solve_quartic(
coeffs: &[f64],
_var: &str,
mut path: ResolutionPathBuilder,
) -> SolverResult<(Solution, ResolutionPath)> {
if coeffs.len() < 5 {
return Err(SolverError::CannotSolve(
"Not a quartic polynomial".to_string(),
));
}
let e = coeffs[0];
let d = coeffs[1];
let c = coeffs[2];
let b = coeffs[3];
let a = coeffs[4];
if a.abs() < 1e-15 {
return Err(SolverError::CannotSolve(
"Leading coefficient is zero".to_string(),
));
}
let p = b / a;
let q = c / a;
let r = d / a;
let s = e / a;
path = path.step(
Operation::Simplify,
format!(
"Normalized quartic: x⁴ + {}x³ + {}x² + {}x + {} = 0",
p, q, r, s
),
Expression::Integer(0),
);
let alpha = q - 3.0 * p * p / 8.0;
let beta = r - p * q / 2.0 + p * p * p / 8.0;
let gamma = s - p * r / 4.0 + p * p * q / 16.0 - 3.0 * p * p * p * p / 256.0;
path = path.step(
Operation::Simplify,
format!(
"Depressed quartic: y⁴ + {}y² + {}y + {} = 0",
alpha, beta, gamma
),
Expression::Integer(0),
);
let shift = -p / 4.0;
if beta.abs() < 1e-15 {
let disc = alpha * alpha - 4.0 * gamma;
if disc < -1e-15 {
let u1_real = -alpha / 2.0;
let u1_imag = (-disc).sqrt() / 2.0;
let mut roots = Vec::new();
for sign1 in [-1.0, 1.0] {
let u_real = u1_real;
let u_imag = sign1 * u1_imag;
let r = (u_real * u_real + u_imag * u_imag).sqrt();
let sqrt_real = ((r + u_real) / 2.0).sqrt();
let sqrt_imag = u_imag.signum() * ((r - u_real) / 2.0).sqrt();
roots.push(Expression::Complex(num_complex::Complex64::new(
sqrt_real + shift,
sqrt_imag,
)));
roots.push(Expression::Complex(num_complex::Complex64::new(
-sqrt_real + shift,
-sqrt_imag,
)));
}
let resolution_path = path.finish(roots[0].clone());
return Ok((Solution::Multiple(roots), resolution_path));
} else {
let u1 = (-alpha + disc.sqrt()) / 2.0;
let u2 = (-alpha - disc.sqrt()) / 2.0;
let mut roots = Vec::new();
for u in [u1, u2] {
if u >= 0.0 {
roots.push(simplify_numeric_expression(u.sqrt() + shift));
roots.push(simplify_numeric_expression(-u.sqrt() + shift));
} else {
let imag = (-u).sqrt();
roots.push(Expression::Complex(num_complex::Complex64::new(
shift, imag,
)));
roots.push(Expression::Complex(num_complex::Complex64::new(
shift, -imag,
)));
}
}
let resolution_path = path.finish(roots[0].clone());
return Ok((Solution::Multiple(roots), resolution_path));
}
}
let resolvent_coeffs = vec![
-beta * beta / 64.0,
(alpha * alpha - 4.0 * gamma) / 16.0,
alpha / 2.0,
1.0,
];
let dep_p = resolvent_coeffs[1] - resolvent_coeffs[2] * resolvent_coeffs[2] / 3.0;
let dep_q = resolvent_coeffs[0] - resolvent_coeffs[2] * resolvent_coeffs[1] / 3.0
+ 2.0 * resolvent_coeffs[2] * resolvent_coeffs[2] * resolvent_coeffs[2] / 27.0;
let disc_cubic = -4.0 * dep_p * dep_p * dep_p - 27.0 * dep_q * dep_q;
let m: f64;
if disc_cubic > 1e-10 {
let sqrt_term = 2.0 * (-dep_p / 3.0).sqrt();
let theta = (3.0 * dep_q / (dep_p * sqrt_term)).acos() / 3.0;
m = sqrt_term * theta.cos() - resolvent_coeffs[2] / 3.0;
} else {
let sqrt_term = (dep_q * dep_q / 4.0 + dep_p * dep_p * dep_p / 27.0)
.abs()
.sqrt();
let sign = if dep_q < 0.0 { 1.0 } else { -1.0 };
let u = (sign * sqrt_term - dep_q / 2.0).abs().cbrt()
* (sign * sqrt_term - dep_q / 2.0).signum();
let v = if u.abs() > 1e-10 {
-dep_p / (3.0 * u)
} else {
0.0
};
m = u + v - resolvent_coeffs[2] / 3.0;
}
path = path.step(
Operation::Simplify,
format!("Resolvent cubic root: m = {}", m),
Expression::Integer(0),
);
let sqrt_2m_alpha = (2.0 * m + alpha).max(0.0).sqrt();
let term = if sqrt_2m_alpha.abs() > 1e-10 {
beta / (2.0 * sqrt_2m_alpha)
} else {
0.0
};
let mut roots = Vec::new();
let a1 = 1.0;
let b1 = sqrt_2m_alpha;
let c1 = m + term;
let disc1 = b1 * b1 - 4.0 * a1 * c1;
if disc1 >= 0.0 {
roots.push(simplify_numeric_expression(
(-b1 + disc1.sqrt()) / 2.0 + shift,
));
roots.push(simplify_numeric_expression(
(-b1 - disc1.sqrt()) / 2.0 + shift,
));
} else {
let real = -b1 / 2.0 + shift;
let imag = (-disc1).sqrt() / 2.0;
roots.push(Expression::Complex(num_complex::Complex64::new(real, imag)));
roots.push(Expression::Complex(num_complex::Complex64::new(
real, -imag,
)));
}
let b2 = -sqrt_2m_alpha;
let c2 = m - term;
let disc2 = b2 * b2 - 4.0 * a1 * c2;
if disc2 >= 0.0 {
roots.push(simplify_numeric_expression(
(-b2 + disc2.sqrt()) / 2.0 + shift,
));
roots.push(simplify_numeric_expression(
(-b2 - disc2.sqrt()) / 2.0 + shift,
));
} else {
let real = -b2 / 2.0 + shift;
let imag = (-disc2).sqrt() / 2.0;
roots.push(Expression::Complex(num_complex::Complex64::new(real, imag)));
roots.push(Expression::Complex(num_complex::Complex64::new(
real, -imag,
)));
}
path = path.step(
Operation::Simplify,
"Applied Ferrari's method".to_string(),
roots[0].clone(),
);
let resolution_path = path.finish(roots[0].clone());
Ok((Solution::Multiple(roots), resolution_path))
}
/// Finds all complex roots of a polynomial with the Durand–Kerner (Weierstrass) iteration.
///
/// `coeffs` lists coefficients lowest degree first; the degree is `coeffs.len() - 1`.
/// The polynomial is first scaled to be monic, which the iteration requires. Roots whose
/// imaginary part has magnitude below `1e-10` are reported as real numbers.
///
/// # Errors
/// `CannotSolve` when fewer than two coefficients are given or the leading one is ~0.
fn solve_polynomial_numerically(
    coeffs: &[f64],
    _var: &str,
    mut path: ResolutionPathBuilder,
) -> SolverResult<(Solution, ResolutionPath)> {
    // Also guards the `len() - 1` below against an empty slice.
    if coeffs.len() < 2 {
        return Err(SolverError::CannotSolve("Invalid polynomial".to_string()));
    }
    let degree = coeffs.len() - 1;
    let leading = coeffs[degree];
    if leading.abs() < 1e-15 {
        return Err(SolverError::CannotSolve(
            "Leading coefficient is zero".to_string(),
        ));
    }
    path = path.step(
        Operation::Simplify,
        format!(
            "Solving degree {} polynomial numerically (Durand-Kerner method)",
            degree
        ),
        Expression::Integer(0),
    );
    // The Weierstrass update z ← z − p(z)/∏(z − z_j) is only valid for a monic p.
    let monic: Vec<f64> = coeffs.iter().map(|&k| k / leading).collect();
    // Cauchy bound: every root lies within this radius.
    let radius = 1.0 + monic[..degree].iter().map(|k| k.abs()).fold(0.0, f64::max);
    let mut roots: Vec<num_complex::Complex64> = (0..degree)
        .map(|k| {
            let angle = 2.0 * std::f64::consts::PI * (k as f64) / (degree as f64) + 0.4;
            num_complex::Complex64::new(radius * angle.cos(), radius * angle.sin())
        })
        .collect();
    let max_iter = 500;
    let tolerance = 1e-12;
    for _ in 0..max_iter {
        let mut max_change: f64 = 0.0;
        for i in 0..degree {
            let z = roots[i];
            // Horner evaluation of the monic polynomial at z.
            let p_val = monic
                .iter()
                .rev()
                .fold(num_complex::Complex64::new(0.0, 0.0), |acc, &k| {
                    acc * z + num_complex::Complex64::new(k, 0.0)
                });
            let denom = (0..degree)
                .filter(|&j| j != i)
                .fold(num_complex::Complex64::new(1.0, 0.0), |acc, j| {
                    acc * (z - roots[j])
                });
            if denom.norm() > 1e-15 {
                let delta = p_val / denom;
                roots[i] -= delta;
                max_change = max_change.max(delta.norm());
            }
        }
        if max_change < tolerance {
            break;
        }
    }
    let root_exprs: Vec<Expression> = roots
        .iter()
        .map(|r| {
            if r.im.abs() < 1e-10 {
                simplify_numeric_expression(r.re)
            } else {
                Expression::Complex(*r)
            }
        })
        .collect();
    path = path.step(
        Operation::Simplify,
        format!("Found {} roots numerically", degree),
        root_exprs[0].clone(),
    );
    let resolution_path = path.finish(root_exprs[0].clone());
    Ok((Solution::Multiple(root_exprs), resolution_path))
}
/// Returns `true` if `expr` is (at most) linear in the variable `var`.
///
/// Any sub-expression that does not mention `var` is a constant with respect to it and
/// therefore linear. This includes other variables, `sqrt(2)`, and `a^b` without `var`.
/// Expressions that do mention `var` are linear when:
/// * they are `var` itself;
/// * they are a sum or difference of linear parts;
/// * they are a product in which only one factor mentions `var`;
/// * they are a quotient whose divisor does not mention `var`.
///
/// Powers and function applications that mention `var` are rejected; this covers `x^2`,
/// `2^x` and `sin(x)`.
fn is_linear_in_variable(expr: &Expression, var: &str) -> bool {
    if !contains_variable(expr, var) {
        return true;
    }
    match expr {
        // Other names were handled above, so this is `var` itself.
        Expression::Variable(_) => true,
        Expression::Unary(_, inner) => is_linear_in_variable(inner, var),
        Expression::Binary(op, left, right) => match op {
            BinaryOp::Add | BinaryOp::Sub => {
                is_linear_in_variable(left, var) && is_linear_in_variable(right, var)
            }
            BinaryOp::Mul => {
                !(contains_variable(left, var) && contains_variable(right, var))
                    && is_linear_in_variable(left, var)
                    && is_linear_in_variable(right, var)
            }
            BinaryOp::Div => !contains_variable(right, var) && is_linear_in_variable(left, var),
            _ => false,
        },
        // Previously `2^x` was accepted because only the exponent's linearity was checked.
        Expression::Power(_, _) | Expression::Function(_, _) => false,
        // Numeric leaves never contain `var`; they returned `true` above.
        Expression::Integer(_)
        | Expression::Rational(_)
        | Expression::Float(_)
        | Expression::Complex(_)
        | Expression::Constant(_) => true,
    }
}
/// Solver for equations that reduce to `a·x² + b·x + c = 0`.
#[derive(Debug, Default)]
pub struct QuadraticSolver;

impl QuadraticSolver {
    /// Creates a quadratic solver (the type carries no state).
    pub fn new() -> Self {
        QuadraticSolver
    }
}
impl Solver for QuadraticSolver {
    /// Solves `left = right` as `a·x² + b·x + c = 0` after moving everything to one side.
    ///
    /// * If `a ≈ 0`, falls back to the linear root `-c / b`, or to `Infinite` / `NoSolution`
    ///   when `b ≈ 0` too.
    /// * If `Δ > 0`, returns two real roots, computed in a cancellation-free form.
    /// * If `Δ ≈ 0`, returns one root as `Solution::Unique`.
    /// * If `Δ < 0`, returns a complex-conjugate pair.
    ///
    /// # Errors
    /// * `CannotSolve` if `variable` does not occur in the equation.
    /// * `UnsupportedEquationType` if the polynomial degree in `variable` exceeds 2. The
    ///   coefficient extraction only tracks terms up to `x²` and would silently drop
    ///   higher ones, so for example `x^3 = 8` would come back as `NoSolution`.
    /// * `NoSolution` for a contradictory constant equation such as `0 = 3`.
    fn solve(
        &self,
        equation: &Equation,
        variable: &Variable,
    ) -> SolverResult<(Solution, ResolutionPath)> {
        let var_name = &variable.name;
        let initial_expr = Expression::Binary(
            BinaryOp::Sub,
            Box::new(equation.left.clone()),
            Box::new(equation.right.clone()),
        );
        if !contains_variable(&equation.left, var_name)
            && !contains_variable(&equation.right, var_name)
        {
            return Err(SolverError::CannotSolve(format!(
                "Variable '{}' not found in equation",
                var_name
            )));
        }
        let combined = initial_expr.simplify();
        let mut path = ResolutionPathBuilder::new(initial_expr);
        if get_polynomial_degree(&combined, var_name) > 2 {
            return Err(SolverError::UnsupportedEquationType);
        }
        let (a, b, c) = extract_quadratic_coefficients(&combined, var_name);
        path = path.step(
            Operation::Simplify,
            format!("Identified coefficients: a={}, b={}, c={}", a, b, c),
            combined.clone(),
        );
        if a.abs() < 1e-15 {
            if b.abs() < 1e-15 {
                if c.abs() < 1e-15 {
                    let resolution_path = path.finish(Expression::Integer(0));
                    return Ok((Solution::Infinite, resolution_path));
                } else {
                    return Err(SolverError::NoSolution);
                }
            }
            // Same integer/float normalisation as every other root in this file.
            let solution = simplify_numeric_expression(-c / b);
            let resolution_path = path.finish(solution.clone());
            return Ok((Solution::Unique(solution), resolution_path));
        }
        let discriminant = b * b - 4.0 * a * c;
        path = path.step(
            Operation::Simplify,
            format!("Computed discriminant: Δ = b² - 4ac = {}", discriminant),
            combined.clone(),
        );
        let epsilon = 1e-15;
        if discriminant > epsilon {
            let sqrt_disc = discriminant.sqrt();
            // Stable form: compute the root where -b and ±√Δ share a sign, then the other one
            // from Vieta (x1·x2 = c/a), avoiding cancellation when b² ≫ 4ac.
            let q_stable = -0.5 * (b + b.signum() * sqrt_disc);
            let (x1, x2) = if b.is_sign_negative() {
                (q_stable / a, c / q_stable)
            } else {
                (c / q_stable, q_stable / a)
            };
            let root1 = simplify_numeric_expression(x1);
            let root2 = simplify_numeric_expression(x2);
            path = path.step(
                Operation::Simplify,
                format!("Quadratic formula: x = (-b ± √Δ)/(2a) = {} or {}", x1, x2),
                root1.clone(),
            );
            let resolution_path = path.finish(root1.clone());
            Ok((Solution::Multiple(vec![root1, root2]), resolution_path))
        } else if discriminant.abs() <= epsilon {
            let x = -b / (2.0 * a);
            let root = simplify_numeric_expression(x);
            path = path.step(
                Operation::Simplify,
                format!("Quadratic formula (Δ = 0): x = -b/(2a) = {}", x),
                root.clone(),
            );
            let resolution_path = path.finish(root.clone());
            Ok((Solution::Unique(root), resolution_path))
        } else {
            let real_part = -b / (2.0 * a);
            let imag_part = (-discriminant).sqrt() / (2.0 * a);
            let root1 = Expression::Complex(num_complex::Complex64::new(real_part, imag_part));
            let root2 = Expression::Complex(num_complex::Complex64::new(real_part, -imag_part));
            path = path.step(
                Operation::Simplify,
                format!("Complex roots: x = {} ± {}i", real_part, imag_part),
                root1.clone(),
            );
            let resolution_path = path.finish(root1.clone());
            Ok((Solution::Multiple(vec![root1, root2]), resolution_path))
        }
    }

    /// Accepts equations where some side shows an obvious nonlinearity (such as `x^2`).
    fn can_solve(&self, equation: &Equation) -> bool {
        has_obvious_nonlinearity(&equation.left) || has_obvious_nonlinearity(&equation.right)
    }
}
/// Solver for polynomial equations of any degree, dispatching on the detected degree.
#[derive(Debug, Default)]
pub struct PolynomialSolver;

impl PolynomialSolver {
    /// Creates a polynomial solver (the type carries no state).
    pub fn new() -> Self {
        PolynomialSolver
    }
}
impl Solver for PolynomialSolver {
    /// Solves a polynomial equation by dispatching on its degree in `variable`.
    ///
    /// The degree is estimated with `get_polynomial_degree` and handled as follows:
    /// * degree 0: `Infinite` if the constant is ~0, otherwise `NoSolution`;
    /// * degrees 1 and 2: delegated to `LinearSolver` and `QuadraticSolver`;
    /// * degrees 3 and 4: closed-form cubic and quartic solvers;
    /// * higher degrees: Durand–Kerner iteration.
    ///
    /// # Errors
    /// * `CannotSolve` if `variable` is absent, or a degree-0 side cannot be evaluated.
    /// * `NoSolution` for a non-zero constant equation.
    /// * Any error from the delegated solver is propagated.
    fn solve(
        &self,
        equation: &Equation,
        variable: &Variable,
    ) -> SolverResult<(Solution, ResolutionPath)> {
        let name = &variable.name;
        let difference = || {
            Expression::Binary(
                BinaryOp::Sub,
                Box::new(equation.left.clone()),
                Box::new(equation.right.clone()),
            )
        };
        let mut builder = ResolutionPathBuilder::new(difference());
        let mentioned =
            contains_variable(&equation.left, name) || contains_variable(&equation.right, name);
        if !mentioned {
            return Err(SolverError::CannotSolve(format!(
                "Variable '{}' not found in equation",
                name
            )));
        }
        let combined = difference().simplify();
        let degree = get_polynomial_degree(&combined, name);
        builder = builder.step(
            Operation::Simplify,
            format!("Identified polynomial of degree {}", degree),
            combined.clone(),
        );
        match degree {
            0 => {
                let value = combined.evaluate(&HashMap::new()).ok_or_else(|| {
                    SolverError::CannotSolve("Cannot evaluate constant expression".to_string())
                })?;
                if value.abs() < 1e-15 {
                    Ok((Solution::Infinite, builder.finish(Expression::Integer(0))))
                } else {
                    Err(SolverError::NoSolution)
                }
            }
            1 => LinearSolver::new().solve(equation, variable),
            2 => QuadraticSolver::new().solve(equation, variable),
            3 => solve_cubic(
                &extract_polynomial_coefficients(&combined, name, 3),
                name,
                builder,
            ),
            4 => solve_quartic(
                &extract_polynomial_coefficients(&combined, name, 4),
                name,
                builder,
            ),
            higher => solve_polynomial_numerically(
                &extract_polynomial_coefficients(&combined, name, higher),
                name,
                builder,
            ),
        }
    }

    /// Accepts equations whose two sides both have a polynomial shape.
    fn can_solve(&self, equation: &Equation) -> bool {
        [&equation.left, &equation.right]
            .iter()
            .all(|side| is_polynomial_expression(side))
    }
}
/// Solver for equations involving transcendental functions. The visible code matches
/// `sin`, `cos` and `tan` applied to the variable and inverts them with `asin`, `acos`
/// and `atan`.
#[derive(Debug, Default)]
pub struct TranscendentalSolver;
impl TranscendentalSolver {
pub fn new() -> Self {
Self
}
fn solve_trig_equation(
&self,
equation: &Equation,
variable: &Variable,
path: &mut ResolutionPath,
) -> Option<Expression> {
let var_name = &variable.name;
if let Some((result, func, value)) = self.match_trig_pattern_with_validation(
&equation.left,
&equation.right,
var_name,
crate::ast::Function::Sin,
crate::ast::Function::Asin,
) {
if let Err(_e) = Self::validate_trig_domain(value, &func) {
return None; }
path.add_step(ResolutionStep::new(
Operation::ApplyFunction("asin".to_string()),
format!("Apply arcsine to solve sin({}) = value", variable),
result.clone(),
));
return Some(result);
}
if let Some((result, func, value)) = self.match_trig_pattern_with_validation(
&equation.right,
&equation.left,
var_name,
crate::ast::Function::Sin,
crate::ast::Function::Asin,
) {
if let Err(_e) = Self::validate_trig_domain(value, &func) {
return None;
}
path.add_step(ResolutionStep::new(
Operation::ApplyFunction("asin".to_string()),
format!("Apply arcsine to solve sin({}) = value", variable),
result.clone(),
));
return Some(result);
}
if let Some((result, func, value)) = self.match_trig_pattern_with_validation(
&equation.left,
&equation.right,
var_name,
crate::ast::Function::Cos,
crate::ast::Function::Acos,
) {
if let Err(_e) = Self::validate_trig_domain(value, &func) {
return None;
}
path.add_step(ResolutionStep::new(
Operation::ApplyFunction("acos".to_string()),
format!("Apply arccosine to solve cos({}) = value", variable),
result.clone(),
));
return Some(result);
}
if let Some((result, func, value)) = self.match_trig_pattern_with_validation(
&equation.right,
&equation.left,
var_name,
crate::ast::Function::Cos,
crate::ast::Function::Acos,
) {
if let Err(_e) = Self::validate_trig_domain(value, &func) {
return None;
}
path.add_step(ResolutionStep::new(
Operation::ApplyFunction("acos".to_string()),
format!("Apply arccosine to solve cos({}) = value", variable),
result.clone(),
));
return Some(result);
}
if let Some(result) = self.match_trig_pattern(
&equation.left,
&equation.right,
var_name,
crate::ast::Function::Tan,
crate::ast::Function::Atan,
) {
path.add_step(ResolutionStep::new(
Operation::ApplyFunction("atan".to_string()),
format!("Apply arctangent to solve tan({}) = value", variable),
result.clone(),
));
return Some(result);
}
if let Some(result) = self.match_trig_pattern(
&equation.right,
&equation.left,
var_name,
crate::ast::Function::Tan,
crate::ast::Function::Atan,
) {
path.add_step(ResolutionStep::new(
Operation::ApplyFunction("atan".to_string()),
format!("Apply arctangent to solve tan({}) = value", variable),
result.clone(),
));
return Some(result);
}
None
}
/// Matches a trigonometric term in `left` against the constant `right` and
/// also reports the argument handed to the inverse function, so the caller
/// can check it against the inverse function's domain.
///
/// Recognised shapes of `left` (with `right` free of `var`):
/// - `func(var)`                     → `inverse_func(right)`
/// - `func(k·var)`                   → `inverse_func(right) / k`
/// - `func(var)·k` or `k·func(var)`  → `inverse_func(right / k)`
///
/// Returns `(solution, inverse_func, value)`, where `value` is the numeric
/// value of the inverse function's argument. It is `f64::NAN` when that
/// argument cannot be evaluated (it contains other symbols). NaN fails every
/// comparison, so `validate_trig_domain` accepts it: an unknown value is not
/// a domain error. Previously a symbolic `right` made the whole match fail.
/// For `k·func(var)` with symbolic `k`, the wrong number (`right` instead of
/// `right / k`) used to be validated.
fn match_trig_pattern_with_validation(
    &self,
    left: &Expression,
    right: &Expression,
    var: &str,
    func: crate::ast::Function,
    inverse_func: crate::ast::Function,
) -> Option<(Expression, crate::ast::Function, f64)> {
    if contains_variable(right, var) {
        return None;
    }
    let numeric = |expr: &Expression| expr.evaluate(&HashMap::new()).unwrap_or(f64::NAN);
    if let Expression::Function(f, args) = left {
        if *f == func && args.len() == 1 {
            let arg = &args[0];
            if matches!(arg, Expression::Variable(v) if v.name == var) {
                let result = Expression::Function(inverse_func.clone(), vec![right.clone()]);
                return Some((result.simplify(), inverse_func, numeric(right)));
            }
            if let Some(coeff) = extract_coefficient(arg, var) {
                let inverse_applied =
                    Expression::Function(inverse_func.clone(), vec![right.clone()]);
                let result = Expression::Binary(
                    BinaryOp::Div,
                    Box::new(inverse_applied),
                    Box::new(coeff),
                );
                return Some((result.simplify(), inverse_func, numeric(right)));
            }
        }
    }
    if let Expression::Binary(BinaryOp::Mul, mul_left, mul_right) = left {
        // Try the function as the left factor first, then as the right one.
        for (factor, scale) in [(mul_left, mul_right), (mul_right, mul_left)] {
            if let Expression::Function(f, args) = factor.as_ref() {
                if *f == func
                    && args.len() == 1
                    && !contains_variable(scale, var)
                    && matches!(&args[0], Expression::Variable(v) if v.name == var)
                {
                    let divided = Expression::Binary(
                        BinaryOp::Div,
                        Box::new(right.clone()),
                        Box::new(scale.as_ref().clone()),
                    );
                    let value = numeric(&divided);
                    let result = Expression::Function(inverse_func.clone(), vec![divided]);
                    return Some((result.simplify(), inverse_func, value));
                }
            }
        }
    }
    None
}
/// Symbolic trigonometric matcher: recognises `func(x)`, `func(k·x)`,
/// `func(x)·k` and `k·func(x)` in `left` (with `right` free of `var`).
/// It returns the simplified solution built with `inverse_func`, without any
/// numeric domain check, or `None` when `left` has none of these shapes.
fn match_trig_pattern(
    &self,
    left: &Expression,
    right: &Expression,
    var: &str,
    func: crate::ast::Function,
    inverse_func: crate::ast::Function,
) -> Option<Expression> {
    if contains_variable(right, var) {
        return None;
    }
    let is_target = |e: &Expression| matches!(e, Expression::Variable(v) if v.name == var);
    match left {
        Expression::Function(f, args) if *f == func && args.len() == 1 => {
            if is_target(&args[0]) {
                return Some(Expression::Function(inverse_func, vec![right.clone()]).simplify());
            }
            // func(k·x) = c  ⇒  x = inverse(c) / k
            let coeff = extract_coefficient(&args[0], var)?;
            let inverse_applied = Expression::Function(inverse_func, vec![right.clone()]);
            Some(
                Expression::Binary(BinaryOp::Div, Box::new(inverse_applied), Box::new(coeff))
                    .simplify(),
            )
        }
        Expression::Binary(BinaryOp::Mul, lhs, rhs) => {
            for (factor, scale) in [(lhs, rhs), (rhs, lhs)] {
                let matched = match factor.as_ref() {
                    Expression::Function(f, args) => {
                        *f == func
                            && args.len() == 1
                            && !contains_variable(scale, var)
                            && is_target(&args[0])
                    }
                    _ => false,
                };
                if matched {
                    // k·func(x) = c  ⇒  x = inverse(c / k)
                    let divided = Expression::Binary(
                        BinaryOp::Div,
                        Box::new(right.clone()),
                        Box::new(scale.as_ref().clone()),
                    );
                    return Some(Expression::Function(inverse_func, vec![divided]).simplify());
                }
            }
            None
        }
        _ => None,
    }
}
/// Tries to isolate `variable` in a logarithmic equation, appending the step
/// that was applied to `path`.
///
/// Patterns are tried in a fixed order: natural log, then base-10 log, then
/// the two-argument `log`. Each is tried first with the logarithm on the left
/// side and then on the right. Returns `None` when nothing matches.
fn solve_log_equation(
    &self,
    equation: &Equation,
    variable: &Variable,
    path: &mut ResolutionPath,
) -> Option<Expression> {
    let var_name = &variable.name;
    let sides = [
        (&equation.left, &equation.right),
        (&equation.right, &equation.left),
    ];

    for (log_side, other_side) in sides {
        if let Some(solution) = self.match_log_pattern(log_side, other_side, var_name) {
            path.add_step(ResolutionStep::new(
                Operation::ApplyFunction("exp".to_string()),
                format!("Apply exponential to solve ln({}) = value", variable),
                solution.clone(),
            ));
            return Some(solution);
        }
    }

    for (log_side, other_side) in sides {
        if let Some(solution) = self.match_log10_pattern(log_side, other_side, var_name) {
            path.add_step(ResolutionStep::new(
                Operation::PowerBothSides(Expression::Integer(10)),
                format!("Apply 10^x to solve log10({}) = value", variable),
                solution.clone(),
            ));
            return Some(solution);
        }
    }

    for (log_side, other_side) in sides {
        if let Some(solution) = self.match_log_base_pattern(log_side, other_side, var_name) {
            path.add_step(ResolutionStep::new(
                Operation::ApplyLogProperty("exponential form".to_string()),
                format!(
                    "Convert logarithm to exponential form to solve for {}",
                    variable
                ),
                solution.clone(),
            ));
            return Some(solution);
        }
    }

    None
}
fn match_log_pattern(
&self,
left: &Expression,
right: &Expression,
var: &str,
) -> Option<Expression> {
if contains_variable(right, var) {
return None;
}
if let Expression::Function(crate::ast::Function::Ln, args) = left {
if args.len() == 1 {
if let Expression::Variable(v) = &args[0] {
if v.name == var {
let result =
Expression::Function(crate::ast::Function::Exp, vec![right.clone()]);
return Some(result.simplify());
}
}
}
}
if let Expression::Binary(BinaryOp::Mul, mul_left, mul_right) = left {
if let Expression::Function(crate::ast::Function::Ln, args) = mul_left.as_ref() {
if args.len() == 1 && !contains_variable(mul_right, var) {
if let Expression::Variable(v) = &args[0] {
if v.name == var {
let divided = Expression::Binary(
BinaryOp::Div,
Box::new(right.clone()),
Box::new(mul_right.as_ref().clone()),
);
let result =
Expression::Function(crate::ast::Function::Exp, vec![divided]);
return Some(result.simplify());
}
}
}
}
if let Expression::Function(crate::ast::Function::Ln, args) = mul_right.as_ref() {
if args.len() == 1 && !contains_variable(mul_left, var) {
if let Expression::Variable(v) = &args[0] {
if v.name == var {
let divided = Expression::Binary(
BinaryOp::Div,
Box::new(right.clone()),
Box::new(mul_left.as_ref().clone()),
);
let result =
Expression::Function(crate::ast::Function::Exp, vec![divided]);
return Some(result.simplify());
}
}
}
}
}
None
}
/// Matches `log10(x) = c`, with `left` as the logarithm and `right` (= `c`)
/// free of `var`, and returns the simplified `10^c`. Otherwise returns `None`.
fn match_log10_pattern(
    &self,
    left: &Expression,
    right: &Expression,
    var: &str,
) -> Option<Expression> {
    match left {
        Expression::Function(crate::ast::Function::Log10, args)
            if !contains_variable(right, var)
                && args.len() == 1
                && matches!(&args[0], Expression::Variable(v) if v.name == var) =>
        {
            let power = Expression::Power(
                Box::new(Expression::Integer(10)),
                Box::new(right.clone()),
            );
            Some(power.simplify())
        }
        _ => None,
    }
}
/// Matches a two-argument `log` whose first argument is the target variable
/// and whose second argument is free of `var`. That second argument is used
/// as the base (presumably the AST's `log(value, base)` order — confirm).
/// With `right` (= `c`) free of `var`, it returns the simplified `base^c`.
fn match_log_base_pattern(
    &self,
    left: &Expression,
    right: &Expression,
    var: &str,
) -> Option<Expression> {
    match left {
        Expression::Function(crate::ast::Function::Log, args)
            if !contains_variable(right, var)
                && args.len() == 2
                && matches!(&args[0], Expression::Variable(v) if v.name == var)
                && !contains_variable(&args[1], var) =>
        {
            let base = args[1].clone();
            Some(Expression::Power(Box::new(base), Box::new(right.clone())).simplify())
        }
        _ => None,
    }
}
/// Tries to isolate `variable` in an exponential equation, appending the step
/// that was applied to `path`.
///
/// `exp(..)` patterns are tried first, then constant-base powers `b^(..)`.
/// Each is tried with the exponential term on the left side, then on the
/// right. Returns `None` when nothing matches.
fn solve_exp_equation(
    &self,
    equation: &Equation,
    variable: &Variable,
    path: &mut ResolutionPath,
) -> Option<Expression> {
    let var_name = &variable.name;
    let sides = [
        (&equation.left, &equation.right),
        (&equation.right, &equation.left),
    ];

    for (exp_side, other_side) in sides {
        if let Some(solution) = self.match_exp_pattern(exp_side, other_side, var_name) {
            path.add_step(ResolutionStep::new(
                Operation::ApplyFunction("ln".to_string()),
                format!("Apply natural logarithm to solve exp({}) = value", variable),
                solution.clone(),
            ));
            return Some(solution);
        }
    }

    for (power_side, other_side) in sides {
        if let Some(solution) = self.match_power_pattern(power_side, other_side, var_name) {
            path.add_step(ResolutionStep::new(
                Operation::ApplyLogProperty("change of base".to_string()),
                format!("Apply logarithm to solve for {} in exponent", variable),
                solution.clone(),
            ));
            return Some(solution);
        }
    }

    None
}
/// Matches an `exp` term in `left` against `right` (= `c`, free of `var`).
///
/// Recognised shapes and the simplified solution returned:
/// - `exp(x) = c`                 → `ln(c)`
/// - `exp(k·x) = c`               → `ln(c) / k`
/// - `exp(x)·k = c`, `k·exp(x) = c` → `ln(c / k)`
fn match_exp_pattern(
    &self,
    left: &Expression,
    right: &Expression,
    var: &str,
) -> Option<Expression> {
    if contains_variable(right, var) {
        return None;
    }
    let is_target = |e: &Expression| matches!(e, Expression::Variable(v) if v.name == var);
    let ln = |arg: Expression| Expression::Function(crate::ast::Function::Ln, vec![arg]);
    match left {
        Expression::Function(crate::ast::Function::Exp, args) if args.len() == 1 => {
            if is_target(&args[0]) {
                return Some(ln(right.clone()).simplify());
            }
            extract_coefficient(&args[0], var).map(|coeff| {
                Expression::Binary(BinaryOp::Div, Box::new(ln(right.clone())), Box::new(coeff))
                    .simplify()
            })
        }
        Expression::Binary(BinaryOp::Mul, lhs, rhs) => {
            for (factor, scale) in [(lhs, rhs), (rhs, lhs)] {
                if let Expression::Function(crate::ast::Function::Exp, args) = factor.as_ref() {
                    if args.len() == 1 && !contains_variable(scale, var) && is_target(&args[0]) {
                        let divided = Expression::Binary(
                            BinaryOp::Div,
                            Box::new(right.clone()),
                            Box::new(scale.as_ref().clone()),
                        );
                        return Some(ln(divided).simplify());
                    }
                }
            }
            None
        }
        _ => None,
    }
}
/// Matches a constant-base power `b^x = c` or `b^(k·x) = c`, where `left` is
/// the power (base free of `var`, exponent containing it) and `right` (= `c`)
/// is free of `var`.
///
/// Returns the simplified `ln(c) / ln(b)`, or `(ln(c) / ln(b)) / k`, or `None`.
fn match_power_pattern(
    &self,
    left: &Expression,
    right: &Expression,
    var: &str,
) -> Option<Expression> {
    if contains_variable(right, var) {
        return None;
    }
    let (base, exponent) = match left {
        Expression::Power(base, exponent)
            if !contains_variable(base, var) && contains_variable(exponent, var) =>
        {
            (base, exponent)
        }
        _ => return None,
    };
    let ln = |arg: Expression| Expression::Function(crate::ast::Function::Ln, vec![arg]);
    // Change of base: b^e = c  ⇒  e = ln(c) / ln(b)
    let log_ratio = Expression::Binary(
        BinaryOp::Div,
        Box::new(ln(right.clone())),
        Box::new(ln(base.as_ref().clone())),
    );
    if matches!(exponent.as_ref(), Expression::Variable(v) if v.name == var) {
        return Some(log_ratio.simplify());
    }
    extract_coefficient(exponent, var).map(|coeff| {
        Expression::Binary(BinaryOp::Div, Box::new(log_ratio), Box::new(coeff)).simplify()
    })
}
fn has_transcendental_function(expr: &Expression) -> bool {
match expr {
Expression::Function(func, _) => {
matches!(
func,
crate::ast::Function::Sin
| crate::ast::Function::Cos
| crate::ast::Function::Tan
| crate::ast::Function::Asin
| crate::ast::Function::Acos
| crate::ast::Function::Atan
| crate::ast::Function::Sinh
| crate::ast::Function::Cosh
| crate::ast::Function::Tanh
| crate::ast::Function::Exp
| crate::ast::Function::Ln
| crate::ast::Function::Log
| crate::ast::Function::Log2
| crate::ast::Function::Log10
)
}
Expression::Unary(_, inner) => Self::has_transcendental_function(inner),
Expression::Binary(_, left, right) => {
Self::has_transcendental_function(left) || Self::has_transcendental_function(right)
}
Expression::Power(base, exp) => {
has_any_variable(exp)
|| Self::has_transcendental_function(base)
|| Self::has_transcendental_function(exp)
}
_ => false,
}
}
/// Rejects an `asin`/`acos` argument whose magnitude exceeds 1. Every other
/// function (and every value for them) is accepted.
///
/// # Errors
/// `SolverError::Other` with a domain-error message.
fn validate_trig_domain(value: f64, func: &crate::ast::Function) -> Result<(), SolverError> {
    let bounded_inverse = matches!(
        func,
        crate::ast::Function::Asin | crate::ast::Function::Acos
    );
    if bounded_inverse && value.abs() > 1.0 {
        Err(SolverError::Other(format!(
            "Domain error: {:?} requires |value| ≤ 1, got {}",
            func, value
        )))
    } else {
        Ok(())
    }
}
}
impl Solver for TranscendentalSolver {
fn solve(
&self,
equation: &Equation,
variable: &Variable,
) -> SolverResult<(Solution, ResolutionPath)> {
let var_name = &variable.name;
let left_has_var = contains_variable(&equation.left, var_name);
let right_has_var = contains_variable(&equation.right, var_name);
if !left_has_var && !right_has_var {
return Err(SolverError::CannotSolve(format!(
"Variable '{}' not found in equation",
var_name
)));
}
let initial_expr = Expression::Binary(
BinaryOp::Sub,
Box::new(equation.left.clone()),
Box::new(equation.right.clone()),
);
let mut path = ResolutionPath::new(initial_expr);
if let Some(result) = self.solve_trig_equation(equation, variable, &mut path) {
if let Expression::Function(func, args) = &result {
if args.len() == 1 {
if let Some(val) = args[0].evaluate(&HashMap::new()) {
Self::validate_trig_domain(val, func)?;
}
}
}
let evaluated = evaluate_constants(&result);
path.set_result(evaluated.clone());
return Ok((Solution::Unique(evaluated), path));
}
if let Some(result) = self.solve_log_equation(equation, variable, &mut path) {
let evaluated = evaluate_constants(&result);
path.set_result(evaluated.clone());
return Ok((Solution::Unique(evaluated), path));
}
if let Some(result) = self.solve_exp_equation(equation, variable, &mut path) {
let evaluated = evaluate_constants(&result);
path.set_result(evaluated.clone());
return Ok((Solution::Unique(evaluated), path));
}
Err(SolverError::CannotSolve(
"Transcendental equation pattern not recognized or too complex".to_string(),
))
}
fn can_solve(&self, equation: &Equation) -> bool {
Self::has_transcendental_function(&equation.left)
|| Self::has_transcendental_function(&equation.right)
}
}
/// Outcome of solving a system of linear equations.
#[derive(Clone, Debug)]
pub enum SystemSolution {
    /// Exactly one solution: the value of every variable.
    Unique(HashMap<Variable, Expression>),
    /// Infinitely many solutions, parameterised by the `free` variables.
    Infinite {
        /// Each pivot variable expressed in terms of the free variables.
        bound: HashMap<Variable, Expression>,
        /// Variables that may take any value.
        free: Vec<Variable>,
    },
    /// The equations are inconsistent.
    NoSolution,
}
/// A linear system in matrix form `A·x = b`.
#[derive(Clone, Debug)]
pub struct LinearSystem {
    /// `A`: one row per equation, one column per entry of `variables`.
    coefficients: Vec<Vec<f64>>,
    /// `b`: the right-hand side of each equation.
    constants: Vec<f64>,
    /// The unknowns, in column order.
    variables: Vec<Variable>,
}
impl LinearSystem {
/// Builds the matrix form of `equations` over the unknowns `variables`.
///
/// Each equation `lhs = rhs` is normalised to `lhs - rhs = 0`, simplified,
/// and split into per-variable coefficients plus a constant. The constant is
/// moved to the right-hand side.
///
/// # Errors
/// `SolverError::Other` if either slice is empty or a term is not linear in
/// the given variables.
pub fn from_equations(equations: &[Equation], variables: &[Variable]) -> SolverResult<Self> {
    if equations.is_empty() || variables.is_empty() {
        return Err(SolverError::Other("Empty system".to_string()));
    }
    let rows = equations
        .iter()
        .map(|equation| {
            let normalized = Expression::Binary(
                BinaryOp::Sub,
                Box::new(equation.left.clone()),
                Box::new(equation.right.clone()),
            )
            .simplify();
            Self::extract_linear_coefficients(&normalized, variables)
        })
        .collect::<SolverResult<Vec<_>>>()?;
    // `a·x + k = 0` is stored as `a·x = -k`.
    let (coefficients, constants): (Vec<Vec<f64>>, Vec<f64>) =
        rows.into_iter().map(|(row, k)| (row, -k)).unzip();
    Ok(Self {
        coefficients,
        constants,
        variables: variables.to_vec(),
    })
}
/// Splits `expr` into additive terms. Each term is attributed to the first
/// variable in `variables` that it contains, or else to the constant.
///
/// Returns `(coefficients, constant)`, with coefficients in the same order as
/// `variables`.
///
/// # Errors
/// Propagates coefficient-extraction errors. Returns `SolverError::Other`
/// for a variable-free term that cannot be evaluated.
fn extract_linear_coefficients(
    expr: &Expression,
    variables: &[Variable],
) -> SolverResult<(Vec<f64>, f64)> {
    let mut coeffs = vec![0.0; variables.len()];
    let mut constant = 0.0;
    for term in Self::collect_additive_terms(expr) {
        match variables
            .iter()
            .position(|var| term.contains_variable(&var.name))
        {
            Some(idx) => coeffs[idx] += Self::extract_coefficient(&term, &variables[idx])?,
            None => {
                let empty_vars: HashMap<String, f64> = HashMap::new();
                constant += term.evaluate(&empty_vars).ok_or_else(|| {
                    SolverError::Other(format!("Cannot evaluate constant term: {}", term))
                })?;
            }
        }
    }
    Ok((coeffs, constant))
}
/// Flattens a sum/difference into its additive terms, pushing negation down
/// so that `a - (b + c)` yields `[a, -b, -c]`.
///
/// A negated sum is distributed as well, so `-(x + y)` yields `[-x, -y]`.
/// Previously it stayed one term, which `extract_coefficient` rejects.
/// Negations of non-sums are kept as a single `-term`, exactly as before.
fn collect_additive_terms(expr: &Expression) -> Vec<Expression> {
    match expr {
        Expression::Binary(BinaryOp::Add, left, right) => {
            let mut terms = Self::collect_additive_terms(left);
            terms.extend(Self::collect_additive_terms(right));
            terms
        }
        Expression::Binary(BinaryOp::Sub, left, right) => {
            let mut terms = Self::collect_additive_terms(left);
            terms.extend(Self::negated_additive_terms(right));
            terms
        }
        Expression::Unary(UnaryOp::Neg, inner) => Self::negated_additive_terms(inner),
        _ => vec![expr.clone()],
    }
}

/// Additive terms of `-expr`: each term of `expr`, wrapped in a negation.
fn negated_additive_terms(expr: &Expression) -> Vec<Expression> {
    Self::collect_additive_terms(expr)
        .into_iter()
        .map(|term| Expression::Unary(UnaryOp::Neg, Box::new(term)))
        .collect()
}
/// Returns the numeric coefficient of `var` in a single additive `term`.
///
/// Handles a bare variable, negation, products with a constant factor, and
/// division by a constant. A term that does not contain `var` yields `0.0`.
///
/// # Errors
/// - `SolverError::Other` for non-linear shapes (both factors contain `var`,
///   or `var` is in a denominator), unevaluable coefficients/divisors, or
///   unsupported term shapes containing `var`.
/// - `SolverError::DivisionByZero` when the divisor is (numerically) zero.
fn extract_coefficient(term: &Expression, var: &Variable) -> SolverResult<f64> {
    match term {
        Expression::Variable(v) if v.name == var.name => Ok(1.0),
        Expression::Unary(UnaryOp::Neg, inner) => {
            Self::extract_coefficient(inner, var).map(|coeff| -coeff)
        }
        Expression::Binary(BinaryOp::Mul, left, right) => {
            let in_left = left.contains_variable(&var.name);
            let in_right = right.contains_variable(&var.name);
            if in_left && in_right {
                return Err(SolverError::Other(format!(
                    "Non-linear term: {} * {} both contain {}",
                    left, right, var.name
                )));
            }
            // The factor without the variable must evaluate to a number.
            let (constant_side, variable_side) = if in_left {
                (right, left)
            } else {
                (left, right)
            };
            let empty: HashMap<String, f64> = HashMap::new();
            let factor = constant_side.evaluate(&empty).ok_or_else(|| {
                SolverError::Other(format!("Cannot evaluate coefficient: {}", constant_side))
            })?;
            let inner_coeff = Self::extract_coefficient(variable_side, var)?;
            Ok(factor * inner_coeff)
        }
        Expression::Binary(BinaryOp::Div, left, right) => {
            if right.contains_variable(&var.name) {
                return Err(SolverError::Other(format!(
                    "Non-linear: variable {} in denominator",
                    var.name
                )));
            }
            let empty: HashMap<String, f64> = HashMap::new();
            let divisor = match right.evaluate(&empty) {
                Some(d) => d,
                None => {
                    return Err(SolverError::Other(format!(
                        "Cannot evaluate divisor: {}",
                        right
                    )))
                }
            };
            if divisor.abs() < 1e-15 {
                return Err(SolverError::DivisionByZero);
            }
            Self::extract_coefficient(left, var).map(|coeff| coeff / divisor)
        }
        _ if term.contains_variable(&var.name) => Err(SolverError::Other(format!(
            "Cannot extract coefficient from: {}",
            term
        ))),
        _ => Ok(0.0),
    }
}
pub fn solve(&self) -> SolverResult<SystemSolution> {
let n_eqs = self.coefficients.len();
let n_vars = self.variables.len();
let mut augmented: Vec<Vec<f64>> = self
.coefficients
.iter()
.zip(self.constants.iter())
.map(|(row, &c)| {
let mut new_row = row.clone();
new_row.push(c);
new_row
})
.collect();
let mut pivot_row = 0;
let mut pivot_cols = Vec::new();
for col in 0..n_vars {
if pivot_row >= n_eqs {
break;
}
let mut max_row = pivot_row;
let mut max_val = augmented[pivot_row][col].abs();
for row in (pivot_row + 1)..n_eqs {
if augmented[row][col].abs() > max_val {
max_val = augmented[row][col].abs();
max_row = row;
}
}
if max_val < 1e-15 {
continue;
}
if max_row != pivot_row {
augmented.swap(pivot_row, max_row);
}
pivot_cols.push(col);
let pivot_val = augmented[pivot_row][col];
for row in (pivot_row + 1)..n_eqs {
let factor = augmented[row][col] / pivot_val;
augmented[row][col] = 0.0;
for c in (col + 1)..=n_vars {
augmented[row][c] -= factor * augmented[pivot_row][c];
}
}
pivot_row += 1;
}
let rank = pivot_cols.len();
for row in rank..n_eqs {
let rhs = augmented[row][n_vars];
let all_zeros = augmented[row][0..n_vars].iter().all(|&x| x.abs() < 1e-15);
if all_zeros && rhs.abs() > 1e-15 {
return Ok(SystemSolution::NoSolution);
}
}
if rank == n_vars {
let mut solution_values = vec![0.0; n_vars];
for i in (0..rank).rev() {
let col = pivot_cols[i];
let mut sum = augmented[i][n_vars];
for j in (col + 1)..n_vars {
sum -= augmented[i][j] * solution_values[j];
}
solution_values[col] = sum / augmented[i][col];
}
let mut result = HashMap::new();
for (i, var) in self.variables.iter().enumerate() {
let val = solution_values[i];
let expr = if (val - val.round()).abs() < 1e-10 {
Expression::Integer(val.round() as i64)
} else {
Expression::Float(val)
};
result.insert(var.clone(), expr);
}
Ok(SystemSolution::Unique(result))
} else {
let pivot_set: std::collections::HashSet<_> = pivot_cols.iter().cloned().collect();
let free_cols: Vec<_> = (0..n_vars).filter(|c| !pivot_set.contains(c)).collect();
let free_vars: Vec<_> = free_cols
.iter()
.map(|&c| self.variables[c].clone())
.collect();
let mut bound = HashMap::new();
for i in (0..rank).rev() {
let col = pivot_cols[i];
let rhs = augmented[i][n_vars];
let mut terms: Vec<Expression> = vec![];
if rhs.abs() > 1e-15 {
terms.push(if (rhs - rhs.round()).abs() < 1e-10 {
Expression::Integer(rhs.round() as i64)
} else {
Expression::Float(rhs)
});
}
for &free_col in &free_cols {
let coeff = -augmented[i][free_col] / augmented[i][col];
if coeff.abs() > 1e-15 {
let free_var = Expression::Variable(self.variables[free_col].clone());
let term = if (coeff - coeff.round()).abs() < 1e-10 {
let int_coeff = coeff.round() as i64;
if int_coeff == 1 {
free_var
} else if int_coeff == -1 {
Expression::Unary(UnaryOp::Neg, Box::new(free_var))
} else {
Expression::Binary(
BinaryOp::Mul,
Box::new(Expression::Integer(int_coeff)),
Box::new(free_var),
)
}
} else {
Expression::Binary(
BinaryOp::Mul,
Box::new(Expression::Float(coeff)),
Box::new(free_var),
)
};
terms.push(term);
}
}
let expr = if terms.is_empty() {
Expression::Integer(0)
} else if terms.len() == 1 {
terms.remove(0)
} else {
let mut result = terms.remove(0);
for term in terms {
result =
Expression::Binary(BinaryOp::Add, Box::new(result), Box::new(term));
}
result
};
let pivot_coeff = augmented[i][col];
let final_expr = if (pivot_coeff - 1.0).abs() < 1e-15 {
expr
} else {
Expression::Binary(
BinaryOp::Div,
Box::new(expr),
Box::new(if (pivot_coeff - pivot_coeff.round()).abs() < 1e-10 {
Expression::Integer(pivot_coeff.round() as i64)
} else {
Expression::Float(pivot_coeff)
}),
)
};
bound.insert(self.variables[col].clone(), final_expr);
}
Ok(SystemSolution::Infinite {
bound,
free: free_vars,
})
}
}
/// Solves a square 2×2 or 3×3 system with Cramer's rule.
///
/// If the coefficient determinant is (numerically) zero, this falls back to
/// elimination via `solve`, which classifies the singular system.
///
/// # Errors
/// `SolverError::Other` if the system is not square or not 2×2/3×3.
pub fn solve_cramers(&self) -> SolverResult<SystemSolution> {
    let n = self.variables.len();
    if self.coefficients.len() != n {
        return Err(SolverError::Other(
            "Cramer's rule requires square system".to_string(),
        ));
    }
    let det: fn(&[Vec<f64>]) -> f64 = match n {
        2 => Self::det_2x2,
        3 => Self::det_3x3,
        _ => {
            return Err(SolverError::Other(
                "Cramer's rule only implemented for 2x2 and 3x3 systems".to_string(),
            ))
        }
    };
    let det_a = det(&self.coefficients);
    if det_a.abs() < 1e-15 {
        return self.solve();
    }
    let mut result = HashMap::new();
    for (col, variable) in self.variables.iter().enumerate() {
        // Replace column `col` with the constants vector.
        let mut replaced = self.coefficients.clone();
        for (row, &c) in self.constants.iter().enumerate() {
            replaced[row][col] = c;
        }
        let value = det(&replaced) / det_a;
        let expr = if (value - value.round()).abs() < 1e-10 {
            Expression::Integer(value.round() as i64)
        } else {
            Expression::Float(value)
        };
        result.insert(variable.clone(), expr);
    }
    Ok(SystemSolution::Unique(result))
}
/// Determinant of the top-left 2×2 block of `m`.
fn det_2x2(m: &[Vec<f64>]) -> f64 {
    let (top, bottom) = (&m[0], &m[1]);
    top[0] * bottom[1] - top[1] * bottom[0]
}
/// Determinant of the top-left 3×3 block of `m`, by cofactor expansion
/// along the first row.
fn det_3x3(m: &[Vec<f64>]) -> f64 {
    let (r0, r1, r2) = (&m[0], &m[1], &m[2]);
    let minor_0 = r1[1] * r2[2] - r1[2] * r2[1];
    let minor_1 = r1[0] * r2[2] - r1[2] * r2[0];
    let minor_2 = r1[0] * r2[1] - r1[1] * r2[0];
    r0[0] * minor_0 - r0[1] * minor_1 + r0[2] * minor_2
}
}
/// Stateless entry point for solving systems of equations (currently linear
/// systems only).
#[derive(Default, Debug)]
pub struct SystemSolver;
impl SystemSolver {
pub fn new() -> Self {
Self
}
pub fn solve_linear_system(
&self,
equations: &[Equation],
variables: &[Variable],
) -> SolverResult<SystemSolution> {
let system = LinearSystem::from_equations(equations, variables)?;
system.solve()
}
pub fn solve_cramers(
&self,
equations: &[Equation],
variables: &[Variable],
) -> SolverResult<SystemSolution> {
let system = LinearSystem::from_equations(equations, variables)?;
system.solve_cramers()
}
pub fn solve_system(
&self,
equations: &[Equation],
variables: &[Variable],
) -> SolverResult<HashMap<Variable, Solution>> {
let result = self.solve_linear_system(equations, variables)?;
match result {
SystemSolution::Unique(sol) => {
let mut out = HashMap::new();
for (var, expr) in sol {
out.insert(var, Solution::Unique(expr));
}
Ok(out)
}
SystemSolution::Infinite { bound, free: _ } => {
let mut out = HashMap::new();
for (var, expr) in bound {
out.insert(
var,
Solution::Parametric {
expression: expr,
constraints: vec![],
},
);
}
Ok(out)
}
SystemSolution::NoSolution => Err(SolverError::NoSolution),
}
}
}
/// Dispatching solver that hands each equation to the first specialised
/// solver able to handle it.
#[derive(Debug)]
pub struct SmartSolver {
    /// Tried first.
    linear: LinearSolver,
    /// Tried second.
    quadratic: QuadraticSolver,
    /// Tried third.
    polynomial: PolynomialSolver,
    /// Tried last.
    transcendental: TranscendentalSolver,
}
impl SmartSolver {
    /// Creates a dispatcher holding one instance of each specialised solver.
    pub fn new() -> Self {
        let linear = LinearSolver::new();
        let quadratic = QuadraticSolver::new();
        let polynomial = PolynomialSolver::new();
        let transcendental = TranscendentalSolver::new();
        Self {
            linear,
            quadratic,
            polynomial,
            transcendental,
        }
    }
}
impl Default for SmartSolver {
fn default() -> Self {
Self::new()
}
}
impl Solver for SmartSolver {
    /// Delegates to the first sub-solver whose `can_solve` accepts the
    /// equation, in the order linear → quadratic → polynomial →
    /// transcendental. There is no fallback if the chosen solver fails.
    ///
    /// # Errors
    /// `SolverError::UnsupportedEquationType` when no sub-solver accepts the
    /// equation; otherwise whatever the chosen sub-solver returns.
    fn solve(
        &self,
        equation: &Equation,
        variable: &Variable,
    ) -> SolverResult<(Solution, ResolutionPath)> {
        if self.linear.can_solve(equation) {
            return self.linear.solve(equation, variable);
        }
        if self.quadratic.can_solve(equation) {
            return self.quadratic.solve(equation, variable);
        }
        if self.polynomial.can_solve(equation) {
            return self.polynomial.solve(equation, variable);
        }
        if self.transcendental.can_solve(equation) {
            return self.transcendental.solve(equation, variable);
        }
        Err(SolverError::UnsupportedEquationType)
    }

    /// Accepts the equation if any sub-solver does (checked in dispatch order).
    fn can_solve(&self, equation: &Equation) -> bool {
        let accepted = self.linear.can_solve(equation)
            || self.quadratic.can_solve(equation)
            || self.polynomial.can_solve(equation)
            || self.transcendental.can_solve(equation);
        accepted
    }
}
/// Solves `equation` for `target` with `SmartSolver`. If `known_values` is
/// non-empty, those values are substituted into the solution, which is then
/// simplified and constant-folded.
///
/// Returns the resolution path; its result is the final solution expression.
///
/// # Errors
/// - Anything `SmartSolver::solve` returns.
/// - `SolverError::NoSolution` / `SolverError::InfiniteSolutions` for those
///   solution kinds.
/// - `SolverError::Other` for multiple or parametric solutions, which are not
///   supported here.
pub fn solve_for(
    equation: &Equation,
    target: &str,
    known_values: &HashMap<String, f64>,
) -> Result<ResolutionPath, SolverError> {
    let target_var = Variable::new(target);
    let (solution, mut path) = SmartSolver::new().solve(equation, &target_var)?;
    let solution_expr = match solution {
        Solution::Unique(expr) => expr,
        Solution::None => return Err(SolverError::NoSolution),
        Solution::Infinite => return Err(SolverError::InfiniteSolutions),
        Solution::Multiple(_) => {
            return Err(SolverError::Other(
                "Multiple solutions not yet supported in solve_for".to_string(),
            ))
        }
        Solution::Parametric { .. } => {
            return Err(SolverError::Other(
                "Parametric solutions not yet supported in solve_for".to_string(),
            ))
        }
    };

    let final_expr = if known_values.is_empty() {
        solution_expr
    } else {
        let evaluated =
            evaluate_constants(&substitute_values(&solution_expr, known_values).simplify());
        // A single step records all substitutions. The variable/value pair is
        // a placeholder, not a real substitution.
        path.add_step(ResolutionStep::new(
            Operation::Substitute {
                variable: Variable::new("known_values"),
                value: Expression::Integer(0),
            },
            "Substitute known values and evaluate".to_string(),
            evaluated.clone(),
        ));
        evaluated
    };
    path.set_result(final_expr);
    Ok(path)
}
/// Returns a copy of `expr` with every variable found in `values` replaced by
/// `Expression::Float(value)`. Recurses through unary, binary, function and
/// power nodes; every other node is cloned unchanged.
fn substitute_values(expr: &Expression, values: &HashMap<String, f64>) -> Expression {
    let boxed = |e: &Expression| Box::new(substitute_values(e, values));
    match expr {
        Expression::Variable(v) => values
            .get(&v.name)
            .map_or_else(|| expr.clone(), |&value| Expression::Float(value)),
        Expression::Unary(op, inner) => Expression::Unary(*op, boxed(inner.as_ref())),
        Expression::Binary(op, left, right) => {
            Expression::Binary(*op, boxed(left.as_ref()), boxed(right.as_ref()))
        }
        Expression::Function(func, args) => {
            let substituted_args = args.iter().map(|arg| substitute_values(arg, values)).collect();
            Expression::Function(func.clone(), substituted_args)
        }
        Expression::Power(base, exp) => Expression::Power(boxed(base.as_ref()), boxed(exp.as_ref())),
        _ => expr.clone(),
    }
}
pub fn compute_partial_derivative(
equation: &Equation,
output_var: &str,
input_var: &str,
values: &HashMap<String, f64>,
) -> Result<f64, SolverError> {
let output_expr = if let Expression::Variable(v) = &equation.left {
if v.name == output_var {
&equation.right
} else if let Expression::Variable(v2) = &equation.right {
if v2.name == output_var {
&equation.left
} else {
return Err(SolverError::CannotSolve(format!(
"Output variable '{}' not found in equation",
output_var
)));
}
} else {
return Err(SolverError::CannotSolve(format!(
"Output variable '{}' not found in equation",
output_var
)));
}
} else if let Expression::Variable(v) = &equation.right {
if v.name == output_var {
&equation.left
} else {
return Err(SolverError::CannotSolve(format!(
"Output variable '{}' not found in equation",
output_var
)));
}
} else {
return Err(SolverError::CannotSolve(
"Equation does not have output variable isolated".to_string(),
));
};
let derivative_expr = output_expr.differentiate(input_var);
let simplified = derivative_expr.simplify();
simplified.evaluate(values).ok_or_else(|| {
SolverError::Other(format!(
"Failed to evaluate derivative - missing or invalid values"
))
})
}
/// Evaluates `∂output/∂input` for every name in `input_vars`, keyed by input
/// name. If a name repeats, the later result wins.
///
/// # Errors
/// The first error returned by `compute_partial_derivative`, in input order.
pub fn compute_all_partial_derivatives(
    equation: &Equation,
    output_var: &str,
    input_vars: &[String],
    values: &HashMap<String, f64>,
) -> Result<HashMap<String, f64>, SolverError> {
    input_vars
        .iter()
        .map(|input_var| {
            compute_partial_derivative(equation, output_var, input_var, values)
                .map(|derivative| (input_var.clone(), derivative))
        })
        .collect()
}
#[cfg(test)]
mod system_solver_tests {
    use super::*;
    use crate::ast::{BinaryOp, Equation, Expression, Variable};

    /// Bare variable reference.
    fn var(name: &str) -> Expression {
        Expression::Variable(Variable::new(name))
    }

    /// Integer literal.
    fn int(n: i64) -> Expression {
        Expression::Integer(n)
    }

    /// `lhs <op> rhs` with both operands boxed.
    fn binary(op: BinaryOp, lhs: Expression, rhs: Expression) -> Expression {
        Expression::Binary(op, Box::new(lhs), Box::new(rhs))
    }

    fn add(left: Expression, right: Expression) -> Expression {
        binary(BinaryOp::Add, left, right)
    }

    fn sub(left: Expression, right: Expression) -> Expression {
        binary(BinaryOp::Sub, left, right)
    }

    fn mul(left: Expression, right: Expression) -> Expression {
        binary(BinaryOp::Mul, left, right)
    }

    /// Asserts `result` is a unique solution in which every listed variable
    /// evaluates, with no external bindings, to exactly the paired value.
    fn assert_unique(result: SystemSolution, expected: &[(&Variable, f64)]) {
        match result {
            SystemSolution::Unique(sol) => {
                let no_bindings: HashMap<String, f64> = HashMap::new();
                for &(variable, value) in expected {
                    assert_eq!(sol.get(variable).unwrap().evaluate(&no_bindings), Some(value));
                }
            }
            _ => panic!("Expected unique solution"),
        }
    }

    #[test]
    fn test_2x2_unique_solution() {
        let (x, y) = (Variable::new("x"), Variable::new("y"));
        let equations = [
            Equation::new("eq1", add(var("x"), var("y")), int(5)),
            Equation::new("eq2", sub(var("x"), var("y")), int(1)),
        ];
        let outcome = SystemSolver::new()
            .solve_linear_system(&equations, &[x.clone(), y.clone()])
            .unwrap();
        assert_unique(outcome, &[(&x, 3.0), (&y, 2.0)]);
    }

    #[test]
    fn test_2x2_with_coefficients() {
        let (x, y) = (Variable::new("x"), Variable::new("y"));
        let equations = [
            Equation::new(
                "eq1",
                add(mul(int(2), var("x")), mul(int(3), var("y"))),
                int(8),
            ),
            Equation::new("eq2", sub(mul(int(4), var("x")), var("y")), int(2)),
        ];
        let outcome = SystemSolver::new()
            .solve_linear_system(&equations, &[x.clone(), y.clone()])
            .unwrap();
        assert_unique(outcome, &[(&x, 1.0), (&y, 2.0)]);
    }

    #[test]
    fn test_3x3_unique_solution() {
        let (x, y, z) = (Variable::new("x"), Variable::new("y"), Variable::new("z"));
        let equations = [
            Equation::new("eq1", add(add(var("x"), var("y")), var("z")), int(6)),
            Equation::new(
                "eq2",
                sub(add(mul(int(2), var("x")), var("y")), var("z")),
                int(1),
            ),
            Equation::new(
                "eq3",
                add(sub(var("x"), var("y")), mul(int(2), var("z"))),
                int(5),
            ),
        ];
        let outcome = SystemSolver::new()
            .solve_linear_system(&equations, &[x.clone(), y.clone(), z.clone()])
            .unwrap();
        assert_unique(outcome, &[(&x, 1.0), (&y, 2.0), (&z, 3.0)]);
    }

    #[test]
    fn test_underdetermined_system() {
        let (x, y) = (Variable::new("x"), Variable::new("y"));
        let equations = [Equation::new("eq1", add(var("x"), var("y")), int(5))];
        let outcome = SystemSolver::new()
            .solve_linear_system(&equations, &[x.clone(), y.clone()])
            .unwrap();
        if let SystemSolution::Infinite { bound, free } = outcome {
            assert!(!free.is_empty());
            assert!(!bound.is_empty());
        } else {
            panic!("Expected infinite solutions");
        }
    }

    #[test]
    fn test_inconsistent_system() {
        let (x, y) = (Variable::new("x"), Variable::new("y"));
        let equations = [
            Equation::new("eq1", add(var("x"), var("y")), int(5)),
            Equation::new("eq2", add(var("x"), var("y")), int(6)),
        ];
        let outcome = SystemSolver::new()
            .solve_linear_system(&equations, &[x.clone(), y.clone()])
            .unwrap();
        assert!(matches!(outcome, SystemSolution::NoSolution));
    }

    #[test]
    fn test_cramers_rule_2x2() {
        let (x, y) = (Variable::new("x"), Variable::new("y"));
        let equations = [
            Equation::new("eq1", add(var("x"), var("y")), int(5)),
            Equation::new("eq2", sub(var("x"), var("y")), int(1)),
        ];
        let outcome = SystemSolver::new()
            .solve_cramers(&equations, &[x.clone(), y.clone()])
            .unwrap();
        assert_unique(outcome, &[(&x, 3.0), (&y, 2.0)]);
    }

    #[test]
    fn test_linear_system_struct() {
        let (x, y) = (Variable::new("x"), Variable::new("y"));
        let equations = [
            Equation::new("eq1", add(var("x"), var("y")), int(5)),
            Equation::new("eq2", sub(var("x"), var("y")), int(1)),
        ];
        let system = LinearSystem::from_equations(&equations, &[x.clone(), y.clone()]).unwrap();
        assert_eq!(system.coefficients.len(), 2);
        assert_eq!(system.constants.len(), 2);
    }

    #[test]
    fn test_overdetermined_consistent() {
        let (x, y) = (Variable::new("x"), Variable::new("y"));
        let equations = [
            Equation::new("eq1", add(var("x"), var("y")), int(5)),
            Equation::new("eq2", sub(var("x"), var("y")), int(1)),
            Equation::new("eq3", mul(int(2), var("x")), int(6)),
        ];
        let outcome = SystemSolver::new()
            .solve_linear_system(&equations, &[x.clone(), y.clone()])
            .unwrap();
        assert_unique(outcome, &[(&x, 3.0), (&y, 2.0)]);
    }
}