num_valid/functions/
sqrt.rs

1#![deny(rustdoc::broken_intra_doc_links)]
2
3//! Square root operations.
4//!
5//! This module provides the [`Sqrt`] trait for computing square roots of validated
6//! real and complex numbers.
7
8use crate::{
9    core::{errors::capture_backtrace, policies::StrictFinitePolicy},
10    functions::FunctionErrors,
11    kernels::{RawComplexTrait, RawRealTrait, RawScalarTrait},
12};
13use num::Complex;
14use std::{backtrace::Backtrace, fmt};
15use thiserror::Error;
16use try_create::ValidationPolicy;
17
18//------------------------------------------------------------------------------------------------
19/// Errors that can occur during the input validation phase when attempting to compute
20/// the square root of a real number.
21///
22/// This enum is used as the source for the [`Input`](SqrtRealErrors::Input) variant of [`SqrtRealErrors`].
23/// It is generic over `RawReal: RawRealTrait`, which defines the specific raw real number type and its associated
24/// validation error type.
25///
26/// # Variants
27///
28/// - [`Self::NegativeValue`]: The input value is negative, which is not allowed for real square roots.
29/// - [`Self::ValidationError`]: The input value failed basic validation checks (e.g., NaN, infinity, subnormal) according to the validation policy.
30#[derive(Debug, Error)]
31pub enum SqrtRealInputErrors<RawReal: RawRealTrait> {
32    /// The input value for the square root computation is negative.
33    ///
34    /// The square root of a negative real number is not a real number.
35    #[error("the input value ({value:?}) is negative!")]
36    NegativeValue {
37        /// The negative input value.
38        value: RawReal,
39
40        /// The backtrace of the error.
41        backtrace: Backtrace,
42    },
43
44    /// The input value failed basic validation checks (e.g., it is NaN, infinite, or subnormal).
45    ///
46    /// This error occurs if the input value itself is considered invalid
47    /// according to the validation policy (e.g., [`StrictFinitePolicy`]),
48    /// before the domain-specific check (like negativity) for the square root
49    /// operation is performed.
50    #[error("the input value is invalid!")]
51    ValidationError {
52        /// The source error that occurred during validation.
53        #[source]
54        #[backtrace]
55        source: <RawReal as RawScalarTrait>::ValidationErrors,
56    },
57}
58
59/// Errors that can occur during the input validation phase when computing the square root of a complex number.
60///
61/// This enum is used as the source for the [`Input`](SqrtComplexErrors::Input) variant of [`SqrtComplexErrors`].
62/// It is generic over `RawComplex: RawComplexTrait`.
63///
64/// # Variants
65///
66/// - [`Self::ValidationError`]: The input complex number failed basic validation checks (e.g., its components are NaN, infinite, or subnormal) according to the validation policy.
67#[derive(Debug, Error)]
68pub enum SqrtComplexInputErrors<RawComplex: RawComplexTrait> {
69    /// The input complex number failed basic validation checks (e.g., its components are NaN, infinite, or subnormal).
70    ///
71    /// This error occurs if the input complex value itself is considered invalid
72    /// according to the validation policy (e.g., [`StrictFinitePolicy`]),
73    /// before any domain-specific checks for the complex square root operation are performed.
74    /// It signifies that the real or imaginary part (or both) does not meet fundamental validity criteria.
75    #[error("the input value is invalid!")]
76    ValidationError {
77        /// The detailed source error from the raw complex number's validation.
78        ///
79        /// This field encapsulates the specific error of type `<RawComplex as RawScalarTrait>::ValidationErrors`
80        /// that was reported during the validation of the input complex number's components.
81        #[source]
82        #[backtrace]
83        source: <RawComplex as RawScalarTrait>::ValidationErrors,
84    },
85}
86
87/// A type alias for [`FunctionErrors`], specialized for errors that can occur during
88/// the computation of the square root of a real number.
89///
90/// This type represents the possible failures when calling [`Sqrt::try_sqrt()`] on a real number.
91///
92/// # Generic Parameters
93///
94/// - `RawReal`: A type that implements [`RawRealTrait`]. This defines:
95///   - The raw error type for the input real number via `SqrtRealInputErrors<RawReal>`.
96///   - The raw error type for the output real number (the square root) also via `<RawReal as RawScalarTrait>::ValidationErrors`.
97///
98/// # Variants
99///
100/// This type alias wraps [`FunctionErrors`], which has the following variants in this context:
101///
102/// - `Input { source: SqrtRealInputErrors<RawReal> }`:
103///   Indicates that the input real number was invalid for square root computation.
104///   This could be due to the number being negative or failing general validation checks
105///   (e.g., NaN, Infinity, subnormal). The `source` field provides more specific details
106///   via [`SqrtRealInputErrors`].
107///
108/// - `Output { source: <RawReal as RawScalarTrait>::ValidationErrors }`:
109///   Indicates that the computed square root (which should be a real number)
110///   failed validation. This typically means the result of the `sqrt` operation yielded
111///   a non-finite value (NaN or Infinity), which is unexpected if the input was valid
112///   (non-negative and finite).
113pub type SqrtRealErrors<RawReal> =
114    FunctionErrors<SqrtRealInputErrors<RawReal>, <RawReal as RawScalarTrait>::ValidationErrors>;
115
116/// A type alias for [`FunctionErrors`], specialized for errors that can occur during
117/// the computation of the square root of a complex number.
118///
119/// This type represents the possible failures when calling [`Sqrt::try_sqrt()`] on a complex number.
120///
121/// # Generic Parameters
122///
123/// - `RawComplex`: A type that implements [`RawComplexTrait`]. This defines:
124///   - The raw error type for the input complex number via `SqrtComplexInputErrors<RawComplex>`.
125///   - The raw error type for the output complex number (the square root) also via `<RawComplex as RawScalarTrait>::ValidationErrors`.
126///
127/// # Variants
128///
129/// This type alias wraps [`FunctionErrors`], which has the following variants in this context:
130///
131/// - `Input { source: SqrtComplexInputErrors<RawComplex> }`:
132///   Indicates that the input complex number was invalid for square root computation.
133///   This is typically due to the complex number's components (real or imaginary parts)
134///   failing general validation checks (e.g., NaN, Infinity, subnormal).
135///   The `source` field provides more specific details via [`SqrtComplexInputErrors`].
136///
137/// - `Output { source: <RawComplex as RawScalarTrait>::ValidationErrors }`:
138///   Indicates that the computed complex square root itself failed validation.
139///   This typically means the result of the `sqrt` operation yielded a complex number
140///   with non-finite components (NaN or Infinity), which is unexpected if the input was valid.
141pub type SqrtComplexErrors<RawComplex> = FunctionErrors<
142    SqrtComplexInputErrors<RawComplex>,
143    <RawComplex as RawScalarTrait>::ValidationErrors,
144>;
145//--------------------------------------------------------------------------------------------
146
//--------------------------------------------------------------------------------------------
/// A trait for computing the principal square root of a number.
///
/// The principal square root of a non-negative real number `x` is the unique non-negative real
/// number `y` such that `y^2 = x`.
///
/// Every non-zero complex number `z` has two square roots, `w` and `-w`, with `w^2 = z`. This
/// trait computes the principal one, `exp(0.5 * log(z))`, with `log` taken on its principal
/// branch. The principal root therefore always has a non-negative real part, and the branch
/// cut lies along the negative real axis. For `z = 0` the only square root is `0`.
///
/// Two entry points are provided:
///
/// - [`try_sqrt`](Sqrt::try_sqrt): a fallible version that validates both the input and the
///   result, reporting failures through [`Sqrt::Error`].
/// - [`sqrt`](Sqrt::sqrt): a version that returns `Self` directly. The implementations in this
///   module panic in debug builds if validation fails and skip validation in release builds.
///
/// # Calling the trait method
///
/// Some implementors (for example `f64` and `num::Complex<f64>`) also have an *inherent*
/// method named `sqrt`. In method-call syntax inherent methods take precedence, so `x.sqrt()`
/// calls the inherent method rather than [`Sqrt::sqrt`]. Use `Sqrt::sqrt(x)` or
/// `<T as Sqrt>::sqrt(x)` to call the trait method explicitly.
pub trait Sqrt: Sized {
    /// The error type returned by [`try_sqrt`](Sqrt::try_sqrt) when the input or the result
    /// fails validation.
    type Error: fmt::Debug;

