num_valid/functions/sqrt.rs
1#![deny(rustdoc::broken_intra_doc_links)]
2
3use crate::{
4 functions::FunctionErrors,
5 kernels::{RawComplexTrait, RawRealTrait, RawScalarTrait},
6 validation::StrictFinitePolicy,
7};
8use num::Complex;
9use std::{backtrace::Backtrace, fmt};
10use thiserror::Error;
11use try_create::ValidationPolicy;
12
13//------------------------------------------------------------------------------------------------
14/// Errors that can occur during the input validation phase when attempting to compute
15/// the square root of a real number.
16///
17/// This enum is used as the source for the [`Input`](SqrtRealErrors::Input) variant of [`SqrtRealErrors`].
18/// It is generic over `RawReal: RawRealTrait`, which defines the specific raw real number type and its associated
19/// validation error type.
20///
21/// # Variants
22///
23/// - [`Self::NegativeValue`]: The input value is negative, which is not allowed for real square roots.
24/// - [`Self::ValidationError`]: The input value failed basic validation checks (e.g., NaN, infinity, subnormal) according to the validation policy.
25#[derive(Debug, Error)]
26pub enum SqrtRealInputErrors<RawReal: RawRealTrait> {
27 /// The input value for the square root computation is negative.
28 ///
29 /// The square root of a negative real number is not a real number.
30 #[error("the input value ({value:?}) is negative!")]
31 NegativeValue {
32 /// The negative input value.
33 value: RawReal,
34
35 /// The backtrace of the error.
36 backtrace: Backtrace,
37 },
38
39 /// The input value failed basic validation checks (e.g., it is NaN, infinite, or subnormal).
40 ///
41 /// This error occurs if the input value itself is considered invalid
42 /// according to the validation policy (e.g., [`StrictFinitePolicy`]),
43 /// before the domain-specific check (like negativity) for the square root
44 /// operation is performed.
45 #[error("the input value is invalid!")]
46 ValidationError {
47 /// The source error that occurred during validation.
48 #[source]
49 #[backtrace]
50 source: <RawReal as RawScalarTrait>::ValidationErrors,
51 },
52}
53
54/// Errors that can occur during the input validation phase when computing the square root of a complex number.
55///
56/// This enum is used as the source for the [`Input`](SqrtComplexErrors::Input) variant of [`SqrtComplexErrors`].
57/// It is generic over `RawComplex: RawComplexTrait`.
58///
59/// # Variants
60///
61/// - [`Self::ValidationError`]: The input complex number failed basic validation checks (e.g., its components are NaN, infinite, or subnormal) according to the validation policy.
62#[derive(Debug, Error)]
63pub enum SqrtComplexInputErrors<RawComplex: RawComplexTrait> {
64 /// The input complex number failed basic validation checks (e.g., its components are NaN, infinite, or subnormal).
65 ///
66 /// This error occurs if the input complex value itself is considered invalid
67 /// according to the validation policy (e.g., [`StrictFinitePolicy`]),
68 /// before any domain-specific checks for the complex square root operation are performed.
69 /// It signifies that the real or imaginary part (or both) does not meet fundamental validity criteria.
70 #[error("the input value is invalid!")]
71 ValidationError {
72 /// The detailed source error from the raw complex number's validation.
73 ///
74 /// This field encapsulates the specific error of type `<RawComplex as RawScalarTrait>::ValidationErrors`
75 /// that was reported during the validation of the input complex number's components.
76 #[source]
77 #[backtrace]
78 source: <RawComplex as RawScalarTrait>::ValidationErrors,
79 },
80}
81
82/// A type alias for [`FunctionErrors`], specialized for errors that can occur during
83/// the computation of the square root of a real number.
84///
85/// This type represents the possible failures when calling [`Sqrt::try_sqrt()`] on a real number.
86///
87/// # Generic Parameters
88///
89/// - `RawReal`: A type that implements [`RawRealTrait`]. This defines:
90/// - The raw error type for the input real number via `SqrtRealInputErrors<RawReal>`.
91/// - The raw error type for the output real number (the square root) also via `<RawReal as RawScalarTrait>::ValidationErrors`.
92///
93/// # Variants
94///
95/// This type alias wraps [`FunctionErrors`], which has the following variants in this context:
96///
97/// - `Input { source: SqrtRealInputErrors<RawReal> }`:
98/// Indicates that the input real number was invalid for square root computation.
99/// This could be due to the number being negative or failing general validation checks
100/// (e.g., NaN, Infinity, subnormal). The `source` field provides more specific details
101/// via [`SqrtRealInputErrors`].
102///
103/// - `Output { source: <RawReal as RawScalarTrait>::ValidationErrors }`:
104/// Indicates that the computed square root (which should be a real number)
105/// failed validation. This typically means the result of the `sqrt` operation yielded
106/// a non-finite value (NaN or Infinity), which is unexpected if the input was valid
107/// (non-negative and finite).
108pub type SqrtRealErrors<RawReal> =
109 FunctionErrors<SqrtRealInputErrors<RawReal>, <RawReal as RawScalarTrait>::ValidationErrors>;
110
111/// A type alias for [`FunctionErrors`], specialized for errors that can occur during
112/// the computation of the square root of a complex number.
113///
114/// This type represents the possible failures when calling [`Sqrt::try_sqrt()`] on a complex number.
115///
116/// # Generic Parameters
117///
118/// - `RawComplex`: A type that implements [`RawComplexTrait`]. This defines:
119/// - The raw error type for the input complex number via `SqrtComplexInputErrors<RawComplex>`.
120/// - The raw error type for the output complex number (the square root) also via `<RawComplex as RawScalarTrait>::ValidationErrors`.
121///
122/// # Variants
123///
124/// This type alias wraps [`FunctionErrors`], which has the following variants in this context:
125///
126/// - `Input { source: SqrtComplexInputErrors<RawComplex> }`:
127/// Indicates that the input complex number was invalid for square root computation.
128/// This is typically due to the complex number's components (real or imaginary parts)
129/// failing general validation checks (e.g., NaN, Infinity, subnormal).
130/// The `source` field provides more specific details via [`SqrtComplexInputErrors`].
131///
132/// - `Output { source: <RawComplex as RawScalarTrait>::ValidationErrors }`:
133/// Indicates that the computed complex square root itself failed validation.
134/// This typically means the result of the `sqrt` operation yielded a complex number
135/// with non-finite components (NaN or Infinity), which is unexpected if the input was valid.
136pub type SqrtComplexErrors<RawComplex> = FunctionErrors<
137 SqrtComplexInputErrors<RawComplex>,
138 <RawComplex as RawScalarTrait>::ValidationErrors,
139>;
140//--------------------------------------------------------------------------------------------
141
142//--------------------------------------------------------------------------------------------
/// A trait for computing the principal square root of a number.
///
/// The principal square root of a non-negative real number `x` is the unique non-negative real
/// number `y` such that `y^2 = x`.
/// Every non-zero complex number `z` has two square roots `w` satisfying `w^2 = z`. This trait
/// computes the principal one, `exp(0.5 * log(z))` taken with the principal branch of `log`.
/// That root has a non-negative real part.
///
/// The trait provides two methods:
/// - [`try_sqrt`](Sqrt::try_sqrt) is fallible and validates both the input and the result.
/// - [`sqrt`](Sqrt::sqrt) is infallible and may panic in debug builds if validation fails.
///
/// # Method resolution
///
/// Types such as `f64` and `num::Complex<f64>` also have an *inherent* `sqrt` method. Inherent
/// methods take precedence over trait methods, so `x.sqrt()` calls the inherent method.
/// Write `Sqrt::sqrt(x)` to call the trait method explicitly.
pub trait Sqrt: Sized {
    /// The error type that can be returned by [`try_sqrt`](Sqrt::try_sqrt).
    type Error: fmt::Debug;

