quad_rs/bounds/
mod.rs

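// This module defines the traits that integration problems and their outputs must satisfy,
// together with implementations for the primitive float types and, behind the `ndarray`
// feature, for ndarray arrays.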
2
3use crate::EvaluationError;
4use nalgebra::ComplexField;
5use num_complex::Complex;
6use num_traits::float::{Float, FloatCore};
7use serde::{de::DeserializeOwned, Serialize};
8use std::fmt::{Debug, Display};
9use trellis_runner::TrellisFloat;
10
11pub trait RealIntegrableScalar:
12    IntegrableFloat
13    + crate::RescaleError
14    + crate::AccumulateError<Self>
15    + crate::IntegrationOutput<Scalar = Self, Float = Self>
16    + nalgebra::RealField
17{
18}
19
20impl RealIntegrableScalar for f64 {}
21impl RealIntegrableScalar for f32 {}
22
23pub trait IntegrableFloat:
24    Clone
25    + Debug
26    + Display
27    + FloatCore
28    + Float
29    + Serialize
30    + DeserializeOwned
31    + PartialOrd
32    + PartialEq
33    + TrellisFloat
34{
35}
36
37// All integrals must satisfy the integrate trait
38pub trait Integrable {
39    // An integration method takes an input
40    type Input;
41    // And converts it to an output
42    type Output: IntegrationOutput;
43
44    // The integral must provide the integrand at `input`
45    fn integrand(&self, input: &Self::Input) -> Result<Self::Output, EvaluationError<Self::Input>>;
46}
47
/// Post-processes a raw error estimate into the error reported to the caller.
///
/// `self` is the raw error estimate. `result_abs` and `result_asc` are auxiliary quantities
/// of the same type, computed alongside the integral. NOTE(review): these are presumably
/// QUADPACK's `resabs` / `resasc`; confirm at the call sites.
pub trait RescaleError {
    /// Returns the rescaled error estimate.
    fn rescale(&self, result_abs: Self, result_asc: Self) -> Self;
}
51
/// The value produced by an integrand, and therefore the value of an integral.
///
/// An output may be real or complex, and may be a scalar or an array. Every output must be
/// `Clone + Default + Send + Sync`. It must also support the listed `argmin_math` operations:
/// addition and subtraction with itself, multiplication and division by `Self::Scalar`, and
/// an L2 norm valued in `Self::Float`.
pub trait IntegrationOutput:
    Clone
    + Default
    + argmin_math::ArgminMul<<Self as IntegrationOutput>::Scalar, Self>
    + argmin_math::ArgminDiv<<Self as IntegrationOutput>::Scalar, Self>
    + argmin_math::ArgminAdd<Self, Self>
    + argmin_math::ArgminSub<Self, Self>
    + argmin_math::ArgminL2Norm<Self::Float>
    + Send
    + Sync
{
    /// The real-valued counterpart of the output. For a real output this is `Self`.
    type Real;
    /// The scalar element type. For a scalar output this is `Self`; for an array output it is
    /// the element type. Its real field must be `Self::Float`.
    type Scalar: ComplexField<RealField = Self::Float>;
    /// The underlying real float of `Self::Scalar`. For a real scalar output this is
    /// `Self::Scalar` itself.
    type Float: IntegrableFloat;

    /// Reduces the output to a single real magnitude of type `Self::Float`.
    /// For a complex scalar this is its modulus.
    fn modulus(&self) -> Self::Float;
    /// Returns `false` if the output contains NaN or infinity.
    fn is_finite(&self) -> bool;
}
79
/// Reduces an error estimate distributed over an `IntegrationOutput` to a single real
/// value of type `R`.
///
/// Implementors must be `Send + Sync`.
pub trait AccumulateError<R>
where
    Self: Send + Sync,
{
    /// The largest component of the error.
    fn max(&self) -> R;
    /// The average component of the error.
    fn mean(&self) -> R;
}
85
86impl IntegrableFloat for f32 {}
87impl IntegrableFloat for f64 {}
88
89// Accumulate
90
91impl AccumulateError<Self> for f64 {
92    fn max(&self) -> Self {
93        *self
94    }
95    fn mean(&self) -> Self {
96        *self
97    }
98}
99
100impl AccumulateError<Self> for f32 {
101    fn max(&self) -> Self {
102        *self
103    }
104    fn mean(&self) -> Self {
105        *self
106    }
107}
108
#[cfg(feature = "ndarray")]
impl<T: num_traits::float::FloatCore + num_traits::FromPrimitive + PartialOrd + Send + Sync>
    AccumulateError<T> for ndarray::Array1<T>
{
    /// Largest element of the error vector.
    ///
    /// # Panics
    ///
    /// Panics if `ndarray_stats::QuantileExt::max` returns an error. Per the ndarray-stats
    /// documentation, that happens for an empty array or when two elements cannot be
    /// ordered (e.g. a NaN is present).
    fn max(&self) -> T {
        *ndarray_stats::QuantileExt::max(self)
            .expect("error array must be non-empty and free of NaN to take its maximum")
    }

    /// Arithmetic mean of the error vector.
    ///
    /// # Panics
    ///
    /// Panics if ndarray's `mean` returns `None` (an empty array).
    fn mean(&self) -> T {
        // Call ndarray's inherent `mean` by path. A bare `self.mean()` here resolves to it
        // only because inherent methods win over this trait method, and it reads like
        // unbounded recursion.
        ndarray::ArrayBase::mean(self).expect("error array must be non-empty to take its mean")
    }
}
120
#[cfg(feature = "ndarray")]
impl<T: num_traits::float::FloatCore + num_traits::FromPrimitive + PartialOrd + Send + Sync>
    AccumulateError<T> for ndarray::Array2<T>
{
    /// Largest element of the error matrix.
    ///
    /// # Panics
    ///
    /// Panics if `ndarray_stats::QuantileExt::max` returns an error. Per the ndarray-stats
    /// documentation, that happens for an empty array or when two elements cannot be
    /// ordered (e.g. a NaN is present).
    fn max(&self) -> T {
        *ndarray_stats::QuantileExt::max(self)
            .expect("error array must be non-empty and free of NaN to take its maximum")
    }

    /// Arithmetic mean over all elements of the error matrix.
    ///
    /// # Panics
    ///
    /// Panics if ndarray's `mean` returns `None` (an empty array).
    fn mean(&self) -> T {
        // Call ndarray's inherent `mean` by path. A bare `self.mean()` here resolves to it
        // only because inherent methods win over this trait method, and it reads like
        // unbounded recursion.
        ndarray::ArrayBase::mean(self).expect("error array must be non-empty to take its mean")
    }
}
132
133// Rescale
134impl RescaleError for f32 {
135    fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
136        let mut error = self.abs();
137        if result_asc != 0.0 && error != 0.0 {
138            let exponent = 1.5;
139            let scale = ComplexField::powf(200. * error / result_asc, exponent);
140
141            if scale < 1. {
142                error = result_asc * scale;
143            } else {
144                error = result_asc;
145            }
146        }
147
148        if result_abs > f32::EPSILON / (50. * f32::EPSILON) {
149            let min_err = 50. * f32::EPSILON * result_abs;
150            if min_err > error {
151                error = min_err;
152            }
153        }
154        error
155    }
156}
157
158impl RescaleError for f64 {
159    fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
160        let mut error = self.abs();
161        if result_asc != 0.0 && error != 0.0 {
162            let exponent = 1.5;
163            let scale = ComplexField::powf(200. * error / result_asc, exponent);
164
165            if scale < 1. {
166                error = result_asc * scale;
167            } else {
168                error = result_asc;
169            }
170        }
171
172        if result_abs > f64::EPSILON / (50. * f64::EPSILON) {
173            let min_err = 50. * f64::EPSILON * result_abs;
174            if min_err > error {
175                error = min_err;
176            }
177        }
178        error
179    }
180}
181
#[cfg(feature = "ndarray")]
impl<T> RescaleError for ndarray::Array1<T>
where
    T: RescaleError,
{
    /// Rescales every element of the error vector using the matching elements of
    /// `result_abs` and `result_asc`.
    ///
    /// The three arrays are walked in lock-step with `zip`, so the result is only as long
    /// as the shortest of them.
    fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
        let paired_results = result_abs.into_iter().zip(result_asc);
        self.iter()
            .zip(paired_results)
            .map(|(raw_error, (abs_value, asc_value))| raw_error.rescale(abs_value, asc_value))
            .collect()
    }
}
195
#[cfg(feature = "ndarray")]
impl<T> RescaleError for ndarray::Array2<T>
where
    T: RescaleError,
{
    /// Rescales every element of the error matrix using the matching elements of
    /// `result_abs` and `result_asc`.
    ///
    /// Elements are paired in `iter()` order, collected into a flat `Array1`, and reshaped
    /// back to `self.dim()`.
    ///
    /// # Panics
    ///
    /// The final `unwrap` panics if the reshape fails, e.g. when `result_abs` or
    /// `result_asc` has fewer elements than `self`.
    fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
        let shape = self.dim();
        let flat: ndarray::Array1<T> = self
            .iter()
            .zip(result_abs.into_iter().zip(result_asc))
            .map(|(raw_error, (abs_value, asc_value))| raw_error.rescale(abs_value, asc_value))
            .collect();
        flat.into_shape(shape).unwrap()
    }
}
211
212// Output
213impl IntegrationOutput for Complex<f32> {
214    type Real = f32;
215    type Scalar = Self;
216    type Float = f32;
217
218    fn modulus(&self) -> Self::Real {
219        <Self as ComplexField>::modulus(*self)
220    }
221
222    fn is_finite(&self) -> bool {
223        ComplexField::is_finite(self)
224    }
225}
226
227impl IntegrationOutput for f32 {
228    type Real = Self;
229    type Scalar = Self;
230    type Float = Self;
231
232    fn modulus(&self) -> Self::Real {
233        <Self as ComplexField>::modulus(*self)
234    }
235
236    fn is_finite(&self) -> bool {
237        ComplexField::is_finite(self)
238    }
239}
240
241impl IntegrationOutput for Complex<f64> {
242    type Real = f64;
243    type Scalar = Self;
244    type Float = f64;
245
246    fn modulus(&self) -> Self::Real {
247        <Self as ComplexField>::modulus(*self)
248    }
249
250    fn is_finite(&self) -> bool {
251        ComplexField::is_finite(self)
252    }
253}
254
255impl IntegrationOutput for f64 {
256    type Real = Self;
257    type Scalar = Self;
258    type Float = Self;
259
260    fn modulus(&self) -> Self::Real {
261        <Self as ComplexField>::modulus(*self)
262    }
263
264    fn is_finite(&self) -> bool {
265        ComplexField::is_finite(self)
266    }
267}
268
/// `IntegrationOutput` for one-dimensional arrays of real or complex elements `T`.
///
/// `Real` is the array of the elements' real-field values, `Scalar` is the element type and
/// `Float` is the elements' real field. The where-clause requires the `argmin_math`
/// arithmetic on both the element array and the real-field array. It also requires
/// `AccumulateError` on the real-field array.
#[cfg(feature = "ndarray")]
impl<T: ComplexField + Default> IntegrationOutput for ndarray::Array1<T>
where
    Self: argmin_math::ArgminAdd<Self, Self>
        + argmin_math::ArgminSub<Self, Self>
        + argmin_math::ArgminDiv<T, Self>
        + argmin_math::ArgminMul<T, Self>
        + argmin_math::ArgminL2Norm<<T as ComplexField>::RealField>,
    ndarray::Array1<<T as ComplexField>::RealField>: argmin_math::ArgminAdd<
            <T as ComplexField>::RealField,
            ndarray::Array1<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminAdd<
            ndarray::Array1<<T as ComplexField>::RealField>,
            ndarray::Array1<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminMul<
            <T as ComplexField>::RealField,
            ndarray::Array1<<T as ComplexField>::RealField>,
        > + AccumulateError<<T as ComplexField>::RealField>,
    T: Copy,
    <T as ComplexField>::RealField:
        Default + FloatCore + IntegrableFloat + RescaleError + std::iter::Sum,
{
    type Real = ndarray::Array1<<T as ComplexField>::RealField>;
    type Scalar = T;
    type Float = <T as ComplexField>::RealField;

    // Sum of the element moduli (an L1 norm), not the Euclidean norm.
    fn modulus(&self) -> Self::Float {
        self.iter().map(|each| each.modulus()).sum()
    }

    // `true` only if every element is finite.
    fn is_finite(&self) -> bool {
        self.iter().all(|value| ComplexField::is_finite(value))
    }
}
303
/// `IntegrationOutput` for two-dimensional arrays of real or complex elements `T`.
///
/// `Real` is the array of the elements' real-field values, `Scalar` is the element type and
/// `Float` is the elements' real field. The where-clause requires the `argmin_math`
/// arithmetic on both the element array and the real-field array. It also requires
/// `AccumulateError` on the real-field array.
#[cfg(feature = "ndarray")]
impl<T: ComplexField + Default> IntegrationOutput for ndarray::Array2<T>
where
    Self: argmin_math::ArgminAdd<Self, Self>
        + argmin_math::ArgminSub<Self, Self>
        + argmin_math::ArgminDiv<T, Self>
        + argmin_math::ArgminMul<T, Self>
        + argmin_math::ArgminL2Norm<<T as ComplexField>::RealField>,
    ndarray::Array2<<T as ComplexField>::RealField>: argmin_math::ArgminAdd<
            <T as ComplexField>::RealField,
            ndarray::Array2<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminAdd<
            ndarray::Array2<<T as ComplexField>::RealField>,
            ndarray::Array2<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminMul<
            <T as ComplexField>::RealField,
            ndarray::Array2<<T as ComplexField>::RealField>,
        > + AccumulateError<<T as ComplexField>::RealField>,
    T: Copy,
    <T as ComplexField>::RealField:
        Default + FloatCore + IntegrableFloat + RescaleError + std::iter::Sum,
{
    type Real = ndarray::Array2<<T as ComplexField>::RealField>;
    type Scalar = T;
    type Float = <T as ComplexField>::RealField;

    // Sum of the element moduli over the whole matrix (an L1 norm).
    // TODO: maybe let the caller choose how the summed error is defined.
    fn modulus(&self) -> Self::Float {
        self.iter().map(|each| each.modulus()).sum()
    }

    // `true` only if every element is finite.
    fn is_finite(&self) -> bool {
        self.iter().all(|value| ComplexField::is_finite(value))
    }
}