#![deny(rustdoc::broken_intra_doc_links)]
use crate::{
core::{errors::capture_backtrace, policies::StrictFinitePolicy},
functions::FunctionErrors,
kernels::{RawComplexTrait, RawRealTrait, RawScalarTrait},
};
use num::Complex;
use std::{backtrace::Backtrace, fmt};
use thiserror::Error;
use try_create::ValidationPolicy;
/// Reasons why the input of a real-valued square root is rejected.
///
/// `RawReal` is the raw real scalar type the error refers to; this file
/// instantiates it with `f64`.
#[derive(Debug, Error)]
pub enum SqrtRealInputErrors<RawReal: RawRealTrait> {
/// The input is negative.
///
/// In the `f64` implementation in this file this variant is produced only
/// after the input has already passed validation, when it compares `< 0.0`.
#[error("the input value ({value:?}) is negative!")]
NegativeValue {
/// The negative input value that was rejected.
value: RawReal,
/// Backtrace recorded when the error was created (the `f64`
/// implementation in this file fills it with `capture_backtrace()`).
backtrace: Backtrace,
},
/// The input was rejected by the validation step that runs before the
/// square root is computed.
///
/// The tests in this file expect NaN, infinite and subnormal `f64` inputs
/// to be reported as input errors.
#[error("the input value is invalid!")]
ValidationError {
/// The underlying validation error reported for the input.
#[source]
#[backtrace]
source: <RawReal as RawScalarTrait>::ValidationErrors,
},
}
/// Reasons why the input of a complex-valued square root is rejected.
///
/// `RawComplex` is the raw complex scalar type the error refers to; this file
/// instantiates it with `Complex<f64>`. Unlike the real case there is no
/// "negative value" variant: the only input failure is a validation failure.
#[derive(Debug, Error)]
pub enum SqrtComplexInputErrors<RawComplex: RawComplexTrait> {
/// The input was rejected by the validation step that runs before the
/// square root is computed.
///
/// The tests in this file expect a NaN, infinite or subnormal real or
/// imaginary part to be reported as an input error.
#[error("the input value is invalid!")]
ValidationError {
/// The underlying validation error reported for the input.
#[source]
#[backtrace]
source: <RawComplex as RawScalarTrait>::ValidationErrors,
},
}
/// Error type returned by the fallible real square root.
///
/// It is `FunctionErrors` specialised with `SqrtRealInputErrors<RawReal>` for
/// problems with the input and with the validation errors of `RawReal` for
/// problems with the computed result. Code in this file uses its
/// `Input { .. }` and `Output { source }` variants.
pub type SqrtRealErrors<RawReal> =
FunctionErrors<SqrtRealInputErrors<RawReal>, <RawReal as RawScalarTrait>::ValidationErrors>;
/// Error type returned by the fallible complex square root.
///
/// It is `FunctionErrors` specialised with `SqrtComplexInputErrors<RawComplex>`
/// for problems with the input and with the validation errors of `RawComplex`
/// for problems with the computed result. Code in this file uses its
/// `Input { .. }` and `Output { source }` variants.
pub type SqrtComplexErrors<RawComplex> = FunctionErrors<
SqrtComplexInputErrors<RawComplex>,
<RawComplex as RawScalarTrait>::ValidationErrors,
>;
/// Square root with a fallible (`try_sqrt`) and an infallible-looking (`sqrt`)
/// entry point.
///
/// Both methods consume `self` and return a value of the same type.
pub trait Sqrt: Sized {
/// Error returned by `try_sqrt`. It must be `Debug` so that the result can
/// be unwrapped.
type Error: fmt::Debug;
/// Computes the square root, reporting failures as `Self::Error`.
///
/// # Errors
///
/// The native implementations in this file return an error when the input
/// fails validation, when a real input is negative, or when the computed
/// result fails validation.
#[must_use = "this `Result` may contain an error that should be handled"]
fn try_sqrt(self) -> Result<Self, <Self as Sqrt>::Error>;
/// Computes the square root without returning a `Result`.
///
/// # Panics
///
/// The native implementations in this file call `try_sqrt().unwrap()` when
/// `debug_assertions` are enabled, so they panic on any error in that
/// configuration. Without `debug_assertions` they call the underlying
/// square root directly and perform no validation.
fn sqrt(self) -> Self;
}
/// Square root for native `f64` values, with input and output checked by
/// `StrictFinitePolicy::<f64, 53>`.
impl Sqrt for f64 {
    /// Input problems are described by `SqrtRealInputErrors<f64>`; output
    /// problems carry the validation error of the computed value.
    type Error = SqrtRealErrors<f64>;

