1#[cfg(feature = "rayon")]
22use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
23
24use crate::gamma::gamma;
25use crate::{DMatrixf64, Node, Weight, __impl_node_weight_rule};
26
27use std::backtrace::Backtrace;
28
/// A Gauss-Laguerre quadrature rule for integrals of the form `∫₀^∞ f(x) x^alpha e^(-x) dx`.
///
/// Instances are created with [`GaussLaguerre::new`], which computes the nodes and weights
/// with the Golub-Welsch algorithm.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GaussLaguerre {
    /// Quadrature nodes paired with their weights; `new` stores them sorted by ascending node.
    node_weight_pairs: Vec<(Node, Weight)>,
    /// Exponent of the `x^alpha` factor in the weight function; `new` only accepts
    /// finite values larger than -1.0.
    alpha: f64,
}
55
56impl GaussLaguerre {
57 pub fn new(deg: usize, alpha: f64) -> Result<Self, GaussLaguerreError> {
74 match (deg >= 2, (alpha.is_finite() && alpha > -1.0)) {
75 (true, true) => Ok(()),
76 (false, true) => Err(GaussLaguerreErrorReason::Degree),
77 (true, false) => Err(GaussLaguerreErrorReason::Alpha),
78 (false, false) => Err(GaussLaguerreErrorReason::DegreeAlpha),
79 }
80 .map_err(GaussLaguerreError::new)?;
81
82 let mut companion_matrix = DMatrixf64::from_element(deg, deg, 0.0);
83
84 let mut diag = alpha + 1.0;
85 for idx in 0..deg - 1 {
87 let idx_f64 = 1.0 + idx as f64;
88 let off_diag = (idx_f64 * (idx_f64 + alpha)).sqrt();
89 unsafe {
90 *companion_matrix.get_unchecked_mut((idx, idx)) = diag;
91 *companion_matrix.get_unchecked_mut((idx, idx + 1)) = off_diag;
92 *companion_matrix.get_unchecked_mut((idx + 1, idx)) = off_diag;
93 }
94 diag += 2.0;
95 }
96 unsafe {
97 *companion_matrix.get_unchecked_mut((deg - 1, deg - 1)) = diag;
98 }
99 let eigen = companion_matrix.symmetric_eigen();
101
102 let scale_factor = gamma(alpha + 1.0);
103
104 let mut node_weight_pairs: Vec<(f64, f64)> = eigen
106 .eigenvalues
107 .into_iter()
108 .copied()
109 .zip(
110 (eigen.eigenvectors.row(0).map(|x| x * x) * scale_factor)
111 .into_iter()
112 .copied(),
113 )
114 .collect();
115 node_weight_pairs.sort_unstable_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
116
117 Ok(GaussLaguerre {
118 node_weight_pairs,
119 alpha,
120 })
121 }
122
123 pub fn integrate<F>(&self, integrand: F) -> f64
127 where
128 F: Fn(f64) -> f64,
129 {
130 let result: f64 = self
131 .node_weight_pairs
132 .iter()
133 .map(|(x_val, w_val)| integrand(*x_val) * w_val)
134 .sum();
135 result
136 }
137
138 #[cfg(feature = "rayon")]
139 pub fn par_integrate<F>(&self, integrand: F) -> f64
141 where
142 F: Fn(f64) -> f64 + Sync,
143 {
144 let result: f64 = self
145 .node_weight_pairs
146 .par_iter()
147 .map(|(x_val, w_val)| integrand(*x_val) * w_val)
148 .sum();
149 result
150 }
151
152 #[inline]
154 pub const fn alpha(&self) -> f64 {
155 self.alpha
156 }
157}
158
// Presumably generates the node/weight accessors (`nodes`, `weights`, `iter`) and the
// `IntoIterator` impl for `GaussLaguerre`, backed by the listed iterator types. The tests in
// this file rely on those methods; confirm against the macro definition.
__impl_node_weight_rule! {GaussLaguerre, GaussLaguerreNodes, GaussLaguerreWeights, GaussLaguerreIter, GaussLaguerreIntoIter}
160
/// The error returned by [`GaussLaguerre::new`] when its arguments are invalid.
#[derive(Debug)]
pub struct GaussLaguerreError {
    /// Which argument(s) were rejected.
    reason: GaussLaguerreErrorReason,
    /// Backtrace taken with `Backtrace::capture` when the error was created. It is only
    /// populated when backtrace capture is enabled in the environment.
    backtrace: Backtrace,
}
167
168use core::fmt;
169impl fmt::Display for GaussLaguerreError {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 const DEGREE_LIMIT: &str = "degree must be at least 2";
172 const ALPHA_LIMIT: &str = "alpha must be larger than -1.0";
173 match self.reason() {
174 GaussLaguerreErrorReason::Degree => write!(f, "{DEGREE_LIMIT}"),
175 GaussLaguerreErrorReason::Alpha => write!(f, "{ALPHA_LIMIT}"),
176 GaussLaguerreErrorReason::DegreeAlpha => write!(f, "{DEGREE_LIMIT}, and {ALPHA_LIMIT}"),
177 }
178 }
179}
180
181impl GaussLaguerreError {
182 #[inline]
184 pub(crate) fn new(reason: GaussLaguerreErrorReason) -> Self {
185 Self {
186 reason,
187 backtrace: Backtrace::capture(),
188 }
189 }
190
191 #[inline]
193 pub fn reason(&self) -> GaussLaguerreErrorReason {
194 self.reason
195 }
196
197 #[inline]
201 pub fn backtrace(&self) -> &Backtrace {
202 &self.backtrace
203 }
204}
205
// Uses the default `Error` methods (no underlying `source`); the human-readable message
// comes from the `Display` impl.
impl std::error::Error for GaussLaguerreError {}
207
/// Why [`GaussLaguerre::new`] rejected its arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum GaussLaguerreErrorReason {
    /// The degree was less than 2 (`alpha` was valid).
    Degree,
    /// `alpha` was not finite or was not larger than -1.0 (the degree was valid).
    Alpha,
    /// Both the degree and `alpha` were invalid.
    DegreeAlpha,
}
219
220impl GaussLaguerreErrorReason {
221 #[inline]
223 pub fn was_bad_degree(&self) -> bool {
224 matches!(self, Self::Degree | Self::DegreeAlpha)
225 }
226
227 #[inline]
229 pub fn was_bad_alpha(&self) -> bool {
230 matches!(self, Self::Alpha | Self::DegreeAlpha)
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use core::f64::consts::PI;

    #[test]
    fn golub_welsch_2_alpha_5() {
        let (x, w): (Vec<_>, Vec<_>) = GaussLaguerre::new(2, 5.0).unwrap().into_iter().unzip();
        let x_should = [4.354_248_688_935_409, 9.645_751_311_064_59];
        let w_should = [82.677_868_380_553_63, 37.322_131_619_446_37];
        for (i, x_val) in x_should.iter().enumerate() {
            assert_abs_diff_eq!(*x_val, x[i], epsilon = 1e-12);
        }
        for (i, w_val) in w_should.iter().enumerate() {
            assert_abs_diff_eq!(*w_val, w[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn golub_welsch_3_alpha_0() {
        let (x, w): (Vec<_>, Vec<_>) = GaussLaguerre::new(3, 0.0).unwrap().into_iter().unzip();
        let x_should = [
            0.415_774_556_783_479_1,
            2.294_280_360_279_042,
            6.289_945_082_937_479_4,
        ];
        let w_should = [
            0.711_093_009_929_173,
            0.278_517_733_569_240_87,
            0.010_389_256_501_586_135,
        ];
        for (i, x_val) in x_should.iter().enumerate() {
            assert_abs_diff_eq!(*x_val, x[i], epsilon = 1e-14);
        }
        for (i, w_val) in w_should.iter().enumerate() {
            assert_abs_diff_eq!(*w_val, w[i], epsilon = 1e-14);
        }
    }

    #[test]
    fn golub_welsch_3_alpha_1_5() {
        let (x, w): (Vec<_>, Vec<_>) = GaussLaguerre::new(3, 1.5).unwrap().into_iter().unzip();
        let x_should = [
            1.220_402_317_558_883_8,
            3.808_880_721_467_068,
            8.470_716_960_974_048,
        ];
        let w_should = [
            0.730_637_894_350_016,
            0.566_249_100_686_605_7,
            0.032_453_393_142_515_25,
        ];
        for (i, x_val) in x_should.iter().enumerate() {
            assert_abs_diff_eq!(*x_val, x[i], epsilon = 1e-14);
        }
        for (i, w_val) in w_should.iter().enumerate() {
            assert_abs_diff_eq!(*w_val, w[i], epsilon = 1e-14);
        }
    }

    #[test]
    fn golub_welsch_5_alpha_negative() {
        let (x, w): (Vec<_>, Vec<_>) = GaussLaguerre::new(5, -0.9).unwrap().into_iter().unzip();
        let x_should = [
            0.020_777_151_319_288_104,
            0.808_997_536_134_602_1,
            2.674_900_020_624_07,
            5.869_026_089_963_398,
            11.126_299_201_958_641,
        ];
        let w_should = [
            8.738_289_241_242_436,
            0.702_782_353_089_744_5,
            0.070_111_720_632_849_48,
            0.002_312_760_116_115_564,
            1.162_358_758_613_074_8E-5,
        ];
        for (i, x_val) in x_should.iter().enumerate() {
            assert_abs_diff_eq!(*x_val, x[i], epsilon = 1e-14);
        }
        for (i, w_val) in w_should.iter().enumerate() {
            assert_abs_diff_eq!(*w_val, w[i], epsilon = 1e-14);
        }
    }

    #[test]
    fn check_laguerre_error() {
        let laguerre_rule = GaussLaguerre::new(0, -0.25);
        assert!(laguerre_rule
            .as_ref()
            .is_err_and(|x| x.reason() == GaussLaguerreErrorReason::Degree));
        assert_eq!(
            format!("{}", laguerre_rule.err().unwrap()),
            "degree must be at least 2"
        );
        assert_eq!(
            GaussLaguerre::new(0, -0.25).map_err(|e| e.reason()),
            Err(GaussLaguerreErrorReason::Degree)
        );

        assert_eq!(
            GaussLaguerre::new(1, -0.25).map_err(|e| e.reason()),
            Err(GaussLaguerreErrorReason::Degree)
        );

        let laguerre_rule = GaussLaguerre::new(5, -1.0);
        assert!(laguerre_rule
            .as_ref()
            .is_err_and(|x| x.reason() == GaussLaguerreErrorReason::Alpha));
        assert_eq!(
            format!("{}", laguerre_rule.err().unwrap()),
            "alpha must be larger than -1.0"
        );

        assert_eq!(
            GaussLaguerre::new(5, -1.0).map_err(|e| e.reason()),
            Err(GaussLaguerreErrorReason::Alpha)
        );
        assert_eq!(
            GaussLaguerre::new(5, -2.0).map_err(|e| e.reason()),
            Err(GaussLaguerreErrorReason::Alpha)
        );

        let laguerre_rule = GaussLaguerre::new(0, -1.0);
        assert!(laguerre_rule
            .as_ref()
            .is_err_and(|x| x.reason() == GaussLaguerreErrorReason::DegreeAlpha));
        assert_eq!(
            format!("{}", laguerre_rule.err().unwrap()),
            "degree must be at least 2, and alpha must be larger than -1.0"
        );

        assert_eq!(
            GaussLaguerre::new(0, -1.0).map_err(|e| e.reason()),
            Err(GaussLaguerreErrorReason::DegreeAlpha)
        );

        assert_eq!(
            GaussLaguerre::new(0, -2.0).map_err(|e| e.reason()),
            Err(GaussLaguerreErrorReason::DegreeAlpha)
        );
        assert_eq!(
            GaussLaguerre::new(1, -1.0).map_err(|e| e.reason()),
            Err(GaussLaguerreErrorReason::DegreeAlpha)
        );
        assert_eq!(
            GaussLaguerre::new(1, -2.0).map_err(|e| e.reason()),
            Err(GaussLaguerreErrorReason::DegreeAlpha)
        );
    }

    /// `alpha` must also be finite: NaN and the infinities are rejected like
    /// out-of-range values, both alone and combined with a bad degree.
    #[test]
    fn check_non_finite_alpha() {
        for alpha in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert_eq!(
                GaussLaguerre::new(5, alpha).map_err(|e| e.reason()),
                Err(GaussLaguerreErrorReason::Alpha)
            );
            assert_eq!(
                GaussLaguerre::new(1, alpha).map_err(|e| e.reason()),
                Err(GaussLaguerreErrorReason::DegreeAlpha)
            );
        }
    }

    /// The reason predicates must agree with the three error variants.
    #[test]
    fn check_error_reason_predicates() {
        let degree = GaussLaguerreErrorReason::Degree;
        let alpha = GaussLaguerreErrorReason::Alpha;
        let both = GaussLaguerreErrorReason::DegreeAlpha;
        assert!(degree.was_bad_degree());
        assert!(!degree.was_bad_alpha());
        assert!(!alpha.was_bad_degree());
        assert!(alpha.was_bad_alpha());
        assert!(both.was_bad_degree());
        assert!(both.was_bad_alpha());
    }

    /// The getter returns the constructor's `alpha`. Nodes come out strictly ascending,
    /// and all weights are positive.
    #[test]
    fn check_alpha_and_ordering() {
        let rule = GaussLaguerre::new(10, 2.3).unwrap();
        assert_eq!(rule.alpha(), 2.3);
        let (x, w): (Vec<_>, Vec<_>) = rule.into_iter().unzip();
        assert!(x.windows(2).all(|pair| pair[0] < pair[1]));
        assert!(w.iter().all(|&w_val| w_val > 0.0));
    }

    #[test]
    fn check_derives() {
        let quad = GaussLaguerre::new(10, 1.0).unwrap();
        let quad_clone = quad.clone();
        assert_eq!(quad, quad_clone);
        let other_quad = GaussLaguerre::new(10, 2.0).unwrap();
        assert_ne!(quad, other_quad);
    }

    #[test]
    fn check_iterators() {
        let rule = GaussLaguerre::new(3, 0.5).unwrap();

        let ans = 15.0 / 8.0 * core::f64::consts::PI.sqrt();

        assert_abs_diff_eq!(
            rule.iter().fold(0.0, |tot, (n, w)| tot + n * n * w),
            ans,
            epsilon = 1e-14
        );

        assert_abs_diff_eq!(
            rule.nodes()
                .zip(rule.weights())
                .fold(0.0, |tot, (n, w)| tot + n * n * w),
            ans,
            epsilon = 1e-14
        );

        assert_abs_diff_eq!(
            rule.into_iter().fold(0.0, |tot, (n, w)| tot + n * n * w),
            ans,
            epsilon = 1e-14
        );
    }

    #[test]
    fn check_some_integrals() {
        let rule = GaussLaguerre::new(10, -0.5).unwrap();

        assert_abs_diff_eq!(
            rule.integrate(|x| x * x),
            3.0 * PI.sqrt() / 4.0,
            epsilon = 1e-14
        );

        assert_abs_diff_eq!(
            rule.integrate(|x| x.sin()),
            (PI.sqrt() * (PI / 8.0).sin()) / (2.0_f64.powf(0.25)),
            epsilon = 1e-7,
        );
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn par_check_some_integrals() {
        let rule = GaussLaguerre::new(10, -0.5).unwrap();

        assert_abs_diff_eq!(
            rule.par_integrate(|x| x * x),
            3.0 * PI.sqrt() / 4.0,
            epsilon = 1e-14
        );

        assert_abs_diff_eq!(
            rule.par_integrate(|x| x.sin()),
            (PI.sqrt() * (PI / 8.0).sin()) / (2.0_f64.powf(0.25)),
            epsilon = 1e-7,
        );
    }
}