use crate::EvaluationError;
use nalgebra::ComplexField;
use num_complex::Complex;
use num_traits::float::{Float, FloatCore};
use serde::{de::DeserializeOwned, Serialize};
use std::fmt::{Debug, Display};
use trellis_runner::TrellisFloat;
/// A real floating-point scalar that can be used directly as an integration
/// result and as its own error estimate.
///
/// This is a convenience bound that bundles [`IntegrableFloat`],
/// `nalgebra::RealField`, the crate's error-handling traits and an
/// `IntegrationOutput` whose `Scalar` and `Float` are both `Self`.
pub trait RealIntegrableScalar:
    nalgebra::RealField
    + IntegrableFloat
    + crate::IntegrationOutput<Scalar = Self, Float = Self>
    + crate::AccumulateError<Self>
    + crate::RescaleError
{
}

// Both primitive float widths satisfy every supertrait above.
impl RealIntegrableScalar for f32 {}
impl RealIntegrableScalar for f64 {}
/// Floating-point element type accepted by the integrators.
///
/// The bound requires the following:
/// - both `num_traits` float traits (`Float` and `FloatCore`);
/// - ordering and equality;
/// - `Debug`/`Display` formatting;
/// - serde serialization and deserialization;
/// - `trellis_runner::TrellisFloat`.
pub trait IntegrableFloat:
    Float
    + FloatCore
    + TrellisFloat
    + PartialEq
    + PartialOrd
    + Clone
    + Debug
    + Display
    + Serialize
    + DeserializeOwned
{
}
/// Something that can be evaluated as an integrand.
///
/// Implementors choose the type of the evaluation point (`Input`) and the type
/// of the value produced there (`Output`).
pub trait Integrable {
    /// Type of the point at which the integrand is evaluated.
    type Input;

    /// Type of the integrand's value; must be usable as an integration output.
    type Output: IntegrationOutput;

    /// Evaluates the integrand at `input`.
    ///
    /// # Errors
    ///
    /// Returns an `EvaluationError` (parameterised by the input type) when
    /// evaluation fails.
    fn integrand(
        &self,
        input: &Self::Input,
    ) -> Result<Self::Output, EvaluationError<Self::Input>>;
}
/// Post-processes a raw quadrature error estimate.
///
/// The receiver is the raw error estimate. `result_abs` and `result_asc` are
/// auxiliary magnitudes of the same type that the rescaling uses. In QUADPACK
/// these are the integral of |f| and of |f - mean(f)|; that this crate passes
/// the same quantities is an assumption to confirm against callers.
pub trait RescaleError {
    /// Returns the rescaled error estimate.
    fn rescale(&self, result_abs: Self, result_asc: Self) -> Self;
}
/// Value type an integrand may return: a real or complex scalar, or (with the
/// `ndarray` feature) a 1-D/2-D array of them.
///
/// The `argmin_math` supertraits provide the vector-space arithmetic
/// (add/sub, scaling by `Scalar`, L2 norm) needed to combine results.
pub trait IntegrationOutput:
Clone
+ Default
// Scaling by the element scalar (multiply / divide).
+ argmin_math::ArgminMul<<Self as IntegrationOutput>::Scalar, Self>
+ argmin_math::ArgminDiv<<Self as IntegrationOutput>::Scalar, Self>
// Element-wise addition and subtraction of two outputs.
+ argmin_math::ArgminAdd<Self, Self>
+ argmin_math::ArgminSub<Self, Self>
// Euclidean norm, returned in the underlying float type.
+ argmin_math::ArgminL2Norm<Self::Float>
+ Send
+ Sync
{
/// Real-valued counterpart of the output. In this file it is the float itself
/// for scalars and an array of real components for ndarray outputs.
type Real;
/// Element scalar (real or complex) whose real field is `Float`.
type Scalar: ComplexField<RealField = Self::Float>;
/// Underlying floating-point type.
type Float: IntegrableFloat;
/// Magnitude of the output as a single float. The impls in this file return
/// the complex modulus for scalars and the sum of entry moduli for arrays.
fn modulus(&self) -> Self::Float;
/// Whether the output is finite. The impls in this file require every
/// component to be finite.
fn is_finite(&self) -> bool;
}
/// Reduces a (possibly multi-component) error estimate to a single value of
/// type `R`.
///
/// Scalars reduce to themselves; arrays reduce over their components.
pub trait AccumulateError<R>: Send + Sync {
    /// Largest component of the error estimate.
    fn max(&self) -> R;

