// mathhook_core/functions/elementary/sqrt.rs

//! Square root function intelligence
//!
//! Mathematical knowledge for the square root function: derivatives,
//! antiderivatives, special values, and simplification rules.
5
6use crate::core::constants::EPSILON;
7use crate::core::{Expression, Number, Symbol};
8use crate::functions::properties::*;
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// Square Root Function Intelligence
13///
14/// Dedicated intelligence system for the square root function
15/// with complete mathematical properties.
16pub struct SqrtIntelligence {
17    properties: HashMap<String, FunctionProperties>,
18}
19
20impl Default for SqrtIntelligence {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl SqrtIntelligence {
27    /// Create new square root intelligence system
28    ///
29    /// # Examples
30    ///
31    /// ```
32    /// use mathhook_core::functions::elementary::sqrt::SqrtIntelligence;
33    ///
34    /// let intelligence = SqrtIntelligence::new();
35    /// assert!(intelligence.has_function("sqrt"));
36    /// ```
37    pub fn new() -> Self {
38        let mut intelligence = Self {
39            properties: HashMap::with_capacity(1),
40        };
41
42        intelligence.initialize_sqrt();
43        intelligence
44    }
45
46    /// Get square root function properties
47    ///
48    /// # Examples
49    ///
50    /// ```
51    /// use mathhook_core::functions::elementary::sqrt::SqrtIntelligence;
52    ///
53    /// let intelligence = SqrtIntelligence::new();
54    /// let props = intelligence.get_properties();
55    /// assert!(props.contains_key("sqrt"));
56    /// ```
57    pub fn get_properties(&self) -> HashMap<String, FunctionProperties> {
58        self.properties.clone()
59    }
60
61    /// Check if function is square root
62    ///
63    /// # Arguments
64    ///
65    /// * `name` - The function name to check
66    ///
67    /// # Examples
68    ///
69    /// ```
70    /// use mathhook_core::functions::elementary::sqrt::SqrtIntelligence;
71    ///
72    /// let intelligence = SqrtIntelligence::new();
73    /// assert!(intelligence.has_function("sqrt"));
74    /// assert!(!intelligence.has_function("sin"));
75    /// ```
76    pub fn has_function(&self, name: &str) -> bool {
77        self.properties.contains_key(name)
78    }
79
80    /// Initialize square root function
81    fn initialize_sqrt(&mut self) {
82        self.properties.insert(
83            "sqrt".to_owned(),
84            FunctionProperties::Elementary(Box::new(ElementaryProperties {
85                derivative_rule: Some(DerivativeRule {
86                    rule_type: DerivativeRuleType::Custom {
87                        builder: Arc::new(|arg: &Expression| {
88                            let sqrt_arg = Expression::function("sqrt", vec![arg.clone()]);
89                            let denominator =
90                                Expression::mul(vec![Expression::integer(2), sqrt_arg]);
91                            Expression::mul(vec![
92                                Expression::integer(1),
93                                Expression::pow(denominator, Expression::integer(-1)),
94                            ])
95                        }),
96                    },
97                    result_template: "d/dx sqrt(x) = 1/(2*sqrt(x)) for x > 0".to_owned(),
98                }),
99                antiderivative_rule: Some(AntiderivativeRule {
100                    rule_type: AntiderivativeRuleType::Custom {
101                        builder: Arc::new(|var: Symbol| {
102                            Expression::mul(vec![
103                                Expression::rational(2, 3),
104                                Expression::pow(
105                                    Expression::symbol(var),
106                                    Expression::rational(3, 2),
107                                ),
108                            ])
109                        }),
110                    },
111                    result_template: "∫sqrt(x)dx = (2/3)x^(3/2) + C".to_owned(),
112                    constant_handling: ConstantOfIntegration::AddConstant,
113                }),
114                special_values: vec![
115                    SpecialValue {
116                        input: "0".to_owned(),
117                        output: Expression::integer(0),
118                        latex_explanation: "\\sqrt{0} = 0".to_owned(),
119                    },
120                    SpecialValue {
121                        input: "1".to_owned(),
122                        output: Expression::integer(1),
123                        latex_explanation: "\\sqrt{1} = 1".to_owned(),
124                    },
125                    SpecialValue {
126                        input: "4".to_owned(),
127                        output: Expression::integer(2),
128                        latex_explanation: "\\sqrt{4} = 2".to_owned(),
129                    },
130                    SpecialValue {
131                        input: "9".to_owned(),
132                        output: Expression::integer(3),
133                        latex_explanation: "\\sqrt{9} = 3".to_owned(),
134                    },
135                ],
136                identities: Box::new(vec![
137                    MathIdentity {
138                        name: "Product Rule".to_owned(),
139                        lhs: Expression::function(
140                            "sqrt",
141                            vec![Expression::mul(vec![
142                                Expression::symbol("a"),
143                                Expression::symbol("b"),
144                            ])],
145                        ),
146                        rhs: Expression::mul(vec![
147                            Expression::function("sqrt", vec![Expression::symbol("a")]),
148                            Expression::function("sqrt", vec![Expression::symbol("b")]),
149                        ]),
150                        conditions: vec!["a, b ≥ 0".to_owned()],
151                    },
152                    MathIdentity {
153                        name: "Power Simplification".to_owned(),
154                        lhs: Expression::function(
155                            "sqrt",
156                            vec![Expression::pow(
157                                Expression::symbol("x"),
158                                Expression::integer(2),
159                            )],
160                        ),
161                        rhs: Expression::function("abs", vec![Expression::symbol("x")]),
162                        conditions: vec!["x ∈ ℝ".to_owned()],
163                    },
164                ]),
165                domain_range: Box::new(DomainRangeData {
166                    domain: Domain::Union(vec![
167                        Domain::Interval(Expression::integer(0), Expression::infinity()),
168                        Domain::Complex,
169                    ]),
170                    range: Range::Bounded(Expression::integer(0), Expression::infinity()),
171                    singularities: vec![],
172                }),
173                periodicity: None,
174                wolfram_name: None,
175            })),
176        );
177    }
178}
179
180/// Simplify square root expressions
181///
182/// Applies mathematical simplification rules for square root.
183///
184/// # Simplification Rules
185///
186/// - sqrt(0) = 0
187/// - sqrt(1) = 1
188/// - sqrt(4) = 2, sqrt(9) = 3, etc. (perfect squares)
189/// - sqrt(x²) = |x|
190/// - sqrt(x⁴) = x² (even powers)
191/// - sqrt(a*b) = sqrt(a)*sqrt(b) (when a, b ≥ 0)
192/// - sqrt(a²*b) = a*sqrt(b) (factor perfect squares)
193/// - sqrt(1/4) = 1/2 (rational perfect squares)
194///
195/// # Arguments
196///
197/// * `arg` - The argument to the square root function
198///
199/// # Returns
200///
201/// Simplified expression
202///
203/// # Examples
204///
205/// ```
206/// use mathhook_core::core::Expression;
207/// use mathhook_core::functions::elementary::sqrt::simplify_sqrt;
208///
209/// let zero = Expression::integer(0);
210/// assert_eq!(simplify_sqrt(&zero), Expression::integer(0));
211///
212/// let four = Expression::integer(4);
213/// assert_eq!(simplify_sqrt(&four), Expression::integer(2));
214///
215/// let squared = Expression::pow(Expression::symbol("x"), Expression::integer(2));
216/// assert_eq!(
217///     simplify_sqrt(&squared),
218///     Expression::function("abs", vec![Expression::symbol("x")])
219/// );
220/// ```
221pub fn simplify_sqrt(arg: &Expression) -> Expression {
222    match arg {
223        Expression::Number(n) => evaluate_sqrt_number(n),
224
225        Expression::Pow(base, exp) if is_square(exp) => {
226            Expression::function("abs", vec![(**base).clone()])
227        }
228
229        Expression::Pow(base, exp) if is_even_power(exp) => simplify_sqrt_even_power(base, exp),
230
231        Expression::Mul(terms) => simplify_sqrt_product(terms),
232
233        Expression::Function { name, args } if name.as_ref() == "sqrt" && args.len() == 1 => {
234            Expression::function("sqrt", vec![args[0].clone()])
235        }
236
237        _ => Expression::function("sqrt", vec![arg.clone()]),
238    }
239}
240
241/// Evaluate square root for numeric arguments
242fn evaluate_sqrt_number(n: &Number) -> Expression {
243    use num_traits::ToPrimitive;
244
245    match n {
246        Number::Integer(i) => {
247            if *i >= 0 {
248                let sqrt_val = (*i as f64).sqrt();
249                if sqrt_val.fract().abs() < EPSILON {
250                    Expression::integer(sqrt_val as i64)
251                } else {
252                    Expression::function("sqrt", vec![Expression::integer(*i)])
253                }
254            } else {
255                let pos_sqrt = evaluate_sqrt_number(&Number::Integer(-i));
256                Expression::mul(vec![
257                    pos_sqrt,
258                    Expression::constant(crate::core::MathConstant::I),
259                ])
260            }
261        }
262        Number::Float(f) => {
263            if *f >= 0.0 {
264                Expression::float(f.sqrt())
265            } else {
266                Expression::mul(vec![
267                    Expression::float((-f).sqrt()),
268                    Expression::constant(crate::core::MathConstant::I),
269                ])
270            }
271        }
272        Number::BigInteger(bi) => {
273            use num_traits::Signed;
274            if **bi >= num_bigint::BigInt::from(0) {
275                if let Some(i_val) = bi.to_i64() {
276                    let sqrt_val = (i_val as f64).sqrt();
277                    if sqrt_val.fract().abs() < EPSILON {
278                        Expression::integer(sqrt_val as i64)
279                    } else {
280                        Expression::function("sqrt", vec![Expression::Number(n.clone())])
281                    }
282                } else {
283                    Expression::function("sqrt", vec![Expression::Number(n.clone())])
284                }
285            } else {
286                let pos_sqrt = evaluate_sqrt_number(&Number::BigInteger(Box::new((**bi).abs())));
287                Expression::mul(vec![
288                    pos_sqrt,
289                    Expression::constant(crate::core::MathConstant::I),
290                ])
291            }
292        }
293        Number::Rational(r) => {
294            let numer = r.numer();
295            let denom = r.denom();
296
297            if let (Some(n_val), Some(d_val)) = (numer.to_i64(), denom.to_i64()) {
298                let n_sqrt = (n_val as f64).sqrt();
299                let d_sqrt = (d_val as f64).sqrt();
300
301                if n_sqrt.fract().abs() < EPSILON && d_sqrt.fract().abs() < EPSILON {
302                    return Expression::rational(n_sqrt as i64, d_sqrt as i64);
303                }
304            }
305
306            Expression::function("sqrt", vec![Expression::Number(n.clone())])
307        }
308    }
309}
310
311/// Check if exponent represents squaring (power of 2)
312fn is_square(exp: &Expression) -> bool {
313    matches!(exp, Expression::Number(Number::Integer(2)))
314}
315
316/// Check if exponent is an even integer
317fn is_even_power(exp: &Expression) -> bool {
318    matches!(exp, Expression::Number(Number::Integer(n)) if n % 2 == 0)
319}
320
321/// Simplify sqrt of even powers: sqrt(x⁴) = x²
322fn simplify_sqrt_even_power(base: &Expression, exp: &Expression) -> Expression {
323    if let Expression::Number(Number::Integer(n)) = exp {
324        Expression::pow(base.clone(), Expression::integer(n / 2))
325    } else {
326        Expression::function("sqrt", vec![Expression::pow(base.clone(), exp.clone())])
327    }
328}
329
330/// Simplify square root of a product: sqrt(a*b) = sqrt(a)*sqrt(b)
331fn simplify_sqrt_product(terms: &[Expression]) -> Expression {
332    let mut perfect_squares = Vec::new();
333    let mut other_terms = Vec::new();
334
335    for term in terms {
336        if let Expression::Pow(base, exp) = term {
337            if is_square(exp) {
338                perfect_squares.push(Expression::function("abs", vec![(**base).clone()]));
339            } else if is_even_power(exp) {
340                if let Expression::Number(Number::Integer(n)) = **exp {
341                    perfect_squares.push(Expression::pow(
342                        (**base).clone(),
343                        Expression::integer(n / 2),
344                    ));
345                } else {
346                    other_terms.push(term.clone());
347                }
348            } else {
349                other_terms.push(term.clone());
350            }
351        } else if let Expression::Number(n) = term {
352            match evaluate_sqrt_number(n) {
353                expr @ Expression::Number(_) => perfect_squares.push(expr),
354                _ => other_terms.push(term.clone()),
355            }
356        } else {
357            other_terms.push(term.clone());
358        }
359    }
360
361    if perfect_squares.is_empty() {
362        Expression::function("sqrt", vec![Expression::mul(terms.to_vec())])
363    } else if other_terms.is_empty() {
364        Expression::mul(perfect_squares)
365    } else {
366        perfect_squares.push(Expression::function(
367            "sqrt",
368            vec![Expression::mul(other_terms)],
369        ));
370        Expression::mul(perfect_squares)
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh `x` symbol for the power-simplification tests.
    fn x() -> Expression {
        Expression::symbol("x")
    }

    #[test]
    fn test_sqrt_intelligence_creation() {
        let intelligence = SqrtIntelligence::new();
        assert!(intelligence.has_function("sqrt"));
        assert!(intelligence.get_properties().contains_key("sqrt"));
    }

    #[test]
    fn test_sqrt_properties() {
        let props = SqrtIntelligence::new().get_properties();
        let sqrt_props = props
            .get("sqrt")
            .expect("sqrt properties should be registered");

        assert!(sqrt_props.has_derivative());
        assert!(sqrt_props.has_antiderivative());
        assert_eq!(sqrt_props.special_value_count(), 4);
    }

    #[test]
    fn test_simplify_sqrt_zero() {
        assert_eq!(simplify_sqrt(&Expression::integer(0)), Expression::integer(0));
    }

    #[test]
    fn test_simplify_sqrt_one() {
        assert_eq!(simplify_sqrt(&Expression::integer(1)), Expression::integer(1));
    }

    #[test]
    fn test_simplify_sqrt_perfect_square() {
        for &(square, root) in &[(4_i64, 2_i64), (9, 3)] {
            assert_eq!(
                simplify_sqrt(&Expression::integer(square)),
                Expression::integer(root)
            );
        }
    }

    #[test]
    fn test_simplify_sqrt_square() {
        let squared = Expression::pow(x(), Expression::integer(2));
        assert_eq!(simplify_sqrt(&squared), Expression::function("abs", vec![x()]));
    }

    #[test]
    fn test_simplify_sqrt_even_power() {
        let fourth = Expression::pow(x(), Expression::integer(4));
        assert_eq!(
            simplify_sqrt(&fourth),
            Expression::pow(x(), Expression::integer(2))
        );
    }
}