use core::any::type_name;
use core::fmt::{Debug, Display};
use core::iter::{Product, Sum};
use num_traits::{Num, NumCast, One, Zero};
use serde::Serialize;
use serde::de::DeserializeOwned;
/// Crate-private home of the `Sealed` marker trait.
///
/// `Sealed` itself is declared `pub`, but this module is only `pub(crate)`.
/// Code outside the crate therefore cannot name `Sealed`. Because `Scalar`
/// lists `Sealed` as a supertrait, no downstream crate can implement `Scalar`.
pub(crate) mod scalar_sealed {
    /// Marker trait with no items. Implementing it is what admits a type to
    /// `Scalar`, and only this crate can do so.
    pub trait Sealed {}
}
use scalar_sealed::Sealed;
/// Error returned by [`Scalar::try_cast`] when one component of a scalar
/// cannot be converted into the target scalar's real type.
///
/// Both variants carry the source and target type names as `&'static str`.
/// When `try_cast` produces them, the names come from
/// `core::any::type_name`, whose output is meant for diagnostics only and is
/// not guaranteed to be stable across compiler versions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarCastError {
    /// The real component could not be represented in the target type.
    RealPartOutOfRange {
        /// Name of the scalar type being converted from.
        source: &'static str,
        /// Name of the scalar type being converted to.
        target: &'static str,
    },
    /// The imaginary component could not be represented in the target type.
    ImagPartOutOfRange {
        /// Name of the scalar type being converted from.
        source: &'static str,
        /// Name of the scalar type being converted to.
        target: &'static str,
    },
}

/// Human-readable, single-line description naming the failing component and
/// both scalar types.
impl core::fmt::Display for ScalarCastError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Both fields are `&'static str` (Copy), so binding them by value out
        // of `*self` is free.
        match *self {
            Self::RealPartOutOfRange { source, target } => {
                write!(
                    f,
                    "real component of scalar type {source} cannot be represented as {target}"
                )
            }
            Self::ImagPartOutOfRange { source, target } => {
                write!(
                    f,
                    "imaginary component of scalar type {source} cannot be represented as {target}"
                )
            }
        }
    }
}