    /// Arithmetic mean of the error estimate's components.
    fn mean(&self) -> R;
}
// `IntegrableFloat` is a pure bound-bundling trait with no methods, so the
// primitive floats only need empty marker impls.
impl IntegrableFloat for f32 {}
impl IntegrableFloat for f64 {}
impl AccumulateError<Self> for f64 {
fn max(&self) -> Self {
*self
}
fn mean(&self) -> Self {
*self
}
}
impl AccumulateError<Self> for f32 {
fn max(&self) -> Self {
*self
}
fn mean(&self) -> Self {
*self
}
}
#[cfg(feature = "ndarray")]
/// Reduces a 1-D array of per-component error estimates.
impl<T: num_traits::float::FloatCore + num_traits::FromPrimitive + PartialOrd + Send + Sync>
    AccumulateError<T> for ndarray::Array1<T>
{
    /// Returns the largest component of the error estimate.
    ///
    /// If any component is NaN the result is NaN. This replaces a panic and
    /// matches the scalar impls, which return a NaN estimate unchanged.
    /// Ties keep the first maximal component.
    ///
    /// # Panics
    ///
    /// Panics if the array is empty.
    fn max(&self) -> T {
        let mut components = self.iter().copied();
        let first = components
            .next()
            .expect("AccumulateError::max called on an empty error array");
        components.fold(first, |largest, component| {
            if largest.is_nan() || component.is_nan() {
                T::nan()
            } else if component > largest {
                component
            } else {
                largest
            }
        })
    }

    /// Returns the arithmetic mean of the error components.
    ///
    /// # Panics
    ///
    /// Panics if the array is empty.
    fn mean(&self) -> T {
        // Call ndarray's inherent `mean` by path. A bare `self.mean()` inside this
        // trait method reads like infinite recursion, even though inherent methods
        // take priority during method resolution.
        ndarray::ArrayBase::mean(self)
            .expect("AccumulateError::mean called on an empty error array")
    }
}
#[cfg(feature = "ndarray")]
/// Reduces a 2-D array of per-component error estimates.
impl<T: num_traits::float::FloatCore + num_traits::FromPrimitive + PartialOrd + Send + Sync>
    AccumulateError<T> for ndarray::Array2<T>
{
    /// Returns the largest component of the error estimate.
    ///
    /// If any component is NaN the result is NaN. This replaces a panic and
    /// matches the scalar impls, which return a NaN estimate unchanged.
    /// Ties keep the first maximal component in iteration order.
    ///
    /// # Panics
    ///
    /// Panics if the array is empty.
    fn max(&self) -> T {
        let mut components = self.iter().copied();
        let first = components
            .next()
            .expect("AccumulateError::max called on an empty error array");
        components.fold(first, |largest, component| {
            if largest.is_nan() || component.is_nan() {
                T::nan()
            } else if component > largest {
                component
            } else {
                largest
            }
        })
    }

    /// Returns the arithmetic mean of all error components.
    ///
    /// # Panics
    ///
    /// Panics if the array is empty.
    fn mean(&self) -> T {
        // Call ndarray's inherent `mean` by path. A bare `self.mean()` inside this
        // trait method reads like infinite recursion, even though inherent methods
        // take priority during method resolution.
        ndarray::ArrayBase::mean(self)
            .expect("AccumulateError::mean called on an empty error array")
    }
}
impl RescaleError for f32 {
fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
let mut error = self.abs();
if result_asc != 0.0 && error != 0.0 {
let exponent = 1.5;
let scale = ComplexField::powf(200. * error / result_asc, exponent);
if scale < 1. {
error = result_asc * scale;
} else {
error = result_asc;
}
}
if result_abs > f32::EPSILON / (50. * f32::EPSILON) {
let min_err = 50. * f32::EPSILON * result_abs;
if min_err > error {
error = min_err;
}
}
error
}
}
impl RescaleError for f64 {
fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
let mut error = self.abs();
if result_asc != 0.0 && error != 0.0 {
let exponent = 1.5;
let scale = ComplexField::powf(200. * error / result_asc, exponent);
if scale < 1. {
error = result_asc * scale;
} else {
error = result_asc;
}
}
if result_abs > f64::EPSILON / (50. * f64::EPSILON) {
let min_err = 50. * f64::EPSILON * result_abs;
if min_err > error {
error = min_err;
}
}
error
}
}
#[cfg(feature = "ndarray")]
/// Rescales a 1-D error estimate component by component.
///
/// Each component of `self` is rescaled with the matching components of
/// `result_abs` and `result_asc`. Iteration stops at the shortest of the three
/// arrays.
impl<T> RescaleError for ndarray::Array1<T>
where
    T: RescaleError,
{
    fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
        let auxiliary = IntoIterator::into_iter(result_abs).zip(result_asc);
        self.iter()
            .zip(auxiliary)
            .map(|(raw_error, (abs, asc))| raw_error.rescale(abs, asc))
            .collect()
    }
}
#[cfg(feature = "ndarray")]
/// Rescales a 2-D error estimate component by component.
///
/// Components are visited in the logical iteration order of `self` and paired
/// with the matching components of `result_abs` and `result_asc`. The flat
/// results are then reshaped back to `self`'s dimensions.
impl<T> RescaleError for ndarray::Array2<T>
where
    T: RescaleError,
{
    fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
        let shape = self.dim();
        let auxiliary = IntoIterator::into_iter(result_abs).zip(result_asc);
        let flat: ndarray::Array1<T> = self
            .iter()
            .zip(auxiliary)
            .map(|(raw_error, (abs, asc))| raw_error.rescale(abs, asc))
            .collect();
        flat.into_shape(shape).unwrap()
    }
}
/// Single-precision complex integration output; its real and float types are `f32`.
impl IntegrationOutput for Complex<f32> {
    type Scalar = Self;
    type Float = f32;
    type Real = f32;

