num_valid/functions/
sqrt.rs

1#![deny(rustdoc::broken_intra_doc_links)]
2
3use crate::{
4    functions::FunctionErrors,
5    kernels::{RawComplexTrait, RawRealTrait, RawScalarTrait},
6    validation::StrictFinitePolicy,
7};
8use num::Complex;
9use std::{backtrace::Backtrace, fmt};
10use thiserror::Error;
11use try_create::ValidationPolicy;
12
13//------------------------------------------------------------------------------------------------
14/// Errors that can occur during the input validation phase when attempting to compute
15/// the square root of a real number.
16///
17/// This enum is used as the source for the [`Input`](SqrtRealErrors::Input) variant of [`SqrtRealErrors`].
18/// It is generic over `RawReal: RawRealTrait`, which defines the specific raw real number type and its associated
19/// validation error type.
20///
21/// # Variants
22///
23/// - [`Self::NegativeValue`]: The input value is negative, which is not allowed for real square roots.
24/// - [`Self::ValidationError`]: The input value failed basic validation checks (e.g., NaN, infinity, subnormal) according to the validation policy.
25#[derive(Debug, Error)]
26pub enum SqrtRealInputErrors<RawReal: RawRealTrait> {
27    /// The input value for the square root computation is negative.
28    ///
29    /// The square root of a negative real number is not a real number.
30    #[error("the input value ({value:?}) is negative!")]
31    NegativeValue {
32        /// The negative input value.
33        value: RawReal,
34
35        /// The backtrace of the error.
36        backtrace: Backtrace,
37    },
38
39    /// The input value failed basic validation checks (e.g., it is NaN, infinite, or subnormal).
40    ///
41    /// This error occurs if the input value itself is considered invalid
42    /// according to the validation policy (e.g., [`StrictFinitePolicy`]),
43    /// before the domain-specific check (like negativity) for the square root
44    /// operation is performed.
45    #[error("the input value is invalid!")]
46    ValidationError {
47        /// The source error that occurred during validation.
48        #[source]
49        #[backtrace]
50        source: <RawReal as RawScalarTrait>::ValidationErrors,
51    },
52}
53
54/// Errors that can occur during the input validation phase when computing the square root of a complex number.
55///
56/// This enum is used as the source for the [`Input`](SqrtComplexErrors::Input) variant of [`SqrtComplexErrors`].
57/// It is generic over `RawComplex: RawComplexTrait`.
58///
59/// # Variants
60///
61/// - [`Self::ValidationError`]: The input complex number failed basic validation checks (e.g., its components are NaN, infinite, or subnormal) according to the validation policy.
62#[derive(Debug, Error)]
63pub enum SqrtComplexInputErrors<RawComplex: RawComplexTrait> {
64    /// The input complex number failed basic validation checks (e.g., its components are NaN, infinite, or subnormal).
65    ///
66    /// This error occurs if the input complex value itself is considered invalid
67    /// according to the validation policy (e.g., [`StrictFinitePolicy`]),
68    /// before any domain-specific checks for the complex square root operation are performed.
69    /// It signifies that the real or imaginary part (or both) does not meet fundamental validity criteria.
70    #[error("the input value is invalid!")]
71    ValidationError {
72        /// The detailed source error from the raw complex number's validation.
73        ///
74        /// This field encapsulates the specific error of type `<RawComplex as RawScalarTrait>::ValidationErrors`
75        /// that was reported during the validation of the input complex number's components.
76        #[source]
77        #[backtrace]
78        source: <RawComplex as RawScalarTrait>::ValidationErrors,
79    },
80}
81
/// A type alias for [`FunctionErrors`], specialized for the errors that can occur while
/// computing the square root of a real number.
///
/// This is the error type returned by [`Sqrt::try_sqrt()`] when it is called on a real number.
///
/// # Generic Parameters
///
/// - `RawReal`: A type that implements [`RawRealTrait`]. It determines both type arguments of
///   the underlying [`FunctionErrors`]:
///   - the input error type, `SqrtRealInputErrors<RawReal>`;
///   - the output error type, `<RawReal as RawScalarTrait>::ValidationErrors`.
///
/// # Variants
///
/// This type alias wraps [`FunctionErrors`], which has the following variants in this context:
///
/// - `Input { source: SqrtRealInputErrors<RawReal> }`:
///   The input real number was refused before any computation took place, either because it
///   is negative or because it failed the general validation checks
///   (e.g., NaN, Infinity, subnormal). The `source` field tells the two cases apart
///   via [`SqrtRealInputErrors`].
///
/// - `Output { source: <RawReal as RawScalarTrait>::ValidationErrors }`:
///   The input was accepted, but the computed square root was itself rejected by the
///   validation policy (e.g., it is NaN or Infinity). This is not expected to happen for an
///   input that is non-negative and finite.
pub type SqrtRealErrors<RawReal> =
    FunctionErrors<SqrtRealInputErrors<RawReal>, <RawReal as RawScalarTrait>::ValidationErrors>;
