mathhook_core/functions/elementary/
sqrt_eval.rs1use crate::core::constants::EPSILON;
4use crate::core::{Expression, Number};
5
6pub fn sqrt(arg: &Expression) -> Expression {
30 match arg {
31 Expression::Number(n) => evaluate_sqrt_number(n),
32 _ => Expression::function("sqrt", vec![arg.clone()]),
33 }
34}
35
36fn evaluate_sqrt_number(n: &Number) -> Expression {
37 match n {
38 Number::Integer(i) if *i >= 0 => {
39 let sqrt_val = (*i as f64).sqrt();
40 if sqrt_val.fract().abs() < EPSILON {
41 Expression::integer(sqrt_val as i64)
42 } else {
43 Expression::float(sqrt_val)
44 }
45 }
46 Number::Integer(i) if *i < 0 => {
47 let abs_val = i.abs();
48 let sqrt_abs = (abs_val as f64).sqrt();
49 if sqrt_abs.fract().abs() < EPSILON {
50 let sqrt_int = sqrt_abs as i64;
51 if sqrt_int == 1 {
52 Expression::i()
53 } else {
54 Expression::mul(vec![Expression::integer(sqrt_int), Expression::i()])
55 }
56 } else {
57 Expression::mul(vec![Expression::float(sqrt_abs), Expression::i()])
58 }
59 }
60 Number::Float(f) if *f >= 0.0 => Expression::float(f.sqrt()),
61 Number::Float(f) if *f < 0.0 => {
62 let sqrt_abs = f.abs().sqrt();
63 Expression::mul(vec![Expression::float(sqrt_abs), Expression::i()])
64 }
65 _ => Expression::function("sqrt", vec![Expression::Number(n.clone())]),
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::SQRT_2;

    // Absolute tolerance used when comparing floating-point roots.
    const TOL: f64 = 1e-10;

    // Checks that `expr` is the two-factor product `[Float(~expected), i]`.
    // `not_float_msg` is the panic message used when the first factor is not
    // a float literal.
    fn assert_float_times_i(expr: Expression, expected: f64, not_float_msg: &str) {
        let factors = match expr {
            Expression::Mul(factors) => factors,
            _ => panic!("Expected multiplication expression"),
        };
        assert_eq!(factors.len(), 2);
        match &factors[0] {
            Expression::Number(Number::Float(coefficient)) => {
                assert!((coefficient - expected).abs() < TOL);
            }
            _ => panic!("{}", not_float_msg),
        }
        assert_eq!(factors[1], Expression::i());
    }

    #[test]
    fn test_sqrt_zero() {
        let result = sqrt(&Expression::integer(0));
        assert_eq!(result, Expression::integer(0));
    }

    #[test]
    fn test_sqrt_perfect_square() {
        for (square, root) in [(4_i64, 2_i64), (9, 3)] {
            assert_eq!(sqrt(&Expression::integer(square)), Expression::integer(root));
        }
    }

    #[test]
    fn test_sqrt_non_perfect() {
        match sqrt(&Expression::integer(2)) {
            Expression::Number(Number::Float(root)) => assert!((root - SQRT_2).abs() < TOL),
            _ => panic!("Expected float result"),
        }
    }

    #[test]
    fn test_sqrt_negative_one() {
        let result = sqrt(&Expression::integer(-1));
        assert_eq!(result, Expression::i());
    }

    #[test]
    fn test_sqrt_negative_perfect_square() {
        let actual = sqrt(&Expression::integer(-4));
        let expected = Expression::mul(vec![Expression::integer(2), Expression::i()]);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_sqrt_negative_non_perfect() {
        let result = sqrt(&Expression::integer(-2));
        assert_float_times_i(result, SQRT_2, "Expected float for sqrt(2)");
    }

    #[test]
    fn test_sqrt_negative_float() {
        let result = sqrt(&Expression::float(-4.0));
        assert_float_times_i(result, 2.0, "Expected float for sqrt(4.0)");
    }
}