use std::collections::HashMap;
use rand_v10::RngExt;
use rand_v10::rng;
use crate::numerical::elementary::eval_expr;
use crate::numerical::integrate::QuadratureMethod;
use crate::numerical::integrate::quadrature;
use crate::symbolic::calculus::differentiate;
use crate::symbolic::calculus::substitute;
use crate::symbolic::core::Expr;
use crate::symbolic::matrix;
use crate::symbolic::simplify_dag::simplify;
/// Base tolerance for the numeric comparisons in this module; some checks scale it
/// (e.g. by 10 or 100) to allow for coarser numerical approximations.
const TOLERANCE: f64 = 1e-6;
/// Number of random sample points used per check (the indefinite-integral check allows
/// up to twice this many attempts to collect that many successful evaluations).
const NUM_SAMPLES: usize = 100;
/// Numerically checks that `solution` satisfies every equation in `equations`.
///
/// Each equation is reduced to a residual (`lhs - rhs` for an `Expr::Eq`, otherwise the
/// expression itself). Every `(variable, expression)` pair in `solution` is substituted into
/// the residual, which is then evaluated at `NUM_SAMPLES` random points where each name in
/// `free_vars` is drawn at random from `-10.0..10.0`.
///
/// Sample points where evaluation fails or yields a non-finite value are skipped.
///
/// # Returns
///
/// `true` if, for every equation, at least one sample evaluated to a finite value and no
/// finite sample had magnitude greater than `TOLERANCE`. Returns `false` as soon as a sample
/// exceeds the tolerance, or when an equation could not be evaluated at any sample point
/// (e.g. the solution leaves a variable unbound), so that an equation that cannot be checked
/// is not reported as satisfied.
#[must_use]
pub fn verify_equation_solution<S: std::hash::BuildHasher>(
    equations: &[Expr],
    solution: &HashMap<String, Expr, S>,
    free_vars: &[&str],
) -> bool {
    let mut rng = rng();
    for eq in equations {
        let unwrapped_eq = unwrap_dag(eq.clone());
        let diff = if let Expr::Eq(lhs, rhs) = unwrapped_eq {
            simplify(&Expr::new_sub(lhs, rhs))
        } else {
            unwrapped_eq
        };
        // Substituting the solution does not depend on the sampled values, so build the
        // simplified residual once per equation instead of once per sample.
        let mut residual = diff;
        for (var, sol_expr) in solution {
            residual = substitute(&residual, var, sol_expr);
        }
        let residual = simplify(&residual);
        let mut evaluated_any = false;
        for _ in 0..NUM_SAMPLES {
            let mut current_vars = HashMap::new();
            for var in free_vars {
                current_vars.insert((*var).to_string(), rng.random_range(-10.0..10.0));
            }
            match eval_expr(&residual, &current_vars) {
                | Ok(val) if val.is_finite() => {
                    if val.abs() > TOLERANCE {
                        return false;
                    }
                    evaluated_any = true;
                },
                // Points outside the residual's domain (or giving NaN/inf) carry no evidence.
                | _ => continue,
            }
        }
        if !evaluated_any {
            return false;
        }
    }
    true
}
/// Converts an `Expr::Dag` node back into a plain expression tree via `to_expr`.
///
/// If the conversion fails, the original `Dag` wrapper is returned unchanged; any
/// non-`Dag` expression is passed through as-is.
pub(crate) fn unwrap_dag(expr: Expr) -> Expr {
    if let Expr::Dag(node) = expr {
        return node.to_expr().unwrap_or(Expr::Dag(node));
    }
    expr
}
/// Numerically checks that `integral_result` is an antiderivative of `integrand`.
///
/// The residual `integrand - d(integral_result)/d(var)` is simplified once and evaluated at
/// random points of `var` drawn from `-10.0..10.0`. Up to `2 * NUM_SAMPLES` points are tried,
/// stopping once `NUM_SAMPLES` of them evaluate successfully; points that fail to evaluate
/// are ignored.
///
/// # Returns
///
/// `false` as soon as a successfully evaluated residual exceeds `TOLERANCE` in magnitude;
/// otherwise `true` if at least one point could be evaluated.
#[must_use]
pub fn verify_indefinite_integral(
    integrand: &Expr,
    integral_result: &Expr,
    var: &str,
) -> bool {
    let residual = simplify(&Expr::new_sub(
        integrand.clone(),
        differentiate(integral_result, var),
    ));
    let mut generator = rng();
    let mut hits = 0;
    for _ in 0..NUM_SAMPLES * 2 {
        if hits >= NUM_SAMPLES {
            break;
        }
        let mut point = HashMap::new();
        point.insert(var.to_string(), generator.random_range(-10.0..10.0));
        let Ok(value) = eval_expr(&residual, &point) else {
            continue;
        };
        if value.abs() > TOLERANCE {
            return false;
        }
        hits += 1;
    }
    hits > 0
}
/// Compares a symbolic definite-integral value with a numerical quadrature of `integrand`.
///
/// `symbolic_result` is simplified and evaluated with no variables bound. `integrand` is
/// integrated over `range` in `var` using Simpson's rule with 1000 subintervals.
///
/// The tolerance is `TOLERANCE` scaled by `max(1, |symbolic value|)`. This makes it absolute
/// for small results and relative for large ones. Quadrature error grows with the magnitude
/// of the integral, so a purely absolute threshold would reject correct large results.
///
/// # Returns
///
/// `false` if the symbolic result cannot be evaluated to a finite number, if quadrature
/// fails, or if the two values disagree beyond the tolerance.
#[must_use]
pub fn verify_definite_integral(
    integrand: &Expr,
    var: &str,
    range: (f64, f64),
    symbolic_result: &Expr,
) -> bool {
    let symbolic_val = match eval_expr(&simplify(symbolic_result), &HashMap::new()) {
        | Ok(v) if v.is_finite() => v,
        | _ => return false,
    };
    let allowed = TOLERANCE * symbolic_val.abs().max(1.0);
    quadrature(integrand, var, range, 1000, &QuadratureMethod::Simpson)
        .is_ok_and(|numerical_val| (symbolic_val - numerical_val).abs() < allowed)
}
/// Numerically checks that `solution` satisfies the ODE `ode` in the function `func_name`.
///
/// The ODE is reduced to a residual (`lhs - rhs` for an `Expr::Eq`, otherwise the
/// expression itself). The names `func_name`, `"{func_name}'"` and `"{func_name}''"` are
/// replaced by `solution` and its first and second derivatives with respect to `var`. The
/// simplified residual is then evaluated at `NUM_SAMPLES` random points of `var` drawn from
/// `-10.0..10.0`.
///
/// Points where evaluation fails or yields a non-finite value are skipped.
///
/// # Returns
///
/// `false` as soon as a finite residual exceeds `10 * TOLERANCE` in magnitude, or if no
/// sample point could be evaluated at all; otherwise `true`.
#[must_use]
pub fn verify_ode_solution(
    ode: &Expr,
    solution: &Expr,
    func_name: &str,
    var: &str,
) -> bool {
    let unwrapped_ode = unwrap_dag(ode.clone());
    let eq_zero = if let Expr::Eq(lhs, rhs) = unwrapped_ode {
        Expr::new_sub(lhs, rhs)
    } else {
        unwrapped_ode
    };
    // None of the symbolic work depends on the sample point, so the residual is built once
    // rather than re-derived and re-simplified for every sample.
    let y_prime = differentiate(solution, var);
    let y_double_prime = differentiate(&y_prime, var);
    let mut residual = simplify(&eq_zero);
    residual = substitute(&residual, func_name, solution);
    residual = substitute(&residual, &format!("{func_name}'"), &y_prime);
    residual = substitute(&residual, &format!("{func_name}''"), &y_double_prime);
    let residual = simplify(&residual);

    let mut rng = rng();
    let mut evaluated_any = false;
    for _ in 0..NUM_SAMPLES {
        let x_val = rng.random_range(-10.0..10.0);
        let mut vars = HashMap::new();
        vars.insert(var.to_string(), x_val);
        match eval_expr(&residual, &vars) {
            | Ok(val) if val.is_finite() => {
                if val.abs() > TOLERANCE * 10.0 {
                    return false;
                }
                evaluated_any = true;
            },
            // Outside the solution's domain (or NaN/inf): no evidence either way.
            | _ => continue,
        }
    }
    evaluated_any
}
/// Checks that `inverse` is a two-sided inverse of `original`.
///
/// Both products `original * inverse` and `inverse * original` must simplify to a square
/// matrix whose entries evaluate, with no variables bound, to within `TOLERANCE` of the
/// identity. Checking both sides rejects one-sided inverses of non-square matrices.
/// For example, a 2×3 matrix times a 3×2 right-inverse gives the 2×2 identity.
///
/// # Returns
///
/// `false` if either product is not a matrix, is not square, contains an entry that cannot
/// be evaluated to a number, or differs from the identity beyond the tolerance.
#[must_use]
pub fn verify_matrix_inverse(
    original: &Expr,
    inverse: &Expr,
) -> bool {
    evaluates_to_identity_matrix(&matrix::mul_matrices(original, inverse))
        && evaluates_to_identity_matrix(&matrix::mul_matrices(inverse, original))
}