110
/// A type alias for [`FunctionErrors`], specialized for the errors that can occur while
/// computing the square root of a complex number.
///
/// This is the error type returned by [`Sqrt::try_sqrt()`] when it is called on a complex number.
///
/// # Generic Parameters
///
/// - `RawComplex`: A type that implements [`RawComplexTrait`]. It determines both type
///   arguments of the underlying [`FunctionErrors`]:
///   - the input error type, `SqrtComplexInputErrors<RawComplex>`;
///   - the output error type, `<RawComplex as RawScalarTrait>::ValidationErrors`.
///
/// # Variants
///
/// This type alias wraps [`FunctionErrors`], which has the following variants in this context:
///
/// - `Input { source: SqrtComplexInputErrors<RawComplex> }`:
///   The input complex number was refused before any computation took place because its
///   components (real part, imaginary part, or both) failed the general validation checks
///   (e.g., NaN, Infinity, subnormal).
///   The `source` field provides the details via [`SqrtComplexInputErrors`].
///
/// - `Output { source: <RawComplex as RawScalarTrait>::ValidationErrors }`:
///   The input was accepted, but the computed complex square root was itself rejected by the
///   validation policy (e.g., one of its components is NaN or Infinity). This is not expected
///   to happen for a valid input.
pub type SqrtComplexErrors<RawComplex> = FunctionErrors<
    SqrtComplexInputErrors<RawComplex>,
    <RawComplex as RawScalarTrait>::ValidationErrors,
>;
140//--------------------------------------------------------------------------------------------
141
142//--------------------------------------------------------------------------------------------
/// Principal square root, in a checked and an unchecked flavour.
///
/// For a non-negative real number `x`, the principal square root is the unique non-negative
/// real number `y` with `y^2 = x`.
/// For a complex number `z`, a square root is any `w` with `w^2 = z`. Every non-zero complex
/// number has two of them; this trait yields the principal one, typically defined as
/// `exp(0.5 * log(z))`, which usually has a non-negative real part.
///
/// Two entry points are offered:
///
/// - [`try_sqrt`](Sqrt::try_sqrt) validates its input and its output and reports problems
///   through a `Result`;
/// - [`sqrt`](Sqrt::sqrt) returns the value directly and may panic in debug builds if
///   validation fails.
pub trait Sqrt: Sized {
    /// The error reported by [`try_sqrt`](Sqrt::try_sqrt) when the computation is refused.
    type Error: fmt::Debug;

    /// Computes the principal square root of `self`, reporting invalid data as an error.
    ///
    /// Implementations are expected to check the input against both the mathematical domain
    /// (e.g., non-negative for reals) and a general validation policy
    /// (e.g., [`StrictFinitePolicy`]), then compute the square root, and finally check the
    /// result against the same policy.
    ///
    /// # Returns
    ///
    /// - `Ok(Self)`: the input is acceptable and the computed square root passed validation.
    /// - `Err(Self::Error)`: the input was refused (e.g., negative real, NaN, Infinity) or the
    ///   computed square root failed validation (see below).
    ///
    /// # Errors
    ///
    /// - [`SqrtRealErrors::Input`] (for reals) or [`SqrtComplexErrors::Input`] (for complex),
    ///   i.e. [`FunctionErrors::Input`], if the input is invalid
    ///   (e.g., negative real, NaN, Infinity, subnormal).
    /// - [`SqrtRealErrors::Output`] or [`SqrtComplexErrors::Output`],
    ///   i.e. [`FunctionErrors::Output`], if the computed result is rejected by the validation
    ///   policy (e.g., NaN, Infinity).
    ///
    /// # Examples
    ///
    /// ```
    /// use num_valid::functions::Sqrt;
    /// use num::Complex;
    ///
    /// // Real numbers: only non-negative, valid inputs are accepted.
    /// assert_eq!(4.0_f64.try_sqrt().unwrap(), 2.0);
    /// assert!((-1.0_f64).try_sqrt().is_err()); // Negative real
    /// assert!(f64::NAN.try_sqrt().is_err());
    ///
    /// // Complex numbers: sqrt(-4) = 2i
    /// let z = Complex::new(-4.0, 0.0);
    /// let sqrt_z = z.try_sqrt().unwrap();
    /// assert!((sqrt_z.re).abs() < 1e-9 && (sqrt_z.im - 2.0).abs() < 1e-9);
    /// ```
    fn try_sqrt(self) -> Result<Self, Self::Error>;

