gauss_quad/midpoint/
mod.rs

1//! Numerical integration using the midpoint rule.
2//!
3//! This is one of the simplest integration schemes.
4//!
5//! 1. Divide the domain into equally sized sections.
6//! 2. Find the function value at the midpoint of each section.
7//! 3. The section's integral is approximated as a rectangle as wide as the section and as tall as the function
8//!    value at the midpoint.
9//!
10//! ```
11//! use gauss_quad::midpoint::{Midpoint, MidpointError};
12//! use approx::assert_abs_diff_eq;
13//!
14//! use core::f64::consts::PI;
15//!
16//! let eps = 0.001;
17//!
18//! let n = 30;
19//! let quad = Midpoint::new(n)?;
20//!
21//! // integrate some functions
22//! let two_thirds = quad.integrate(-1.0, 1.0, |x| x * x);
23//! assert_abs_diff_eq!(two_thirds, 0.66666, epsilon = eps);
24//!
25//! let estimate_sin = quad.integrate(-PI, PI, |x| x.sin());
26//! assert_abs_diff_eq!(estimate_sin, 0.0, epsilon = eps);
27//!
28//! // some functions need more steps than others
29//! let m = 100;
30//! let better_quad = Midpoint::new(m)?;
31//!
32//! let piecewise = better_quad.integrate(-5.0, 5.0, |x|
33//!     if x > 1.0 && x < 2.0 {
34//!         (-x * x).exp()
35//!     } else {
36//!         0.0
37//!     }
38//! );
39//!
40//! assert_abs_diff_eq!(0.135257, piecewise, epsilon = eps);
41//! # Ok::<(), MidpointError>(())
42//! ```
43
44#[cfg(feature = "rayon")]
45use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
46
47use crate::{Node, __impl_node_rule};
48
49use std::backtrace::Backtrace;
50
51/// A midpoint rule.
52/// ```
53/// # use gauss_quad::midpoint::{Midpoint, MidpointError};
54/// // initialize a midpoint rule with 100 cells
55/// let quad: Midpoint = Midpoint::new(100)?;
56///
57/// // numerically integrate a function from -1.0 to 1.0 using the midpoint rule
58/// let approx = quad.integrate(-1.0, 1.0, |x| x * x);
59/// # Ok::<(), MidpointError>(())
60/// ```
61#[derive(Debug, Clone, PartialEq)]
62#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
63pub struct Midpoint {
64    /// The dimensionless midpoints
65    nodes: Vec<Node>,
66}
67
68impl Midpoint {
69    /// Initialize a new midpoint rule with `degree` number of cells. The nodes are evenly spaced.
70    // -- code based on Luca Palmieri's "Scientific computing: a Rust adventure [Part 2 - Array1]"
71    //    <https://www.lpalmieri.com/posts/2019-04-07-scientific-computing-a-rust-adventure-part-2-array1/>
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if `degree` is less than 1.
76    pub fn new(degree: usize) -> Result<Self, MidpointError> {
77        if degree > 0 {
78            Ok(Self {
79                nodes: (0..degree).map(|d| d as f64).collect(),
80            })
81        } else {
82            Err(MidpointError::new())
83        }
84    }
85
86    /// Integrate over the domain [a, b].
87    pub fn integrate<F>(&self, a: f64, b: f64, integrand: F) -> f64
88    where
89        F: Fn(f64) -> f64,
90    {
91        let rect_width = (b - a) / self.nodes.len() as f64;
92
93        let sum: f64 = self
94            .nodes
95            .iter()
96            .map(|&node| integrand(a + rect_width * (0.5 + node)))
97            .sum();
98
99        sum * rect_width
100    }
101
102    #[cfg(feature = "rayon")]
103    /// Same as [`integrate`](Midpoint::integrate) but runs in parallel.
104    pub fn par_integrate<F>(&self, a: f64, b: f64, integrand: F) -> f64
105    where
106        F: Fn(f64) -> f64 + Sync,
107    {
108        let rect_width = (b - a) / self.nodes.len() as f64;
109
110        let sum: f64 = self
111            .nodes
112            .par_iter()
113            .map(|&node| integrand(a + rect_width * (0.5 + node)))
114            .sum();
115
116        sum * rect_width
117    }
118}
119
// Hands the shared node-rule boilerplate for `Midpoint` to the crate-level
// `__impl_node_rule!` macro (imported from the crate root at the top of this file).
// NOTE(review): the tests below call `degree()`, `iter()` and `into_iter()` on `Midpoint`,
// and none of them are defined in this file. Presumably this macro generates them, with
// `MidpointIter` / `MidpointIntoIter` naming the generated iterator types. Confirm this
// against the macro definition.
__impl_node_rule! {Midpoint, MidpointIter, MidpointIntoIter}
121
/// The error returned by [`Midpoint::new`] if given a degree of 0.
///
/// It carries no data other than the [`Backtrace`] recorded where it was created,
/// which can be retrieved with [`MidpointError::backtrace`].
#[derive(Debug)]
pub struct MidpointError(Backtrace);
125
126use core::fmt;
127impl fmt::Display for MidpointError {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        write!(f, "the degree of the midpoint rule needs to be at least 1")
130    }
131}
132
133impl MidpointError {
134    /// Calls [`Backtrace::capture`] and wraps the result in a `MidpointError` struct.
135    fn new() -> Self {
136        Self(Backtrace::capture())
137    }
138
139    /// Returns a [`Backtrace`] to where the error was created.
140    ///
141    /// This backtrace is captured with [`Backtrace::capture`], see it for more information about how to make it display information when printed.
142    #[inline]
143    pub fn backtrace(&self) -> &Backtrace {
144        &self.0
145    }
146}
147
148impl std::error::Error for MidpointError {}
149
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    #[test]
    fn check_midpoint_integration() {
        let quad = Midpoint::new(100).unwrap();
        let integral = quad.integrate(0.0, 1.0, |x| x * x);
        assert_abs_diff_eq!(integral, 1.0 / 3.0, epsilon = 0.0001);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn par_check_midpoint_integration() {
        let quad = Midpoint::new(100).unwrap();
        let integral = quad.par_integrate(0.0, 1.0, |x| x * x);
        assert_abs_diff_eq!(integral, 1.0 / 3.0, epsilon = 0.0001);
    }

    #[test]
    fn check_midpoint_error() {
        // `unwrap_err` both asserts that construction failed and yields the error.
        let error = Midpoint::new(0).unwrap_err();
        assert_eq!(
            error.to_string(),
            "the degree of the midpoint rule needs to be at least 1"
        );
    }

    #[test]
    fn check_single_cell_rule() {
        // Degree 1 is the smallest valid rule: a single rectangle of height f((a + b) / 2).
        let quad = Midpoint::new(1).unwrap();
        assert_abs_diff_eq!(quad.integrate(0.0, 2.0, |x| x * x), 2.0, epsilon = 1e-15);
    }

    #[test]
    fn check_exact_for_affine_integrands() {
        // The midpoint rule integrates affine functions exactly, whatever the cell count.
        for degree in [1, 2, 7] {
            let quad = Midpoint::new(degree).unwrap();
            let integral = quad.integrate(0.0, 2.0, |x| 3.0 * x + 1.0);
            assert_abs_diff_eq!(integral, 8.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn check_reversed_and_empty_domains() {
        let quad = Midpoint::new(4).unwrap();
        let forward = quad.integrate(0.0, 1.0, |x| x * x);
        let backward = quad.integrate(1.0, 0.0, |x| x * x);
        assert_abs_diff_eq!(forward, -backward, epsilon = 1e-15);
        assert_abs_diff_eq!(quad.integrate(2.0, 2.0, |x| x * x), 0.0, epsilon = 1e-15);
    }

    #[test]
    fn check_derives() {
        let quad = Midpoint::new(10).unwrap();
        let quad_clone = quad.clone();
        assert_eq!(quad, quad_clone);
        let other_quad = Midpoint::new(3).unwrap();
        assert_ne!(quad, other_quad);
    }

    #[test]
    fn check_iterators() {
        let rule = Midpoint::new(100).unwrap();
        let a = 0.0;
        let b = 1.0;
        let ans = 1.0 / 3.0;
        let rect_width = (b - a) / rule.degree() as f64;

        assert_abs_diff_eq!(
            ans,
            rule.iter().fold(0.0, |tot, n| {
                let x = a + rect_width * (0.5 + n);
                tot + x * x
            }) * rect_width,
            epsilon = 1e-4
        );

        assert_abs_diff_eq!(
            ans,
            rule.into_iter().fold(0.0, |tot, n| {
                let x = a + rect_width * (0.5 + n);
                tot + x * x
            }) * rect_width,
            epsilon = 1e-4
        );
    }
}