/// Returns `true` when `product` simplifies to a square matrix numerically equal (within
/// `TOLERANCE`) to the identity; NaN entries never count as equal.
fn evaluates_to_identity_matrix(product: &Expr) -> bool {
    let Expr::Matrix(rows) = unwrap_dag(simplify(product)) else {
        return false;
    };
    let n = rows.len();
    rows.iter().enumerate().all(|(i, row)| {
        row.len() == n
            && row.iter().enumerate().all(|(j, item)| {
                let expected = if i == j { 1.0 } else { 0.0 };
                eval_expr(item, &HashMap::new())
                    .is_ok_and(|val| (val - expected).abs() <= TOLERANCE)
            })
    })
}
/// Compares a symbolic derivative against a numerical gradient of the original function.
///
/// At `NUM_SAMPLES` random points of `var` drawn from `-10.0..10.0`, `derivative_func` is
/// evaluated and compared with `crate::numerical::calculus::gradient` of `original_func`.
/// Points where either evaluation fails, returns no gradient component, or is non-finite
/// are skipped.
///
/// The allowed difference is `100 * TOLERANCE` scaled by `max(1, |symbolic derivative|)`.
/// Numerical differentiation loses absolute accuracy as the derivative grows, so a fixed
/// absolute bound would reject correct results such as `d/dx exp(x)` near `x = 10`.
///
/// # Returns
///
/// `false` as soon as a sample disagrees beyond the tolerance, or if no sample point could
/// be compared at all; otherwise `true`.
#[must_use]
pub fn verify_derivative(
    original_func: &Expr,
    derivative_func: &Expr,
    var: &str,
) -> bool {
    let mut rng = rng();
    let mut compared_any = false;
    for _ in 0..NUM_SAMPLES {
        let x_val = rng.random_range(-10.0..10.0);
        let mut vars_map = HashMap::new();
        vars_map.insert(var.to_string(), x_val);
        let Ok(symbolic_deriv_val) = eval_expr(derivative_func, &vars_map) else {
            continue;
        };
        // `.first()` avoids a panic if the gradient unexpectedly comes back empty.
        let Some(numerical_deriv_val) =
            crate::numerical::calculus::gradient(original_func, &[var], &[x_val])
                .ok()
                .and_then(|grad_vec| grad_vec.first().copied())
        else {
            continue;
        };
        // NaN/inf would make the comparison below vacuously pass, so skip such points.
        if !symbolic_deriv_val.is_finite() || !numerical_deriv_val.is_finite() {
            continue;
        }
        let allowed = TOLERANCE * 100.0 * symbolic_deriv_val.abs().max(1.0);
        if (symbolic_deriv_val - numerical_deriv_val).abs() > allowed {
            return false;
        }
        compared_any = true;
    }
    compared_any
}
/// Numerically checks that `f` approaches `limit_val` as `var` approaches `target`.
///
/// `target` and `limit_val` are simplified and evaluated with no variables bound. `f` is
/// then evaluated on both sides of the target at distances `1e-3`, `1e-5` and `1e-7`. At
/// distance `eps`, the allowed deviation from the limit is `100 * eps + TOLERANCE`.
/// Probe points where `f` fails to evaluate or yields a non-finite value are skipped.
///
/// # Returns
///
/// `false` if `target` or `limit_val` cannot be evaluated, if any probe deviates beyond
/// its tolerance, or if `f` could not be evaluated at any probe point; otherwise `true`.
#[must_use]
pub fn verify_limit(
    f: &Expr,
    var: &str,
    target: &Expr,
    limit_val: &Expr,
) -> bool {
    let Ok(x0) = eval_expr(&simplify(target), &HashMap::new()) else {
        return false;
    };
    let Ok(l) = eval_expr(&simplify(limit_val), &HashMap::new()) else {
        return false;
    };
    let epsilons: [f64; 3] = [1e-3, 1e-5, 1e-7];
    let mut evaluated_any = false;
    for &eps in &epsilons {
        let allowed = eps.mul_add(100.0, TOLERANCE);
        // Approach from the right, then from the left.
        for x in [x0 + eps, x0 - eps] {
            let mut vars = HashMap::new();
            vars.insert(var.to_string(), x);
            match eval_expr(f, &vars) {
                | Ok(val) if val.is_finite() => {
                    if (val - l).abs() > allowed {
                        return false;
                    }
                    evaluated_any = true;
                },
                | _ => {},
            }
        }
    }
    evaluated_any
}