num_valid/functions/sqrt.rs
1#![deny(rustdoc::broken_intra_doc_links)]
2
3//! Square root operations.
4//!
5//! This module provides the [`Sqrt`] trait for computing square roots of validated
6//! real and complex numbers.
7
8use crate::{
9 core::{errors::capture_backtrace, policies::StrictFinitePolicy},
10 functions::FunctionErrors,
11 kernels::{RawComplexTrait, RawRealTrait, RawScalarTrait},
12};
13use num::Complex;
14use std::{backtrace::Backtrace, fmt};
15use thiserror::Error;
16use try_create::ValidationPolicy;
17
18//------------------------------------------------------------------------------------------------
19/// Errors that can occur during the input validation phase when attempting to compute
20/// the square root of a real number.
21///
22/// This enum is used as the source for the [`Input`](SqrtRealErrors::Input) variant of [`SqrtRealErrors`].
23/// It is generic over `RawReal: RawRealTrait`, which defines the specific raw real number type and its associated
24/// validation error type.
25///
26/// # Variants
27///
28/// - [`Self::NegativeValue`]: The input value is negative, which is not allowed for real square roots.
29/// - [`Self::ValidationError`]: The input value failed basic validation checks (e.g., NaN, infinity, subnormal) according to the validation policy.
30#[derive(Debug, Error)]
31pub enum SqrtRealInputErrors<RawReal: RawRealTrait> {
32 /// The input value for the square root computation is negative.
33 ///
34 /// The square root of a negative real number is not a real number.
35 #[error("the input value ({value:?}) is negative!")]
36 NegativeValue {
37 /// The negative input value.
38 value: RawReal,
39
40 /// The backtrace of the error.
41 backtrace: Backtrace,
42 },
43
44 /// The input value failed basic validation checks (e.g., it is NaN, infinite, or subnormal).
45 ///
46 /// This error occurs if the input value itself is considered invalid
47 /// according to the validation policy (e.g., [`StrictFinitePolicy`]),
48 /// before the domain-specific check (like negativity) for the square root
49 /// operation is performed.
50 #[error("the input value is invalid!")]
51 ValidationError {
52 /// The source error that occurred during validation.
53 #[source]
54 #[backtrace]
55 source: <RawReal as RawScalarTrait>::ValidationErrors,
56 },
57}
58
59/// Errors that can occur during the input validation phase when computing the square root of a complex number.
60///
61/// This enum is used as the source for the [`Input`](SqrtComplexErrors::Input) variant of [`SqrtComplexErrors`].
62/// It is generic over `RawComplex: RawComplexTrait`.
63///
64/// # Variants
65///
66/// - [`Self::ValidationError`]: The input complex number failed basic validation checks (e.g., its components are NaN, infinite, or subnormal) according to the validation policy.
67#[derive(Debug, Error)]
68pub enum SqrtComplexInputErrors<RawComplex: RawComplexTrait> {
69 /// The input complex number failed basic validation checks (e.g., its components are NaN, infinite, or subnormal).
70 ///
71 /// This error occurs if the input complex value itself is considered invalid
72 /// according to the validation policy (e.g., [`StrictFinitePolicy`]),
73 /// before any domain-specific checks for the complex square root operation are performed.
74 /// It signifies that the real or imaginary part (or both) does not meet fundamental validity criteria.
75 #[error("the input value is invalid!")]
76 ValidationError {
77 /// The detailed source error from the raw complex number's validation.
78 ///
79 /// This field encapsulates the specific error of type `<RawComplex as RawScalarTrait>::ValidationErrors`
80 /// that was reported during the validation of the input complex number's components.
81 #[source]
82 #[backtrace]
83 source: <RawComplex as RawScalarTrait>::ValidationErrors,
84 },
85}
86
87/// A type alias for [`FunctionErrors`], specialized for errors that can occur during
88/// the computation of the square root of a real number.
89///
90/// This type represents the possible failures when calling [`Sqrt::try_sqrt()`] on a real number.
91///
92/// # Generic Parameters
93///
94/// - `RawReal`: A type that implements [`RawRealTrait`]. This defines:
95/// - The raw error type for the input real number via `SqrtRealInputErrors<RawReal>`.
96/// - The raw error type for the output real number (the square root) also via `<RawReal as RawScalarTrait>::ValidationErrors`.
97///
98/// # Variants
99///
100/// This type alias wraps [`FunctionErrors`], which has the following variants in this context:
101///
102/// - `Input { source: SqrtRealInputErrors<RawReal> }`:
103/// Indicates that the input real number was invalid for square root computation.
104/// This could be due to the number being negative or failing general validation checks
105/// (e.g., NaN, Infinity, subnormal). The `source` field provides more specific details
106/// via [`SqrtRealInputErrors`].
107///
108/// - `Output { source: <RawReal as RawScalarTrait>::ValidationErrors }`:
109/// Indicates that the computed square root (which should be a real number)
110/// failed validation. This typically means the result of the `sqrt` operation yielded
111/// a non-finite value (NaN or Infinity), which is unexpected if the input was valid
112/// (non-negative and finite).
113pub type SqrtRealErrors<RawReal> =
114 FunctionErrors<SqrtRealInputErrors<RawReal>, <RawReal as RawScalarTrait>::ValidationErrors>;
115
116/// A type alias for [`FunctionErrors`], specialized for errors that can occur during
117/// the computation of the square root of a complex number.
118///
119/// This type represents the possible failures when calling [`Sqrt::try_sqrt()`] on a complex number.
120///
121/// # Generic Parameters
122///
123/// - `RawComplex`: A type that implements [`RawComplexTrait`]. This defines:
124/// - The raw error type for the input complex number via `SqrtComplexInputErrors<RawComplex>`.
125/// - The raw error type for the output complex number (the square root) also via `<RawComplex as RawScalarTrait>::ValidationErrors`.
126///
127/// # Variants
128///
129/// This type alias wraps [`FunctionErrors`], which has the following variants in this context:
130///
131/// - `Input { source: SqrtComplexInputErrors<RawComplex> }`:
132/// Indicates that the input complex number was invalid for square root computation.
133/// This is typically due to the complex number's components (real or imaginary parts)
134/// failing general validation checks (e.g., NaN, Infinity, subnormal).
135/// The `source` field provides more specific details via [`SqrtComplexInputErrors`].
136///
137/// - `Output { source: <RawComplex as RawScalarTrait>::ValidationErrors }`:
138/// Indicates that the computed complex square root itself failed validation.
139/// This typically means the result of the `sqrt` operation yielded a complex number
140/// with non-finite components (NaN or Infinity), which is unexpected if the input was valid.
141pub type SqrtComplexErrors<RawComplex> = FunctionErrors<
142 SqrtComplexInputErrors<RawComplex>,
143 <RawComplex as RawScalarTrait>::ValidationErrors,
144>;
145//--------------------------------------------------------------------------------------------
146
147//--------------------------------------------------------------------------------------------
148/// A trait for computing the principal square root of a number.
149///
150/// The principal square root of a non-negative real number `x` is the unique non-negative real number `y`
151/// such that `y^2 = x`.
152/// For a complex number `z`, its square root `w` satisfies `w^2 = z`. Complex numbers
153/// (except 0) have two square roots; this trait computes the principal square root,
154/// typically defined as `exp(0.5 * log(z))`, which usually has a non-negative real part.
155///
156/// This trait provides both a fallible version ([`try_sqrt`](Sqrt::try_sqrt)) that performs validation
157/// and an infallible version ([`sqrt`](Sqrt::sqrt)) that may panic in debug builds if validation fails.
158pub trait Sqrt: Sized {
159 /// The error type that can be returned by the `try_sqrt` method.
160 type Error: fmt::Debug;
161
162 /// Attempts to compute the principal square root of `self`, returning a `Result`.
163 ///
164 /// Implementations should validate the input `self` according to the domain
165 /// (e.g., non-negative for reals) and a general validation policy (e.g., [`StrictFinitePolicy`]).
166 /// If the input is valid, the square root is computed, and then the result
167 /// is also validated using the same policy.
168 ///
169 /// # Returns
170 ///
171 /// - `Ok(Self)`: If the input is valid for the square root operation and both the input
172 /// and the computed square root satisfy the validation policy.
173 /// - `Err(SqrtRealErrors)`: If the input is invalid (e.g., negative for real sqrt, NaN, Infinity)
174 /// or if the computed square root is invalid (see below).
175 ///
176 /// # Errors
177 ///
178 /// - Returns [`SqrtRealErrors::Input`] (for reals) or [`SqrtComplexErrors::Input`] (for complex)
179 /// via [`FunctionErrors::Input`] if the input is invalid (e.g., negative real, NaN, Infinity, subnormal).
180 /// - Returns [`SqrtRealErrors::Output`] or [`SqrtComplexErrors::Output`] via [`FunctionErrors::Output`]
181 /// if the result of the computation is not finite (e.g., NaN, Infinity) as per the validation policy.
182 ///
183 /// # Examples
184 ///
185 /// ```
186 /// use num_valid::functions::Sqrt;
187 /// use num::Complex;
188 ///
189 /// assert_eq!(4.0_f64.try_sqrt().unwrap(), 2.0);
190 /// assert!((-1.0_f64).try_sqrt().is_err()); // Negative real
191 /// assert!(f64::NAN.try_sqrt().is_err());
192 ///
193 /// let z = Complex::new(-4.0, 0.0); // sqrt(-4) = 2i
194 /// let sqrt_z = z.try_sqrt().unwrap();
195 /// assert!((sqrt_z.re).abs() < 1e-9 && (sqrt_z.im - 2.0).abs() < 1e-9);
196 /// ```
197 #[must_use = "this `Result` may contain an error that should be handled"]
198 fn try_sqrt(self) -> Result<Self, <Self as Sqrt>::Error>;
199
200 /// Computes the square principal square root of `self`.
201 fn sqrt(self) -> Self;
202}
203
204impl Sqrt for f64 {
205 type Error = SqrtRealErrors<f64>;
206
207 /// Attempts to compute the principal square root of `self`, returning a `Result`.
208 ///
209 /// Implementations should validate the input `self` according to the domain
210 /// (e.g., non-negative for reals) and a general validation policy (e.g., [`StrictFinitePolicy`]).
211 /// If the input is valid, the square root is computed, and then the result
212 /// is also validated using the same policy.
213 ///
214 /// # Returns
215 ///
216 /// - `Ok(Self)`: If the input is valid for the square root operation and both the input
217 /// and the computed square root satisfy the validation policy.
218 /// - `Err(Self::Error)`: If the input is invalid (e.g., negative for real sqrt, NaN, Infinity)
219 /// or if the computed square root is invalid.
220 ///
221 /// # Examples
222 ///
223 /// ```
224 /// use num_valid::functions::Sqrt;
225 /// use num::Complex;
226 ///
227 /// assert_eq!(4.0_f64.try_sqrt().unwrap(), 2.0);
228 /// assert!((-1.0_f64).try_sqrt().is_err()); // Negative real
229 /// assert!(f64::NAN.try_sqrt().is_err());
230 ///
231 /// let z = Complex::new(-4.0, 0.0); // sqrt(-4) = 2i
232 /// let sqrt_z = z.try_sqrt().unwrap();
233 /// assert!((sqrt_z.re).abs() < 1e-9 && (sqrt_z.im - 2.0).abs() < 1e-9);
234 /// ```
235 #[inline(always)]
236 fn try_sqrt(self) -> Result<f64, <f64 as Sqrt>::Error> {
237 StrictFinitePolicy::<f64, 53>::validate(self)
238 .map_err(|e| SqrtRealInputErrors::ValidationError { source: e }.into())
239 .and_then(|validated_value| {
240 if validated_value < 0.0 {
241 Err(SqrtRealInputErrors::NegativeValue {
242 value: validated_value,
243 backtrace: capture_backtrace(),
244 }
245 .into())
246 } else {
247 StrictFinitePolicy::<f64, 53>::validate(f64::sqrt(validated_value))
248 .map_err(|e| SqrtRealErrors::Output { source: e })
249 }
250 })
251 }
252
253 /// Computes and returns the principal square root of `self`.
254 ///
255 /// # Behavior
256 ///
257 /// - **Debug Builds (`#[cfg(debug_assertions)]`)**: This method internally calls
258 /// [`try_sqrt().unwrap()`](Sqrt::try_sqrt). It will panic if `try_sqrt` returns an `Err`.
259 /// - **Release Builds (`#[cfg(not(debug_assertions))]`)**: This method calls the underlying
260 /// square root function directly (e.g., `f64::sqrt`).
261 /// The behavior for invalid inputs (like `sqrt(-1.0)` for `f64` returning NaN)
262 /// depends on the underlying implementation.
263 ///
264 /// # Panics
265 ///
266 /// In debug builds, this method will panic if [`try_sqrt()`](Sqrt::try_sqrt) would return an `Err`.
267 ///
268 /// # Examples
269 ///
270 /// ```
271 /// use num_valid::functions::Sqrt;
272 /// use num::Complex;
273 ///
274 /// assert_eq!(9.0_f64.sqrt(), 3.0);
275 ///
276 /// // For f64, sqrt of negative is NaN in release, panics in debug with ftl's Sqrt
277 /// #[cfg(not(debug_assertions))]
278 /// assert!((-1.0_f64).sqrt().is_nan());
279 ///
280 /// let z = Complex::new(0.0, 4.0); // sqrt(4i) = sqrt(2) + i*sqrt(2)
281 /// let sqrt_z = z.sqrt();
282 /// let expected_val = std::f64::consts::SQRT_2;
283 /// assert!((sqrt_z.re - expected_val).abs() < 1e-9 && (sqrt_z.im - expected_val).abs() < 1e-9);
284 /// ```
285 #[inline(always)]
286 fn sqrt(self) -> Self {
287 #[cfg(debug_assertions)]
288 {
289 self.try_sqrt().unwrap()
290 }
291 #[cfg(not(debug_assertions))]
292 {
293 f64::sqrt(self)
294 }
295 }
296}
297
298impl Sqrt for Complex<f64> {
299 type Error = SqrtComplexErrors<Complex<f64>>;
300
301 /// Attempts to compute the principal square root of `self` (a `Complex<f64>`).
302 ///
303 /// This method first validates `self` using [`StrictFinitePolicy`] (components must be finite and normal).
304 /// If valid, it computes `Complex::sqrt` and validates the result using [`StrictFinitePolicy`].
305 ///
306 /// # Returns
307 ///
308 /// - `Ok(Complex<f64>)`: If `self` and the computed square root have finite and normal components.
309 /// - `Err(SqrtComplexErrors<Complex<f64>>)`: If `self` or the result has invalid components.
310 #[inline(always)]
311 fn try_sqrt(self) -> Result<Self, <Self as Sqrt>::Error> {
312 StrictFinitePolicy::<Complex<f64>, 53>::validate(self)
313 .map_err(|e| SqrtComplexInputErrors::ValidationError { source: e }.into())
314 .and_then(|validated_value| {
315 StrictFinitePolicy::<Complex<f64>, 53>::validate(Complex::<f64>::sqrt(
316 validated_value,
317 ))
318 .map_err(|e| SqrtComplexErrors::Output { source: e })
319 })
320 }
321
322 /// Computes and returns the principal square root of `self` (a `Complex<f64>`).
323 ///
324 /// # Behavior
325 ///
326 /// - **Debug Builds**: Calls `try_sqrt().unwrap()`. Panics on invalid input/output.
327 /// - **Release Builds**: Calls `Complex::sqrt(self)` directly.
328 ///
329 /// # Panics
330 ///
331 /// In debug builds, if `try_sqrt()` would return an `Err`.
332 #[inline(always)]
333 fn sqrt(self) -> Self {
334 #[cfg(debug_assertions)]
335 {
336 self.try_sqrt().unwrap()
337 }
338 #[cfg(not(debug_assertions))]
339 {
340 Complex::<f64>::sqrt(self)
341 }
342 }
343}
344
345//------------------------------------------------------------------------------------------------
346
347//------------------------------------------------------------------------------------------------
348#[cfg(test)]
349mod tests {
350 use super::*;
351 use num::Complex;
352
353 #[cfg(feature = "rug")]
354 use try_create::TryNew;
355
356 mod sqrt {
357 use super::*;
358
359 mod native64 {
360 use super::*;
361
362 mod real {
363 use super::*;
364
365 #[test]
366 fn test_f64_sqrt_valid() {
367 let value = 4.0;
368
369 assert_eq!(value.try_sqrt().unwrap(), 2.0);
370 assert_eq!(<f64 as Sqrt>::sqrt(value), 2.0);
371 }
372
373 #[test]
374 fn test_f64_sqrt_negative_value() {
375 let value = -4.0;
376 let result = value.try_sqrt();
377 assert!(matches!(result, Err(SqrtRealErrors::Input { .. })));
378 }
379
380 #[test]
381 fn test_f64_sqrt_subnormal() {
382 let value = f64::MIN_POSITIVE / 2.0;
383 let result = value.try_sqrt();
384 assert!(matches!(result, Err(SqrtRealErrors::Input { .. })));
385 }
386
387 #[test]
388 fn test_f64_sqrt_zero() {
389 let value = 0.0;
390 let result = value.try_sqrt();
391 assert!(matches!(result, Ok(0.0)));
392 }
393
394 #[test]
395 fn test_f64_sqrt_nan() {
396 let value = f64::NAN;
397 let result = value.try_sqrt();
398 assert!(matches!(result, Err(SqrtRealErrors::Input { .. })));
399 }
400
401 #[test]
402 fn test_f64_sqrt_infinite() {
403 let value = f64::INFINITY;
404 let result = value.try_sqrt();
405 println!("result: {result:?}");
406 assert!(matches!(result, Err(SqrtRealErrors::Input { .. })));
407 }
408 }
409
410 mod complex {
411 use super::*;
412
413 #[test]
414 fn test_complex_f64_sqrt_valid() {
415 let value = Complex::new(4.0, 0.0);
416
417 let expected_result = Complex::new(2.0, 0.0);
418
419 assert_eq!(value.try_sqrt().unwrap(), expected_result);
420 assert_eq!(<Complex<f64> as Sqrt>::sqrt(value), expected_result);
421 }
422
423 #[test]
424 fn test_complex_f64_sqrt_invalid() {
425 let value = Complex::new(f64::NAN, 0.0);
426 let result = value.try_sqrt();
427 assert!(matches!(result, Err(SqrtComplexErrors::Input { .. })));
428
429 let value = Complex::new(0.0, f64::NAN);
430 let result = value.try_sqrt();
431 assert!(matches!(result, Err(SqrtComplexErrors::Input { .. })));
432
433 let value = Complex::new(f64::INFINITY, 0.0);
434 let result = value.try_sqrt();
435 assert!(matches!(result, Err(SqrtComplexErrors::Input { .. })));
436
437 let value = Complex::new(0.0, f64::INFINITY);
438 let result = value.try_sqrt();
439 assert!(matches!(result, Err(SqrtComplexErrors::Input { .. })));
440
441 let value = Complex::new(f64::MIN_POSITIVE / 2.0, 0.0);
442 let result = value.try_sqrt();
443 assert!(matches!(result, Err(SqrtComplexErrors::Input { .. })));
444
445 let value = Complex::new(0., f64::MIN_POSITIVE / 2.0);
446 let result = value.try_sqrt();
447 assert!(matches!(result, Err(SqrtComplexErrors::Input { .. })));
448 }
449 }
450 }
451
452 #[cfg(feature = "rug")]
453 mod rug53 {
454 use super::*;
455 use crate::backends::rug::validated::{ComplexRugStrictFinite, RealRugStrictFinite};
456
457 mod real {
458 use super::*;
459
460 #[test]
461 fn test_rug_float_sqrt_valid() {
462 let value =
463 RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, 4.0)).unwrap();
464 let expected_result =
465 RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, 2.0)).unwrap();
466
467 assert_eq!(value.clone().try_sqrt().unwrap(), expected_result);
468 assert_eq!(value.sqrt(), expected_result);
469 }
470
471 #[test]
472 fn test_rug_float_sqrt_negative_value() {
473 let value =
474 RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, -4.0)).unwrap();
475 let result = value.try_sqrt();
476 assert!(matches!(result, Err(SqrtRealErrors::Input { .. })));
477 }
478 }
479
480 mod complex {
481 use super::*;
482
483 #[test]
484 fn test_complex_rug_float_sqrt_valid() {
485 let value = ComplexRugStrictFinite::<53>::try_new(rug::Complex::with_val(
486 53,
487 (rug::Float::with_val(53, 4.0), rug::Float::with_val(53, 0.0)),
488 ))
489 .unwrap();
490
491 let expected_result =
492 ComplexRugStrictFinite::<53>::try_new(rug::Complex::with_val(
493 53,
494 (rug::Float::with_val(53, 2.0), rug::Float::with_val(53, 0.0)),
495 ))
496 .unwrap();
497
498 assert_eq!(value.clone().try_sqrt().unwrap(), expected_result);
499 assert_eq!(value.sqrt(), expected_result);
500 }
501 }
502 }
503 }
504}