use crate::ast::{BinaryOp, Expression, Function, UnaryOp, Variable};
use crate::resolution_path::{Operation, ResolutionStep};
use std::collections::HashMap;
/// Outcome of replacing an expression with a simpler approximate form.
///
/// Produced by `apply_small_angle_approx`; the numeric fields are built from
/// the threshold the caller supplied to that function.
#[derive(Debug, Clone, PartialEq)]
pub struct ApproxResult {
/// The simplified expression that stands in for the original one.
pub approximation: Expression,
/// Reported bound on the absolute error of the approximation.
pub error_bound: f64,
/// Interval `(low, high)` reported alongside the approximation.
pub valid_range: (f64, f64),
/// Human-readable statement of the identity applied, e.g. `"sin(θ) ≈ θ"`.
pub formula_used: String,
}
/// Kinds of approximation whose applicability can be checked with
/// `is_approximation_valid`.
#[derive(Debug, Clone, PartialEq)]
pub enum ApproxType {
/// Small-angle replacement of `sin(θ)`.
SmallAngleSin,
/// Small-angle replacement of `cos(θ)`.
SmallAngleCos,
/// Small-angle replacement of `tan(θ)`.
SmallAngleTan,
/// Small-angle replacement of `1 - cos(θ)`.
SmallAngle1MinusCos,
/// Scaled exponential form.
// NOTE(review): the payload is presumably the scaling factor (cf.
// `ScaledExpForm::scaling_factor`); `is_approximation_valid` ignores it — confirm.
ScaledExp(f64),
/// Pythagorean-style approximation for small values.
// NOTE(review): exact formula is not visible in this file — confirm against callers.
PythagoreanSmall,
}
/// Selectable scalings for an exponential's argument.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ScaledExpForm {
    /// Unscaled form (factor `1.0`).
    Standard,
    /// Argument scaled by `0.1`.
    Scaled01,
    /// Argument scaled by `0.01`.
    Scaled001,
    /// Argument scaled by a caller-chosen factor.
    Custom(f64),
}