    /// Attempts to compute the principal square root of `self`, returning a `Result`.
    ///
    /// Implementations should validate `self` against two things: the domain of the function
    /// (e.g. non-negative for reals) and a general validation policy (e.g.
    /// [`StrictFinitePolicy`]). If the input is valid, the square root is computed and the
    /// result is validated with the same policy.
    ///
    /// # Returns
    ///
    /// - `Ok(Self)`: the input is valid for the square root operation, and both the input and
    ///   the computed square root satisfy the validation policy.
    /// - `Err(Self::Error)`: the input is invalid (e.g. negative for a real square root, NaN,
    ///   infinity), or the computed square root is invalid.
    ///
    /// # Errors
    ///
    /// For the implementations on `f64` ([`SqrtRealErrors`]) and `Complex<f64>`
    /// ([`SqrtComplexErrors`]):
    ///
    /// - [`FunctionErrors::Input`] if the input is invalid (e.g. a negative real, NaN,
    ///   infinity or a subnormal value).
    /// - [`FunctionErrors::Output`] if the computed result does not satisfy the validation
    ///   policy (e.g. it is NaN or infinite).
    ///
    /// # Examples
    ///
    /// ```
    /// use num_valid::functions::Sqrt;
    /// use num::Complex;
    ///
    /// assert_eq!(4.0_f64.try_sqrt().unwrap(), 2.0);
    /// assert!((-1.0_f64).try_sqrt().is_err()); // Negative real
    /// assert!(f64::NAN.try_sqrt().is_err());
    ///
    /// let z = Complex::new(-4.0, 0.0); // sqrt(-4) = 2i
    /// let sqrt_z = z.try_sqrt().unwrap();
    /// assert!((sqrt_z.re).abs() < 1e-9 && (sqrt_z.im - 2.0).abs() < 1e-9);
    /// ```
    #[must_use = "this `Result` may contain an error that should be handled"]
    fn try_sqrt(self) -> Result<Self, <Self as Sqrt>::Error>;

