use std::collections::HashMap;
use thales::ast::{BinaryOp, Equation, Expression, Variable};
use thales::solver::{solve_for, LinearSolver, Solver, SolverError};
/// Builds an expression that refers to the variable called `name`.
fn var(name: &str) -> Expression {
    let variable = Variable::new(name);
    Expression::Variable(variable)
}
/// Wraps the integer literal `n` as an `Expression::Integer`.
fn int(n: i64) -> Expression {
    let make_integer = Expression::Integer;
    make_integer(n)
}
/// Wraps the floating-point literal `x` as an `Expression::Float`.
fn float(x: f64) -> Expression {
    let make_float = Expression::Float;
    make_float(x)
}
/// Builds the binary node `left <op> right`, boxing both operands.
fn binary(op: BinaryOp, left: Expression, right: Expression) -> Expression {
    let lhs = Box::new(left);
    let rhs = Box::new(right);
    Expression::Binary(op, lhs, rhs)
}
fn mul(left: Expression, right: Expression) -> Expression {
binary(BinaryOp::Mul, left, right)
}
fn add(left: Expression, right: Expression) -> Expression {
binary(BinaryOp::Add, left, right)
}
fn div(left: Expression, right: Expression) -> Expression {
binary(BinaryOp::Div, left, right)
}
/// Builds the power node `base ^ exp`.
fn pow(base: Expression, exp: Expression) -> Expression {
    let (boxed_base, boxed_exp) = (Box::new(base), Box::new(exp));
    Expression::Power(boxed_base, boxed_exp)
}
/// Solving `x = 5` for `x` must yield the unique solution `5`.
#[test]
fn test_linear_solver_simple_equality() {
    let equation = Equation::new("test", var("x"), int(5));
    let solver = LinearSolver::new();
    let target = Variable::new("x");
    // `expect` puts the SolverError in the failure message; the previous
    // `assert!(result.is_ok())` discarded it.
    let (solution, _path) = solver
        .solve(&equation, &target)
        .expect("LinearSolver should solve x = 5 for x");
    match solution {
        thales::solver::Solution::Unique(expr) => assert_eq!(expr, int(5)),
        other => panic!("Expected unique solution, got {:?}", other),
    }
}
/// Solving `2 * x = 10` for `x` must yield the unique solution `5`.
#[test]
fn test_linear_solver_multiplication() {
    let equation = Equation::new("test", mul(int(2), var("x")), int(10));
    let solver = LinearSolver::new();
    let target = Variable::new("x");
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &target)
        .expect("LinearSolver should solve 2x = 10 for x");
    match solution {
        thales::solver::Solution::Unique(expr) => assert_eq!(expr, int(5)),
        other => panic!("Expected unique solution, got {:?}", other),
    }
}
/// Solving `x + 3 = 7` for `x` must yield the unique solution `4`.
#[test]
fn test_linear_solver_addition() {
    let equation = Equation::new("test", add(var("x"), int(3)), int(7));
    let solver = LinearSolver::new();
    let target = Variable::new("x");
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &target)
        .expect("LinearSolver should solve x + 3 = 7 for x");
    match solution {
        thales::solver::Solution::Unique(expr) => assert_eq!(expr, int(4)),
        other => panic!("Expected unique solution, got {:?}", other),
    }
}
/// Solving `2 * x + 3 = 7` for `x` must yield the unique solution `2`.
#[test]
fn test_linear_solver_ax_plus_b() {
    let left = add(mul(int(2), var("x")), int(3));
    let equation = Equation::new("test", left, int(7));
    let solver = LinearSolver::new();
    let target = Variable::new("x");
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &target)
        .expect("LinearSolver should solve 2x + 3 = 7 for x");
    match solution {
        thales::solver::Solution::Unique(expr) => assert_eq!(expr, int(2)),
        other => panic!("Expected unique solution, got {:?}", other),
    }
}
/// Solving `2 + 3 = 5` for `x` must fail with `SolverError::CannotSolve`,
/// whose message mentions "not found".
#[test]
fn test_linear_solver_variable_not_found() {
    let equation = Equation::new("test", add(int(2), int(3)), int(5));
    let solver = LinearSolver::new();
    let target = Variable::new("x");
    // A single match replaces `is_err()` + `unwrap_err()`. An unexpected
    // success or a different error variant is reported with its value.
    match solver.solve(&equation, &target) {
        Err(SolverError::CannotSolve(msg)) => {
            assert!(msg.contains("not found"), "unexpected error message: {}", msg);
        }
        Err(other) => panic!("Expected CannotSolve error, got {:?}", other),
        Ok(ok) => panic!("Expected CannotSolve error, but solving succeeded: {:?}", ok),
    }
}
/// `F = m * a` solved for `F` is just the right-hand side `m * a`.
#[test]
fn test_force_equation_solve_for_f() {
    let equation = Equation::new("force", var("F"), mul(var("m"), var("a")));
    let solver = LinearSolver::new();
    let target = Variable::new("F");
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &target)
        .expect("LinearSolver should solve F = m * a for F");
    match solution {
        thales::solver::Solution::Unique(expr) => assert_eq!(expr, mul(var("m"), var("a"))),
        other => panic!("Expected unique solution, got {:?}", other),
    }
}
/// `F = m * a` solved for `m` must give `F / a`.
#[test]
fn test_force_equation_solve_for_m() {
    let equation = Equation::new("force", var("F"), mul(var("m"), var("a")));
    let solver = LinearSolver::new();
    let target = Variable::new("m");
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &target)
        .expect("LinearSolver should solve F = m * a for m");
    match solution {
        thales::solver::Solution::Unique(expr) => assert_eq!(expr, div(var("F"), var("a"))),
        other => panic!("Expected unique solution, got {:?}", other),
    }
}
/// `F = m * a` solved for `a` must give `F / m`.
#[test]
fn test_force_equation_solve_for_a() {
    let equation = Equation::new("force", var("F"), mul(var("m"), var("a")));
    let solver = LinearSolver::new();
    let target = Variable::new("a");
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &target)
        .expect("LinearSolver should solve F = m * a for a");
    match solution {
        thales::solver::Solution::Unique(expr) => assert_eq!(expr, div(var("F"), var("m"))),
        other => panic!("Expected unique solution, got {:?}", other),
    }
}
/// `v = d / t` solved for `v` is just the right-hand side `d / t`.
#[test]
fn test_velocity_equation_solve_for_v() {
    let equation = Equation::new("velocity", var("v"), div(var("d"), var("t")));
    let solver = LinearSolver::new();
    let target = Variable::new("v");
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &target)
        .expect("LinearSolver should solve v = d / t for v");
    match solution {
        thales::solver::Solution::Unique(expr) => assert_eq!(expr, div(var("d"), var("t"))),
        other => panic!("Expected unique solution, got {:?}", other),
    }
}
/// Rearranging `v = d / t` for `d`.
///
/// An `Err` from the solver is tolerated, because this inversion is not
/// required to be supported. A solution that *is* returned must be unique
/// and expressed in terms of `v` and `t` only.
#[test]
fn test_velocity_equation_solve_for_d() {
    let equation = Equation::new("velocity", var("v"), div(var("d"), var("t")));
    let solver = LinearSolver::new();
    let target = Variable::new("d");
    match solver.solve(&equation, &target) {
        Ok((thales::solver::Solution::Unique(expr), _path)) => {
            println!("Got solution: {:?}", expr);
            // Printing alone could never fail the test. Check that the target
            // was actually isolated and the answer uses the other two symbols.
            assert!(
                !expr.contains_variable("d"),
                "solution still depends on `d`: {:?}",
                expr
            );
            assert!(
                expr.contains_variable("v") && expr.contains_variable("t"),
                "solution should be expressed in `v` and `t`, got {:?}",
                expr
            );
        }
        Ok((other, _path)) => panic!("Expected unique solution, got {:?}", other),
        // Unsupported inversion: deliberately not treated as a failure.
        Err(_) => {}
    }
}
/// `E = m * c^2` solved for `E` is just the right-hand side `m * c^2`.
#[test]
fn test_energy_equation_solve_for_e() {
    let right = mul(var("m"), pow(var("c"), int(2)));
    let equation = Equation::new("energy", var("E"), right);
    let solver = LinearSolver::new();
    let target = Variable::new("E");
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &target)
        .expect("LinearSolver should solve E = m * c^2 for E");
    match solution {
        thales::solver::Solution::Unique(expr) => {
            assert_eq!(expr, mul(var("m"), pow(var("c"), int(2))));
        }
        other => panic!("Expected unique solution, got {:?}", other),
    }
}
/// `E = m * c^2` solved for `m` must give `E / c^2`.
#[test]
fn test_energy_equation_solve_for_m() {
    let right = mul(var("m"), pow(var("c"), int(2)));
    let equation = Equation::new("energy", var("E"), right);
    let solver = LinearSolver::new();
    let target = Variable::new("m");
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &target)
        .expect("LinearSolver should solve E = m * c^2 for m");
    match solution {
        thales::solver::Solution::Unique(expr) => {
            assert_eq!(expr, div(var("E"), pow(var("c"), int(2))));
        }
        other => panic!("Expected unique solution, got {:?}", other),
    }
}
/// `y = m * x + b` solved for `y` is just the right-hand side `m * x + b`.
#[test]
fn test_linear_equation_solve_for_y() {
    let right = add(mul(var("m"), var("x")), var("b"));
    let equation = Equation::new("line", var("y"), right);
    let solver = LinearSolver::new();
    let target = Variable::new("y");
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &target)
        .expect("LinearSolver should solve y = m * x + b for y");
    match solution {
        thales::solver::Solution::Unique(expr) => {
            assert_eq!(expr, add(mul(var("m"), var("x")), var("b")));
        }
        other => panic!("Expected unique solution, got {:?}", other),
    }
}
/// With every input known (`m = 2`, `a = 3`), `solve_for` on `F = m * a`
/// must reduce `F` to the number 6, as either a float or an integer.
#[test]
fn test_solve_for_with_values() {
    let equation = Equation::new("force", var("F"), mul(var("m"), var("a")));
    let mut known_values = HashMap::new();
    known_values.insert("m".to_string(), 2.0);
    known_values.insert("a".to_string(), 3.0);
    // One step replaces the separate eprintln / is_ok / unwrap sequence,
    // and the error is still reported on failure.
    let path = solve_for(&equation, "F", &known_values)
        .unwrap_or_else(|e| panic!("Error solving equation: {:?}", e));
    match &path.result {
        Expression::Float(val) => {
            assert!((val - 6.0).abs() < 1e-10, "expected 6.0, got {}", val);
        }
        Expression::Integer(val) => assert_eq!(*val, 6),
        other => panic!("Expected numeric result, got: {:?}", other),
    }
}
/// With only `m` known, solving `F = m * a` for `F` must leave `a` unresolved
/// in the result.
#[test]
fn test_solve_for_partial_values() {
    let equation = Equation::new("force", var("F"), mul(var("m"), var("a")));
    let mut known_values = HashMap::new();
    known_values.insert("m".to_string(), 2.0);
    let path = solve_for(&equation, "F", &known_values)
        .unwrap_or_else(|e| panic!("Error solving equation: {:?}", e));
    // The result is shown only when the check fails, replacing the
    // unconditional debug println.
    assert!(
        path.result.contains_variable("a"),
        "expected `a` to remain in the result, got {:?}",
        path.result
    );
}
/// With no known values, solving `F = m * a` for `F` returns `m * a` unchanged.
#[test]
fn test_solve_for_no_values() {
    let equation = Equation::new("force", var("F"), mul(var("m"), var("a")));
    let known_values = HashMap::new();
    // Surface the solver error on failure instead of a bare `is_ok()` assertion.
    let path = solve_for(&equation, "F", &known_values)
        .unwrap_or_else(|e| panic!("Error solving equation: {:?}", e));
    assert_eq!(path.result, mul(var("m"), var("a")));
}
/// `solve_for` on `2 * x + 3 = 7` for `x` (no known values) yields `2`.
#[test]
fn test_solve_for_simple_arithmetic() {
    let left = add(mul(int(2), var("x")), int(3));
    let equation = Equation::new("test", left, int(7));
    let known_values = HashMap::new();
    // Surface the solver error on failure instead of a bare `is_ok()` assertion.
    let path = solve_for(&equation, "x", &known_values)
        .unwrap_or_else(|e| panic!("Error solving equation: {:?}", e));
    assert_eq!(path.result, int(2));
}
/// `solve_for` must return an error when the requested variable (`x`) does
/// not occur in the equation `2 + 3 = 5`.
#[test]
fn test_solve_for_variable_not_in_equation() {
    let no_values = HashMap::new();
    let equation = Equation::new("test", add(int(2), int(3)), int(5));
    let outcome = solve_for(&equation, "x", &no_values);
    assert!(outcome.is_err());
}
/// `LinearSolver` must accept the linear equation `2 * x + 3 = 7`.
#[test]
fn test_can_solve_linear() {
    let lhs = add(mul(int(2), var("x")), int(3));
    let linear = Equation::new("test", lhs, int(7));
    assert!(LinearSolver::new().can_solve(&linear));
}
/// `LinearSolver` must reject the quadratic equation `x^2 + 2 * x + 1 = 0`.
#[test]
fn test_cannot_solve_quadratic() {
    let squared = pow(var("x"), int(2));
    let linear_term = mul(int(2), var("x"));
    let lhs = add(add(squared, linear_term), int(1));
    let quadratic = Equation::new("test", lhs, int(0));
    assert!(!LinearSolver::new().can_solve(&quadratic));
}
use thales::solver::{PolynomialSolver, QuadraticSolver, Solution};
/// `x^2 - 5x + 6 = 0` factors as `(x - 2)(x - 3)`. `QuadraticSolver` must
/// return two roots that evaluate to 2 and 3.
#[test]
fn test_quadratic_solver_two_real_roots() {
    let left = add(add(pow(var("x"), int(2)), mul(int(-5), var("x"))), int(6));
    let equation = Equation::new("test", left, int(0));
    let solver = QuadraticSolver::new();
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &Variable::new("x"))
        .expect("QuadraticSolver should solve x^2 - 5x + 6 = 0");
    match solution {
        Solution::Multiple(roots) => {
            assert_eq!(roots.len(), 2);
            let vals: Vec<f64> = roots
                .iter()
                .filter_map(|r| r.evaluate(&HashMap::new()))
                .collect();
            assert!(vals.iter().any(|v| (v - 2.0).abs() < 1e-10), "missing root 2 in {:?}", vals);
            assert!(vals.iter().any(|v| (v - 3.0).abs() < 1e-10), "missing root 3 in {:?}", vals);
        }
        other => panic!("Expected multiple solutions, got {:?}", other),
    }
}
/// `x^2 + 1 = 0` has no real roots. `QuadraticSolver` must return two
/// complex roots with zero real part and imaginary part of magnitude 1.
#[test]
fn test_quadratic_solver_complex_roots() {
    let left = add(pow(var("x"), int(2)), int(1));
    let equation = Equation::new("test", left, int(0));
    let solver = QuadraticSolver::new();
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &Variable::new("x"))
        .expect("QuadraticSolver should solve x^2 + 1 = 0");
    match solution {
        Solution::Multiple(roots) => {
            assert_eq!(roots.len(), 2);
            for root in &roots {
                match root {
                    Expression::Complex(c) => {
                        assert!(c.re.abs() < 1e-10);
                        assert!((c.im.abs() - 1.0).abs() < 1e-10);
                    }
                    other => panic!("Expected complex roots, got {:?}", other),
                }
            }
        }
        other => panic!("Expected multiple solutions, got {:?}", other),
    }
}
/// `x^3 - 1 = 0` has three roots, one of which is the real root 1.
#[test]
fn test_cubic_solver_x3_minus_1() {
    let left = add(pow(var("x"), int(3)), int(-1));
    let equation = Equation::new("test", left, int(0));
    let solver = PolynomialSolver::new();
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &Variable::new("x"))
        .expect("PolynomialSolver should solve x^3 - 1 = 0");
    match solution {
        Solution::Multiple(roots) => {
            assert_eq!(roots.len(), 3);
            // Only roots that `evaluate` returns a value for are kept here.
            let real_roots: Vec<f64> = roots
                .iter()
                .filter_map(|r| r.evaluate(&HashMap::new()))
                .collect();
            assert!(
                real_roots.iter().any(|v| (v - 1.0).abs() < 1e-10),
                "missing real root 1 in {:?}",
                real_roots
            );
        }
        other => panic!("Expected multiple solutions, got {:?}", other),
    }
}
/// The depressed cubic `x^3 - 6x - 9 = 0` has three roots, including the
/// real root 3 (27 - 18 - 9 = 0).
#[test]
fn test_cubic_solver_depressed_cubic() {
    let left = add(add(pow(var("x"), int(3)), mul(int(-6), var("x"))), int(-9));
    let equation = Equation::new("test", left, int(0));
    let solver = PolynomialSolver::new();
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &Variable::new("x"))
        .expect("PolynomialSolver should solve x^3 - 6x - 9 = 0");
    match solution {
        Solution::Multiple(roots) => {
            assert_eq!(roots.len(), 3);
            let real_roots: Vec<f64> = roots
                .iter()
                .filter_map(|r| r.evaluate(&HashMap::new()))
                .collect();
            assert!(
                real_roots.iter().any(|v| (v - 3.0).abs() < 1e-10),
                "missing real root 3 in {:?}",
                real_roots
            );
        }
        other => panic!("Expected multiple solutions, got {:?}", other),
    }
}
/// `x^4 - 1 = 0` has real roots ±1 and two non-real complex roots.
#[test]
fn test_quartic_solver_x4_minus_1() {
    let left = add(pow(var("x"), int(4)), int(-1));
    let equation = Equation::new("test", left, int(0));
    let solver = PolynomialSolver::new();
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &Variable::new("x"))
        .expect("PolynomialSolver should solve x^4 - 1 = 0");
    match solution {
        Solution::Multiple(roots) => {
            assert_eq!(roots.len(), 4);
            let mut real_roots = Vec::new();
            let mut complex_roots = Vec::new();
            for root in &roots {
                match root {
                    Expression::Integer(n) => real_roots.push(*n as f64),
                    Expression::Float(f) => real_roots.push(*f),
                    // A complex value with a negligible imaginary part counts as real.
                    Expression::Complex(c) if c.im.abs() < 1e-10 => real_roots.push(c.re),
                    Expression::Complex(_) => complex_roots.push(root.clone()),
                    _ => {}
                }
            }
            assert!(
                real_roots.iter().any(|v| (v - 1.0).abs() < 1e-10),
                "missing real root 1 in {:?}",
                real_roots
            );
            assert!(
                real_roots.iter().any(|v| (v + 1.0).abs() < 1e-10),
                "missing real root -1 in {:?}",
                real_roots
            );
            assert_eq!(complex_roots.len(), 2, "complex roots: {:?}", complex_roots);
        }
        other => panic!("Expected multiple solutions, got {:?}", other),
    }
}
/// `x^4 - 5x^2 + 4 = (x^2 - 1)(x^2 - 4)` has the four real roots ±1 and ±2.
#[test]
fn test_quartic_solver_biquadratic() {
    let x4 = pow(var("x"), int(4));
    let x2 = pow(var("x"), int(2));
    let left = add(add(x4, mul(int(-5), x2)), int(4));
    let equation = Equation::new("test", left, int(0));
    let solver = PolynomialSolver::new();
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &Variable::new("x"))
        .expect("PolynomialSolver should solve x^4 - 5x^2 + 4 = 0");
    match solution {
        Solution::Multiple(roots) => {
            assert_eq!(roots.len(), 4);
            let vals: Vec<f64> = roots
                .iter()
                .filter_map(|r| r.evaluate(&HashMap::new()))
                .collect();
            for expected in [1.0_f64, -1.0, 2.0, -2.0] {
                assert!(
                    vals.iter().any(|v| (v - expected).abs() < 1e-10),
                    "missing root {} in {:?}",
                    expected,
                    vals
                );
            }
        }
        other => panic!("Expected multiple solutions, got {:?}", other),
    }
}
/// `x^5 - x - 1 = 0` has no closed-form solution. The solver must return five
/// roots, and the single real one is approximately 1.1673.
#[test]
fn test_polynomial_solver_quintic_numerical() {
    let x5 = pow(var("x"), int(5));
    let left = add(add(x5, mul(int(-1), var("x"))), int(-1));
    let equation = Equation::new("test", left, int(0));
    let solver = PolynomialSolver::new();
    // Surface the SolverError on failure instead of a bare `is_ok()` assertion.
    let (solution, _path) = solver
        .solve(&equation, &Variable::new("x"))
        .expect("PolynomialSolver should solve x^5 - x - 1 = 0");
    match solution {
        Solution::Multiple(roots) => {
            assert_eq!(roots.len(), 5);
            // Keep only roots that evaluate to a finite real number.
            let real_roots: Vec<f64> = roots
                .iter()
                .filter_map(|r| r.evaluate(&HashMap::new()))
                .filter(|v| v.is_finite())
                .collect();
            assert!(!real_roots.is_empty(), "no finite real roots in {:?}", roots);
            // Loose tolerance: the expected value is only known to 4 decimals.
            assert!(
                real_roots.iter().any(|v| (v - 1.1673).abs() < 0.01),
                "missing real root ~1.1673 in {:?}",
                real_roots
            );
        }
        other => panic!("Expected multiple solutions, got {:?}", other),
    }
}