use proptest::prelude::*;
use std::collections::HashMap;
use thales::ast::{BinaryOp, Expression, UnaryOp, Variable};
use thales::parser::{parse_equation, parse_expression};
use thales::transforms::{Cartesian2D, Cartesian3D};
// Property tests for the differentiation rules. Each test builds an expression
// tree, differentiates it with respect to `x`, simplifies, evaluates the result
// at a fixed point and compares it against the hand-derived closed form.
//
// Every expression here is a polynomial in `x` evaluated with `x` bound, so a
// derivative that fails to evaluate (`None`) is reported as a failure instead of
// letting the property pass without checking anything.
proptest! {
    // d/dx (a*x + b*x) = a + b, independent of x (checked at x = 5).
    #[test]
    fn derivative_is_linear(a in -100i32..100i32, b in -100i32..100i32) {
        let expr = Expression::Binary(
            BinaryOp::Add,
            Box::new(Expression::Binary(
                BinaryOp::Mul,
                Box::new(Expression::Integer(a as i64)),
                Box::new(Expression::Variable(Variable::new("x"))),
            )),
            Box::new(Expression::Binary(
                BinaryOp::Mul,
                Box::new(Expression::Integer(b as i64)),
                Box::new(Expression::Variable(Variable::new("x"))),
            )),
        );
        let derivative = expr.differentiate("x").simplify();
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 5.0);
        let value = derivative
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("d/dx (a*x + b*x) did not evaluate"))?;
        let expected = (a + b) as f64;
        prop_assert!(
            (value - expected).abs() < 1e-10,
            "got {}, expected {}",
            value,
            expected
        );
    }

    // Product rule: d/dx (x^2 * (coef*x)) = 3*coef*x^2, i.e. 12*coef at x = 2.
    #[test]
    fn product_rule_holds(coef in 1i32..10i32) {
        let f = Expression::Power(
            Box::new(Expression::Variable(Variable::new("x"))),
            Box::new(Expression::Integer(2)),
        );
        let g = Expression::Binary(
            BinaryOp::Mul,
            Box::new(Expression::Integer(coef as i64)),
            Box::new(Expression::Variable(Variable::new("x"))),
        );
        let product = Expression::Binary(BinaryOp::Mul, Box::new(f), Box::new(g));
        let derivative = product.differentiate("x").simplify();
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 2.0);
        let value = derivative
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("d/dx (x^2 * coef*x) did not evaluate"))?;
        // 3 * coef * x^2 with x = 2.
        let expected = 3.0 * coef as f64 * 4.0;
        prop_assert!(
            (value - expected).abs() < 1e-8,
            "got {}, expected {}",
            value,
            expected
        );
    }

    // Chain rule: d/dx (coef*x)^2 = 2*coef^2*x, checked at x = 3.
    #[test]
    fn chain_rule_holds(coef in 1i32..10i32) {
        let inner = Expression::Binary(
            BinaryOp::Mul,
            Box::new(Expression::Integer(coef as i64)),
            Box::new(Expression::Variable(Variable::new("x"))),
        );
        let expr = Expression::Power(Box::new(inner), Box::new(Expression::Integer(2)));
        let derivative = expr.differentiate("x").simplify();
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 3.0);
        let value = derivative
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("d/dx (coef*x)^2 did not evaluate"))?;
        let expected = 2.0 * (coef * coef) as f64 * 3.0;
        prop_assert!(
            (value - expected).abs() < 1e-8,
            "got {}, expected {}",
            value,
            expected
        );
    }

    // Constant-multiple / power rule: d/dx c*x^n = c*n*x^(n-1), checked at x = 2.
    #[test]
    fn constant_multiple_rule(c in -100i32..100i32, exp in 1i32..5i32) {
        // c = 0 is not exercised; reject it so proptest draws a replacement case.
        prop_assume!(c != 0);
        let expr = Expression::Binary(
            BinaryOp::Mul,
            Box::new(Expression::Integer(c as i64)),
            Box::new(Expression::Power(
                Box::new(Expression::Variable(Variable::new("x"))),
                Box::new(Expression::Integer(exp as i64)),
            )),
        );
        let derivative = expr.differentiate("x").simplify();
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 2.0);
        let value = derivative
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("d/dx c*x^n did not evaluate"))?;
        let expected = c as f64 * exp as f64 * 2.0_f64.powi(exp - 1);
        prop_assert!(
            (value - expected).abs() < 1e-8,
            "got {}, expected {}",
            value,
            expected
        );
    }
}
// Parser round-trip properties: parse a generated string, render the result with
// `Display`, reparse the rendering, and check that both trees evaluate to the same
// value with `x` bound.
//
// Generated inputs the parser itself rejects are skipped, because only accepted
// input is in scope for a round-trip property. Once the input parses, the rendered
// form MUST reparse and both trees MUST evaluate; any other outcome is a broken
// round-trip and fails the test.
proptest! {
    #[test]
    fn parse_roundtrip_linear(a in -100i32..100i32, b in -100i32..100i32) {
        // Inputs with a zero coefficient on x are not exercised here.
        prop_assume!(a != 0);
        let expr_str = format!("{}*x + {}", a, b);
        let expr = match parse_expression(&expr_str) {
            Ok(expr) => expr,
            Err(_) => return Ok(()),
        };
        let rendered = expr.to_string();
        let expr2 = parse_expression(&rendered).map_err(|_| {
            TestCaseError::fail(format!(
                "rendered form {:?} of {:?} did not reparse",
                rendered, expr_str
            ))
        })?;
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 5.0);
        let v1 = expr
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail(format!("{:?} did not evaluate", expr_str)))?;
        let v2 = expr2
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail(format!("{:?} did not evaluate", rendered)))?;
        prop_assert!(
            (v1 - v2).abs() < 1e-10,
            "{:?} evaluates to {} but its rendering {:?} evaluates to {}",
            expr_str,
            v1,
            rendered,
            v2
        );
    }

    #[test]
    fn parse_roundtrip_equation(
        a in 1i32..100i32,
        b in -100i32..100i32,
        c in -100i32..100i32
    ) {
        let eq_str = format!("{}*x + {} = {}", a, b, c);
        let eq = match parse_equation(&eq_str) {
            Ok(eq) => eq,
            Err(_) => return Ok(()),
        };
        let rendered = format!("{} = {}", eq.left, eq.right);
        let eq2 = parse_equation(&rendered).map_err(|_| {
            TestCaseError::fail(format!(
                "rendered form {:?} of {:?} did not reparse",
                rendered, eq_str
            ))
        })?;
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 3.0);
        let l1 = eq
            .left
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("left side of original did not evaluate"))?;
        let r1 = eq
            .right
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("right side of original did not evaluate"))?;
        let l2 = eq2
            .left
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("left side of reparse did not evaluate"))?;
        let r2 = eq2
            .right
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("right side of reparse did not evaluate"))?;
        prop_assert!(
            (l1 - l2).abs() < 1e-10,
            "left side changed across round-trip of {:?}: {} vs {}",
            eq_str,
            l1,
            l2
        );
        prop_assert!(
            (r1 - r2).abs() < 1e-10,
            "right side changed across round-trip of {:?}: {} vs {}",
            eq_str,
            r1,
            r2
        );
    }
}
// Numeric-consistency properties: symbolic derivatives against finite differences,
// and Cartesian magnitudes against the radius of the converted polar/spherical form.
proptest! {
    // The symbolic derivative of coef*x^2 must agree with a central finite
    // difference of the original expression. A failure to evaluate either side is
    // reported instead of skipping the case.
    #[test]
    fn derivative_matches_finite_difference(
        coef in 1i32..20i32,
        x in 1.0f64..10.0f64
    ) {
        let expr = Expression::Binary(
            BinaryOp::Mul,
            Box::new(Expression::Integer(coef as i64)),
            Box::new(Expression::Power(
                Box::new(Expression::Variable(Variable::new("x"))),
                Box::new(Expression::Integer(2)),
            )),
        );
        let derivative = expr.differentiate("x").simplify();
        // Builds a fresh variable map with only `x` bound to `point`.
        let bind_x = |point: f64| {
            let mut vars = HashMap::new();
            vars.insert("x".to_string(), point);
            vars
        };
        let symbolic = derivative
            .evaluate(&bind_x(x))
            .ok_or_else(|| TestCaseError::fail("symbolic derivative did not evaluate"))?;
        let h = 1e-6;
        let f_plus = expr
            .evaluate(&bind_x(x + h))
            .ok_or_else(|| TestCaseError::fail("f(x + h) did not evaluate"))?;
        let f_minus = expr
            .evaluate(&bind_x(x - h))
            .ok_or_else(|| TestCaseError::fail("f(x - h) did not evaluate"))?;
        // Central difference. Its truncation error depends on f''', which is zero
        // for a quadratic, so the tolerance only has to absorb floating-point
        // cancellation (roughly eps * |f| / h).
        let numerical = (f_plus - f_minus) / (2.0 * h);
        prop_assert!(
            (symbolic - numerical).abs() < 1e-4,
            "symbolic {} vs numerical {} at x = {}",
            symbolic,
            numerical,
            x
        );
    }

    // Converting to polar form must preserve the distance from the origin.
    #[test]
    fn coordinate_magnitude_preserved(x in -1000.0f64..1000.0f64, y in -1000.0f64..1000.0f64) {
        let cart = Cartesian2D::new(x, y);
        let polar = cart.to_polar();
        let cart_magnitude = cart.magnitude();
        let polar_magnitude = polar.r;
        prop_assert!((cart_magnitude - polar_magnitude).abs() < 1e-10);
    }

    // Converting to spherical form must preserve the distance from the origin.
    #[test]
    fn coordinate_3d_magnitude_preserved(
        x in -100.0f64..100.0f64,
        y in -100.0f64..100.0f64,
        z in -100.0f64..100.0f64
    ) {
        let cart = Cartesian3D::new(x, y, z);
        let spherical = cart.to_spherical();
        let cart_magnitude = cart.magnitude();
        let spherical_magnitude = spherical.r;
        prop_assert!((cart_magnitude - spherical_magnitude).abs() < 1e-10);
    }
}
// Algebraic identities on integer constants. Constant arithmetic with no free
// variables must always evaluate, so `None` is a failure. Each side is checked both
// against the other and against the exact integer result.
proptest! {
    // (a + b) + c == a + (b + c) == a + b + c
    #[test]
    fn addition_is_associative(a in -100i32..100i32, b in -100i32..100i32, c in -100i32..100i32) {
        let left = Expression::Binary(
            BinaryOp::Add,
            Box::new(Expression::Binary(
                BinaryOp::Add,
                Box::new(Expression::Integer(a as i64)),
                Box::new(Expression::Integer(b as i64)),
            )),
            Box::new(Expression::Integer(c as i64)),
        );
        let right = Expression::Binary(
            BinaryOp::Add,
            Box::new(Expression::Integer(a as i64)),
            Box::new(Expression::Binary(
                BinaryOp::Add,
                Box::new(Expression::Integer(b as i64)),
                Box::new(Expression::Integer(c as i64)),
            )),
        );
        let vars = HashMap::new();
        let lv = left
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("(a + b) + c did not evaluate"))?;
        let rv = right
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("a + (b + c) did not evaluate"))?;
        let expected = (a + b + c) as f64;
        prop_assert!((lv - rv).abs() < 1e-10, "{} vs {}", lv, rv);
        prop_assert!((lv - expected).abs() < 1e-10, "got {}, expected {}", lv, expected);
    }

    // a * b == b * a == the integer product
    #[test]
    fn multiplication_is_commutative(a in -100i32..100i32, b in -100i32..100i32) {
        let left = Expression::Binary(
            BinaryOp::Mul,
            Box::new(Expression::Integer(a as i64)),
            Box::new(Expression::Integer(b as i64)),
        );
        let right = Expression::Binary(
            BinaryOp::Mul,
            Box::new(Expression::Integer(b as i64)),
            Box::new(Expression::Integer(a as i64)),
        );
        let vars = HashMap::new();
        let lv = left
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("a * b did not evaluate"))?;
        let rv = right
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("b * a did not evaluate"))?;
        let expected = (a * b) as f64;
        prop_assert!((lv - rv).abs() < 1e-10, "{} vs {}", lv, rv);
        prop_assert!((lv - expected).abs() < 1e-10, "got {}, expected {}", lv, expected);
    }

    // a * (b + c) == a*b + a*c, compared after simplification of each side.
    #[test]
    fn distributive_property(a in -50i32..50i32, b in -50i32..50i32, c in -50i32..50i32) {
        let left = Expression::Binary(
            BinaryOp::Mul,
            Box::new(Expression::Integer(a as i64)),
            Box::new(Expression::Binary(
                BinaryOp::Add,
                Box::new(Expression::Integer(b as i64)),
                Box::new(Expression::Integer(c as i64)),
            )),
        );
        let right = Expression::Binary(
            BinaryOp::Add,
            Box::new(Expression::Binary(
                BinaryOp::Mul,
                Box::new(Expression::Integer(a as i64)),
                Box::new(Expression::Integer(b as i64)),
            )),
            Box::new(Expression::Binary(
                BinaryOp::Mul,
                Box::new(Expression::Integer(a as i64)),
                Box::new(Expression::Integer(c as i64)),
            )),
        );
        let vars = HashMap::new();
        let lv = left
            .simplify()
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("simplified a * (b + c) did not evaluate"))?;
        let rv = right
            .simplify()
            .evaluate(&vars)
            .ok_or_else(|| TestCaseError::fail("simplified a*b + a*c did not evaluate"))?;
        let expected = (a * (b + c)) as f64;
        prop_assert!((lv - rv).abs() < 1e-10, "{} vs {}", lv, rv);
        prop_assert!((lv - expected).abs() < 1e-10, "got {}, expected {}", lv, expected);
    }
}
// Spot-checks sin/cos/tan against textbook values at standard angles.
// Each row is (function, argument in radians, expected value).
#[test]
fn known_trigonometric_values() {
    use std::f64::consts::PI;
    use thales::ast::Function;

    let cases = [
        (Function::Sin, 0.0, 0.0),
        (Function::Cos, 0.0, 1.0),
        (Function::Sin, PI / 2.0, 1.0),
        (Function::Cos, PI / 2.0, 0.0),
        (Function::Sin, PI / 6.0, 0.5),
        (Function::Tan, PI / 4.0, 1.0),
    ];
    let no_vars = HashMap::new();
    for (func, angle, expected) in cases {
        let call = Expression::Function(func, vec![Expression::Float(angle)]);
        let actual = call.evaluate(&no_vars).unwrap();
        assert!((actual - expected).abs() < 1e-10);
    }
}
// Spot-checks exp/ln/log10 at inputs with exactly known results.
// Each row is (function, argument, expected value).
#[test]
fn known_exponential_values() {
    use std::f64::consts::E;
    use thales::ast::Function;

    let cases = [
        (Function::Exp, 0.0, 1.0),
        (Function::Exp, 1.0, E),
        (Function::Ln, 1.0, 0.0),
        (Function::Ln, E, 1.0),
        (Function::Log10, 10.0, 1.0),
        (Function::Log10, 100.0, 2.0),
    ];
    let no_vars = HashMap::new();
    for (func, input, expected) in cases {
        let call = Expression::Function(func, vec![Expression::Float(input)]);
        let actual = call.evaluate(&no_vars).unwrap();
        assert!((actual - expected).abs() < 1e-10);
    }
}
// Spot-checks integer powers, including a zero exponent and a negative base.
// Each row is (base, exponent, expected value).
#[test]
fn known_power_values() {
    let cases: [(i64, i64, f64); 4] = [(2, 0, 1.0), (2, 3, 8.0), (5, 2, 25.0), (-1, 2, 1.0)];
    for (base, exponent, expected) in cases {
        let power = Expression::Power(
            Box::new(Expression::Integer(base)),
            Box::new(Expression::Integer(exponent)),
        );
        let actual = power.evaluate(&HashMap::new()).unwrap();
        assert!((actual - expected).abs() < 1e-10);
    }
}
// Dividing by zero must not yield an ordinary finite number. The evaluator may
// either refuse (`None`) or return an infinity.
#[test]
fn edge_case_division_by_zero() {
    let expr = Expression::Binary(
        BinaryOp::Div,
        Box::new(Expression::Integer(5)),
        Box::new(Expression::Integer(0)),
    );
    // Pattern-match instead of `is_none() || unwrap()` so there is no panic path
    // and the offending value is reported on failure.
    match expr.evaluate(&HashMap::new()) {
        None => {}
        Some(value) => assert!(
            value.is_infinite(),
            "5 / 0 evaluated to {}, expected None or an infinity",
            value
        ),
    }
}
// An empty source string is not a valid expression and must be rejected.
#[test]
fn edge_case_empty_input() {
    let outcome = parse_expression("");
    assert!(outcome.is_err());
}
// 0^0 is a classic edge case. The evaluator must produce some value for it rather
// than `None`; which value it produces is deliberately not pinned here.
#[test]
fn edge_case_zero_power_zero() {
    let zero = || Box::new(Expression::Integer(0));
    let zero_to_zero = Expression::Power(zero(), zero());
    assert!(zero_to_zero.evaluate(&HashMap::new()).is_some());
}
// sqrt(-1) has no real value. The evaluator may return `None`; if it returns a
// number, that number must be NaN.
#[test]
fn edge_case_negative_square_root() {
    let sqrt_neg_one =
        Expression::Function(thales::ast::Function::Sqrt, vec![Expression::Integer(-1)]);
    // Keep only a result that would violate the contract: a real, non-NaN number.
    let real_result = sqrt_neg_one
        .evaluate(&HashMap::new())
        .filter(|value| !value.is_nan());
    assert!(real_result.is_none());
}
// 10^100 is far beyond the `i64` range, so naive integer exponentiation would
// overflow. The evaluator must still return a value, and that value must not be
// NaN. Either a finite float approximation or an infinity is acceptable.
#[test]
fn edge_case_overflow_protection() {
    let expr = Expression::Power(
        Box::new(Expression::Integer(10)),
        Box::new(Expression::Integer(100)),
    );
    let value = expr
        .evaluate(&HashMap::new())
        .expect("10^100 should evaluate to a value, not None");
    // `is_finite() || is_infinite()` is exactly `!is_nan()`; state the real check.
    assert!(!value.is_nan(), "10^100 evaluated to NaN");
}
// The origin must map to a polar point with zero radius.
#[test]
fn edge_case_zero_cartesian_to_polar() {
    let origin = Cartesian2D::new(0.0, 0.0);
    let radius = origin.to_polar().r;
    assert!(radius.abs() < 1e-10);
}
// Simplifying an already-simplified expression must not change its meaning.
// Idempotence is checked semantically, by evaluating both results at x = 5, since
// structural equality of `Expression` is not relied on here. Both results must
// also equal the known value of 0 + x at x = 5.
#[test]
fn edge_case_simplification_idempotent() {
    let expr = Expression::Binary(
        BinaryOp::Add,
        Box::new(Expression::Integer(0)),
        Box::new(Expression::Variable(Variable::new("x"))),
    );
    let simplified_once = expr.simplify();
    let simplified_twice = simplified_once.simplify();
    let mut vars = HashMap::new();
    vars.insert("x".to_string(), 5.0);
    // A `None` here would previously make the test pass without checking anything.
    let once = simplified_once
        .evaluate(&vars)
        .expect("simplify(0 + x) should evaluate with x bound");
    let twice = simplified_twice
        .evaluate(&vars)
        .expect("simplify(simplify(0 + x)) should evaluate with x bound");
    assert!((once - twice).abs() < 1e-10, "once = {}, twice = {}", once, twice);
    assert!((once - 5.0).abs() < 1e-10, "0 + x at x = 5 evaluated to {}", once);
}
// Double negation: -(-x) must simplify to something that evaluates to x (7 here).
#[test]
fn edge_case_unary_negation() {
    let x = Expression::Variable(Variable::new("x"));
    // `x` is not used again, so move it in directly instead of cloning it.
    let neg_x = Expression::Unary(UnaryOp::Neg, Box::new(x));
    let double_neg = Expression::Unary(UnaryOp::Neg, Box::new(neg_x));
    let simplified = double_neg.simplify();
    let mut vars = HashMap::new();
    vars.insert("x".to_string(), 7.0);
    let result = simplified.evaluate(&vars);
    assert_eq!(result, Some(7.0));
}