    /// Computes the principal square root of `self`.
    ///
    /// This is the non-`Result` counterpart of [`try_sqrt`](Sqrt::try_sqrt). How invalid
    /// inputs are handled is up to the implementation. The `f64` and `Complex<f64>`
    /// implementations in this module panic in debug builds and skip validation in release
    /// builds.
    ///
    /// # Panics
    ///
    /// Implementations may panic if the input or the result is invalid (see above).
    fn sqrt(self) -> Self;
}
203
204impl Sqrt for f64 {
205    type Error = SqrtRealErrors<f64>;
206
207    /// Attempts to compute the principal square root of `self`, returning a `Result`.
208    ///
209    /// Implementations should validate the input `self` according to the domain
210    /// (e.g., non-negative for reals) and a general validation policy (e.g., [`StrictFinitePolicy`]).
211    /// If the input is valid, the square root is computed, and then the result
212    /// is also validated using the same policy.
213    ///
214    /// # Returns
215    ///
216    /// - `Ok(Self)`: If the input is valid for the square root operation and both the input
217    ///   and the computed square root satisfy the validation policy.
218    /// - `Err(Self::Error)`: If the input is invalid (e.g., negative for real sqrt, NaN, Infinity)
219    ///   or if the computed square root is invalid.
220    ///
221    /// # Examples
222    ///
223    /// ```
224    /// use num_valid::functions::Sqrt;
225    /// use num::Complex;
226    ///
227    /// assert_eq!(4.0_f64.try_sqrt().unwrap(), 2.0);
228    /// assert!((-1.0_f64).try_sqrt().is_err()); // Negative real
229    /// assert!(f64::NAN.try_sqrt().is_err());
230    ///
231    /// let z = Complex::new(-4.0, 0.0); // sqrt(-4) = 2i
232    /// let sqrt_z = z.try_sqrt().unwrap();
233    /// assert!((sqrt_z.re).abs() < 1e-9 && (sqrt_z.im - 2.0).abs() < 1e-9);
234    /// ```
235    #[inline(always)]
236    fn try_sqrt(self) -> Result<f64, <f64 as Sqrt>::Error> {
237        StrictFinitePolicy::<f64, 53>::validate(self)
238            .map_err(|e| SqrtRealInputErrors::ValidationError { source: e }.into())
239            .and_then(|validated_value| {
240                if validated_value < 0.0 {
241                    Err(SqrtRealInputErrors::NegativeValue {
242                        value: validated_value,
243                        backtrace: capture_backtrace(),
244                    }
245                    .into())
246                } else {
247                    StrictFinitePolicy::<f64, 53>::validate(f64::sqrt(validated_value))
248                        .map_err(|e| SqrtRealErrors::Output { source: e })
249                }
250            })
251    }
252
253    /// Computes and returns the principal square root of `self`.
254    ///
255    /// # Behavior
256    ///
257    /// - **Debug Builds (`#[cfg(debug_assertions)]`)**: This method internally calls
258    ///   [`try_sqrt().unwrap()`](Sqrt::try_sqrt). It will panic if `try_sqrt` returns an `Err`.
259    /// - **Release Builds (`#[cfg(not(debug_assertions))]`)**: This method calls the underlying
260    ///   square root function directly (e.g., `f64::sqrt`).
261    ///   The behavior for invalid inputs (like `sqrt(-1.0)` for `f64` returning NaN)
262    ///   depends on the underlying implementation.
263    ///
264    /// # Panics
265    ///
266    /// In debug builds, this method will panic if [`try_sqrt()`](Sqrt::try_sqrt) would return an `Err`.
267    ///
268    /// # Examples
269    ///
270    /// ```
271    /// use num_valid::functions::Sqrt;
272    /// use num::Complex;
273    ///
274    /// assert_eq!(9.0_f64.sqrt(), 3.0);
275    ///
276    /// // For f64, sqrt of negative is NaN in release, panics in debug with ftl's Sqrt
277    /// #[cfg(not(debug_assertions))]
278    /// assert!((-1.0_f64).sqrt().is_nan());
279    ///
280    /// let z = Complex::new(0.0, 4.0); // sqrt(4i) = sqrt(2) + i*sqrt(2)
281    /// let sqrt_z = z.sqrt();
282    /// let expected_val = std::f64::consts::SQRT_2;
283    /// assert!((sqrt_z.re - expected_val).abs() < 1e-9 && (sqrt_z.im - expected_val).abs() < 1e-9);
284    /// ```
285    #[inline(always)]
286    fn sqrt(self) -> Self {
287        #[cfg(debug_assertions)]
288        {
289            self.try_sqrt().unwrap()
290        }
291        #[cfg(not(debug_assertions))]
292        {
293            f64::sqrt(self)
294        }
295    }
296}
297
298impl Sqrt for Complex<f64> {
299    type Error = SqrtComplexErrors<Complex<f64>>;
300
301    /// Attempts to compute the principal square root of `self` (a `Complex<f64>`).
302    ///
303    /// This method first validates `self` using [`StrictFinitePolicy`] (components must be finite and normal).
304    /// If valid, it computes `Complex::sqrt` and validates the result using [`StrictFinitePolicy`].
305    ///
306    /// # Returns
307    ///
308    /// - `Ok(Complex<f64>)`: If `self` and the computed square root have finite and normal components.
309    /// - `Err(SqrtComplexErrors<Complex<f64>>)`: If `self` or the result has invalid components.
310    #[inline(always)]
311    fn try_sqrt(self) -> Result<Self, <Self as Sqrt>::Error> {
312        StrictFinitePolicy::<Complex<f64>, 53>::validate(self)
313            .map_err(|e| SqrtComplexInputErrors::ValidationError { source: e }.into())
314            .and_then(|validated_value| {
315                StrictFinitePolicy::<Complex<f64>, 53>::validate(Complex::<f64>::sqrt(
316                    validated_value,
317                ))
318                .map_err(|e| SqrtComplexErrors::Output { source: e })
319            })
320    }
321
322    /// Computes and returns the principal square root of `self` (a `Complex<f64>`).
323    ///
324    /// # Behavior
325    ///
326    /// - **Debug Builds**: Calls `try_sqrt().unwrap()`. Panics on invalid input/output.
327    /// - **Release Builds**: Calls `Complex::sqrt(self)` directly.
328    ///
329    /// # Panics
330    ///
331    /// In debug builds, if `try_sqrt()` would return an `Err`.
332    #[inline(always)]
333    fn sqrt(self) -> Self {
334        #[cfg(debug_assertions)]
335        {
336            self.try_sqrt().unwrap()
337        }
338        #[cfg(not(debug_assertions))]
339        {
340            Complex::<f64>::sqrt(self)
341        }
342    }
343}
344
//------------------------------------------------------------------------------------------------