    /// Complex modulus |z|, delegated to `ComplexField`.
    fn modulus(&self) -> Self::Float {
        let value = *self;
        ComplexField::modulus(value)
    }

    /// True when both the real and imaginary parts are finite (per `ComplexField`).
    fn is_finite(&self) -> bool {
        <Self as ComplexField>::is_finite(self)
    }
}
/// Single-precision real integration output; every associated type is `f32` itself.
impl IntegrationOutput for f32 {
    type Scalar = Self;
    type Float = Self;
    type Real = Self;

    /// Absolute value, delegated to `ComplexField::modulus`.
    fn modulus(&self) -> Self::Float {
        let value = *self;
        ComplexField::modulus(value)
    }

    /// Finiteness check, delegated to `ComplexField`.
    fn is_finite(&self) -> bool {
        <Self as ComplexField>::is_finite(self)
    }
}
/// Double-precision complex integration output; its real and float types are `f64`.
impl IntegrationOutput for Complex<f64> {
    type Scalar = Self;
    type Float = f64;
    type Real = f64;

    /// Complex modulus |z|, delegated to `ComplexField`.
    fn modulus(&self) -> Self::Float {
        let value = *self;
        ComplexField::modulus(value)
    }

    /// True when both the real and imaginary parts are finite (per `ComplexField`).
    fn is_finite(&self) -> bool {
        <Self as ComplexField>::is_finite(self)
    }
}
/// Double-precision real integration output; every associated type is `f64` itself.
impl IntegrationOutput for f64 {
    type Scalar = Self;
    type Float = Self;
    type Real = Self;