    /// Attempts to compute the principal square root of `self`, returning a `Result`.
    ///
    /// Implementations should validate the input `self` in two ways:
    /// - against the domain of the operation (e.g. non-negative for reals);
    /// - against a general validation policy (e.g. [`StrictFinitePolicy`]).
    ///
    /// If the input is valid, the square root is computed, and the result is then validated
    /// with the same policy.
    ///
    /// # Returns
    ///
    /// - `Ok(Self)`: The input is valid for the square root operation, and both the input
    ///   and the computed square root satisfy the validation policy.
    /// - `Err(Self::Error)`: The input is invalid (e.g. negative for real sqrt, NaN, infinity),
    ///   or the computed square root is invalid (see below).
    ///
    /// # Errors
    ///
    /// - Returns [`SqrtRealErrors::Input`] (for reals) or [`SqrtComplexErrors::Input`] (for complex)
    ///   via [`FunctionErrors::Input`] if the input is invalid (e.g. negative real, NaN, infinity,
    ///   subnormal).
    /// - Returns [`SqrtRealErrors::Output`] or [`SqrtComplexErrors::Output`] via
    ///   [`FunctionErrors::Output`] if the result of the computation does not satisfy the
    ///   validation policy (e.g. it is NaN or infinite).
    ///
    /// # Examples
    ///
    /// ```
    /// use num_valid::functions::Sqrt;
    /// use num::Complex;
    ///
    /// assert_eq!(4.0_f64.try_sqrt().unwrap(), 2.0);
    /// assert!((-1.0_f64).try_sqrt().is_err()); // Negative real
    /// assert!(f64::NAN.try_sqrt().is_err());
    ///
    /// let z = Complex::new(-4.0, 0.0); // sqrt(-4) = 2i
    /// let sqrt_z = z.try_sqrt().unwrap();
    /// assert!((sqrt_z.re).abs() < 1e-9 && (sqrt_z.im - 2.0).abs() < 1e-9);
    /// ```
    fn try_sqrt(self) -> Result<Self, <Self as Sqrt>::Error>;