//------------------------------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use num::Complex;

    #[cfg(feature = "rug")]
    use try_create::TryNew;

    mod sqrt {
        use super::*;

        mod native64 {
            use super::*;

            mod real {
                use super::*;

                /// Asserts that `result` is an input error raised by the validation policy
                /// (as opposed to the negativity check).
                fn assert_validation_error(result: Result<f64, SqrtRealErrors<f64>>) {
                    assert!(
                        matches!(
                            result,
                            Err(SqrtRealErrors::Input {
                                source: SqrtRealInputErrors::ValidationError { .. },
                                ..
                            })
                        ),
                        "expected an input validation error, got {result:?}"
                    );
                }

                #[test]
                fn test_f64_sqrt_valid() {
                    let value = 4.0_f64;

                    assert_eq!(value.try_sqrt().unwrap(), 2.0);
                    assert_eq!(<f64 as Sqrt>::sqrt(value), 2.0);
                }

                #[test]
                fn test_f64_sqrt_zero() {
                    assert_eq!(0.0_f64.try_sqrt().unwrap(), 0.0);
                }

                #[test]
                fn test_f64_sqrt_negative_value() {
                    let result = (-4.0_f64).try_sqrt();
                    assert!(
                        matches!(
                            result,
                            Err(SqrtRealErrors::Input {
                                source: SqrtRealInputErrors::NegativeValue { value, .. },
                                ..
                            }) if value == -4.0
                        ),
                        "expected a negative-value input error, got {result:?}"
                    );
                }

                #[test]
                fn test_f64_sqrt_subnormal() {
                    assert_validation_error((f64::MIN_POSITIVE / 2.0).try_sqrt());
                }

                #[test]
                fn test_f64_sqrt_nan() {
                    assert_validation_error(f64::NAN.try_sqrt());
                }

                #[test]
                fn test_f64_sqrt_infinite() {
                    assert_validation_error(f64::INFINITY.try_sqrt());
                    // Validation runs before the negativity check, so -inf is reported as
                    // a validation error rather than as a negative value.
                    assert_validation_error(f64::NEG_INFINITY.try_sqrt());
                }

                #[test]
                #[cfg(debug_assertions)]
                #[should_panic]
                fn test_f64_sqrt_trait_method_panics_on_negative_in_debug() {
                    let _ = <f64 as Sqrt>::sqrt(-4.0);
                }

                #[test]
                #[cfg(not(debug_assertions))]
                fn test_f64_sqrt_trait_method_is_nan_on_negative_in_release() {
                    assert!(<f64 as Sqrt>::sqrt(-4.0).is_nan());
                }
            }

            mod complex {
                use super::*;

                #[test]
                fn test_complex_f64_sqrt_valid() {
                    let value = Complex::new(4.0_f64, 0.0);

                    let expected_result = Complex::new(2.0, 0.0);

                    assert_eq!(value.try_sqrt().unwrap(), expected_result);
                    assert_eq!(<Complex<f64> as Sqrt>::sqrt(value), expected_result);
                }

                #[test]
                fn test_complex_f64_sqrt_negative_real_axis() {
                    // Unlike the real case, a negative real part is a valid input: sqrt(-4) = 2i.
                    let root = Complex::new(-4.0_f64, 0.0).try_sqrt().unwrap();
                    assert!(
                        root.re.abs() < 1e-12 && (root.im - 2.0).abs() < 1e-12,
                        "unexpected root: {root:?}"
                    );
                }

                #[test]
                fn test_complex_f64_sqrt_invalid() {
                    let invalid_inputs = [
                        Complex::new(f64::NAN, 0.0),
                        Complex::new(0.0, f64::NAN),
                        Complex::new(f64::INFINITY, 0.0),
                        Complex::new(0.0, f64::INFINITY),
                        Complex::new(f64::MIN_POSITIVE / 2.0, 0.0),
                        Complex::new(0.0, f64::MIN_POSITIVE / 2.0),
                    ];

                    for value in invalid_inputs {
                        let result = value.try_sqrt();
                        assert!(
                            matches!(
                                result,
                                Err(SqrtComplexErrors::Input {
                                    source: SqrtComplexInputErrors::ValidationError { .. },
                                    ..
                                })
                            ),
                            "expected an input validation error for {value:?}, got {result:?}"
                        );
                    }
                }
            }
        }

        #[cfg(feature = "rug")]
        mod rug53 {
            use super::*;
            use crate::backends::rug::validated::{ComplexRugStrictFinite, RealRugStrictFinite};

            mod real {
                use super::*;

                #[test]
                fn test_rug_float_sqrt_valid() {
                    let value =
                        RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, 4.0)).unwrap();
                    let expected_result =
                        RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, 2.0)).unwrap();

                    assert_eq!(value.clone().try_sqrt().unwrap(), expected_result);
                    assert_eq!(value.sqrt(), expected_result);
                }

                #[test]
                fn test_rug_float_sqrt_negative_value() {
                    let value =
                        RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, -4.0)).unwrap();
                    let result = value.try_sqrt();
                    assert!(matches!(result, Err(SqrtRealErrors::Input { .. })));
                }
            }

            mod complex {
                use super::*;

                #[test]
                fn test_complex_rug_float_sqrt_valid() {
                    let value = ComplexRugStrictFinite::<53>::try_new(rug::Complex::with_val(
                        53,
                        (rug::Float::with_val(53, 4.0), rug::Float::with_val(53, 0.0)),
                    ))
                    .unwrap();

                    let expected_result =
                        ComplexRugStrictFinite::<53>::try_new(rug::Complex::with_val(
                            53,
                            (rug::Float::with_val(53, 2.0), rug::Float::with_val(53, 0.0)),
                        ))
                        .unwrap();

                    assert_eq!(value.clone().try_sqrt().unwrap(), expected_result);
                    assert_eq!(value.sqrt(), expected_result);
                }
            }
        }
    }
}