    /// Absolute value, delegated to `ComplexField::modulus`.
    fn modulus(&self) -> Self::Float {
        let value = *self;
        ComplexField::modulus(value)
    }

    /// Finiteness check, delegated to `ComplexField`.
    fn is_finite(&self) -> bool {
        <Self as ComplexField>::is_finite(self)
    }
}
#[cfg(feature = "ndarray")]
/// 1-D array integration output with element scalar `T` (real or complex).
///
/// `modulus` aggregates the entries into a single float by summing their
/// moduli, which is an L1-style magnitude.
impl<T: ComplexField + Default> IntegrationOutput for ndarray::Array1<T>
where
    Self: argmin_math::ArgminAdd<Self, Self>
        + argmin_math::ArgminSub<Self, Self>
        + argmin_math::ArgminDiv<T, Self>
        + argmin_math::ArgminMul<T, Self>
        + argmin_math::ArgminL2Norm<<T as ComplexField>::RealField>,
    ndarray::Array1<<T as ComplexField>::RealField>: argmin_math::ArgminAdd<
            <T as ComplexField>::RealField,
            ndarray::Array1<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminAdd<
            ndarray::Array1<<T as ComplexField>::RealField>,
            ndarray::Array1<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminMul<
            <T as ComplexField>::RealField,
            ndarray::Array1<<T as ComplexField>::RealField>,
        > + AccumulateError<<T as ComplexField>::RealField>,
    T: Copy,
    <T as ComplexField>::RealField:
        Default + FloatCore + IntegrableFloat + RescaleError + std::iter::Sum,
{
    type Scalar = T;
    type Float = <T as ComplexField>::RealField;
    type Real = ndarray::Array1<<T as ComplexField>::RealField>;

    /// Sum of the moduli of all entries, accumulated in iteration order.
    fn modulus(&self) -> Self::Float {
        self.iter()
            .copied()
            .map(<T as ComplexField>::modulus)
            .sum()
    }

    /// True unless some entry is non-finite.
    fn is_finite(&self) -> bool {
        !self
            .iter()
            .any(|entry| !<T as ComplexField>::is_finite(entry))
    }
}
#[cfg(feature = "ndarray")]
/// 2-D array integration output with element scalar `T` (real or complex).
///
/// `modulus` aggregates the entries into a single float by summing their
/// moduli, which is an L1-style magnitude.
impl<T: ComplexField + Default> IntegrationOutput for ndarray::Array2<T>
where
    Self: argmin_math::ArgminAdd<Self, Self>
        + argmin_math::ArgminSub<Self, Self>
        + argmin_math::ArgminDiv<T, Self>
        + argmin_math::ArgminMul<T, Self>
        + argmin_math::ArgminL2Norm<<T as ComplexField>::RealField>,
    ndarray::Array2<<T as ComplexField>::RealField>: argmin_math::ArgminAdd<
            <T as ComplexField>::RealField,
            ndarray::Array2<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminAdd<
            ndarray::Array2<<T as ComplexField>::RealField>,
            ndarray::Array2<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminMul<
            <T as ComplexField>::RealField,
            ndarray::Array2<<T as ComplexField>::RealField>,
        > + AccumulateError<<T as ComplexField>::RealField>,
    T: Copy,
    <T as ComplexField>::RealField:
        Default + FloatCore + IntegrableFloat + RescaleError + std::iter::Sum,
{
    type Scalar = T;
    type Float = <T as ComplexField>::RealField;
    type Real = ndarray::Array2<<T as ComplexField>::RealField>;

    /// Sum of the moduli of all entries, accumulated in iteration order.
    fn modulus(&self) -> Self::Float {
        self.iter()
            .copied()
            .map(<T as ComplexField>::modulus)
            .sum()
    }

    /// True unless some entry is non-finite.
    fn is_finite(&self) -> bool {
        !self
            .iter()
            .any(|entry| !<T as ComplexField>::is_finite(entry))
    }
}