    /// Computes the principal square root of `self`.
    ///
    /// This is the infallible counterpart of [`try_sqrt`](Sqrt::try_sqrt). Implementations may
    /// panic when validation fails (for example only in debug builds), or may skip validation
    /// entirely. Refer to the documentation of each implementation.
    fn sqrt(self) -> Self;
}
197
198impl Sqrt for f64 {
199 type Error = SqrtRealErrors<f64>;
200
201 /// Attempts to compute the principal square root of `self`, returning a `Result`.
202 ///
203 /// Implementations should validate the input `self` according to the domain
204 /// (e.g., non-negative for reals) and a general validation policy (e.g., [`StrictFinitePolicy`]).
205 /// If the input is valid, the square root is computed, and then the result
206 /// is also validated using the same policy.
207 ///
208 /// # Returns
209 ///
210 /// - `Ok(Self)`: If the input is valid for the square root operation and both the input
211 /// and the computed square root satisfy the validation policy.
212 /// - `Err(Self::Error)`: If the input is invalid (e.g., negative for real sqrt, NaN, Infinity)
213 /// or if the computed square root is invalid.
214 ///
215 /// # Examples
216 ///
217 /// ```
218 /// use num_valid::functions::Sqrt;
219 /// use num::Complex;
220 ///
221 /// assert_eq!(4.0_f64.try_sqrt().unwrap(), 2.0);
222 /// assert!((-1.0_f64).try_sqrt().is_err()); // Negative real
223 /// assert!(f64::NAN.try_sqrt().is_err());
224 ///
225 /// let z = Complex::new(-4.0, 0.0); // sqrt(-4) = 2i
226 /// let sqrt_z = z.try_sqrt().unwrap();
227 /// assert!((sqrt_z.re).abs() < 1e-9 && (sqrt_z.im - 2.0).abs() < 1e-9);
228 /// ```
229 #[inline(always)]
230 fn try_sqrt(self) -> Result<f64, <f64 as Sqrt>::Error> {
231 StrictFinitePolicy::<f64, 53>::validate(self)
232 .map_err(|e| SqrtRealInputErrors::ValidationError { source: e }.into())
233 .and_then(|validated_value| {
234 if validated_value < 0.0 {
235 Err(SqrtRealInputErrors::NegativeValue {
236 value: validated_value,
237 backtrace: Backtrace::force_capture(),
238 }
239 .into())
240 } else {
241 StrictFinitePolicy::<f64, 53>::validate(f64::sqrt(validated_value))
242 .map_err(|e| SqrtRealErrors::Output { source: e })
243 }
244 })
245 }
246
247 /// Computes and returns the principal square root of `self`.
248 ///
249 /// # Behavior
250 ///
251 /// - **Debug Builds (`#[cfg(debug_assertions)]`)**: This method internally calls
252 /// [`try_sqrt().unwrap()`](Sqrt::try_sqrt). It will panic if `try_sqrt` returns an `Err`.
253 /// - **Release Builds (`#[cfg(not(debug_assertions))]`)**: This method calls the underlying
254 /// square root function directly (e.g., `f64::sqrt`).
255 /// The behavior for invalid inputs (like `sqrt(-1.0)` for `f64` returning NaN)
256 /// depends on the underlying implementation.
257 ///
258 /// # Panics
259 ///
260 /// In debug builds, this method will panic if [`try_sqrt()`](Sqrt::try_sqrt) would return an `Err`.
261 ///
262 /// # Examples
263 ///
264 /// ```
265 /// use num_valid::functions::Sqrt;
266 /// use num::Complex;
267 ///
268 /// assert_eq!(9.0_f64.sqrt(), 3.0);
269 ///
270 /// // For f64, sqrt of negative is NaN in release, panics in debug with ftl's Sqrt
271 /// #[cfg(not(debug_assertions))]
272 /// assert!((-1.0_f64).sqrt().is_nan());
273 ///
274 /// let z = Complex::new(0.0, 4.0); // sqrt(4i) = sqrt(2) + i*sqrt(2)
275 /// let sqrt_z = z.sqrt();
276 /// let expected_val = std::f64::consts::SQRT_2;
277 /// assert!((sqrt_z.re - expected_val).abs() < 1e-9 && (sqrt_z.im - expected_val).abs() < 1e-9);
278 /// ```
279 #[inline(always)]
280 fn sqrt(self) -> Self {
281 #[cfg(debug_assertions)]
282 {
283 self.try_sqrt().unwrap()
284 }
285 #[cfg(not(debug_assertions))]
286 {
287 f64::sqrt(self)
288 }
289 }
290}
291
292impl Sqrt for Complex<f64> {
293 type Error = SqrtComplexErrors<Complex<f64>>;
294
295 /// Attempts to compute the principal square root of `self` (a `Complex<f64>`).
296 ///
297 /// This method first validates `self` using [`StrictFinitePolicy`] (components must be finite and normal).
298 /// If valid, it computes `Complex::sqrt` and validates the result using [`StrictFinitePolicy`].
299 ///
300 /// # Returns
301 ///
302 /// - `Ok(Complex<f64>)`: If `self` and the computed square root have finite and normal components.
303 /// - `Err(SqrtComplexErrors<Complex<f64>>)`: If `self` or the result has invalid components.
304 #[inline(always)]
305 fn try_sqrt(self) -> Result<Self, <Self as Sqrt>::Error> {
306 StrictFinitePolicy::<Complex<f64>, 53>::validate(self)
307 .map_err(|e| SqrtComplexInputErrors::ValidationError { source: e }.into())
308 .and_then(|validated_value| {
309 StrictFinitePolicy::<Complex<f64>, 53>::validate(Complex::<f64>::sqrt(
310 validated_value,
311 ))
312 .map_err(|e| SqrtComplexErrors::Output { source: e })
313 })
314 }
315
316 /// Computes and returns the principal square root of `self` (a `Complex<f64>`).
317 ///
318 /// # Behavior
319 ///
320 /// - **Debug Builds**: Calls `try_sqrt().unwrap()`. Panics on invalid input/output.
321 /// - **Release Builds**: Calls `Complex::sqrt(self)` directly.
322 ///
323 /// # Panics
324 ///
325 /// In debug builds, if `try_sqrt()` would return an `Err`.
326 #[inline(always)]
327 fn sqrt(self) -> Self {
328 #[cfg(debug_assertions)]
329 {
330 self.try_sqrt().unwrap()
331 }
332 #[cfg(not(debug_assertions))]
333 {
334 Complex::<f64>::sqrt(self)
335 }
336 }
337}
338
339//------------------------------------------------------------------------------------------------
340
341//------------------------------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use num::Complex;

    #[cfg(feature = "rug")]
    use try_create::TryNew;

    mod sqrt {
        use super::*;

        mod native64 {
            use super::*;

            mod real {
                use super::*;

                /// Returns `true` if `result` is an input error raised by the validation policy
                /// (as opposed to the negativity check or an output error).
                fn is_input_validation_error(result: &Result<f64, SqrtRealErrors<f64>>) -> bool {
                    matches!(
                        result,
                        Err(SqrtRealErrors::Input {
                            source: SqrtRealInputErrors::ValidationError { .. },
                            ..
                        })
                    )
                }

                #[test]
                fn test_f64_sqrt_valid() {
                    let value = 4.0;

                    assert_eq!(value.try_sqrt().unwrap(), 2.0);
                    assert_eq!(<f64 as Sqrt>::sqrt(value), 2.0);
                }

                #[test]
                fn test_f64_sqrt_negative_value() {
                    let result = (-4.0_f64).try_sqrt();
                    // The failure must come from the domain check and must report the input.
                    match result {
                        Err(SqrtRealErrors::Input {
                            source: SqrtRealInputErrors::NegativeValue { value: reported, .. },
                            ..
                        }) => assert_eq!(reported, -4.0),
                        other => panic!("expected a NegativeValue input error, got {other:?}"),
                    }
                }

                #[test]
                fn test_f64_sqrt_subnormal() {
                    let result = (f64::MIN_POSITIVE / 2.0).try_sqrt();
                    assert!(is_input_validation_error(&result), "got {result:?}");
                }

                #[test]
                fn test_f64_sqrt_zero() {
                    assert_eq!(0.0_f64.try_sqrt().unwrap(), 0.0);
                }

                #[test]
                fn test_f64_sqrt_nan() {
                    let result = f64::NAN.try_sqrt();
                    assert!(is_input_validation_error(&result), "got {result:?}");
                }

                #[test]
                fn test_f64_sqrt_infinite() {
                    let result = f64::INFINITY.try_sqrt();
                    assert!(is_input_validation_error(&result), "got {result:?}");
                }

                #[test]
                fn test_f64_sqrt_negative_infinite_is_validation_error() {
                    // Policy validation runs before the negativity check, so -inf is reported
                    // as a validation failure rather than as a negative value.
                    let result = f64::NEG_INFINITY.try_sqrt();
                    assert!(is_input_validation_error(&result), "got {result:?}");
                }

                #[test]
                #[cfg(debug_assertions)]
                #[should_panic]
                fn test_f64_trait_sqrt_panics_on_negative_in_debug() {
                    let _ = <f64 as Sqrt>::sqrt(-4.0);
                }

                #[test]
                #[cfg(not(debug_assertions))]
                fn test_f64_trait_sqrt_negative_is_nan_in_release() {
                    assert!(<f64 as Sqrt>::sqrt(-4.0).is_nan());
                }
            }

            mod complex {
                use super::*;

                #[test]
                fn test_complex_f64_sqrt_valid() {
                    let value = Complex::new(4.0, 0.0);

                    let expected_result = Complex::new(2.0, 0.0);

                    assert_eq!(value.try_sqrt().unwrap(), expected_result);
                    assert_eq!(<Complex<f64> as Sqrt>::sqrt(value), expected_result);
                }

                #[test]
                fn test_complex_f64_sqrt_negative_real() {
                    // sqrt(-4) = 2i: complex inputs have no negativity domain error.
                    let root = Complex::new(-4.0_f64, 0.0).try_sqrt().unwrap();
                    assert!(root.re.abs() < 1e-12, "got {root:?}");
                    assert!((root.im - 2.0).abs() < 1e-12, "got {root:?}");
                }

                #[test]
                fn test_complex_f64_sqrt_invalid() {
                    let invalid_inputs = [
                        Complex::new(f64::NAN, 0.0),
                        Complex::new(0.0, f64::NAN),
                        Complex::new(f64::INFINITY, 0.0),
                        Complex::new(0.0, f64::INFINITY),
                        Complex::new(f64::MIN_POSITIVE / 2.0, 0.0),
                        Complex::new(0.0, f64::MIN_POSITIVE / 2.0),
                    ];

                    for value in invalid_inputs {
                        let result = value.try_sqrt();
                        assert!(
                            matches!(
                                result,
                                Err(SqrtComplexErrors::Input {
                                    source: SqrtComplexInputErrors::ValidationError { .. },
                                    ..
                                })
                            ),
                            "expected an input validation error for {value:?}, got {result:?}"
                        );
                    }
                }

                #[test]
                #[cfg(debug_assertions)]
                #[should_panic]
                fn test_complex_f64_trait_sqrt_panics_on_invalid_in_debug() {
                    let _ = <Complex<f64> as Sqrt>::sqrt(Complex::new(f64::NAN, 0.0));
                }
            }
        }

        #[cfg(feature = "rug")]
        mod rug53 {
            use super::*;
            use crate::kernels::rug::{ComplexRugStrictFinite, RealRugStrictFinite};

            mod real {
                use super::*;

                #[test]
                fn test_rug_float_sqrt_valid() {
                    let value =
                        RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, 4.0)).unwrap();
                    let expected_result =
                        RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, 2.0)).unwrap();

                    assert_eq!(value.clone().try_sqrt().unwrap(), expected_result);
                    assert_eq!(value.sqrt(), expected_result);
                }

                #[test]
                fn test_rug_float_sqrt_negative_value() {
                    let value =
                        RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, -4.0)).unwrap();
                    let result = value.try_sqrt();
                    assert!(matches!(result, Err(SqrtRealErrors::Input { .. })));
                }
            }

            mod complex {
                use super::*;

                #[test]
                fn test_complex_rug_float_sqrt_valid() {
                    let value = ComplexRugStrictFinite::<53>::try_new(rug::Complex::with_val(
                        53,
                        (rug::Float::with_val(53, 4.0), rug::Float::with_val(53, 0.0)),
                    ))
                    .unwrap();

                    let expected_result =
                        ComplexRugStrictFinite::<53>::try_new(rug::Complex::with_val(
                            53,
                            (rug::Float::with_val(53, 2.0), rug::Float::with_val(53, 0.0)),
                        ))
                        .unwrap();

                    assert_eq!(value.clone().try_sqrt().unwrap(), expected_result);
                    assert_eq!(value.sqrt(), expected_result);
                }
            }
        }
    }
}