/// No underlying cause is tracked, so the default `source()` (always `None`)
/// is used.
impl std::error::Error for ScalarCastError {}
/// Numeric element type shared by this crate's real and complex scalars.
///
/// Every value is treated as a pair of components of type [`Scalar::Real`].
/// Components are read with [`re`](Scalar::re) / [`im`](Scalar::im) and
/// recombined with [`from_re_im`](Scalar::from_re_im). The `Sealed`
/// supertrait lives in the crate-private `scalar_sealed` module, so only this
/// crate can implement `Scalar`.
pub trait Scalar:
    Num
    + NumCast
    + Zero
    + One
    + Copy
    + Clone
    + Default
    + Send
    + Sync
    + 'static
    + Debug
    + Display
    + Sum<Self>
    + Product<Self>
    + Sealed
{
    /// Type of each component (real and imaginary part) of this scalar.
    ///
    /// The `Scalar<Real = Self::Real>` bound makes the component type a scalar
    /// whose own `Real` is itself. Components can therefore be fed back into
    /// `Scalar` methods, as [`Scalar::try_cast`] does with `is_finite`.
    type Real: Num
        + NumCast
        + Zero
        + One
        + Copy
        + Clone
        + Default
        + Send
        + Sync
        + 'static
        + Debug
        + Display
        + Sum<Self::Real>
        + Product<Self::Real>
        + Scalar<Real = Self::Real>;

    /// Complex conjugate of `self`. This is presumably the identity for real
    /// scalars; implementations are not visible here, so confirm.
    fn conj(self) -> Self;

    /// [`abs_real`](Scalar::abs_real) returned as a `Self`, built with a zero
    /// imaginary component.
    #[inline]
    fn abs(self) -> Self {
        Self::from_re_im(self.abs_real(), Self::Real::zero())
    }

    /// [`norm_sqr_real`](Scalar::norm_sqr_real) returned as a `Self`, built
    /// with a zero imaginary component.
    #[inline]
    fn norm_sqr(self) -> Self {
        Self::from_re_im(self.norm_sqr_real(), Self::Real::zero())
    }

    /// Square root of `self` in the same scalar type. Domain and branch-cut
    /// handling is up to each implementation.
    fn sqrt(self) -> Self;

    /// Real component of `self`.
    fn re(self) -> Self::Real;

    /// Imaginary component of `self`. This is presumably zero for real
    /// scalars; confirm in the implementations.
    fn im(self) -> Self::Real;

    /// Absolute value (modulus) of `self` as a real number. Used by
    /// [`Scalar::abs`].
    fn abs_real(self) -> Self::Real;

    /// Squared norm of `self` as a real number (by name, presumably
    /// `|self|²`). Used by [`Scalar::norm_sqr`].
    fn norm_sqr_real(self) -> Self::Real;

    /// Converts `self` into another scalar type `U`, one component at a time.
    ///
    /// `re()` and `im()` are each converted into `U::Real` with
    /// `NumCast::from`, then recombined with `U::from_re_im`. When `U` is a
    /// real type, whether a non-zero imaginary part is kept, dropped or
    /// rejected depends on `U::from_re_im`. NOTE(review): that implementation
    /// is not visible here; confirm.
    ///
    /// Non-finite inputs (NaN, ±∞) pass through whenever `NumCast::from`
    /// accepts them.
    ///
    /// # Errors
    ///
    /// - [`ScalarCastError::RealPartOutOfRange`] if `NumCast::from` rejects
    ///   the real component, or if a finite real component becomes non-finite
    ///   in `U::Real`.
    /// - [`ScalarCastError::ImagPartOutOfRange`] under the same conditions for
    ///   the imaginary component.
    ///
    /// The real component is checked first.
    #[inline]
    fn try_cast<U: Scalar>(self) -> Result<U, ScalarCastError> {
        let source = type_name::<Self>();
        let target = type_name::<U>();

        // The same acceptance rule applies to both components:
        // - `NumCast` must accept the value.
        // - A finite input must not come out non-finite. This catches
        //   conversions that round an out-of-range value to ±∞ instead of
        //   returning `None`.
        let convert = |value: Self::Real| {
            <U::Real as NumCast>::from(value)
                .filter(|&converted| !value.is_finite() || converted.is_finite())
        };

        // Read both components up front, as the original ordering did.
        let (source_re, source_im) = (self.re(), self.im());
        let re = convert(source_re)
            .ok_or(ScalarCastError::RealPartOutOfRange { source, target })?;
        let im = convert(source_im)
            .ok_or(ScalarCastError::ImagPartOutOfRange { source, target })?;
        Ok(U::from_re_im(re, im))
    }

    /// Panicking counterpart of [`Scalar::try_cast`].
    ///
    /// # Panics
    ///
    /// Panics if `try_cast` returns an error. The message is
    /// `"scalar cast failed"` followed by the error's `Debug` output.
    /// `#[track_caller]` makes the reported panic location the caller of
    /// `cast`, not this default body.
    #[inline]
    #[track_caller]
    fn cast<U: Scalar>(self) -> U {
        self.try_cast::<U>().expect("scalar cast failed")
    }

    /// Builds a scalar from its real and imaginary components. How real
    /// scalar types treat `im` is implementation-defined and not visible
    /// here.
    fn from_re_im(re: Self::Real, im: Self::Real) -> Self;

    /// Whether `self` is finite. [`Scalar::try_cast`] uses this to detect
    /// conversions that overflow to a non-finite value.
    fn is_finite(self) -> bool;
}
/// Shorthand bound for a [`Scalar`] that can also be serialized and
/// deserialized (owned) with serde.
///
/// Every type meeting all three bounds gets this trait automatically through
/// the blanket impl that follows, so it is never implemented by hand.
pub trait ScalarSerde: Scalar + Serialize + DeserializeOwned {}

impl<S: Scalar + Serialize + DeserializeOwned> ScalarSerde for S {}