/// A function on `[xmin, xmax]` that, on each segment between consecutive
/// knots, is a finite Legendre series in the segment's local coordinate
/// (mapped onto `[-1, 1]`), multiplied by a per-segment normalization factor.
#[derive(Debug, Clone)]
pub struct PiecewiseLegendrePoly {
    /// Number of Legendre coefficients per segment (rows of `data`, as set by `new`).
    pub polyorder: usize,
    /// Left end of the domain (the first knot).
    pub xmin: f64,
    /// Right end of the domain (the last knot).
    pub xmax: f64,
    /// Segment boundaries; `new` requires them to be strictly increasing.
    pub knots: Vec<f64>,
    /// Segment widths, `knots[i + 1] - knots[i]`.
    pub delta_x: Vec<f64>,
    /// Coefficients: row index = Legendre degree, column index = segment.
    pub data: mdarray::DTensor<f64, 2>,
    /// Symmetry flag; `deriv` negates it for odd derivative orders.
    /// NOTE(review): presumably the parity (+1 even, -1 odd, 0 unknown) — confirm.
    pub symm: i32,
    /// Index of this function; `PiecewiseLegendrePolyVector::from_3d_data` sets
    /// it to the polynomial's position in the vector.
    pub l: i32,
    /// Segment midpoints, `(knots[i] + knots[i + 1]) / 2`.
    pub xm: Vec<f64>,
    /// `2 / delta_x[i]`: slope of the affine map from segment `i` onto `[-1, 1]`.
    pub inv_xs: Vec<f64>,
    /// `sqrt(inv_xs[i])`, multiplied into every value `evaluate` returns.
    pub norms: Vec<f64>,
}
32
33impl PiecewiseLegendrePoly {
34 pub fn new(
36 data: mdarray::DTensor<f64, 2>,
37 knots: Vec<f64>,
38 l: i32,
39 delta_x: Option<Vec<f64>>,
40 symm: i32,
41 ) -> Self {
42 let polyorder = data.shape().0;
43 let nsegments = data.shape().1;
44
45 if knots.len() != nsegments + 1 {
46 panic!(
47 "Invalid knots array: expected {} knots, got {}",
48 nsegments + 1,
49 knots.len()
50 );
51 }
52
53 for i in 1..knots.len() {
55 if knots[i] <= knots[i - 1] {
56 panic!("Knots must be monotonically increasing");
57 }
58 }
59
60 let delta_x =
62 delta_x.unwrap_or_else(|| (1..knots.len()).map(|i| knots[i] - knots[i - 1]).collect());
63
64 for i in 0..delta_x.len() {
66 let expected = knots[i + 1] - knots[i];
67 if (delta_x[i] - expected).abs() > 1e-10 {
68 panic!("delta_x must match knots");
69 }
70 }
71
72 let xm: Vec<f64> = (0..nsegments)
74 .map(|i| 0.5 * (knots[i] + knots[i + 1]))
75 .collect();
76
77 let inv_xs: Vec<f64> = delta_x.iter().map(|&dx| 2.0 / dx).collect();
79
80 let norms: Vec<f64> = inv_xs.iter().map(|&inv_x| inv_x.sqrt()).collect();
82
83 Self {
84 polyorder,
85 xmin: knots[0],
86 xmax: knots[knots.len() - 1],
87 knots,
88 delta_x,
89 data,
90 symm,
91 l,
92 xm,
93 inv_xs,
94 norms,
95 }
96 }
97
98 pub fn with_data(&self, new_data: mdarray::DTensor<f64, 2>) -> Self {
100 Self {
101 data: new_data,
102 ..self.clone()
103 }
104 }
105
106 pub fn symm(&self) -> i32 {
108 self.symm
109 }
110
111 pub fn with_data_and_symmetry(
113 &self,
114 new_data: mdarray::DTensor<f64, 2>,
115 new_symm: i32,
116 ) -> Self {
117 Self {
118 data: new_data,
119 symm: new_symm,
120 ..self.clone()
121 }
122 }
123
124 pub fn rescale_domain(
139 &self,
140 new_knots: Vec<f64>,
141 new_delta_x: Option<Vec<f64>>,
142 new_symm: Option<i32>,
143 ) -> Self {
144 Self::new(
145 self.data.clone(),
146 new_knots,
147 self.l,
148 new_delta_x,
149 new_symm.unwrap_or(self.symm),
150 )
151 }
152
153 pub fn scale_data(&self, factor: f64) -> Self {
166 Self::with_data(
167 self,
168 mdarray::DTensor::<f64, 2>::from_fn(*self.data.shape(), |idx| self.data[idx] * factor),
169 )
170 }
171
172 pub fn evaluate(&self, x: f64) -> f64 {
174 let (i, x_tilde) = self.split(x);
175 let coeffs: Vec<f64> = (0..self.data.shape().0)
177 .map(|row| self.data[[row, i]])
178 .collect();
179 let value = self.evaluate_legendre_polynomial(x_tilde, &coeffs);
180 value * self.norms[i]
181 }
182
183 pub fn evaluate_many(&self, xs: &[f64]) -> Vec<f64> {
185 xs.iter().map(|&x| self.evaluate(x)).collect()
186 }
187
188 pub fn split(&self, x: f64) -> (usize, f64) {
190 if x < self.xmin || x > self.xmax {
191 panic!("x = {} is outside domain [{}, {}]", x, self.xmin, self.xmax);
192 }
193
194 for i in 0..self.knots.len() - 1 {
196 if x >= self.knots[i] && x <= self.knots[i + 1] {
197 let x_tilde = 2.0 * (x - self.xm[i]) / self.delta_x[i];
199 return (i, x_tilde);
200 }
201 }
202
203 let last_idx = self.knots.len() - 2;
205 let x_tilde = 2.0 * (x - self.xm[last_idx]) / self.delta_x[last_idx];
206 (last_idx, x_tilde)
207 }
208
209 pub fn evaluate_legendre_polynomial(&self, x: f64, coeffs: &[f64]) -> f64 {
211 if coeffs.is_empty() {
212 return 0.0;
213 }
214
215 let mut result = 0.0;
216 let mut p_prev = 1.0; let mut p_curr = x; if !coeffs.is_empty() {
221 result += coeffs[0] * p_prev;
222 }
223 if coeffs.len() > 1 {
224 result += coeffs[1] * p_curr;
225 }
226
227 for n in 1..coeffs.len() - 1 {
229 let p_next =
230 ((2.0 * (n as f64) + 1.0) * x * p_curr - (n as f64) * p_prev) / ((n + 1) as f64);
231 result += coeffs[n + 1] * p_next;
232 p_prev = p_curr;
233 p_curr = p_next;
234 }
235
236 result
237 }
238
239 pub fn deriv(&self, n: usize) -> Self {
241 if n == 0 {
242 return self.clone();
243 }
244
245 let mut ddata = self.data.clone();
247 for _ in 0..n {
248 ddata = self.compute_derivative_coefficients(&ddata);
249 }
250
251 let ddata_shape = *ddata.shape();
253 for i in 0..ddata_shape.1 {
254 let inv_x_power = self.inv_xs[i].powi(n as i32);
255 for j in 0..ddata_shape.0 {
256 ddata[[j, i]] *= inv_x_power;
257 }
258 }
259
260 let new_symm = if n % 2 == 0 { self.symm } else { -self.symm };
262
263 Self {
264 data: ddata,
265 symm: new_symm,
266 ..self.clone()
267 }
268 }
269
270 fn compute_derivative_coefficients(
272 &self,
273 coeffs: &mdarray::DTensor<f64, 2>,
274 ) -> mdarray::DTensor<f64, 2> {
275 let mut c = coeffs.clone();
276 let c_shape = *c.shape();
277 let mut n = c_shape.0;
278
279 if n <= 1 {
281 return mdarray::DTensor::<f64, 2>::from_elem([1, c.shape().1], 0.0);
282 }
283
284 n -= 1;
285 let mut der = mdarray::DTensor::<f64, 2>::from_elem([n, c.shape().1], 0.0);
286
287 for j in (2..=n).rev() {
289 for col in 0..c_shape.1 {
291 der[[j - 1, col]] = (2.0 * (j as f64) - 1.0) * c[[j, col]];
292 }
293 for col in 0..c_shape.1 {
295 c[[j - 2, col]] += c[[j, col]];
296 }
297 }
298
299 if n > 1 {
301 for col in 0..c_shape.1 {
302 der[[1, col]] = 3.0 * c[[2, col]];
303 }
304 }
305
306 for col in 0..c_shape.1 {
308 der[[0, col]] = c[[1, col]];
309 }
310
311 der
312 }
313
314 pub fn derivs(&self, x: f64) -> Vec<f64> {
316 let mut results = Vec::new();
317
318 for n in 0..self.polyorder {
320 let deriv_poly = self.deriv(n);
321 results.push(deriv_poly.evaluate(x));
322 }
323
324 results
325 }
326
327 pub fn overlap<F>(&self, f: F) -> f64
329 where
330 F: Fn(f64) -> f64,
331 {
332 let mut integral = 0.0;
333
334 for i in 0..self.knots.len() - 1 {
335 let segment_integral =
336 self.gauss_legendre_quadrature(self.knots[i], self.knots[i + 1], |x| {
337 self.evaluate(x) * f(x)
338 });
339 integral += segment_integral;
340 }
341
342 integral
343 }
344
345 fn gauss_legendre_quadrature<F>(&self, a: f64, b: f64, f: F) -> f64
347 where
348 F: Fn(f64) -> f64,
349 {
350 const XG: [f64; 5] = [
352 -0.906179845938664,
353 -0.538469310105683,
354 0.0,
355 0.538469310105683,
356 0.906179845938664,
357 ];
358 const WG: [f64; 5] = [
359 0.236926885056189,
360 0.478628670499366,
361 0.568888888888889,
362 0.478628670499366,
363 0.236926885056189,
364 ];
365
366 let c1 = (b - a) / 2.0;
367 let c2 = (b + a) / 2.0;
368
369 let mut integral = 0.0;
370 for j in 0..5 {
371 let x = c1 * XG[j] + c2;
372 integral += WG[j] * f(x);
373 }
374
375 integral * c1
376 }
377
378 pub fn roots(&self) -> Vec<f64> {
380 let refined_grid = self.refine_grid(&self.knots, 4);
383
384 self.find_all_roots(&refined_grid)
386 }
387
388 fn refine_grid(&self, grid: &[f64], alpha: usize) -> Vec<f64> {
390 let mut refined = Vec::new();
391
392 for i in 0..grid.len() - 1 {
393 let start = grid[i];
394 let step = (grid[i + 1] - grid[i]) / (alpha as f64);
395 for j in 0..alpha {
396 refined.push(start + (j as f64) * step);
397 }
398 }
399 refined.push(grid[grid.len() - 1]);
400 refined
401 }
402
403 fn find_all_roots(&self, xgrid: &[f64]) -> Vec<f64> {
405 if xgrid.is_empty() {
406 return Vec::new();
407 }
408
409 let fx: Vec<f64> = xgrid.iter().map(|&x| self.evaluate(x)).collect();
411
412 let mut x_hit = Vec::new();
414 for i in 0..fx.len() {
415 if fx[i] == 0.0 {
416 x_hit.push(xgrid[i]);
417 }
418 }
419
420 let mut sign_change = Vec::new();
422 for i in 0..fx.len() - 1 {
423 let has_sign_change = fx[i].signum() != fx[i + 1].signum();
424 let not_hit = fx[i] != 0.0 && fx[i + 1] != 0.0;
425 let sc = has_sign_change && not_hit;
426 sign_change.push(sc);
427 }
428
429 if sign_change.iter().all(|&sc| !sc) {
431 x_hit.sort_by(|a, b| a.partial_cmp(b).unwrap());
432 return x_hit;
433 }
434
435 let mut a_intervals = Vec::new();
437 let mut b_intervals = Vec::new();
438 let mut fa_values = Vec::new();
439
440 for i in 0..sign_change.len() {
441 if sign_change[i] {
442 a_intervals.push(xgrid[i]);
443 b_intervals.push(xgrid[i + 1]);
444 fa_values.push(fx[i]);
445 }
446 }
447
448 let max_elm = xgrid.iter().map(|&x| x.abs()).fold(0.0, f64::max);
450 let epsilon_x = f64::EPSILON * max_elm;
451
452 for i in 0..a_intervals.len() {
454 let root = self.bisect(a_intervals[i], b_intervals[i], fa_values[i], epsilon_x);
455 x_hit.push(root);
456 }
457
458 x_hit.sort_by(|a, b| a.partial_cmp(b).unwrap());
460 x_hit
461 }
462
463 fn bisect(&self, a: f64, b: f64, fa: f64, eps: f64) -> f64 {
465 let mut a = a;
466 let mut b = b;
467 let mut fa = fa;
468
469 loop {
470 let mid = (a + b) / 2.0;
471 if self.close_enough(a, mid, eps) {
472 return mid;
473 }
474
475 let fmid = self.evaluate(mid);
476 if fa.signum() != fmid.signum() {
477 b = mid;
478 } else {
479 a = mid;
480 fa = fmid;
481 }
482 }
483 }
484
485 fn close_enough(&self, a: f64, b: f64, eps: f64) -> bool {
487 (a - b).abs() <= eps
488 }
489
490 pub fn get_xmin(&self) -> f64 {
492 self.xmin
493 }
494 pub fn get_xmax(&self) -> f64 {
495 self.xmax
496 }
497 pub fn get_l(&self) -> i32 {
498 self.l
499 }
500 pub fn get_domain(&self) -> (f64, f64) {
501 (self.xmin, self.xmax)
502 }
503 pub fn get_knots(&self) -> &[f64] {
504 &self.knots
505 }
506 pub fn get_delta_x(&self) -> &[f64] {
507 &self.delta_x
508 }
509 pub fn get_symm(&self) -> i32 {
510 self.symm
511 }
512 pub fn get_data(&self) -> &mdarray::DTensor<f64, 2> {
513 &self.data
514 }
515 pub fn get_norms(&self) -> &[f64] {
516 &self.norms
517 }
518 pub fn get_polyorder(&self) -> usize {
519 self.polyorder
520 }
521}
522
/// An ordered collection of [`PiecewiseLegendrePoly`] values. Domain queries
/// (`xmin`, `xmax`, `get_delta_x`, `get_norms`, `get_polyorder`) read the first
/// polynomial.
#[derive(Debug, Clone)]
pub struct PiecewiseLegendrePolyVector {
    /// The polynomials, in index order.
    pub polyvec: Vec<PiecewiseLegendrePoly>,
}
529
530impl PiecewiseLegendrePolyVector {
531 pub fn new(polyvec: Vec<PiecewiseLegendrePoly>) -> Self {
536 if polyvec.is_empty() {
537 panic!("Cannot create empty PiecewiseLegendrePolyVector");
538 }
539 Self { polyvec }
540 }
541
542 pub fn get_polys(&self) -> &[PiecewiseLegendrePoly] {
544 &self.polyvec
545 }
546
547 pub fn from_3d_data(
549 data3d: mdarray::DTensor<f64, 3>,
550 knots: Vec<f64>,
551 symm: Option<Vec<i32>>,
552 ) -> Self {
553 let npolys = data3d.shape().2;
554 let mut polyvec = Vec::with_capacity(npolys);
555
556 if let Some(ref symm_vec) = symm {
557 if symm_vec.len() != npolys {
558 panic!("Sizes of data and symm don't match");
559 }
560 }
561
562 let delta_x: Vec<f64> = (1..knots.len()).map(|i| knots[i] - knots[i - 1]).collect();
564
565 for i in 0..npolys {
566 let data3d_shape = data3d.shape();
568 let mut data =
569 mdarray::DTensor::<f64, 2>::from_elem([data3d_shape.0, data3d_shape.1], 0.0);
570 for j in 0..data3d_shape.0 {
571 for k in 0..data3d_shape.1 {
572 data[[j, k]] = data3d[[j, k, i]];
573 }
574 }
575
576 let poly = PiecewiseLegendrePoly::new(
577 data,
578 knots.clone(),
579 i as i32,
580 Some(delta_x.clone()),
581 symm.as_ref().map_or(0, |s| s[i]),
582 );
583
584 polyvec.push(poly);
585 }
586
587 Self { polyvec }
588 }
589
590 pub fn size(&self) -> usize {
592 self.polyvec.len()
593 }
594
595 pub fn rescale_domain(
610 &self,
611 new_knots: Vec<f64>,
612 new_delta_x: Option<Vec<f64>>,
613 new_symm: Option<Vec<i32>>,
614 ) -> Self {
615 let polyvec = self
616 .polyvec
617 .iter()
618 .enumerate()
619 .map(|(i, poly)| {
620 let symm = new_symm.as_ref().map(|s| s[i]);
621 poly.rescale_domain(new_knots.clone(), new_delta_x.clone(), symm)
622 })
623 .collect();
624
625 Self { polyvec }
626 }
627
628 pub fn scale_data(&self, factor: f64) -> Self {
640 let polyvec = self
641 .polyvec
642 .iter()
643 .map(|poly| poly.scale_data(factor))
644 .collect();
645
646 Self { polyvec }
647 }
648
649 pub fn get(&self, index: usize) -> Option<&PiecewiseLegendrePoly> {
651 self.polyvec.get(index)
652 }
653
654 #[deprecated(
656 note = "PiecewiseLegendrePolyVector is designed to be immutable. Use get() and create new instances for modifications."
657 )]
658 pub fn get_mut(&mut self, index: usize) -> Option<&mut PiecewiseLegendrePoly> {
659 self.polyvec.get_mut(index)
660 }
661
662 pub fn slice_single(&self, index: usize) -> Option<Self> {
664 self.polyvec.get(index).map(|poly| Self {
665 polyvec: vec![poly.clone()],
666 })
667 }
668
669 pub fn slice_multi(&self, indices: &[usize]) -> Self {
671 for &idx in indices {
673 if idx >= self.polyvec.len() {
674 panic!("Index {} out of range", idx);
675 }
676 }
677
678 {
680 let mut unique_indices = indices.to_vec();
681 unique_indices.sort();
682 unique_indices.dedup();
683 if unique_indices.len() != indices.len() {
684 panic!("Duplicate indices not allowed");
685 }
686 }
687
688 let new_polyvec: Vec<_> = indices
689 .iter()
690 .map(|&idx| self.polyvec[idx].clone())
691 .collect();
692
693 Self {
694 polyvec: new_polyvec,
695 }
696 }
697
698 pub fn evaluate_at(&self, x: f64) -> Vec<f64> {
700 self.polyvec.iter().map(|poly| poly.evaluate(x)).collect()
701 }
702
703 pub fn evaluate_at_many(&self, xs: &[f64]) -> mdarray::DTensor<f64, 2> {
705 let n_funcs = self.polyvec.len();
706 let n_points = xs.len();
707 let mut results = mdarray::DTensor::<f64, 2>::from_elem([n_funcs, n_points], 0.0);
708
709 for (i, poly) in self.polyvec.iter().enumerate() {
710 for (j, &x) in xs.iter().enumerate() {
711 results[[i, j]] = poly.evaluate(x);
712 }
713 }
714
715 results
716 }
717
718 pub fn xmin(&self) -> f64 {
720 if self.polyvec.is_empty() {
721 panic!("Cannot get xmin from empty PiecewiseLegendrePolyVector");
722 }
723 self.polyvec[0].xmin
724 }
725
726 pub fn xmax(&self) -> f64 {
727 if self.polyvec.is_empty() {
728 panic!("Cannot get xmax from empty PiecewiseLegendrePolyVector");
729 }
730 self.polyvec[0].xmax
731 }
732
733 pub fn get_knots(&self, tolerance: Option<f64>) -> Vec<f64> {
734 if self.polyvec.is_empty() {
735 panic!("Cannot get knots from empty PiecewiseLegendrePolyVector");
736 }
737 const DEFAULT_TOLERANCE: f64 = 1e-10;
738 let tolerance = tolerance.unwrap_or(DEFAULT_TOLERANCE);
739
740 let mut all_knots = Vec::new();
742 for poly in &self.polyvec {
743 for &knot in &poly.knots {
744 all_knots.push(knot);
745 }
746 }
747
748 {
750 all_knots.sort_by(|a, b| a.partial_cmp(b).unwrap());
751 all_knots.dedup_by(|a, b| (*a - *b).abs() < tolerance);
752 }
753 all_knots
754 }
755
756 pub fn get_delta_x(&self) -> Vec<f64> {
757 if self.polyvec.is_empty() {
758 panic!("Cannot get delta_x from empty PiecewiseLegendrePolyVector");
759 }
760 self.polyvec[0].delta_x.clone()
761 }
762
763 pub fn get_polyorder(&self) -> usize {
764 if self.polyvec.is_empty() {
765 panic!("Cannot get polyorder from empty PiecewiseLegendrePolyVector");
766 }
767 self.polyvec[0].polyorder
768 }
769
770 pub fn get_norms(&self) -> &[f64] {
771 if self.polyvec.is_empty() {
772 panic!("Cannot get norms from empty PiecewiseLegendrePolyVector");
773 }
774 &self.polyvec[0].norms
775 }
776
777 pub fn get_symm(&self) -> Vec<i32> {
778 if self.polyvec.is_empty() {
779 panic!("Cannot get symm from empty PiecewiseLegendrePolyVector");
780 }
781 self.polyvec.iter().map(|poly| poly.symm).collect()
782 }
783
784 pub fn get_data(&self) -> mdarray::DTensor<f64, 3> {
786 if self.polyvec.is_empty() {
787 panic!("Cannot get data from empty PiecewiseLegendrePolyVector");
788 }
789
790 let nsegments = self.polyvec[0].data.shape().1;
791 let polyorder = self.polyvec[0].polyorder;
792 let npolys = self.polyvec.len();
793
794 let mut data = mdarray::DTensor::<f64, 3>::from_elem([nsegments, polyorder, npolys], 0.0);
795
796 for (poly_idx, poly) in self.polyvec.iter().enumerate() {
797 for segment in 0..nsegments {
798 for degree in 0..polyorder {
799 data[[segment, degree, poly_idx]] = poly.data[[degree, segment]];
800 }
801 }
802 }
803
804 data
805 }
806
807 pub fn roots(&self, tolerance: Option<f64>) -> Vec<f64> {
809 if self.polyvec.is_empty() {
810 panic!("Cannot get roots from empty PiecewiseLegendrePolyVector");
811 }
812 const DEFAULT_TOLERANCE: f64 = 1e-10;
813 let tolerance = tolerance.unwrap_or(DEFAULT_TOLERANCE);
814 let mut all_roots = Vec::new();
815
816 for poly in &self.polyvec {
817 let poly_roots = poly.roots();
818 for root in poly_roots {
819 all_roots.push(root);
820 }
821 }
822
823 {
825 all_roots.sort_by(|a, b| b.partial_cmp(a).unwrap());
826 all_roots.dedup_by(|a, b| (*a - *b).abs() < tolerance);
827 }
828 all_roots
829 }
830
831 pub fn last(&self) -> &PiecewiseLegendrePoly {
835 self.polyvec
836 .last()
837 .expect("Cannot get last from empty PiecewiseLegendrePolyVector")
838 }
839
840 pub fn nroots(&self, tolerance: Option<f64>) -> usize {
842 if self.polyvec.is_empty() {
843 panic!("Cannot get nroots from empty PiecewiseLegendrePolyVector");
844 }
845 self.roots(tolerance).len()
846 }
847}
848
849impl std::ops::Index<usize> for PiecewiseLegendrePolyVector {
850 type Output = PiecewiseLegendrePoly;
851
852 fn index(&self, index: usize) -> &Self::Output {
853 &self.polyvec[index]
854 }
855}
856
857pub fn default_sampling_points(u: &PiecewiseLegendrePolyVector, l: usize) -> Vec<f64> {
871 if (u.xmin() - (-1.0)).abs() > 1e-10 || (u.xmax() - 1.0).abs() > 1e-10 {
874 panic!("Expecting unscaled functions here.");
875 }
876
877 let x0 = if l < u.polyvec.len() {
878 u[l].roots()
880 } else {
881 let poly = u.last();
884 let poly_deriv = poly.deriv(1);
885 let maxima = poly_deriv.roots();
886
887 let left = (maxima[0] + poly.xmin) / 2.0;
889
890 let right = (maxima[maxima.len() - 1] + poly.xmax) / 2.0;
892
893 let mut x0_vec = Vec::with_capacity(maxima.len() + 2);
898 x0_vec.push(left);
899 x0_vec.extend_from_slice(&maxima);
900 x0_vec.push(right);
901 x0_vec
902 };
903
904 if x0.len() != l {
906 eprintln!(
907 "Warning: Expecting to get {} sampling points for corresponding basis function, \
908 instead got {}. This may happen if not enough precision is left in the polynomial.",
909 l,
910 x0.len()
911 );
912 }
913
914 x0
915}
916
// Unit tests live in `poly_tests.rs`. As a child module they can also reach this
// module's private helpers.
#[cfg(test)]
#[path = "poly_tests.rs"]
mod poly_tests;