1#[cfg(feature = "rayon")]
27use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
28
29mod bogaert;
30
31use bogaert::NodeWeightPair;
32
33use crate::{Node, Weight, __impl_node_weight_rule};
34
35use std::backtrace::Backtrace;
36
/// A Gauss-Legendre quadrature rule, stored as a list of `(node, weight)` pairs.
///
/// The nodes live on the reference interval `[-1, 1]`; [`GaussLegendre::integrate`] maps them
/// affinely onto an arbitrary interval `[a, b]`. Construct a rule with [`GaussLegendre::new`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GaussLegendre {
    // One `(node, weight)` pair per index `k` in `1..=deg`, in the order produced by
    // `NodeWeightPair::new(deg, k)` (the tests in this file expect descending node values).
    node_weight_pairs: Vec<(Node, Weight)>,
}
73
74impl GaussLegendre {
75 pub fn new(deg: usize) -> Result<Self, GaussLegendreError> {
86 if deg < 2 {
87 return Err(GaussLegendreError::new());
88 }
89
90 Ok(Self {
91 node_weight_pairs: (1..deg + 1)
92 .map(|k: usize| NodeWeightPair::new(deg, k).into_tuple())
93 .collect(),
94 })
95 }
96
97 #[cfg(feature = "rayon")]
98 pub fn par_new(deg: usize) -> Result<Self, GaussLegendreError> {
104 if deg < 2 {
105 return Err(GaussLegendreError::new());
106 }
107
108 Ok(Self {
109 node_weight_pairs: (1..deg + 1)
110 .into_par_iter()
111 .map(|k| NodeWeightPair::new(deg, k).into_tuple())
112 .collect(),
113 })
114 }
115
116 fn argument_transformation(x: f64, a: f64, b: f64) -> f64 {
117 0.5 * ((b - a) * x + (b + a))
118 }
119
120 fn scale_factor(a: f64, b: f64) -> f64 {
121 0.5 * (b - a)
122 }
123
124 pub fn integrate<F>(&self, a: f64, b: f64, integrand: F) -> f64
139 where
140 F: Fn(f64) -> f64,
141 {
142 let result: f64 = self
143 .node_weight_pairs
144 .iter()
145 .map(|(x_val, w_val)| integrand(Self::argument_transformation(*x_val, a, b)) * w_val)
146 .sum();
147 Self::scale_factor(a, b) * result
148 }
149
150 #[cfg(feature = "rayon")]
151 pub fn par_integrate<F>(&self, a: f64, b: f64, integrand: F) -> f64
164 where
165 F: Fn(f64) -> f64 + Sync,
166 {
167 let result: f64 = self
168 .node_weight_pairs
169 .par_iter()
170 .map(|(x_val, w_val)| integrand(Self::argument_transformation(*x_val, a, b)) * w_val)
171 .sum();
172 Self::scale_factor(a, b) * result
173 }
174}
175
// Invokes the crate-level `__impl_node_weight_rule!` macro for `GaussLegendre`, passing the names
// `GaussLegendreNodes`, `GaussLegendreWeights`, `GaussLegendreIter` and `GaussLegendreIntoIter`.
// NOTE(review): the macro body is not visible here. Judging by the tests in this file, it supplies
// `nodes()`, `weights()`, `iter()` and `IntoIterator` for the rule; confirm against the definition.
__impl_node_weight_rule! {GaussLegendre, GaussLegendreNodes, GaussLegendreWeights, GaussLegendreIter, GaussLegendreIntoIter}
177
/// Error returned by [`GaussLegendre::new`] (and `par_new` with the `rayon` feature) when the
/// requested degree is less than 2.
///
/// A backtrace is captured when the error is created and can be retrieved with
/// [`GaussLegendreError::backtrace`].
#[derive(Debug)]
pub struct GaussLegendreError(Backtrace);
181
182use core::fmt;
183impl fmt::Display for GaussLegendreError {
184 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
185 write!(
186 f,
187 "the degree of the Gauss-Legendre quadrature rule must be at least 2"
188 )
189 }
190}
191
192impl GaussLegendreError {
193 fn new() -> Self {
195 Self(Backtrace::capture())
196 }
197
198 #[inline]
202 pub fn backtrace(&self) -> &Backtrace {
203 &self.0
204 }
205}
206
// Marker impl relying on the trait's default methods: `source()` returns `None`, since the error
// only carries a backtrace and wraps no underlying cause.
impl std::error::Error for GaussLegendreError {}
208
// Unit tests for construction, node/weight values, iteration, integration and error reporting.
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    // Compares against the closed-form 3-point rule: nodes ±sqrt(3/5) and 0, weights 5/9 and 8/9,
    // with nodes expected in descending order.
    #[test]
    fn check_degree_3() {
        let (x, w): (Vec<_>, Vec<_>) = GaussLegendre::new(3).unwrap().into_iter().unzip();

        let x_should = [0.7745966692414834, 0.0000000000000000, -0.7745966692414834];
        let w_should = [0.5555555555555556, 0.8888888888888888, 0.5555555555555556];
        for (i, x_val) in x_should.iter().enumerate() {
            assert_abs_diff_eq!(*x_val, x[i]);
        }
        for (i, w_val) in w_should.iter().enumerate() {
            assert_abs_diff_eq!(*w_val, w[i]);
        }
    }

    #[test]
    fn check_degree_128() {
        let (x, w): (Vec<_>, Vec<_>) = GaussLegendre::new(128).unwrap().into_iter().unzip();

        // Reference values cover only the 64 positive nodes, listed in ascending order. Iterated
        // in reverse, they line up with the first 64 entries of `x`/`w`, which the rule stores
        // with descending nodes. The negative half is not checked here.
        #[rustfmt::skip]
        #[allow(clippy::excessive_precision)]
        let x_should = [0.0122236989606157641980521,0.0366637909687334933302153,0.0610819696041395681037870,0.0854636405045154986364980,0.1097942311276437466729747,0.1340591994611877851175753,0.1582440427142249339974755,0.1823343059853371824103826,0.2063155909020792171540580,0.2301735642266599864109866,0.2538939664226943208556180,0.2774626201779044028062316,0.3008654388776772026671541,0.3240884350244133751832523,0.3471177285976355084261628,0.3699395553498590266165917,0.3925402750332674427356482,0.4149063795522750154922739,0.4370245010371041629370429,0.4588814198335521954490891,0.4804640724041720258582757,0.5017595591361444642896063,0.5227551520511754784539479,0.5434383024128103634441936,0.5637966482266180839144308,0.5838180216287630895500389,0.6034904561585486242035732,0.6228021939105849107615396,0.6417416925623075571535249,0.6602976322726460521059468,0.6784589224477192593677557,0.6962147083695143323850866,0.7135543776835874133438599,0.7304675667419088064717369,0.7469441667970619811698824,0.7629743300440947227797691,0.7785484755064119668504941,0.7936572947621932902433329,0.8082917575079136601196422,0.8224431169556438424645942,0.8361029150609068471168753,0.8492629875779689691636001,0.8619154689395484605906323,0.8740527969580317986954180,0.8856677173453972174082924,0.8967532880491581843864474,0.9073028834017568139214859,0.9173101980809605370364836,0.9267692508789478433346245,0.9356743882779163757831268,0.9440202878302201821211114,0.9518019613412643862177963,0.9590147578536999280989185,0.9656543664319652686458290,0.9717168187471365809043384,0.9771984914639073871653744,0.9820961084357185360247656,0.9864067427245862088712355,0.9901278184917343833379303,0.9932571129002129353034372,0.9957927585349811868641612,0.9977332486255140198821574,0.9990774599773758950119878,0.9998248879471319144736081];

        // Weights corresponding to the positive nodes above, in the same ascending-node order.
        #[rustfmt::skip]
        #[allow(clippy::excessive_precision)]
        let w_should = [0.0244461801962625182113259,0.0244315690978500450548486,0.0244023556338495820932980,0.0243585572646906258532685,0.0243002001679718653234426,0.0242273192228152481200933,0.0241399579890192849977167,0.0240381686810240526375873,0.0239220121367034556724504,0.0237915577810034006387807,0.0236468835844476151436514,0.0234880760165359131530253,0.0233152299940627601224157,0.0231284488243870278792979,0.0229278441436868469204110,0.0227135358502364613097126,0.0224856520327449668718246,0.0222443288937997651046291,0.0219897106684604914341221,0.0217219495380520753752610,0.0214412055392084601371119,0.0211476464682213485370195,0.0208414477807511491135839,0.0205227924869600694322850,0.0201918710421300411806732,0.0198488812328308622199444,0.0194940280587066028230219,0.0191275236099509454865185,0.0187495869405447086509195,0.0183604439373313432212893,0.0179603271850086859401969,0.0175494758271177046487069,0.0171281354231113768306810,0.0166965578015892045890915,0.0162550009097851870516575,0.0158037286593993468589656,0.0153430107688651440859909,0.0148731226021473142523855,0.0143943450041668461768239,0.0139069641329519852442880,0.0134112712886163323144890,0.0129075627392673472204428,0.0123961395439509229688217,0.0118773073727402795758911,0.0113513763240804166932817,0.0108186607395030762476596,0.0102794790158321571332153,0.0097341534150068058635483,0.0091830098716608743344787,0.0086263777986167497049788,0.0080645898904860579729286,0.0074979819256347286876720,0.0069268925668988135634267,0.0063516631617071887872143,0.0057726375428656985893346,0.0051901618326763302050708,0.0046045842567029551182905,0.0040162549837386423131943,0.0034255260409102157743378,0.0028327514714579910952857,0.0022382884309626187436221,0.0016425030186690295387909,0.0010458126793403487793129,0.0004493809602920903763943];

        // Looser tolerance than the 3-point test: these are high-degree computed values.
        for (i, x_val) in x_should.iter().rev().enumerate() {
            assert_abs_diff_eq!(*x_val, x[i], epsilon = 0.000_000_1);
        }
        for (i, w_val) in w_should.iter().rev().enumerate() {
            assert_abs_diff_eq!(*w_val, w[i], epsilon = 0.000_000_1);
        }
    }

    // Degrees 0 and 1 must be rejected with the documented error message.
    #[test]
    fn check_legendre_error() {
        let legendre_rule = GaussLegendre::new(0);
        assert!(legendre_rule.is_err());
        assert_eq!(
            format!("{}", legendre_rule.err().unwrap()),
            "the degree of the Gauss-Legendre quadrature rule must be at least 2"
        );

        let legendre_rule = GaussLegendre::new(1);
        assert!(legendre_rule.is_err());
        assert_eq!(
            format!("{}", legendre_rule.err().unwrap()),
            "the degree of the Gauss-Legendre quadrature rule must be at least 2"
        );
    }

    // Exercises the derived `Clone` and `PartialEq`: a clone compares equal, a different degree does not.
    #[test]
    fn check_derives() {
        let quad = GaussLegendre::new(10).unwrap();
        let quad_clone = quad.clone();
        assert_eq!(quad, quad_clone);
        let other_quad = GaussLegendre::new(3).unwrap();
        assert_ne!(quad, other_quad);
    }

    // Each iteration API should give the same weighted sum of x^2 on [-1, 1], whose exact value is 2/3.
    #[test]
    fn check_iterators() {
        let rule = GaussLegendre::new(3).unwrap();

        assert_abs_diff_eq!(
            2.0 / 3.0,
            rule.iter().fold(0.0, |tot, (n, w)| tot + n * n * w)
        );

        assert_abs_diff_eq!(
            2.0 / 3.0,
            rule.nodes()
                .zip(rule.weights())
                .fold(0.0, |tot, (n, w)| tot + n * n * w)
        );

        assert_abs_diff_eq!(
            2.0 / 3.0,
            rule.into_iter().fold(0.0, |tot, (n, w)| tot + n * n * w)
        );
    }

    // Low-degree polynomials are integrated exactly (up to rounding) by a 5-point rule.
    #[test]
    fn integrate_linear() {
        let quad = GaussLegendre::new(5).unwrap();
        let integral = quad.integrate(0.0, 1.0, |x| x);
        assert_abs_diff_eq!(integral, 0.5, epsilon = 1e-15);
    }

    #[test]
    fn integrate_parabola() {
        let quad = GaussLegendre::new(5).unwrap();
        let integral = quad.integrate(0.0, 3.0, |x| x.powi(2));
        assert_abs_diff_eq!(integral, 9.0, epsilon = 1e-13);
    }

    // The rayon variants must match the sequential results within the same tolerances.
    #[cfg(feature = "rayon")]
    #[test]
    fn par_integrate_linear() {
        let quad = GaussLegendre::par_new(5).unwrap();
        let integral = quad.par_integrate(0.0, 1.0, |x| x);
        assert_abs_diff_eq!(integral, 0.5, epsilon = 1e-15);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn par_integrate_parabola() {
        let quad = GaussLegendre::par_new(5).unwrap();
        let integral = quad.par_integrate(0.0, 3.0, |x| x.powi(2));
        assert_abs_diff_eq!(integral, 9.0, epsilon = 1e-13);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn check_legendre_error_rayon() {
        assert!(GaussLegendre::par_new(0).is_err());
        assert!(GaussLegendre::par_new(1).is_err());
    }
}