mathhook_core/functions/elementary/
sqrt_eval.rs

//! Square root function evaluation
2
3use crate::core::constants::EPSILON;
4use crate::core::{Expression, Number};
5
6/// Evaluate square root function
7///
8/// # Mathematical Definition
9///
10/// √x = x^(1/2)
11///
12/// # Arguments
13///
14/// * `arg` - Expression to compute square root of
15///
16/// # Returns
17///
18/// Square root expression
19///
20/// # Examples
21///
22/// ```
23/// use mathhook_core::functions::elementary::sqrt_eval::sqrt;
24/// use mathhook_core::expr;
25///
26/// let result = sqrt(&expr!(4));
27/// assert_eq!(result, expr!(2));
28/// ```
29pub fn sqrt(arg: &Expression) -> Expression {
30    match arg {
31        Expression::Number(n) => evaluate_sqrt_number(n),
32        _ => Expression::function("sqrt", vec![arg.clone()]),
33    }
34}
35
36fn evaluate_sqrt_number(n: &Number) -> Expression {
37    match n {
38        Number::Integer(i) if *i >= 0 => {
39            let sqrt_val = (*i as f64).sqrt();
40            if sqrt_val.fract().abs() < EPSILON {
41                Expression::integer(sqrt_val as i64)
42            } else {
43                Expression::float(sqrt_val)
44            }
45        }
46        Number::Integer(i) if *i < 0 => {
47            let abs_val = i.abs();
48            let sqrt_abs = (abs_val as f64).sqrt();
49            if sqrt_abs.fract().abs() < EPSILON {
50                let sqrt_int = sqrt_abs as i64;
51                if sqrt_int == 1 {
52                    Expression::i()
53                } else {
54                    Expression::mul(vec![Expression::integer(sqrt_int), Expression::i()])
55                }
56            } else {
57                Expression::mul(vec![Expression::float(sqrt_abs), Expression::i()])
58            }
59        }
60        Number::Float(f) if *f >= 0.0 => Expression::float(f.sqrt()),
61        Number::Float(f) if *f < 0.0 => {
62            let sqrt_abs = f.abs().sqrt();
63            Expression::mul(vec![Expression::float(sqrt_abs), Expression::i()])
64        }
65        _ => Expression::function("sqrt", vec![Expression::Number(n.clone())]),
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::SQRT_2;

    /// Absolute tolerance for comparing floating-point square roots.
    const TOLERANCE: f64 = 1e-10;

    /// Return the payload of a float expression, panicking with `context` otherwise.
    fn float_value(expr: &Expression, context: &str) -> f64 {
        match expr {
            Expression::Number(Number::Float(value)) => *value,
            _ => panic!("{}", context),
        }
    }

    /// Check that `expr` is the product `value * i`.
    ///
    /// `value` must be a float within `TOLERANCE` of `expected`.
    fn assert_float_times_i(expr: Expression, expected: f64, float_context: &str) {
        match expr {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2);
                let coefficient = float_value(&factors[0], float_context);
                assert!((coefficient - expected).abs() < TOLERANCE);
                assert_eq!(factors[1], Expression::i());
            }
            _ => panic!("Expected multiplication expression"),
        }
    }

    #[test]
    fn test_sqrt_zero() {
        let zero = Expression::integer(0);
        assert_eq!(sqrt(&zero), zero);
    }

    #[test]
    fn test_sqrt_perfect_square() {
        for &(square, root) in [(4, 2), (9, 3)].iter() {
            assert_eq!(sqrt(&Expression::integer(square)), Expression::integer(root));
        }
    }

    #[test]
    fn test_sqrt_non_perfect() {
        let value = float_value(&sqrt(&Expression::integer(2)), "Expected float result");
        assert!((value - SQRT_2).abs() < TOLERANCE);
    }

    #[test]
    fn test_sqrt_negative_one() {
        assert_eq!(sqrt(&Expression::integer(-1)), Expression::i());
    }

    #[test]
    fn test_sqrt_negative_perfect_square() {
        let expected = Expression::mul(vec![Expression::integer(2), Expression::i()]);
        assert_eq!(sqrt(&Expression::integer(-4)), expected);
    }

    #[test]
    fn test_sqrt_negative_non_perfect() {
        assert_float_times_i(
            sqrt(&Expression::integer(-2)),
            SQRT_2,
            "Expected float for sqrt(2)",
        );
    }

    #[test]
    fn test_sqrt_negative_float() {
        assert_float_times_i(
            sqrt(&Expression::float(-4.0)),
            2.0,
            "Expected float for sqrt(4.0)",
        );
    }
}