    /// Computes the principal square root of `self` without returning a `Result`.
    ///
    /// The implementations in this module panic in debug builds when
    /// [`try_sqrt`](Sqrt::try_sqrt) would fail, and skip all validation in release builds.
    fn sqrt(self) -> Self;
}
197
198impl Sqrt for f64 {
199    type Error = SqrtRealErrors<f64>;
200
201    /// Attempts to compute the principal square root of `self`, returning a `Result`.
202    ///
203    /// Implementations should validate the input `self` according to the domain
204    /// (e.g., non-negative for reals) and a general validation policy (e.g., [`StrictFinitePolicy`]).
205    /// If the input is valid, the square root is computed, and then the result
206    /// is also validated using the same policy.
207    ///
208    /// # Returns
209    ///
210    /// - `Ok(Self)`: If the input is valid for the square root operation and both the input
211    ///   and the computed square root satisfy the validation policy.
212    /// - `Err(Self::Error)`: If the input is invalid (e.g., negative for real sqrt, NaN, Infinity)
213    ///   or if the computed square root is invalid.
214    ///
215    /// # Examples
216    ///
217    /// ```
218    /// use num_valid::functions::Sqrt;
219    /// use num::Complex;
220    ///
221    /// assert_eq!(4.0_f64.try_sqrt().unwrap(), 2.0);
222    /// assert!((-1.0_f64).try_sqrt().is_err()); // Negative real
223    /// assert!(f64::NAN.try_sqrt().is_err());
224    ///
225    /// let z = Complex::new(-4.0, 0.0); // sqrt(-4) = 2i
226    /// let sqrt_z = z.try_sqrt().unwrap();
227    /// assert!((sqrt_z.re).abs() < 1e-9 && (sqrt_z.im - 2.0).abs() < 1e-9);
228    /// ```
229    #[inline(always)]
230    fn try_sqrt(self) -> Result<f64, <f64 as Sqrt>::Error> {
231        StrictFinitePolicy::<f64, 53>::validate(self)
232            .map_err(|e| SqrtRealInputErrors::ValidationError { source: e }.into())
233            .and_then(|validated_value| {
234                if validated_value < 0.0 {
235                    Err(SqrtRealInputErrors::NegativeValue {
236                        value: validated_value,
237                        backtrace: Backtrace::force_capture(),
238                    }
239                    .into())
240                } else {
241                    StrictFinitePolicy::<f64, 53>::validate(f64::sqrt(validated_value))
242                        .map_err(|e| SqrtRealErrors::Output { source: e })
243                }
244            })
245    }
246
247    /// Computes and returns the principal square root of `self`.
248    ///
249    /// # Behavior
250    ///
251    /// - **Debug Builds (`#[cfg(debug_assertions)]`)**: This method internally calls
252    ///   [`try_sqrt().unwrap()`](Sqrt::try_sqrt). It will panic if `try_sqrt` returns an `Err`.
253    /// - **Release Builds (`#[cfg(not(debug_assertions))]`)**: This method calls the underlying
254    ///   square root function directly (e.g., `f64::sqrt`).
255    ///   The behavior for invalid inputs (like `sqrt(-1.0)` for `f64` returning NaN)
256    ///   depends on the underlying implementation.
257    ///
258    /// # Panics
259    ///
260    /// In debug builds, this method will panic if [`try_sqrt()`](Sqrt::try_sqrt) would return an `Err`.
261    ///
262    /// # Examples
263    ///
264    /// ```
265    /// use num_valid::functions::Sqrt;
266    /// use num::Complex;
267    ///
268    /// assert_eq!(9.0_f64.sqrt(), 3.0);
269    ///
270    /// // For f64, sqrt of negative is NaN in release, panics in debug with ftl's Sqrt
271    /// #[cfg(not(debug_assertions))]
272    /// assert!((-1.0_f64).sqrt().is_nan());
273    ///
274    /// let z = Complex::new(0.0, 4.0); // sqrt(4i) = sqrt(2) + i*sqrt(2)
275    /// let sqrt_z = z.sqrt();
276    /// let expected_val = std::f64::consts::SQRT_2;
277    /// assert!((sqrt_z.re - expected_val).abs() < 1e-9 && (sqrt_z.im - expected_val).abs() < 1e-9);
278    /// ```
279    #[inline(always)]
280    fn sqrt(self) -> Self {
281        #[cfg(debug_assertions)]
282        {
283            self.try_sqrt().unwrap()
284        }
285        #[cfg(not(debug_assertions))]
286        {
287            f64::sqrt(self)
288        }
289    }
290}
291
292impl Sqrt for Complex<f64> {
293    type Error = SqrtComplexErrors<Complex<f64>>;
294
295    /// Attempts to compute the principal square root of `self` (a `Complex<f64>`).
296    ///
297    /// This method first validates `self` using [`StrictFinitePolicy`] (components must be finite and normal).
298    /// If valid, it computes `Complex::sqrt` and validates the result using [`StrictFinitePolicy`].
299    ///
300    /// # Returns
301    ///
302    /// - `Ok(Complex<f64>)`: If `self` and the computed square root have finite and normal components.
303    /// - `Err(SqrtComplexErrors<Complex<f64>>)`: If `self` or the result has invalid components.
304    #[inline(always)]
305    fn try_sqrt(self) -> Result<Self, <Self as Sqrt>::Error> {
306        StrictFinitePolicy::<Complex<f64>, 53>::validate(self)
307            .map_err(|e| SqrtComplexInputErrors::ValidationError { source: e }.into())
308            .and_then(|validated_value| {
309                StrictFinitePolicy::<Complex<f64>, 53>::validate(Complex::<f64>::sqrt(
310                    validated_value,
311                ))
312                .map_err(|e| SqrtComplexErrors::Output { source: e })
313            })
314    }
315
316    /// Computes and returns the principal square root of `self` (a `Complex<f64>`).
317    ///
318    /// # Behavior
319    ///
320    /// - **Debug Builds**: Calls `try_sqrt().unwrap()`. Panics on invalid input/output.
321    /// - **Release Builds**: Calls `Complex::sqrt(self)` directly.
322    ///
323    /// # Panics
324    ///
325    /// In debug builds, if `try_sqrt()` would return an `Err`.
326    #[inline(always)]
327    fn sqrt(self) -> Self {
328        #[cfg(debug_assertions)]
329        {
330            self.try_sqrt().unwrap()
331        }
332        #[cfg(not(debug_assertions))]
333        {
334            Complex::<f64>::sqrt(self)
335        }
336    }
337}
338
339//------------------------------------------------------------------------------------------------
340
341//------------------------------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    //! Unit tests for the [`Sqrt`] implementations.
    //!
    //! The assertions on real inputs check the precise input error variant, so that the
    //! policy check and the sign check of `try_sqrt` cannot be swapped unnoticed.

    use super::*;
    use num::Complex;

    #[cfg(feature = "rug")]
    use try_create::TryNew;

    mod sqrt {
        use super::*;

        mod native64 {
            use super::*;

            mod real {
                use super::*;

                #[test]
                fn test_f64_sqrt_valid() {
                    let value = 4.0_f64;

                    assert_eq!(value.try_sqrt().unwrap(), 2.0);
                    assert_eq!(<f64 as Sqrt>::sqrt(value), 2.0);
                }

                #[test]
                fn test_f64_sqrt_negative_value() {
                    let value = -4.0_f64;
                    let result = value.try_sqrt();
                    // A finite negative value passes the policy and must fail the sign check.
                    assert!(matches!(
                        result,
                        Err(SqrtRealErrors::Input {
                            source: SqrtRealInputErrors::NegativeValue { .. },
                            ..
                        })
                    ));
                }

                // The infallible variant only validates when debug assertions are enabled.
                #[cfg(debug_assertions)]
                #[test]
                #[should_panic]
                fn test_f64_sqrt_negative_value_panics_in_debug() {
                    let _ = <f64 as Sqrt>::sqrt(-4.0);
                }

                #[test]
                fn test_f64_sqrt_subnormal() {
                    let value = f64::MIN_POSITIVE / 2.0;
                    let result = value.try_sqrt();
                    assert!(matches!(
                        result,
                        Err(SqrtRealErrors::Input {
                            source: SqrtRealInputErrors::ValidationError { .. },
                            ..
                        })
                    ));
                }

                #[test]
                fn test_f64_sqrt_zero() {
                    let value = 0.0_f64;
                    let result = value.try_sqrt();
                    assert!(matches!(result, Ok(0.0)));
                }

                #[test]
                fn test_f64_sqrt_nan() {
                    let value = f64::NAN;
                    let result = value.try_sqrt();
                    assert!(matches!(
                        result,
                        Err(SqrtRealErrors::Input {
                            source: SqrtRealInputErrors::ValidationError { .. },
                            ..
                        })
                    ));
                }

                #[test]
                fn test_f64_sqrt_infinite() {
                    let value = f64::INFINITY;
                    let result = value.try_sqrt();
                    assert!(matches!(
                        result,
                        Err(SqrtRealErrors::Input {
                            source: SqrtRealInputErrors::ValidationError { .. },
                            ..
                        })
                    ));
                }
            }

            mod complex {
                use super::*;

                #[test]
                fn test_complex_f64_sqrt_valid() {
                    let value = Complex::new(4.0, 0.0);

                    let expected_result = Complex::new(2.0, 0.0);

                    assert_eq!(value.try_sqrt().unwrap(), expected_result);
                    assert_eq!(<Complex<f64> as Sqrt>::sqrt(value), expected_result);
                }

                #[test]
                fn test_complex_f64_sqrt_invalid() {
                    // Each input has exactly one invalid component (NaN, infinite or
                    // subnormal), placed once in the real part and once in the imaginary part.
                    let invalid_inputs = [
                        Complex::new(f64::NAN, 0.0),
                        Complex::new(0.0, f64::NAN),
                        Complex::new(f64::INFINITY, 0.0),
                        Complex::new(0.0, f64::INFINITY),
                        Complex::new(f64::MIN_POSITIVE / 2.0, 0.0),
                        Complex::new(0.0, f64::MIN_POSITIVE / 2.0),
                    ];

                    for value in invalid_inputs {
                        let result = value.try_sqrt();
                        assert!(
                            matches!(result, Err(SqrtComplexErrors::Input { .. })),
                            "unexpected result for input {value:?}: {result:?}"
                        );
                    }
                }
            }
        }

        #[cfg(feature = "rug")]
        mod rug53 {
            use super::*;
            use crate::kernels::rug::{ComplexRugStrictFinite, RealRugStrictFinite};

            mod real {
                use super::*;

                #[test]
                fn test_rug_float_sqrt_valid() {
                    let value =
                        RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, 4.0)).unwrap();
                    let expected_result =
                        RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, 2.0)).unwrap();

                    assert_eq!(value.clone().try_sqrt().unwrap(), expected_result);
                    assert_eq!(value.sqrt(), expected_result);
                }

                #[test]
                fn test_rug_float_sqrt_negative_value() {
                    let value =
                        RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, -4.0)).unwrap();
                    let result = value.try_sqrt();
                    assert!(matches!(result, Err(SqrtRealErrors::Input { .. })));
                }
            }

            mod complex {
                use super::*;

                #[test]
                fn test_complex_rug_float_sqrt_valid() {
                    let value = ComplexRugStrictFinite::<53>::try_new(rug::Complex::with_val(
                        53,
                        (rug::Float::with_val(53, 4.0), rug::Float::with_val(53, 0.0)),
                    ))
                    .unwrap();

                    let expected_result =
                        ComplexRugStrictFinite::<53>::try_new(rug::Complex::with_val(
                            53,
                            (rug::Float::with_val(53, 2.0), rug::Float::with_val(53, 0.0)),
                        ))
                        .unwrap();

                    assert_eq!(value.clone().try_sqrt().unwrap(), expected_result);
                    assert_eq!(value.sqrt(), expected_result);
                }
            }
        }
    }
}