mathhook_core/functions/elementary/
sqrt.rs1use crate::core::constants::EPSILON;
7use crate::core::{Expression, Number, Symbol};
8use crate::functions::properties::*;
9use std::collections::HashMap;
10use std::sync::Arc;
11
12pub struct SqrtIntelligence {
17 properties: HashMap<String, FunctionProperties>,
18}
19
20impl Default for SqrtIntelligence {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl SqrtIntelligence {
27 pub fn new() -> Self {
38 let mut intelligence = Self {
39 properties: HashMap::with_capacity(1),
40 };
41
42 intelligence.initialize_sqrt();
43 intelligence
44 }
45
46 pub fn get_properties(&self) -> HashMap<String, FunctionProperties> {
58 self.properties.clone()
59 }
60
61 pub fn has_function(&self, name: &str) -> bool {
77 self.properties.contains_key(name)
78 }
79
80 fn initialize_sqrt(&mut self) {
82 self.properties.insert(
83 "sqrt".to_owned(),
84 FunctionProperties::Elementary(Box::new(ElementaryProperties {
85 derivative_rule: Some(DerivativeRule {
86 rule_type: DerivativeRuleType::Custom {
87 builder: Arc::new(|arg: &Expression| {
88 let sqrt_arg = Expression::function("sqrt", vec![arg.clone()]);
89 let denominator =
90 Expression::mul(vec![Expression::integer(2), sqrt_arg]);
91 Expression::mul(vec![
92 Expression::integer(1),
93 Expression::pow(denominator, Expression::integer(-1)),
94 ])
95 }),
96 },
97 result_template: "d/dx sqrt(x) = 1/(2*sqrt(x)) for x > 0".to_owned(),
98 }),
99 antiderivative_rule: Some(AntiderivativeRule {
100 rule_type: AntiderivativeRuleType::Custom {
101 builder: Arc::new(|var: Symbol| {
102 Expression::mul(vec![
103 Expression::rational(2, 3),
104 Expression::pow(
105 Expression::symbol(var),
106 Expression::rational(3, 2),
107 ),
108 ])
109 }),
110 },
111 result_template: "∫sqrt(x)dx = (2/3)x^(3/2) + C".to_owned(),
112 constant_handling: ConstantOfIntegration::AddConstant,
113 }),
114 special_values: vec![
115 SpecialValue {
116 input: "0".to_owned(),
117 output: Expression::integer(0),
118 latex_explanation: "\\sqrt{0} = 0".to_owned(),
119 },
120 SpecialValue {
121 input: "1".to_owned(),
122 output: Expression::integer(1),
123 latex_explanation: "\\sqrt{1} = 1".to_owned(),
124 },
125 SpecialValue {
126 input: "4".to_owned(),
127 output: Expression::integer(2),
128 latex_explanation: "\\sqrt{4} = 2".to_owned(),
129 },
130 SpecialValue {
131 input: "9".to_owned(),
132 output: Expression::integer(3),
133 latex_explanation: "\\sqrt{9} = 3".to_owned(),
134 },
135 ],
136 identities: Box::new(vec![
137 MathIdentity {
138 name: "Product Rule".to_owned(),
139 lhs: Expression::function(
140 "sqrt",
141 vec![Expression::mul(vec![
142 Expression::symbol("a"),
143 Expression::symbol("b"),
144 ])],
145 ),
146 rhs: Expression::mul(vec![
147 Expression::function("sqrt", vec![Expression::symbol("a")]),
148 Expression::function("sqrt", vec![Expression::symbol("b")]),
149 ]),
150 conditions: vec!["a, b ≥ 0".to_owned()],
151 },
152 MathIdentity {
153 name: "Power Simplification".to_owned(),
154 lhs: Expression::function(
155 "sqrt",
156 vec![Expression::pow(
157 Expression::symbol("x"),
158 Expression::integer(2),
159 )],
160 ),
161 rhs: Expression::function("abs", vec![Expression::symbol("x")]),
162 conditions: vec!["x ∈ ℝ".to_owned()],
163 },
164 ]),
165 domain_range: Box::new(DomainRangeData {
166 domain: Domain::Union(vec![
167 Domain::Interval(Expression::integer(0), Expression::infinity()),
168 Domain::Complex,
169 ]),
170 range: Range::Bounded(Expression::integer(0), Expression::infinity()),
171 singularities: vec![],
172 }),
173 periodicity: None,
174 wolfram_name: None,
175 })),
176 );
177 }
178}
179
180pub fn simplify_sqrt(arg: &Expression) -> Expression {
222 match arg {
223 Expression::Number(n) => evaluate_sqrt_number(n),
224
225 Expression::Pow(base, exp) if is_square(exp) => {
226 Expression::function("abs", vec![(**base).clone()])
227 }
228
229 Expression::Pow(base, exp) if is_even_power(exp) => simplify_sqrt_even_power(base, exp),
230
231 Expression::Mul(terms) => simplify_sqrt_product(terms),
232
233 Expression::Function { name, args } if name.as_ref() == "sqrt" && args.len() == 1 => {
234 Expression::function("sqrt", vec![args[0].clone()])
235 }
236
237 _ => Expression::function("sqrt", vec![arg.clone()]),
238 }
239}
240
241fn evaluate_sqrt_number(n: &Number) -> Expression {
243 use num_traits::ToPrimitive;
244
245 match n {
246 Number::Integer(i) => {
247 if *i >= 0 {
248 let sqrt_val = (*i as f64).sqrt();
249 if sqrt_val.fract().abs() < EPSILON {
250 Expression::integer(sqrt_val as i64)
251 } else {
252 Expression::function("sqrt", vec![Expression::integer(*i)])
253 }
254 } else {
255 let pos_sqrt = evaluate_sqrt_number(&Number::Integer(-i));
256 Expression::mul(vec![
257 pos_sqrt,
258 Expression::constant(crate::core::MathConstant::I),
259 ])
260 }
261 }
262 Number::Float(f) => {
263 if *f >= 0.0 {
264 Expression::float(f.sqrt())
265 } else {
266 Expression::mul(vec![
267 Expression::float((-f).sqrt()),
268 Expression::constant(crate::core::MathConstant::I),
269 ])
270 }
271 }
272 Number::BigInteger(bi) => {
273 use num_traits::Signed;
274 if **bi >= num_bigint::BigInt::from(0) {
275 if let Some(i_val) = bi.to_i64() {
276 let sqrt_val = (i_val as f64).sqrt();
277 if sqrt_val.fract().abs() < EPSILON {
278 Expression::integer(sqrt_val as i64)
279 } else {
280 Expression::function("sqrt", vec![Expression::Number(n.clone())])
281 }
282 } else {
283 Expression::function("sqrt", vec![Expression::Number(n.clone())])
284 }
285 } else {
286 let pos_sqrt = evaluate_sqrt_number(&Number::BigInteger(Box::new((**bi).abs())));
287 Expression::mul(vec![
288 pos_sqrt,
289 Expression::constant(crate::core::MathConstant::I),
290 ])
291 }
292 }
293 Number::Rational(r) => {
294 let numer = r.numer();
295 let denom = r.denom();
296
297 if let (Some(n_val), Some(d_val)) = (numer.to_i64(), denom.to_i64()) {
298 let n_sqrt = (n_val as f64).sqrt();
299 let d_sqrt = (d_val as f64).sqrt();
300
301 if n_sqrt.fract().abs() < EPSILON && d_sqrt.fract().abs() < EPSILON {
302 return Expression::rational(n_sqrt as i64, d_sqrt as i64);
303 }
304 }
305
306 Expression::function("sqrt", vec![Expression::Number(n.clone())])
307 }
308 }
309}
310
311fn is_square(exp: &Expression) -> bool {
313 matches!(exp, Expression::Number(Number::Integer(2)))
314}
315
316fn is_even_power(exp: &Expression) -> bool {
318 matches!(exp, Expression::Number(Number::Integer(n)) if n % 2 == 0)
319}
320
321fn simplify_sqrt_even_power(base: &Expression, exp: &Expression) -> Expression {
323 if let Expression::Number(Number::Integer(n)) = exp {
324 Expression::pow(base.clone(), Expression::integer(n / 2))
325 } else {
326 Expression::function("sqrt", vec![Expression::pow(base.clone(), exp.clone())])
327 }
328}
329
330fn simplify_sqrt_product(terms: &[Expression]) -> Expression {
332 let mut perfect_squares = Vec::new();
333 let mut other_terms = Vec::new();
334
335 for term in terms {
336 if let Expression::Pow(base, exp) = term {
337 if is_square(exp) {
338 perfect_squares.push(Expression::function("abs", vec![(**base).clone()]));
339 } else if is_even_power(exp) {
340 if let Expression::Number(Number::Integer(n)) = **exp {
341 perfect_squares.push(Expression::pow(
342 (**base).clone(),
343 Expression::integer(n / 2),
344 ));
345 } else {
346 other_terms.push(term.clone());
347 }
348 } else {
349 other_terms.push(term.clone());
350 }
351 } else if let Expression::Number(n) = term {
352 match evaluate_sqrt_number(n) {
353 expr @ Expression::Number(_) => perfect_squares.push(expr),
354 _ => other_terms.push(term.clone()),
355 }
356 } else {
357 other_terms.push(term.clone());
358 }
359 }
360
361 if perfect_squares.is_empty() {
362 Expression::function("sqrt", vec![Expression::mul(terms.to_vec())])
363 } else if other_terms.is_empty() {
364 Expression::mul(perfect_squares)
365 } else {
366 perfect_squares.push(Expression::function(
367 "sqrt",
368 vec![Expression::mul(other_terms)],
369 ));
370 Expression::mul(perfect_squares)
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// `x^power` with an integer exponent.
    fn x_pow(power: i64) -> Expression {
        Expression::pow(Expression::symbol("x"), Expression::integer(power))
    }

    /// Asserts that `simplify_sqrt` maps the integer `radicand` to the integer `root`.
    fn assert_integer_sqrt(radicand: i64, root: i64) {
        assert_eq!(
            simplify_sqrt(&Expression::integer(radicand)),
            Expression::integer(root)
        );
    }

    #[test]
    fn test_sqrt_intelligence_creation() {
        let registry = SqrtIntelligence::new();
        assert!(registry.has_function("sqrt"));
        assert!(registry.get_properties().contains_key("sqrt"));
    }

    #[test]
    fn test_sqrt_properties() {
        let all_properties = SqrtIntelligence::new().get_properties();
        let sqrt_props = all_properties
            .get("sqrt")
            .expect("SqrtIntelligence::new registers \"sqrt\"");

        assert!(sqrt_props.has_derivative());
        assert!(sqrt_props.has_antiderivative());
        assert_eq!(sqrt_props.special_value_count(), 4);
    }

    #[test]
    fn test_simplify_sqrt_zero() {
        assert_integer_sqrt(0, 0);
    }

    #[test]
    fn test_simplify_sqrt_one() {
        assert_integer_sqrt(1, 1);
    }

    #[test]
    fn test_simplify_sqrt_perfect_square() {
        for (radicand, root) in [(4, 2), (9, 3)] {
            assert_integer_sqrt(radicand, root);
        }
    }

    #[test]
    fn test_simplify_sqrt_square() {
        let expected = Expression::function("abs", vec![Expression::symbol("x")]);
        assert_eq!(simplify_sqrt(&x_pow(2)), expected);
    }

    #[test]
    fn test_simplify_sqrt_even_power() {
        assert_eq!(simplify_sqrt(&x_pow(4)), x_pow(2));
    }
}