    /// Validates `self`, rejects negative values, computes the square root and
    /// validates the result.
    ///
    /// # Errors
    ///
    /// - `SqrtRealInputErrors::ValidationError` (converted into
    ///   `SqrtRealErrors`) when the policy rejects `self`;
    /// - `SqrtRealInputErrors::NegativeValue` (converted into
    ///   `SqrtRealErrors`) when the validated input is `< 0.0`;
    /// - `SqrtRealErrors::Output` when the policy rejects the computed root.
    #[inline(always)]
    fn try_sqrt(self) -> Result<f64, <f64 as Sqrt>::Error> {
        // Step 1: the input must be accepted by the strict-finite policy.
        let radicand = match StrictFinitePolicy::<f64, 53>::validate(self) {
            Ok(accepted) => accepted,
            Err(source) => return Err(SqrtRealInputErrors::ValidationError { source }.into()),
        };

        // Step 2: a real square root is only defined for non-negative values.
        // Note that `-0.0 < 0.0` is false, so this comparison does not flag
        // negative zero.
        if radicand < 0.0 {
            return Err(SqrtRealInputErrors::NegativeValue {
                value: radicand,
                backtrace: capture_backtrace(),
            }
            .into());
        }

        // Step 3: compute with the inherent `f64::sqrt` and validate the result.
        let root = f64::sqrt(radicand);
        StrictFinitePolicy::<f64, 53>::validate(root)
            .map_err(|source| SqrtRealErrors::Output { source })
    }

    /// Square root without a `Result`.
    ///
    /// # Panics
    ///
    /// With `debug_assertions` enabled this unwraps `try_sqrt`, so it panics on
    /// any error. Without them it forwards to the inherent `f64::sqrt` and
    /// performs no validation.
    #[inline(always)]
    fn sqrt(self) -> Self {
        if cfg!(debug_assertions) {
            self.try_sqrt().unwrap()
        } else {
            f64::sqrt(self)
        }
    }
}
/// Square root for native `Complex<f64>` values, with input and output checked
/// by `StrictFinitePolicy::<Complex<f64>, 53>`.
impl Sqrt for Complex<f64> {
    /// Input problems are described by `SqrtComplexInputErrors<Complex<f64>>`;
    /// output problems carry the validation error of the computed value.
    type Error = SqrtComplexErrors<Complex<f64>>;

    /// Validates `self`, computes the complex square root and validates the
    /// result.
    ///
    /// # Errors
    ///
    /// - `SqrtComplexInputErrors::ValidationError` (converted into
    ///   `SqrtComplexErrors`) when the policy rejects `self`;
    /// - `SqrtComplexErrors::Output` when the policy rejects the computed root.
    #[inline(always)]
    fn try_sqrt(self) -> Result<Self, <Self as Sqrt>::Error> {
        // The input must be accepted by the strict-finite policy first.
        let argument = match StrictFinitePolicy::<Complex<f64>, 53>::validate(self) {
            Ok(accepted) => accepted,
            Err(source) => {
                return Err(SqrtComplexInputErrors::ValidationError { source }.into());
            }
        };

        // Compute with the inherent `Complex::sqrt`, then validate the result.
        let root = Complex::<f64>::sqrt(argument);
        StrictFinitePolicy::<Complex<f64>, 53>::validate(root)
            .map_err(|source| SqrtComplexErrors::Output { source })
    }

