1use alloc::collections::BinaryHeap;
12use core::cmp::Ordering;
13
14use numra_core::Scalar;
15
16use crate::error::IntegrationError;
17
18extern crate alloc;
19
20#[derive(Clone, Debug)]
22pub struct QuadOptions<S: Scalar> {
23 pub atol: S,
25 pub rtol: S,
27 pub max_subdivisions: usize,
29 pub points: Vec<S>,
31}
32
33impl<S: Scalar> Default for QuadOptions<S> {
34 fn default() -> Self {
35 Self {
36 atol: S::from_f64(1.49e-8),
37 rtol: S::from_f64(1.49e-8),
38 max_subdivisions: 50,
39 points: Vec::new(),
40 }
41 }
42}
43
44impl<S: Scalar> QuadOptions<S> {
45 pub fn atol(mut self, atol: S) -> Self {
47 self.atol = atol;
48 self
49 }
50
51 pub fn rtol(mut self, rtol: S) -> Self {
53 self.rtol = rtol;
54 self
55 }
56
57 pub fn max_subdivisions(mut self, max: usize) -> Self {
59 self.max_subdivisions = max;
60 self
61 }
62
63 pub fn points(mut self, pts: Vec<S>) -> Self {
65 self.points = pts;
66 self
67 }
68}
69
70#[derive(Clone, Debug)]
72pub struct QuadResult<S: Scalar> {
73 pub value: S,
75 pub error_estimate: S,
77 pub n_evaluations: usize,
79 pub n_subdivisions: usize,
81}
82
/// Non-negative abscissae of the 15-point Kronrod rule on `[-1, 1]`, in
/// increasing order starting at the centre. Each non-zero node `x` is used
/// as the symmetric pair `±x`.
///
/// Indices 0, 2, 4 and 6 are also the nodes of the embedded 7-point Gauss
/// rule. The odd indices are the extra points the Kronrod extension adds.
const K15_NODES: [f64; 8] = [
    0.0,
    0.207_784_955_007_898_5,
    0.405_845_151_377_397_2,
    0.586_087_235_467_691_1,
    0.741_531_185_599_394_5,
    0.864_864_423_359_769_1,
    0.949_107_912_342_758_5,
    0.991_455_371_120_812_6,
];
102
/// Weights of the 15-point Kronrod rule. `K15_WEIGHTS[i]` multiplies the
/// integrand at `K15_NODES[i]` (applied to each of `±x` for non-zero nodes).
const K15_WEIGHTS: [f64; 8] = [
    0.209_482_141_084_727_8,
    0.204_432_940_075_298_9,
    0.190_350_578_064_785_4,
    0.169_004_726_639_267_9,
    0.140_653_259_715_525_9,
    0.104_790_010_322_250_2,
    0.063_092_092_629_978_6,
    0.022_935_322_010_529_2,
];
114
/// Weights of the embedded 7-point Gauss rule. `G7_WEIGHTS[0]` belongs to
/// the centre node, and `G7_WEIGHTS[j]` (j = 1..=3) to `K15_NODES[2 * j]`.
const G7_WEIGHTS: [f64; 4] = [
    0.417_959_183_673_469_4,
    0.381_830_050_505_118_9,
    0.279_705_391_489_276_7,
    0.129_484_966_168_869_7,
];
123
124fn g7k15<S, F>(f: &mut F, a: S, b: S) -> (S, S, usize)
127where
128 S: Scalar,
129 F: FnMut(S) -> S,
130{
131 let mid = (a + b) * S::HALF;
132 let half_len = (b - a) * S::HALF;
133
134 let mut k15 = S::ZERO;
135 let mut g7 = S::ZERO;
136
137 let f_center = f(mid);
139 k15 += S::from_f64(K15_WEIGHTS[0]) * f_center;
140 g7 += S::from_f64(G7_WEIGHTS[0]) * f_center;
141
142 for &i in &[1usize, 3, 5, 7] {
144 let x = half_len * S::from_f64(K15_NODES[i]);
145 let f_pos = f(mid + x);
146 let f_neg = f(mid - x);
147 k15 += S::from_f64(K15_WEIGHTS[i]) * (f_pos + f_neg);
148 }
149
150 for (g_idx, &k_idx) in [2usize, 4, 6].iter().enumerate() {
152 let x = half_len * S::from_f64(K15_NODES[k_idx]);
153 let f_pos = f(mid + x);
154 let f_neg = f(mid - x);
155 let fsum = f_pos + f_neg;
156 k15 += S::from_f64(K15_WEIGHTS[k_idx]) * fsum;
157 g7 += S::from_f64(G7_WEIGHTS[g_idx + 1]) * fsum;
158 }
159
160 (k15 * half_len, g7 * half_len, 15)
161}
162
163struct SubInterval<S: Scalar> {
165 a: S,
166 b: S,
167 result: S,
168 error: S,
169}
170
171impl<S: Scalar> PartialEq for SubInterval<S> {
172 fn eq(&self, other: &Self) -> bool {
173 self.error.to_f64() == other.error.to_f64()
174 }
175}
176
177impl<S: Scalar> Eq for SubInterval<S> {}
178
179impl<S: Scalar> PartialOrd for SubInterval<S> {
180 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
181 Some(self.cmp(other))
182 }
183}
184
185impl<S: Scalar> Ord for SubInterval<S> {
186 fn cmp(&self, other: &Self) -> Ordering {
187 self.error
189 .to_f64()
190 .partial_cmp(&other.error.to_f64())
191 .unwrap_or(Ordering::Equal)
192 }
193}
194
195pub fn quad<S, F>(
209 mut f: F,
210 a: S,
211 b: S,
212 opts: &QuadOptions<S>,
213) -> Result<QuadResult<S>, IntegrationError>
214where
215 S: Scalar,
216 F: FnMut(S) -> S,
217{
218 let mut breakpoints = Vec::new();
220 breakpoints.push(a);
221 for &p in &opts.points {
222 if p > a && p < b {
223 breakpoints.push(p);
224 }
225 }
226 breakpoints.push(b);
227 breakpoints.sort_by(|x, y| {
229 x.to_f64()
230 .partial_cmp(&y.to_f64())
231 .unwrap_or(Ordering::Equal)
232 });
233 breakpoints.dedup_by(|a, b| ((*a) - (*b)).abs() < S::EPSILON);
235
236 let mut heap: BinaryHeap<SubInterval<S>> = BinaryHeap::new();
237 let mut total_result = S::ZERO;
238 let mut total_error = S::ZERO;
239 let mut total_evals = 0usize;
240 let mut n_subdivisions = 0usize;
241
242 for i in 0..breakpoints.len() - 1 {
244 let seg_a = breakpoints[i];
245 let seg_b = breakpoints[i + 1];
246 let (k15, g7, ne) = g7k15(&mut f, seg_a, seg_b);
247 let err = (k15 - g7).abs();
248 total_result += k15;
249 total_error += err;
250 total_evals += ne;
251 n_subdivisions += 1;
252
253 if !k15.is_finite() {
255 let mid = (seg_a + seg_b) * S::HALF;
256 return Err(IntegrationError::InvalidValue { x: mid.to_f64() });
257 }
258
259 heap.push(SubInterval {
260 a: seg_a,
261 b: seg_b,
262 result: k15,
263 error: err,
264 });
265 }
266
267 let tol = opts.atol.max(opts.rtol * total_result.abs());
269 if total_error <= tol {
270 return Ok(QuadResult {
271 value: total_result,
272 error_estimate: total_error,
273 n_evaluations: total_evals,
274 n_subdivisions,
275 });
276 }
277
278 while n_subdivisions < opts.max_subdivisions {
280 let worst = match heap.pop() {
281 Some(w) => w,
282 None => break,
283 };
284
285 let mid = (worst.a + worst.b) * S::HALF;
287
288 let (k15_l, g7_l, ne_l) = g7k15(&mut f, worst.a, mid);
289 let err_l = (k15_l - g7_l).abs();
290
291 let (k15_r, g7_r, ne_r) = g7k15(&mut f, mid, worst.b);
292 let err_r = (k15_r - g7_r).abs();
293
294 total_evals += ne_l + ne_r;
295 n_subdivisions += 1;
296
297 total_result = total_result - worst.result + k15_l + k15_r;
299 total_error = total_error - worst.error + err_l + err_r;
300
301 if !k15_l.is_finite() || !k15_r.is_finite() {
302 return Err(IntegrationError::InvalidValue { x: mid.to_f64() });
303 }
304
305 heap.push(SubInterval {
306 a: worst.a,
307 b: mid,
308 result: k15_l,
309 error: err_l,
310 });
311 heap.push(SubInterval {
312 a: mid,
313 b: worst.b,
314 result: k15_r,
315 error: err_r,
316 });
317
318 let tol = opts.atol.max(opts.rtol * total_result.abs());
319 if total_error <= tol {
320 return Ok(QuadResult {
321 value: total_result,
322 error_estimate: total_error,
323 n_evaluations: total_evals,
324 n_subdivisions,
325 });
326 }
327 }
328
329 Err(IntegrationError::MaxSubdivisions {
331 subdivisions: n_subdivisions,
332 error_estimate: total_error.to_f64(),
333 })
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use core::f64::consts::{E, FRAC_PI_2, PI};

    /// Runs [`quad`] over `[a, b]` with default options and unwraps the result.
    fn quad_default<F: FnMut(f64) -> f64>(f: F, a: f64, b: f64) -> QuadResult<f64> {
        quad(f, a, b, &QuadOptions::default()).unwrap()
    }

    #[test]
    fn test_quad_sin() {
        // ∫₀^π sin x dx = 2
        let r = quad_default(f64::sin, 0.0, PI);
        assert_relative_eq!(r.value, 2.0, epsilon = 1e-10);
        assert!(r.error_estimate < 1e-10);
    }

    #[test]
    fn test_quad_exp() {
        // ∫₀¹ eˣ dx = e − 1
        let r = quad_default(f64::exp, 0.0, 1.0);
        assert_relative_eq!(r.value, E - 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_quad_polynomial() {
        // ∫₀¹ x⁴ dx = 1/5
        let r = quad_default(|x: f64| x.powi(4), 0.0, 1.0);
        assert_relative_eq!(r.value, 0.2, epsilon = 1e-14);
    }

    #[test]
    fn test_quad_singular_sqrt() {
        // ∫₀¹ x^(-1/2) dx = 2; the integrand blows up at the left endpoint.
        let inv_sqrt = |x: f64| {
            if x.abs() < 1e-300 {
                0.0
            } else {
                1.0 / x.sqrt()
            }
        };
        let opts = QuadOptions::default()
            .atol(1e-8)
            .rtol(1e-8)
            .max_subdivisions(100)
            .points(vec![0.0]);
        let r = quad(inv_sqrt, 0.0, 1.0, &opts).unwrap();
        assert_relative_eq!(r.value, 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_quad_oscillatory() {
        // ∫₀^π sin(100x) dx = (1 − cos 100π) / 100 = 0
        let opts = QuadOptions::default().max_subdivisions(200);
        let r = quad(|x: f64| (100.0 * x).sin(), 0.0, PI, &opts).unwrap();
        assert!(r.value.abs() < 1e-6);
    }

    #[test]
    fn test_quad_tight_tolerance() {
        // ∫₀^(π/2) cos x dx = 1
        let opts = QuadOptions::default().atol(1e-14).rtol(1e-14);
        let r = quad(f64::cos, 0.0, FRAC_PI_2, &opts).unwrap();
        assert_relative_eq!(r.value, 1.0, epsilon = 1e-13);
    }

    #[test]
    fn test_quad_f32() {
        let opts = QuadOptions::<f32>::default().atol(1e-4).rtol(1e-4);
        let r = quad(f32::sin, 0.0f32, core::f32::consts::PI, &opts).unwrap();
        assert!((r.value - 2.0).abs() < 1e-4);
    }

    #[test]
    fn test_quad_gaussian() {
        // ∫₋₅⁵ e^(−x²) dx ≈ √π; the tails beyond ±5 are negligible here.
        let r = quad_default(|x: f64| (-x * x).exp(), -5.0, 5.0);
        assert_relative_eq!(r.value, PI.sqrt(), epsilon = 1e-10);
    }
}