1use crate::error::{IntegrateError, IntegrateResult};
7use crate::IntegrateFloat;
8use std::f64::consts::PI;
9use std::fmt::Debug;
10
11#[derive(Debug, Clone)]
13pub struct QuadOptions<F: IntegrateFloat> {
14 pub abs_tol: F,
16 pub rel_tol: F,
18 pub max_evals: usize,
20 pub use_abs_error: bool,
22 pub use_simpson: bool,
24}
25
26impl<F: IntegrateFloat> Default for QuadOptions<F> {
27 fn default() -> Self {
28 Self {
29 abs_tol: F::from_f64(1.49e-8).unwrap(), rel_tol: F::from_f64(1.49e-8).unwrap(), max_evals: 500, use_abs_error: false,
33 use_simpson: false,
34 }
35 }
36}
37
38#[derive(Debug, Clone)]
40pub struct QuadResult<F: IntegrateFloat> {
41 pub value: F,
43 pub abs_error: F,
45 pub n_evals: usize,
47 pub converged: bool,
49}
50
51#[allow(dead_code)]
74pub fn trapezoid<F, Func>(f: Func, a: F, b: F, n: usize) -> F
75where
76 F: IntegrateFloat,
77 Func: Fn(F) -> F,
78{
79 if n == 0 {
80 return F::zero();
81 }
82
83 let h = (b - a) / F::from_usize(n).unwrap();
84 let mut sum = F::from_f64(0.5).unwrap() * (f(a) + f(b));
85
86 for i in 1..n {
87 let x = a + F::from_usize(i).unwrap() * h;
88 sum += f(x);
89 }
90
91 sum * h
92}
93
94#[allow(dead_code)]
117pub fn simpson<F, Func>(mut f: Func, a: F, b: F, n: usize) -> IntegrateResult<F>
118where
119 F: IntegrateFloat,
120 Func: FnMut(F) -> F,
121{
122 if n == 0 {
123 return Ok(F::zero());
124 }
125
126 if !n.is_multiple_of(2) {
127 return Err(IntegrateError::ValueError(
128 "Number of intervals must be even".to_string(),
129 ));
130 }
131
132 let h = (b - a) / F::from_usize(n).unwrap();
133 let mut sum_even = F::zero();
134 let mut sum_odd = F::zero();
135
136 for i in 1..n {
137 let x = a + F::from_usize(i).unwrap() * h;
138 if i % 2 == 0 {
139 sum_even += f(x);
140 } else {
141 sum_odd += f(x);
142 }
143 }
144
145 let result =
146 (f(a) + f(b) + F::from_f64(2.0).unwrap() * sum_even + F::from_f64(4.0).unwrap() * sum_odd)
147 * h
148 / F::from_f64(3.0).unwrap();
149 Ok(result)
150}
151
152#[allow(dead_code)]
176pub fn quad<F, Func>(
177 f: Func,
178 a: F,
179 b: F,
180 options: Option<QuadOptions<F>>,
181) -> IntegrateResult<QuadResult<F>>
182where
183 F: IntegrateFloat,
184 Func: Fn(F) -> F + Copy,
185{
186 let opts = options.unwrap_or_default();
187
188 if opts.use_simpson {
189 let n = 1000; let result = simpson(f, a, b, n)?;
192
193 return Ok(QuadResult {
194 value: result,
195 abs_error: F::from_f64(1e-8).unwrap(), n_evals: n + 1, converged: true,
198 });
199 }
200
201 let mut n_evals = 0;
203
204 let (value, error, converged) = adaptive_quad_impl(f, a, b, &mut n_evals, &opts)?;
206
207 Ok(QuadResult {
208 value,
209 abs_error: error,
210 n_evals,
211 converged,
212 })
213}
214
215#[allow(dead_code)]
217fn adaptive_quad_impl<F, Func>(
218 f: Func,
219 a: F,
220 b: F,
221 n_evals: &mut usize,
222 options: &QuadOptions<F>,
223) -> IntegrateResult<(F, F, bool)>
224where
226 F: IntegrateFloat,
227 Func: Fn(F) -> F + Copy,
228{
229 let n_initial = 10; let mut eval_count_coarse = 0;
232 let coarse_result = {
233 let f_with_count = |x: F| {
235 eval_count_coarse += 1;
236 f(x)
237 };
238 simpson(f_with_count, a, b, n_initial)?
239 };
240 *n_evals += eval_count_coarse;
241
242 let n_refined = 20; let mut eval_count_refined = 0;
245 let refined_result = {
246 let f_with_count = |x: F| {
248 eval_count_refined += 1;
249 f(x)
250 };
251 simpson(f_with_count, a, b, n_refined)?
252 };
253 *n_evals += eval_count_refined;
254
255 let error = (refined_result - coarse_result).abs();
257 let tolerance = if options.use_abs_error {
258 options.abs_tol
259 } else {
260 options.abs_tol + options.rel_tol * refined_result.abs()
261 };
262
263 let converged = error <= tolerance || *n_evals >= options.max_evals;
265
266 if *n_evals >= options.max_evals && error > tolerance {
267 return Err(IntegrateError::ConvergenceError(format!(
268 "Failed to converge after {} function evaluations",
269 *n_evals
270 )));
271 }
272
273 if !converged {
275 let mid = (a + b) / F::from_f64(2.0).unwrap();
276
277 let (left_value, left_error, left_converged) =
279 adaptive_quad_impl(f, a, mid, n_evals, options)?;
280 let (right_value, right_error, right_converged) =
281 adaptive_quad_impl(f, mid, b, n_evals, options)?;
282
283 let value = left_value + right_value;
285 let abs_error = left_error + right_error;
286 let sub_converged = left_converged && right_converged;
287
288 return Ok((value, abs_error, sub_converged));
289 }
290
291 Ok((refined_result, error, converged))
292}
293
294#[allow(dead_code)] fn simpson_with_count<F, Func>(
297 f: &mut Func,
298 a: F,
299 b: F,
300 n: usize,
301 count: &mut usize,
302) -> IntegrateResult<F>
303where
304 F: IntegrateFloat,
305 Func: FnMut(F) -> F,
306{
307 if n == 0 {
308 return Ok(F::zero());
309 }
310
311 if !n.is_multiple_of(2) {
312 return Err(IntegrateError::ValueError(
313 "Number of intervals must be even".to_string(),
314 ));
315 }
316
317 let h = (b - a) / F::from_usize(n).unwrap();
318 let mut sum_even = F::zero();
319 let mut sum_odd = F::zero();
320
321 *count += 2; let fa = f(a);
323 let fb = f(b);
324
325 for i in 1..n {
326 let x = a + F::from_usize(i).unwrap() * h;
327 *count += 1;
328 if i % 2 == 0 {
329 sum_even += f(x);
330 } else {
331 sum_odd += f(x);
332 }
333 }
334
335 let result =
336 (fa + fb + F::from_f64(2.0).unwrap() * sum_even + F::from_f64(4.0).unwrap() * sum_odd) * h
337 / F::from_f64(3.0).unwrap();
338 Ok(result)
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f64::consts::{FRAC_PI_2, PI};

    #[test]
    fn test_trapezoid_rule() {
        // ∫_0^1 x^2 dx = 1/3
        let parabola = trapezoid(|x| x * x, 0.0, 1.0, 100);
        assert_relative_eq!(parabola, 1.0 / 3.0, epsilon = 1e-4);

        // ∫_0^π sin(x) dx = 2
        let sine = trapezoid(|x: f64| x.sin(), 0.0, PI, 1000);
        assert_relative_eq!(sine, 2.0, epsilon = 1e-4);
    }

    #[test]
    fn test_simpson_rule() {
        let parabola = simpson(|x| x * x, 0.0, 1.0, 100).unwrap();
        assert_relative_eq!(parabola, 1.0 / 3.0, epsilon = 1e-8);

        let sine = simpson(|x: f64| x.sin(), 0.0, PI, 100).unwrap();
        assert_relative_eq!(sine, 2.0, epsilon = 1e-6);

        // Simpson's rule needs an even number of intervals.
        let odd_intervals = simpson(|x| x * x, 0.0, 1.0, 99);
        assert!(odd_intervals.is_err());
    }

    #[test]
    fn test_adaptive_quad() {
        let adaptive = quad(|x| x * x, 0.0, 1.0, None).unwrap();
        assert_relative_eq!(adaptive.value, 1.0 / 3.0, epsilon = 1e-8);
        assert!(adaptive.converged);

        let fixed_simpson = QuadOptions {
            use_simpson: true,
            ..Default::default()
        };
        let cosine = quad(|x: f64| x.cos(), 0.0, FRAC_PI_2, Some(fixed_simpson)).unwrap();
        assert_relative_eq!(cosine.value, 1.0, epsilon = 1e-6);
    }
}