impl ScaledExpForm {
    /// Returns the numeric factor associated with this form.
    pub fn scaling_factor(&self) -> f64 {
        // The enum is `Copy`, so matching on the dereferenced value lets the
        // custom factor be returned directly.
        match *self {
            Self::Custom(factor) => factor,
            Self::Standard => 1.0,
            Self::Scaled01 => 0.1,
            Self::Scaled001 => 0.01,
        }
    }
}
/// Replaces a small-angle trigonometric expression with its leading Taylor terms.
///
/// Recognised shapes, where θ must be accepted by `is_small_angle_candidate`
/// for `var`:
///
/// | input        | approximation | error bound for `|θ| ≤ t`            |
/// |--------------|---------------|--------------------------------------|
/// | `sin(θ)`     | `θ`           | `t³ / 6`                             |
/// | `cos(θ)`     | `1 - θ²/2`    | `t⁴ / 24`                            |
/// | `tan(θ)`     | `θ`           | `t³/3 · sec²(t) · (1 + 3·tan²(t))`   |
/// | `1 - cos(θ)` | `θ²/2`        | `t⁴ / 24`                            |
///
/// Here `t = |threshold|`; the sign of `threshold` is ignored so that the
/// reported bound is never negative and `valid_range` is always `(-t, t)`.
///
/// The tangent bound is the Lagrange remainder of the degree-2 Taylor
/// polynomial. The bare leading term `t³/3` is *not* a bound, because every
/// further term of the series of `tan θ - θ` is positive. For `t ≥ π/2` the
/// interval reaches a pole of `tan`, so the bound is reported as infinity.
///
/// Returns `None` when `expr` has none of the shapes above.
// NOTE(review): the bounds apply to the matched argument θ itself, which may be
// `k * var`; confirm callers pass a threshold for θ rather than for `var`.
pub fn apply_small_angle_approx(
    expr: &Expression,
    var: &Variable,
    threshold: f64,
) -> Option<ApproxResult> {
    // Builds the expression `θ² / 2`.
    fn half_square(theta: &Expression) -> Expression {
        let squared = Expression::Power(
            Box::new(theta.clone()),
            Box::new(Expression::Integer(2)),
        );
        Expression::Binary(
            BinaryOp::Div,
            Box::new(squared),
            Box::new(Expression::Integer(2)),
        )
    }

    // Upper bound of |tan θ - θ| on [-t, t]:
    // tan'''(ξ)/6 · t³ with tan'''(ξ) = 2·sec²ξ·(1 + 3·tan²ξ), maximal at ξ = t.
    fn tan_error_bound(t: f64) -> f64 {
        if t >= std::f64::consts::FRAC_PI_2 {
            // The interval contains a pole; no finite bound exists.
            return f64::INFINITY;
        }
        let tan_sq = t.tan().powi(2);
        t.powi(3) / 3.0 * (1.0 + tan_sq) * (1.0 + 3.0 * tan_sq)
    }

    let bound = threshold.abs();
    let valid_range = (-bound, bound);

    match expr {
        Expression::Function(func, args) if args.len() == 1 => {
            let arg = &args[0];
            if !is_small_angle_candidate(arg, var) {
                return None;
            }
            let (approximation, error_bound, formula) = match func {
                Function::Sin => (arg.clone(), bound.powi(3) / 6.0, "sin(θ) ≈ θ"),
                Function::Cos => {
                    let approximation = Expression::Binary(
                        BinaryOp::Sub,
                        Box::new(Expression::Integer(1)),
                        Box::new(half_square(arg)),
                    );
                    (approximation, bound.powi(4) / 24.0, "cos(θ) ≈ 1 - θ²/2")
                }
                Function::Tan => (arg.clone(), tan_error_bound(bound), "tan(θ) ≈ θ"),
                _ => return None,
            };
            Some(ApproxResult {
                approximation,
                error_bound,
                valid_range,
                formula_used: formula.to_string(),
            })
        }
        Expression::Binary(BinaryOp::Sub, left, right) => {
            match (left.as_ref(), right.as_ref()) {
                (Expression::Integer(1), Expression::Function(Function::Cos, args))
                    if args.len() == 1 && is_small_angle_candidate(&args[0], var) =>
                {
                    Some(ApproxResult {
                        approximation: half_square(&args[0]),
                        // Same remainder as the `cos` case: |cos θ - (1 - θ²/2)| ≤ t⁴/24.
                        error_bound: bound.powi(4) / 24.0,
                        valid_range,
                        formula_used: "1 - cos(θ) ≈ θ²/2".to_string(),
                    })
                }
                _ => None,
            }
        }
        _ => None,
    }
}
/// Reports whether `expr` is a form of `var` that the small-angle rules accept.
///
/// Accepted forms are the variable itself, a negation of an accepted form, and
/// a product of exactly one numeric literal (integer or float) with the
/// variable, in either operand order.
fn is_small_angle_candidate(expr: &Expression, var: &Variable) -> bool {
    let is_literal =
        |e: &Expression| matches!(e, Expression::Integer(_) | Expression::Float(_));
    let is_target = |e: &Expression| matches!(e, Expression::Variable(v) if v == var);

    match expr {
        Expression::Variable(v) => v == var,
        Expression::Unary(UnaryOp::Neg, inner) => is_small_angle_candidate(inner, var),
        Expression::Binary(BinaryOp::Mul, left, right) => {
            let (lhs, rhs) = (left.as_ref(), right.as_ref());
            (is_literal(lhs) && is_target(rhs)) || (is_literal(rhs) && is_target(lhs))
        }
        _ => false,
    }
}
/// Evaluates `exact` and `approx` with `var` bound to `value` and returns the
/// absolute difference of the two results.
///
/// An expression whose evaluation does not yield a value contributes `0.0`.
pub fn compute_approximation_error(
    exact: &Expression,
    approx: &Expression,
    var: &Variable,
    value: f64,
) -> f64 {
    let mut bindings = HashMap::new();
    bindings.insert(var.name.clone(), value);

    let eval_or_zero = |e: &Expression| e.evaluate(&bindings).unwrap_or(0.0);

    let reference = eval_or_zero(exact);
    let estimate = eval_or_zero(approx);
    (reference - estimate).abs()
}
/// Chooses a scaled-exponential form from the width of `argument_range`.
///
/// The width is `|max - min|`: up to 20 selects `Standard`, up to 200 selects
/// `Scaled01`, and anything else (including a NaN width) selects `Scaled001`.
// NOTE(review): only the width of the range is inspected, not the magnitude of
// its endpoints — confirm that this is the intended criterion.
pub fn select_exp_scaling(argument_range: (f64, f64)) -> ScaledExpForm {
    let width = (argument_range.1 - argument_range.0).abs();
    match width {
        w if w <= 20.0 => ScaledExpForm::Standard,
        w if w <= 200.0 => ScaledExpForm::Scaled01,
        _ => ScaledExpForm::Scaled001,
    }
}
/// Reports whether `variable_value` lies strictly inside the magnitude limit
/// associated with `approx_type`.
///
/// Limits: `0.2` for every small-angle variant, `700.0` for `ScaledExp`
/// (its payload is not consulted), and `0.1` for `PythagoreanSmall`.
/// A NaN value is never valid.
pub fn is_approximation_valid(approx_type: &ApproxType, variable_value: f64) -> bool {
    let limit = match approx_type {
        ApproxType::PythagoreanSmall => 0.1,
        ApproxType::ScaledExp(_) => 700.0,
        ApproxType::SmallAngleSin
        | ApproxType::SmallAngleCos
        | ApproxType::SmallAngleTan
        | ApproxType::SmallAngle1MinusCos => 0.2,
    };
    variable_value.abs() < limit
}
/// Builds the resolution step that records an approximation substitution.
///
/// The step carries an `ApproximationSubstitution` operation (clones of both
/// expressions plus the error bound), an explanation that names the formula
/// and prints the bound in scientific notation, and a clone of the
/// approximation as its resulting expression.
pub fn generate_approximation_step(
    original: &Expression,
    approximation: &Expression,
    error_bound: f64,
    formula_used: String,
) -> ResolutionStep {
    let operation = Operation::ApproximationSubstitution {
        original: original.clone(),
        approximation: approximation.clone(),
        error_bound,
    };
    let explanation = format!(
        "Apply approximation: {}. Error bound: {:.2e}",
        formula_used, error_bound
    );
    let result = approximation.clone();

    ResolutionStep::new(operation, explanation, result)
}
/// Rewrites `sqrt(1 - x²)` as `sqrt(1 - x) * sqrt(1 + x)`.
///
/// The input must be a single-argument `Sqrt` call whose argument is the
/// subtraction of a power with integer exponent `2` from the integer `1`.
/// Any other shape yields `None`.
pub fn optimize_pythagorean(expr: &Expression) -> Option<Expression> {
    // Peel the expression one layer at a time, bailing out on the first mismatch.
    let radicand = match expr {
        Expression::Function(Function::Sqrt, args) if args.len() == 1 => &args[0],
        _ => return None,
    };
    let (minuend, subtrahend) = match radicand {
        Expression::Binary(BinaryOp::Sub, left, right) => (left.as_ref(), right.as_ref()),
        _ => return None,
    };
    if !matches!(minuend, Expression::Integer(1)) {
        return None;
    }
    let base = match subtrahend {
        Expression::Power(base, exponent)
            if matches!(exponent.as_ref(), Expression::Integer(2)) =>
        {
            base.as_ref()
        }
        _ => return None,
    };

    let sqrt_of = |inner: Expression| Expression::Function(Function::Sqrt, vec![inner]);
    let one = Expression::Integer(1);
    let difference = Expression::Binary(
        BinaryOp::Sub,
        Box::new(one.clone()),
        Box::new(base.clone()),
    );
    let sum = Expression::Binary(BinaryOp::Add, Box::new(one), Box::new(base.clone()));

    Some(Expression::Binary(
        BinaryOp::Mul,
        Box::new(sqrt_of(difference)),
        Box::new(sqrt_of(sum)),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the single-argument call `func(var)`.
    fn call(func: Function, var: &Variable) -> Expression {
        Expression::Function(func, vec![Expression::Variable(var.clone())])
    }

    #[test]
    fn test_small_angle_sin_approximation() {
        let x = Variable::new("x");
        let approx = apply_small_angle_approx(&call(Function::Sin, &x), &x, 0.1)
            .expect("sin(x) should have a small-angle approximation");

        assert_eq!(approx.approximation, Expression::Variable(x.clone()));
        assert!(approx.error_bound < 2e-4);
        assert_eq!(approx.valid_range, (-0.1, 0.1));
        assert_eq!(approx.formula_used, "sin(θ) ≈ θ");
    }

    #[test]
    fn test_small_angle_cos_approximation() {
        let x = Variable::new("x");
        let approx = apply_small_angle_approx(&call(Function::Cos, &x), &x, 0.1)
            .expect("cos(x) should have a small-angle approximation");

        assert!(approx.error_bound < 5e-5);
        assert_eq!(approx.formula_used, "cos(θ) ≈ 1 - θ²/2");
    }

    #[test]
    fn test_small_angle_tan_approximation() {
        let x = Variable::new("x");
        let approx = apply_small_angle_approx(&call(Function::Tan, &x), &x, 0.1)
            .expect("tan(x) should have a small-angle approximation");

        assert_eq!(approx.approximation, Expression::Variable(x));
        assert!(approx.error_bound < 4e-4);
        assert_eq!(approx.formula_used, "tan(θ) ≈ θ");
    }

    #[test]
    fn test_one_minus_cos_approximation() {
        let x = Variable::new("x");
        let one_minus_cos = Expression::Binary(
            BinaryOp::Sub,
            Box::new(Expression::Integer(1)),
            Box::new(call(Function::Cos, &x)),
        );
        let approx = apply_small_angle_approx(&one_minus_cos, &x, 0.1)
            .expect("1 - cos(x) should have a small-angle approximation");

        assert_eq!(approx.formula_used, "1 - cos(θ) ≈ θ²/2");
        assert!(approx.error_bound < 1e-4);
    }

    #[test]
    fn test_error_computation() {
        let x = Variable::new("x");
        let exact = call(Function::Sin, &x);
        let approx = Expression::Variable(x.clone());

        // (evaluation point, strict upper limit on the observed error)
        for (value, limit) in [(0.05, 2.1e-5), (0.1, 2e-4)] {
            let error = compute_approximation_error(&exact, &approx, &x, value);
            assert!(error < limit);
        }
    }

    #[test]
    fn test_approximation_validity() {
        let cases = [
            (ApproxType::SmallAngleSin, 0.05, true),
            (ApproxType::SmallAngleSin, 0.1, true),
            (ApproxType::SmallAngleSin, 1.0, false),
            (ApproxType::SmallAngleCos, 0.5, false),
            (ApproxType::PythagoreanSmall, 0.01, true),
            (ApproxType::PythagoreanSmall, 0.5, false),
        ];
        for (kind, value, expected) in cases {
            assert_eq!(is_approximation_valid(&kind, value), expected);
        }
    }

    #[test]
    fn test_exp_scaling_selection() {
        let cases = [
            (10.0, ScaledExpForm::Standard),
            (100.0, ScaledExpForm::Scaled01),
            (500.0, ScaledExpForm::Scaled001),
        ];
        for (upper, expected) in cases {
            assert_eq!(select_exp_scaling((0.0, upper)), expected);
        }
    }

    #[test]
    fn test_scaled_exp_form_factor() {
        let cases = [
            (ScaledExpForm::Standard, 1.0),
            (ScaledExpForm::Scaled01, 0.1),
            (ScaledExpForm::Scaled001, 0.01),
            (ScaledExpForm::Custom(0.5), 0.5),
        ];
        for (form, factor) in cases {
            assert_eq!(form.scaling_factor(), factor);
        }
    }

    #[test]
    fn test_pythagorean_optimization() {
        let x = Expression::Variable(Variable::new("x"));
        let x_squared = Expression::Power(Box::new(x.clone()), Box::new(Expression::Integer(2)));
        let radicand = Expression::Binary(
            BinaryOp::Sub,
            Box::new(Expression::Integer(1)),
            Box::new(x_squared),
        );
        let sqrt_expr = Expression::Function(Function::Sqrt, vec![radicand]);

        match optimize_pythagorean(&sqrt_expr) {
            Some(Expression::Binary(BinaryOp::Mul, left, right)) => {
                for factor in [left.as_ref(), right.as_ref()] {
                    assert!(matches!(factor, Expression::Function(Function::Sqrt, _)));
                }
            }
            _ => panic!("Expected multiplication of two square roots"),
        }
    }

    #[test]
    fn test_approximation_error_bounds_are_conservative() {
        let x = Variable::new("x");
        let sin_x = call(Function::Sin, &x);
        let outcome = apply_small_angle_approx(&sin_x, &x, 0.1).unwrap();

        for point in [0.01, 0.05, 0.09] {
            let observed = compute_approximation_error(&sin_x, &outcome.approximation, &x, point);
            assert!(
                observed <= outcome.error_bound,
                "Actual error {} exceeds stated bound {} at value {}",
                observed,
                outcome.error_bound,
                point
            );
        }
    }

    #[test]
    fn test_approximation_step_generation() {
        let x = Variable::new("x");
        let original = call(Function::Sin, &x);
        let approx = Expression::Variable(x);

        let step = generate_approximation_step(&original, &approx, 1e-5, "sin(θ) ≈ θ".to_string());

        assert!(step.explanation.contains("sin(θ) ≈ θ"));
        assert!(step.explanation.contains("Error bound"));
        assert_eq!(step.result, approx);
    }
}