    /// Square root without a `Result`.
    ///
    /// # Panics
    ///
    /// With `debug_assertions` enabled this unwraps `try_sqrt`, so it panics on
    /// any error. Without them it forwards to the inherent `Complex::sqrt` and
    /// performs no validation.
    #[inline(always)]
    fn sqrt(self) -> Self {
        if cfg!(debug_assertions) {
            self.try_sqrt().unwrap()
        } else {
            Complex::<f64>::sqrt(self)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use num::Complex;
    #[cfg(feature = "rug")]
    use try_create::TryNew;

    mod sqrt {
        use super::*;

        // Native `f64` / `Complex<f64>` implementations.
        mod native64 {
            use super::*;

            mod real {
                use super::*;

                #[test]
                fn test_f64_sqrt_valid() {
                    let four = 4.0_f64;
                    assert_eq!(four.try_sqrt().unwrap(), 2.0);
                    // Fully qualified so that the trait method is exercised
                    // rather than the inherent `f64::sqrt`.
                    assert_eq!(<f64 as Sqrt>::sqrt(four), 2.0);
                }

                #[test]
                fn test_f64_sqrt_negative_value() {
                    let outcome = (-4.0_f64).try_sqrt();
                    assert!(matches!(outcome, Err(SqrtRealErrors::Input { .. })));
                }

                #[test]
                fn test_f64_sqrt_subnormal() {
                    // Half of the smallest positive normal value is subnormal.
                    let subnormal = f64::MIN_POSITIVE / 2.0;
                    let outcome = subnormal.try_sqrt();
                    assert!(matches!(outcome, Err(SqrtRealErrors::Input { .. })));
                }

                #[test]
                fn test_f64_sqrt_zero() {
                    let outcome = 0.0_f64.try_sqrt();
                    assert!(matches!(outcome, Ok(0.0)));
                }

                #[test]
                fn test_f64_sqrt_nan() {
                    let outcome = f64::NAN.try_sqrt();
                    assert!(matches!(outcome, Err(SqrtRealErrors::Input { .. })));
                }

                #[test]
                fn test_f64_sqrt_infinite() {
                    let outcome = f64::INFINITY.try_sqrt();
                    println!("result: {outcome:?}");
                    assert!(matches!(outcome, Err(SqrtRealErrors::Input { .. })));
                }
            }

            mod complex {
                use super::*;

                #[test]
                fn test_complex_f64_sqrt_valid() {
                    let input = Complex::<f64>::new(4.0, 0.0);
                    let root = Complex::<f64>::new(2.0, 0.0);
                    assert_eq!(input.try_sqrt().unwrap(), root);
                    // Fully qualified so that the trait method is exercised
                    // rather than the inherent `Complex::sqrt`.
                    assert_eq!(<Complex<f64> as Sqrt>::sqrt(input), root);
                }

                #[test]
                fn test_complex_f64_sqrt_invalid() {
                    // True when `try_sqrt` reports the value as an input error.
                    let is_input_error = |input: Complex<f64>| {
                        matches!(input.try_sqrt(), Err(SqrtComplexErrors::Input { .. }))
                    };
                    let subnormal = f64::MIN_POSITIVE / 2.0;

                    // NaN in either component.
                    assert!(is_input_error(Complex::new(f64::NAN, 0.0)));
                    assert!(is_input_error(Complex::new(0.0, f64::NAN)));
                    // Infinity in either component.
                    assert!(is_input_error(Complex::new(f64::INFINITY, 0.0)));
                    assert!(is_input_error(Complex::new(0.0, f64::INFINITY)));
                    // Subnormal in either component.
                    assert!(is_input_error(Complex::new(subnormal, 0.0)));
                    assert!(is_input_error(Complex::new(0., subnormal)));
                }
            }
        }

        // `rug`-backed validated types at 53 bits of precision.
        #[cfg(feature = "rug")]
        mod rug53 {
            use super::*;
            use crate::backends::rug::validated::{ComplexRugStrictFinite, RealRugStrictFinite};

            mod real {
                use super::*;

                #[test]
                fn test_rug_float_sqrt_valid() {
                    // Builds a validated 53-bit real from an `f64`.
                    let strict = |raw: f64| {
                        RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, raw)).unwrap()
                    };
                    let four = strict(4.0);
                    let two = strict(2.0);
                    assert_eq!(four.clone().try_sqrt().unwrap(), two);
                    assert_eq!(four.sqrt(), two);
                }

                #[test]
                fn test_rug_float_sqrt_negative_value() {
                    let minus_four =
                        RealRugStrictFinite::<53>::try_new(rug::Float::with_val(53, -4.0)).unwrap();
                    let outcome = minus_four.try_sqrt();
                    assert!(matches!(outcome, Err(SqrtRealErrors::Input { .. })));
                }
            }

            mod complex {
                use super::*;

                #[test]
                fn test_complex_rug_float_sqrt_valid() {
                    // Builds a validated 53-bit complex from its two parts.
                    let strict = |re: f64, im: f64| {
                        let raw = rug::Complex::with_val(
                            53,
                            (rug::Float::with_val(53, re), rug::Float::with_val(53, im)),
                        );
                        ComplexRugStrictFinite::<53>::try_new(raw).unwrap()
                    };
                    let four = strict(4.0, 0.0);
                    let two = strict(2.0, 0.0);
                    assert_eq!(four.clone().try_sqrt().unwrap(), two);
                    assert_eq!(four.sqrt(), two);
                }
            }
        }
    }
}