1use crate::EvaluationError;
4use nalgebra::ComplexField;
5use num_complex::Complex;
6use num_traits::float::{Float, FloatCore};
7use serde::{de::DeserializeOwned, Serialize};
8use std::fmt::{Debug, Display};
9use trellis_runner::TrellisFloat;
10
11pub trait RealIntegrableScalar:
12 IntegrableFloat
13 + crate::RescaleError
14 + crate::AccumulateError<Self>
15 + crate::IntegrationOutput<Scalar = Self, Float = Self>
16 + nalgebra::RealField
17{
18}
19
20impl RealIntegrableScalar for f64 {}
21impl RealIntegrableScalar for f32 {}
22
23pub trait IntegrableFloat:
24 Clone
25 + Debug
26 + Display
27 + FloatCore
28 + Float
29 + Serialize
30 + DeserializeOwned
31 + PartialOrd
32 + PartialEq
33 + TrellisFloat
34{
35}
36
37pub trait Integrable {
39 type Input;
41 type Output: IntegrationOutput;
43
44 fn integrand(&self, input: &Self::Input) -> Result<Self::Output, EvaluationError<Self::Input>>;
46}
47
/// Post-processing of a raw quadrature error estimate.
pub trait RescaleError {
    /// Returns the adjusted error estimate for `self`.
    ///
    /// `result_abs` and `result_asc` are auxiliary quadrature sums of the same type as
    /// the error. The `f32`/`f64` impls in this file use them to rescale the estimate
    /// and to impose a precision floor.
    fn rescale(&self, result_abs: Self, result_asc: Self) -> Self;
}
51
/// The value produced by an integrand, plus the arithmetic the integrators need in
/// order to combine such values.
///
/// Required supertraits (see the bound list):
/// * `Clone + Default`;
/// * `argmin_math` multiplication and division by an [`IntegrationOutput::Scalar`],
///   producing `Self`;
/// * `argmin_math` addition and subtraction of two outputs;
/// * an `argmin_math` L2 norm returning [`IntegrationOutput::Float`];
/// * `Send + Sync`.
pub trait IntegrationOutput:
    Clone
    + Default
    + argmin_math::ArgminMul<<Self as IntegrationOutput>::Scalar, Self>
    + argmin_math::ArgminDiv<<Self as IntegrationOutput>::Scalar, Self>
    + argmin_math::ArgminAdd<Self, Self>
    + argmin_math::ArgminSub<Self, Self>
    + argmin_math::ArgminL2Norm<Self::Float>
    + Send
    + Sync
{
    /// Real-valued counterpart of the output. It has no bounds here. The impls in this
    /// file set it to `f32`/`f64` for real and complex scalars, and to an array of
    /// `RealField` values for the `ndarray` outputs.
    type Real;
    /// Element type of the output. It is the multiplier/divisor in the
    /// `ArgminMul`/`ArgminDiv` bounds, and its `RealField` must equal
    /// [`IntegrationOutput::Float`].
    type Scalar: ComplexField<RealField = Self::Float>;
    /// Real floating-point type used for norms and moduli.
    type Float: IntegrableFloat;

    /// Magnitude of the output as a single real number. The scalar impls in this file
    /// return `|x|`; the `ndarray` impls return the sum of the element moduli.
    fn modulus(&self) -> Self::Float;
    /// Whether the output is finite. The impls in this file check every component
    /// with `ComplexField::is_finite`.
    fn is_finite(&self) -> bool;
}
79
/// Collapses an (possibly element-wise) error estimate into a summary value of type `R`.
///
/// Implementors must be shareable across threads (`Send + Sync`).
pub trait AccumulateError<R>
where
    Self: Send + Sync,
{
    /// The largest error component.
    fn max(&self) -> R;

    /// The average error component.
    fn mean(&self) -> R;
}
85
86impl IntegrableFloat for f32 {}
87impl IntegrableFloat for f64 {}
88
89impl AccumulateError<Self> for f64 {
92 fn max(&self) -> Self {
93 *self
94 }
95 fn mean(&self) -> Self {
96 *self
97 }
98}
99
100impl AccumulateError<Self> for f32 {
101 fn max(&self) -> Self {
102 *self
103 }
104 fn mean(&self) -> Self {
105 *self
106 }
107}
108
#[cfg(feature = "ndarray")]
impl<T: num_traits::float::FloatCore + num_traits::FromPrimitive + PartialOrd + Send + Sync>
    AccumulateError<T> for ndarray::Array1<T>
{
    /// Largest entry of the error array.
    ///
    /// A `NaN` anywhere in the array is returned as `NaN`, as the scalar `f32`/`f64`
    /// impls do. Previously `ndarray_stats::QuantileExt::max(..).unwrap()` panicked,
    /// because it reports `NaN` as an undefined ordering. Among equal maxima the first
    /// one is kept.
    ///
    /// # Panics
    ///
    /// Panics if the array is empty.
    fn max(&self) -> T {
        self.iter()
            .copied()
            .reduce(|largest, value| {
                if largest.is_nan() || value.is_nan() {
                    T::nan()
                } else if value > largest {
                    value
                } else {
                    largest
                }
            })
            .expect("cannot take the maximum of an empty error array")
    }

    /// Arithmetic mean of the error array.
    ///
    /// # Panics
    ///
    /// Panics if the array is empty.
    fn mean(&self) -> T {
        // Call ndarray's inherent `mean` explicitly. This trait method shares its name,
        // so a bare `self.mean()` reads as if it were recursive.
        ndarray::ArrayBase::mean(self).expect("cannot take the mean of an empty error array")
    }
}
120
#[cfg(feature = "ndarray")]
impl<T: num_traits::float::FloatCore + num_traits::FromPrimitive + PartialOrd + Send + Sync>
    AccumulateError<T> for ndarray::Array2<T>
{
    /// Largest entry of the 2-D error array.
    ///
    /// A `NaN` anywhere in the array is returned as `NaN`, as the scalar `f32`/`f64`
    /// impls do. Previously `ndarray_stats::QuantileExt::max(..).unwrap()` panicked,
    /// because it reports `NaN` as an undefined ordering. Among equal maxima the first
    /// one visited is kept.
    ///
    /// # Panics
    ///
    /// Panics if the array is empty.
    fn max(&self) -> T {
        self.iter()
            .copied()
            .reduce(|largest, value| {
                if largest.is_nan() || value.is_nan() {
                    T::nan()
                } else if value > largest {
                    value
                } else {
                    largest
                }
            })
            .expect("cannot take the maximum of an empty error array")
    }

    /// Arithmetic mean over every entry of the 2-D error array.
    ///
    /// # Panics
    ///
    /// Panics if the array is empty.
    fn mean(&self) -> T {
        // Call ndarray's inherent `mean` explicitly. This trait method shares its name,
        // so a bare `self.mean()` reads as if it were recursive.
        ndarray::ArrayBase::mean(self).expect("cannot take the mean of an empty error array")
    }
}
132
133impl RescaleError for f32 {
135 fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
136 let mut error = self.abs();
137 if result_asc != 0.0 && error != 0.0 {
138 let exponent = 1.5;
139 let scale = ComplexField::powf(200. * error / result_asc, exponent);
140
141 if scale < 1. {
142 error = result_asc * scale;
143 } else {
144 error = result_asc;
145 }
146 }
147
148 if result_abs > f32::EPSILON / (50. * f32::EPSILON) {
149 let min_err = 50. * f32::EPSILON * result_abs;
150 if min_err > error {
151 error = min_err;
152 }
153 }
154 error
155 }
156}
157
158impl RescaleError for f64 {
159 fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
160 let mut error = self.abs();
161 if result_asc != 0.0 && error != 0.0 {
162 let exponent = 1.5;
163 let scale = ComplexField::powf(200. * error / result_asc, exponent);
164
165 if scale < 1. {
166 error = result_asc * scale;
167 } else {
168 error = result_asc;
169 }
170 }
171
172 if result_abs > f64::EPSILON / (50. * f64::EPSILON) {
173 let min_err = 50. * f64::EPSILON * result_abs;
174 if min_err > error {
175 error = min_err;
176 }
177 }
178 error
179 }
180}
181
#[cfg(feature = "ndarray")]
impl<T> RescaleError for ndarray::Array1<T>
where
    T: RescaleError,
{
    /// Rescales every entry of a 1-D error array independently.
    ///
    /// Each entry of `self` is paired with the entries at the same position in
    /// `result_abs` and `result_asc`. Pairing stops at the shortest of the three.
    fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
        let auxiliary = result_abs.into_iter().zip(result_asc);
        self.iter()
            .zip(auxiliary)
            .map(|(raw, (abs, asc))| raw.rescale(abs, asc))
            .collect()
    }
}
195
#[cfg(feature = "ndarray")]
impl<T> RescaleError for ndarray::Array2<T>
where
    T: RescaleError,
{
    /// Rescales every entry of a 2-D error array independently.
    ///
    /// The entries of `self` are visited with `iter()`, paired with the successive
    /// entries of `result_abs` and `result_asc`, and the flat result is reshaped to
    /// `self`'s dimensions.
    ///
    /// # Panics
    ///
    /// Panics if the reshape fails. This happens when an auxiliary array yields fewer
    /// entries than `self` holds.
    fn rescale(&self, result_abs: Self, result_asc: Self) -> Self {
        let dims = self.dim();
        let auxiliary = result_abs.into_iter().zip(result_asc);
        let flat: ndarray::Array1<T> = self
            .iter()
            .zip(auxiliary)
            .map(|(raw, (abs, asc))| raw.rescale(abs, asc))
            .collect();
        flat.into_shape(dims).unwrap()
    }
}
211
212impl IntegrationOutput for Complex<f32> {
214 type Real = f32;
215 type Scalar = Self;
216 type Float = f32;
217
218 fn modulus(&self) -> Self::Real {
219 <Self as ComplexField>::modulus(*self)
220 }
221
222 fn is_finite(&self) -> bool {
223 ComplexField::is_finite(self)
224 }
225}
226
227impl IntegrationOutput for f32 {
228 type Real = Self;
229 type Scalar = Self;
230 type Float = Self;
231
232 fn modulus(&self) -> Self::Real {
233 <Self as ComplexField>::modulus(*self)
234 }
235
236 fn is_finite(&self) -> bool {
237 ComplexField::is_finite(self)
238 }
239}
240
241impl IntegrationOutput for Complex<f64> {
242 type Real = f64;
243 type Scalar = Self;
244 type Float = f64;
245
246 fn modulus(&self) -> Self::Real {
247 <Self as ComplexField>::modulus(*self)
248 }
249
250 fn is_finite(&self) -> bool {
251 ComplexField::is_finite(self)
252 }
253}
254
255impl IntegrationOutput for f64 {
256 type Real = Self;
257 type Scalar = Self;
258 type Float = Self;
259
260 fn modulus(&self) -> Self::Real {
261 <Self as ComplexField>::modulus(*self)
262 }
263
264 fn is_finite(&self) -> bool {
265 ComplexField::is_finite(self)
266 }
267}
268
/// A 1-D `ndarray` of (real or complex) scalars as an integrand output.
///
/// The real counterpart is an array of the element type's `RealField`, which is also the
/// float type used for norms and moduli.
#[cfg(feature = "ndarray")]
impl<T> IntegrationOutput for ndarray::Array1<T>
where
    T: ComplexField + Default + Copy,
    <T as ComplexField>::RealField:
        Default + FloatCore + IntegrableFloat + RescaleError + std::iter::Sum,
    Self: argmin_math::ArgminAdd<Self, Self>
        + argmin_math::ArgminSub<Self, Self>
        + argmin_math::ArgminDiv<T, Self>
        + argmin_math::ArgminMul<T, Self>
        + argmin_math::ArgminL2Norm<<T as ComplexField>::RealField>,
    ndarray::Array1<<T as ComplexField>::RealField>: argmin_math::ArgminAdd<
            <T as ComplexField>::RealField,
            ndarray::Array1<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminAdd<
            ndarray::Array1<<T as ComplexField>::RealField>,
            ndarray::Array1<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminMul<
            <T as ComplexField>::RealField,
            ndarray::Array1<<T as ComplexField>::RealField>,
        > + AccumulateError<<T as ComplexField>::RealField>,
{
    type Real = ndarray::Array1<<T as ComplexField>::RealField>;
    type Scalar = T;
    type Float = <T as ComplexField>::RealField;

    /// Sum of the moduli of all elements (an L1-style magnitude).
    fn modulus(&self) -> Self::Float {
        self.iter().copied().map(<T as ComplexField>::modulus).sum()
    }

    /// `true` only when every element is finite according to `ComplexField::is_finite`.
    fn is_finite(&self) -> bool {
        self.iter().all(<T as ComplexField>::is_finite)
    }
}
303
/// A 2-D `ndarray` of (real or complex) scalars as an integrand output.
///
/// The real counterpart is a 2-D array of the element type's `RealField`, which is also
/// the float type used for norms and moduli.
#[cfg(feature = "ndarray")]
impl<T> IntegrationOutput for ndarray::Array2<T>
where
    T: ComplexField + Default + Copy,
    <T as ComplexField>::RealField:
        Default + FloatCore + IntegrableFloat + RescaleError + std::iter::Sum,
    Self: argmin_math::ArgminAdd<Self, Self>
        + argmin_math::ArgminSub<Self, Self>
        + argmin_math::ArgminDiv<T, Self>
        + argmin_math::ArgminMul<T, Self>
        + argmin_math::ArgminL2Norm<<T as ComplexField>::RealField>,
    ndarray::Array2<<T as ComplexField>::RealField>: argmin_math::ArgminAdd<
            <T as ComplexField>::RealField,
            ndarray::Array2<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminAdd<
            ndarray::Array2<<T as ComplexField>::RealField>,
            ndarray::Array2<<T as ComplexField>::RealField>,
        > + argmin_math::ArgminMul<
            <T as ComplexField>::RealField,
            ndarray::Array2<<T as ComplexField>::RealField>,
        > + AccumulateError<<T as ComplexField>::RealField>,
{
    type Real = ndarray::Array2<<T as ComplexField>::RealField>;
    type Scalar = T;
    type Float = <T as ComplexField>::RealField;

    /// Sum of the moduli of every entry (an L1-style magnitude over the whole matrix).
    fn modulus(&self) -> Self::Float {
        self.iter().map(|&entry| ComplexField::modulus(entry)).sum()
    }

    /// `true` unless some entry fails `ComplexField::is_finite`.
    fn is_finite(&self) -> bool {
        !self.iter().any(|entry| !ComplexField::is_finite(entry))
    }
}