sparse_ir/
poly.rs

1//! Piecewise Legendre polynomial implementations for SparseIR
2//!
3//! This module provides high-performance piecewise Legendre polynomial
4//! functionality compatible with the C++ implementation.
5
6/// A single piecewise Legendre polynomial
7#[derive(Debug, Clone)]
8pub struct PiecewiseLegendrePoly {
9    /// Polynomial order (degree of Legendre polynomials in each segment)
10    pub polyorder: usize,
11    /// Minimum x value of the domain
12    pub xmin: f64,
13    /// Maximum x value of the domain
14    pub xmax: f64,
15    /// Knot points defining the segments
16    pub knots: Vec<f64>,
17    /// Segment widths (for numerical stability)
18    pub delta_x: Vec<f64>,
19    /// Coefficient matrix: [degree][segment_index]
20    pub data: mdarray::DTensor<f64, 2>,
21    /// Symmetry parameter
22    pub symm: i32,
23    /// Polynomial parameter (used in power moments calculation)
24    pub l: i32,
25    /// Segment midpoints
26    pub xm: Vec<f64>,
27    /// Inverse segment widths
28    pub inv_xs: Vec<f64>,
29    /// Normalization factors
30    pub norms: Vec<f64>,
31}
32
33impl PiecewiseLegendrePoly {
34    /// Create a new PiecewiseLegendrePoly from data and knots
35    pub fn new(
36        data: mdarray::DTensor<f64, 2>,
37        knots: Vec<f64>,
38        l: i32,
39        delta_x: Option<Vec<f64>>,
40        symm: i32,
41    ) -> Self {
42        let polyorder = data.shape().0;
43        let nsegments = data.shape().1;
44
45        if knots.len() != nsegments + 1 {
46            panic!(
47                "Invalid knots array: expected {} knots, got {}",
48                nsegments + 1,
49                knots.len()
50            );
51        }
52
53        // Validate knots are sorted
54        for i in 1..knots.len() {
55            if knots[i] <= knots[i - 1] {
56                panic!("Knots must be monotonically increasing");
57            }
58        }
59
60        // Compute delta_x if not provided
61        let delta_x =
62            delta_x.unwrap_or_else(|| (1..knots.len()).map(|i| knots[i] - knots[i - 1]).collect());
63
64        // Validate delta_x matches knots
65        for i in 0..delta_x.len() {
66            let expected = knots[i + 1] - knots[i];
67            if (delta_x[i] - expected).abs() > 1e-10 {
68                panic!("delta_x must match knots");
69            }
70        }
71
72        // Compute segment midpoints
73        let xm: Vec<f64> = (0..nsegments)
74            .map(|i| 0.5 * (knots[i] + knots[i + 1]))
75            .collect();
76
77        // Compute inverse segment widths
78        let inv_xs: Vec<f64> = delta_x.iter().map(|&dx| 2.0 / dx).collect();
79
80        // Compute normalization factors
81        let norms: Vec<f64> = inv_xs.iter().map(|&inv_x| inv_x.sqrt()).collect();
82
83        Self {
84            polyorder,
85            xmin: knots[0],
86            xmax: knots[knots.len() - 1],
87            knots,
88            delta_x,
89            data,
90            symm,
91            l,
92            xm,
93            inv_xs,
94            norms,
95        }
96    }
97
98    /// Create a new PiecewiseLegendrePoly with new data but same structure
99    pub fn with_data(&self, new_data: mdarray::DTensor<f64, 2>) -> Self {
100        Self {
101            data: new_data,
102            ..self.clone()
103        }
104    }
105
106    /// Get the symmetry parameter
107    pub fn symm(&self) -> i32 {
108        self.symm
109    }
110
111    /// Create a new PiecewiseLegendrePoly with new data and symmetry
112    pub fn with_data_and_symmetry(
113        &self,
114        new_data: mdarray::DTensor<f64, 2>,
115        new_symm: i32,
116    ) -> Self {
117        Self {
118            data: new_data,
119            symm: new_symm,
120            ..self.clone()
121        }
122    }
123
124    /// Rescale domain: create a new polynomial with the same data but different knots
125    ///
126    /// This is useful for transforming from one domain to another, e.g.,
127    /// from x ∈ [-1, 1] to τ ∈ [0, β].
128    ///
129    /// # Arguments
130    ///
131    /// * `new_knots` - New knot points
132    /// * `new_delta_x` - Optional new segment widths (computed from knots if None)
133    /// * `new_symm` - Optional new symmetry parameter (keeps old if None)
134    ///
135    /// # Returns
136    ///
137    /// New polynomial with rescaled domain
138    pub fn rescale_domain(
139        &self,
140        new_knots: Vec<f64>,
141        new_delta_x: Option<Vec<f64>>,
142        new_symm: Option<i32>,
143    ) -> Self {
144        Self::new(
145            self.data.clone(),
146            new_knots,
147            self.l,
148            new_delta_x,
149            new_symm.unwrap_or(self.symm),
150        )
151    }
152
153    /// Scale all data values by a constant factor
154    ///
155    /// This is useful for normalizations, e.g., multiplying by √β for
156    /// Fourier transform preparations.
157    ///
158    /// # Arguments
159    ///
160    /// * `factor` - Scaling factor to multiply all data by
161    ///
162    /// # Returns
163    ///
164    /// New polynomial with scaled data
165    pub fn scale_data(&self, factor: f64) -> Self {
166        Self::with_data(
167            self,
168            mdarray::DTensor::<f64, 2>::from_fn(*self.data.shape(), |idx| self.data[idx] * factor),
169        )
170    }
171
172    /// Evaluate the polynomial at a given point
173    pub fn evaluate(&self, x: f64) -> f64 {
174        let (i, x_tilde) = self.split(x);
175        // Extract column i into a Vec
176        let coeffs: Vec<f64> = (0..self.data.shape().0)
177            .map(|row| self.data[[row, i]])
178            .collect();
179        let value = self.evaluate_legendre_polynomial(x_tilde, &coeffs);
180        value * self.norms[i]
181    }
182
183    /// Evaluate the polynomial at multiple points
184    pub fn evaluate_many(&self, xs: &[f64]) -> Vec<f64> {
185        xs.iter().map(|&x| self.evaluate(x)).collect()
186    }
187
188    /// Split x into segment index and normalized x
189    pub fn split(&self, x: f64) -> (usize, f64) {
190        if x < self.xmin || x > self.xmax {
191            panic!("x = {} is outside domain [{}, {}]", x, self.xmin, self.xmax);
192        }
193
194        // Find the segment containing x
195        for i in 0..self.knots.len() - 1 {
196            if x >= self.knots[i] && x <= self.knots[i + 1] {
197                // Transform x to [-1, 1] for Legendre polynomials
198                let x_tilde = 2.0 * (x - self.xm[i]) / self.delta_x[i];
199                return (i, x_tilde);
200            }
201        }
202
203        // Handle edge case: x exactly at the last knot
204        let last_idx = self.knots.len() - 2;
205        let x_tilde = 2.0 * (x - self.xm[last_idx]) / self.delta_x[last_idx];
206        (last_idx, x_tilde)
207    }
208
209    /// Evaluate Legendre polynomial using recurrence relation
210    pub fn evaluate_legendre_polynomial(&self, x: f64, coeffs: &[f64]) -> f64 {
211        if coeffs.is_empty() {
212            return 0.0;
213        }
214
215        let mut result = 0.0;
216        let mut p_prev = 1.0; // P_0(x) = 1
217        let mut p_curr = x; // P_1(x) = x
218
219        // Add first two terms
220        if !coeffs.is_empty() {
221            result += coeffs[0] * p_prev;
222        }
223        if coeffs.len() > 1 {
224            result += coeffs[1] * p_curr;
225        }
226
227        // Use recurrence relation: P_{n+1}(x) = ((2n+1)x*P_n(x) - n*P_{n-1}(x))/(n+1)
228        for n in 1..coeffs.len() - 1 {
229            let p_next =
230                ((2.0 * (n as f64) + 1.0) * x * p_curr - (n as f64) * p_prev) / ((n + 1) as f64);
231            result += coeffs[n + 1] * p_next;
232            p_prev = p_curr;
233            p_curr = p_next;
234        }
235
236        result
237    }
238
239    /// Compute derivative of the polynomial
240    pub fn deriv(&self, n: usize) -> Self {
241        if n == 0 {
242            return self.clone();
243        }
244
245        // Compute derivative coefficients
246        let mut ddata = self.data.clone();
247        for _ in 0..n {
248            ddata = self.compute_derivative_coefficients(&ddata);
249        }
250
251        // Apply scaling factors (C++: ddata.col(i) *= std::pow(inv_xs[i], n))
252        let ddata_shape = *ddata.shape();
253        for i in 0..ddata_shape.1 {
254            let inv_x_power = self.inv_xs[i].powi(n as i32);
255            for j in 0..ddata_shape.0 {
256                ddata[[j, i]] *= inv_x_power;
257            }
258        }
259
260        // Update symmetry: C++: int new_symm = std::pow(-1, n) * symm;
261        let new_symm = if n % 2 == 0 { self.symm } else { -self.symm };
262
263        Self {
264            data: ddata,
265            symm: new_symm,
266            ..self.clone()
267        }
268    }
269
270    /// Compute derivative coefficients using the same algorithm as C++ legder function
271    fn compute_derivative_coefficients(
272        &self,
273        coeffs: &mdarray::DTensor<f64, 2>,
274    ) -> mdarray::DTensor<f64, 2> {
275        let mut c = coeffs.clone();
276        let c_shape = *c.shape();
277        let mut n = c_shape.0;
278
279        // Single derivative step (equivalent to C++ legder with cnt=1)
280        if n <= 1 {
281            return mdarray::DTensor::<f64, 2>::from_elem([1, c.shape().1], 0.0);
282        }
283
284        n -= 1;
285        let mut der = mdarray::DTensor::<f64, 2>::from_elem([n, c.shape().1], 0.0);
286
287        // C++ implementation: for (int j = n; j >= 2; --j)
288        for j in (2..=n).rev() {
289            // C++: der.row(j - 1) = (2 * j - 1) * c.row(j);
290            for col in 0..c_shape.1 {
291                der[[j - 1, col]] = (2.0 * (j as f64) - 1.0) * c[[j, col]];
292            }
293            // C++: c.row(j - 2) += c.row(j);
294            for col in 0..c_shape.1 {
295                c[[j - 2, col]] += c[[j, col]];
296            }
297        }
298
299        // C++: if (n > 1) der.row(1) = 3 * c.row(2);
300        if n > 1 {
301            for col in 0..c_shape.1 {
302                der[[1, col]] = 3.0 * c[[2, col]];
303            }
304        }
305
306        // C++: der.row(0) = c.row(1);
307        for col in 0..c_shape.1 {
308            der[[0, col]] = c[[1, col]];
309        }
310
311        der
312    }
313
314    /// Compute derivatives at a point x
315    pub fn derivs(&self, x: f64) -> Vec<f64> {
316        let mut results = Vec::new();
317
318        // Compute up to polyorder derivatives
319        for n in 0..self.polyorder {
320            let deriv_poly = self.deriv(n);
321            results.push(deriv_poly.evaluate(x));
322        }
323
324        results
325    }
326
327    /// Compute overlap integral with a function
328    pub fn overlap<F>(&self, f: F) -> f64
329    where
330        F: Fn(f64) -> f64,
331    {
332        let mut integral = 0.0;
333
334        for i in 0..self.knots.len() - 1 {
335            let segment_integral =
336                self.gauss_legendre_quadrature(self.knots[i], self.knots[i + 1], |x| {
337                    self.evaluate(x) * f(x)
338                });
339            integral += segment_integral;
340        }
341
342        integral
343    }
344
345    /// Gauss-Legendre quadrature over [a, b]
346    fn gauss_legendre_quadrature<F>(&self, a: f64, b: f64, f: F) -> f64
347    where
348        F: Fn(f64) -> f64,
349    {
350        // 5-point Gauss-Legendre quadrature
351        const XG: [f64; 5] = [
352            -0.906179845938664,
353            -0.538469310105683,
354            0.0,
355            0.538469310105683,
356            0.906179845938664,
357        ];
358        const WG: [f64; 5] = [
359            0.236926885056189,
360            0.478628670499366,
361            0.568888888888889,
362            0.478628670499366,
363            0.236926885056189,
364        ];
365
366        let c1 = (b - a) / 2.0;
367        let c2 = (b + a) / 2.0;
368
369        let mut integral = 0.0;
370        for j in 0..5 {
371            let x = c1 * XG[j] + c2;
372            integral += WG[j] * f(x);
373        }
374
375        integral * c1
376    }
377
378    /// Find roots of the polynomial using C++ compatible algorithm
379    pub fn roots(&self) -> Vec<f64> {
380        // Refine the grid by factor of 4 for better root finding
381        // (C++ uses 2, but RegularizedBoseKernel needs finer resolution)
382        let refined_grid = self.refine_grid(&self.knots, 4);
383
384        // Find all roots using the refined grid
385        self.find_all_roots(&refined_grid)
386    }
387
388    /// Refine grid by factor alpha (C++ compatible)
389    fn refine_grid(&self, grid: &[f64], alpha: usize) -> Vec<f64> {
390        let mut refined = Vec::new();
391
392        for i in 0..grid.len() - 1 {
393            let start = grid[i];
394            let step = (grid[i + 1] - grid[i]) / (alpha as f64);
395            for j in 0..alpha {
396                refined.push(start + (j as f64) * step);
397            }
398        }
399        refined.push(grid[grid.len() - 1]);
400        refined
401    }
402
403    /// Find all roots using refined grid (C++ compatible)
404    fn find_all_roots(&self, xgrid: &[f64]) -> Vec<f64> {
405        if xgrid.is_empty() {
406            return Vec::new();
407        }
408
409        // Evaluate function at all grid points
410        let fx: Vec<f64> = xgrid.iter().map(|&x| self.evaluate(x)).collect();
411
412        // Find exact zeros (direct hits)
413        let mut x_hit = Vec::new();
414        for i in 0..fx.len() {
415            if fx[i] == 0.0 {
416                x_hit.push(xgrid[i]);
417            }
418        }
419
420        // Find sign changes
421        let mut sign_change = Vec::new();
422        for i in 0..fx.len() - 1 {
423            let has_sign_change = fx[i].signum() != fx[i + 1].signum();
424            let not_hit = fx[i] != 0.0 && fx[i + 1] != 0.0;
425            let sc = has_sign_change && not_hit;
426            sign_change.push(sc);
427        }
428
429        // If no sign changes, return only direct hits
430        if sign_change.iter().all(|&sc| !sc) {
431            x_hit.sort_by(|a, b| a.partial_cmp(b).unwrap());
432            return x_hit;
433        }
434
435        // Find intervals with sign changes
436        let mut a_intervals = Vec::new();
437        let mut b_intervals = Vec::new();
438        let mut fa_values = Vec::new();
439
440        for i in 0..sign_change.len() {
441            if sign_change[i] {
442                a_intervals.push(xgrid[i]);
443                b_intervals.push(xgrid[i + 1]);
444                fa_values.push(fx[i]);
445            }
446        }
447
448        // Calculate epsilon for convergence
449        let max_elm = xgrid.iter().map(|&x| x.abs()).fold(0.0, f64::max);
450        let epsilon_x = f64::EPSILON * max_elm;
451
452        // Use bisection for each interval with sign change
453        for i in 0..a_intervals.len() {
454            let root = self.bisect(a_intervals[i], b_intervals[i], fa_values[i], epsilon_x);
455            x_hit.push(root);
456        }
457
458        // Sort and return
459        x_hit.sort_by(|a, b| a.partial_cmp(b).unwrap());
460        x_hit
461    }
462
463    /// Bisection method to find root (C++ compatible)
464    fn bisect(&self, a: f64, b: f64, fa: f64, eps: f64) -> f64 {
465        let mut a = a;
466        let mut b = b;
467        let mut fa = fa;
468
469        loop {
470            let mid = (a + b) / 2.0;
471            if self.close_enough(a, mid, eps) {
472                return mid;
473            }
474
475            let fmid = self.evaluate(mid);
476            if fa.signum() != fmid.signum() {
477                b = mid;
478            } else {
479                a = mid;
480                fa = fmid;
481            }
482        }
483    }
484
485    /// Check if two values are close enough (C++ compatible)
486    fn close_enough(&self, a: f64, b: f64, eps: f64) -> bool {
487        (a - b).abs() <= eps
488    }
489
490    // Accessor methods to match C++ interface
491    pub fn get_xmin(&self) -> f64 {
492        self.xmin
493    }
494    pub fn get_xmax(&self) -> f64 {
495        self.xmax
496    }
497    pub fn get_l(&self) -> i32 {
498        self.l
499    }
500    pub fn get_domain(&self) -> (f64, f64) {
501        (self.xmin, self.xmax)
502    }
503    pub fn get_knots(&self) -> &[f64] {
504        &self.knots
505    }
506    pub fn get_delta_x(&self) -> &[f64] {
507        &self.delta_x
508    }
509    pub fn get_symm(&self) -> i32 {
510        self.symm
511    }
512    pub fn get_data(&self) -> &mdarray::DTensor<f64, 2> {
513        &self.data
514    }
515    pub fn get_norms(&self) -> &[f64] {
516        &self.norms
517    }
518    pub fn get_polyorder(&self) -> usize {
519        self.polyorder
520    }
521}
522
523/// Vector of piecewise Legendre polynomials
524#[derive(Debug, Clone)]
525pub struct PiecewiseLegendrePolyVector {
526    /// Individual polynomials
527    pub polyvec: Vec<PiecewiseLegendrePoly>,
528}
529
530impl PiecewiseLegendrePolyVector {
531    /// Constructor with a vector of PiecewiseLegendrePoly
532    ///
533    /// # Panics
534    /// Panics if the input vector is empty, as empty PiecewiseLegendrePolyVector is not meaningful
535    pub fn new(polyvec: Vec<PiecewiseLegendrePoly>) -> Self {
536        if polyvec.is_empty() {
537            panic!("Cannot create empty PiecewiseLegendrePolyVector");
538        }
539        Self { polyvec }
540    }
541
542    /// Get the polynomials
543    pub fn get_polys(&self) -> &[PiecewiseLegendrePoly] {
544        &self.polyvec
545    }
546
547    /// Constructor with a 3D array, knots, and symmetry vector
548    pub fn from_3d_data(
549        data3d: mdarray::DTensor<f64, 3>,
550        knots: Vec<f64>,
551        symm: Option<Vec<i32>>,
552    ) -> Self {
553        let npolys = data3d.shape().2;
554        let mut polyvec = Vec::with_capacity(npolys);
555
556        if let Some(ref symm_vec) = symm {
557            if symm_vec.len() != npolys {
558                panic!("Sizes of data and symm don't match");
559            }
560        }
561
562        // Compute delta_x from knots
563        let delta_x: Vec<f64> = (1..knots.len()).map(|i| knots[i] - knots[i - 1]).collect();
564
565        for i in 0..npolys {
566            // Extract 2D data for this polynomial
567            let data3d_shape = data3d.shape();
568            let mut data =
569                mdarray::DTensor::<f64, 2>::from_elem([data3d_shape.0, data3d_shape.1], 0.0);
570            for j in 0..data3d_shape.0 {
571                for k in 0..data3d_shape.1 {
572                    data[[j, k]] = data3d[[j, k, i]];
573                }
574            }
575
576            let poly = PiecewiseLegendrePoly::new(
577                data,
578                knots.clone(),
579                i as i32,
580                Some(delta_x.clone()),
581                symm.as_ref().map_or(0, |s| s[i]),
582            );
583
584            polyvec.push(poly);
585        }
586
587        Self { polyvec }
588    }
589
590    /// Get the size of the vector
591    pub fn size(&self) -> usize {
592        self.polyvec.len()
593    }
594
595    /// Rescale domain for all polynomials in the vector
596    ///
597    /// Creates a new PiecewiseLegendrePolyVector where each polynomial has
598    /// the same data but new knots and delta_x.
599    ///
600    /// # Arguments
601    ///
602    /// * `new_knots` - New knot points (same for all polynomials)
603    /// * `new_delta_x` - Optional new segment widths
604    /// * `new_symm` - Optional vector of new symmetry parameters (one per polynomial)
605    ///
606    /// # Returns
607    ///
608    /// New vector with rescaled domains
609    pub fn rescale_domain(
610        &self,
611        new_knots: Vec<f64>,
612        new_delta_x: Option<Vec<f64>>,
613        new_symm: Option<Vec<i32>>,
614    ) -> Self {
615        let polyvec = self
616            .polyvec
617            .iter()
618            .enumerate()
619            .map(|(i, poly)| {
620                let symm = new_symm.as_ref().map(|s| s[i]);
621                poly.rescale_domain(new_knots.clone(), new_delta_x.clone(), symm)
622            })
623            .collect();
624
625        Self { polyvec }
626    }
627
628    /// Scale all data values by a constant factor
629    ///
630    /// Multiplies the data of all polynomials by the same factor.
631    ///
632    /// # Arguments
633    ///
634    /// * `factor` - Scaling factor to multiply all data by
635    ///
636    /// # Returns
637    ///
638    /// New vector with scaled data
639    pub fn scale_data(&self, factor: f64) -> Self {
640        let polyvec = self
641            .polyvec
642            .iter()
643            .map(|poly| poly.scale_data(factor))
644            .collect();
645
646        Self { polyvec }
647    }
648
649    /// Get polynomial by index (immutable)
650    pub fn get(&self, index: usize) -> Option<&PiecewiseLegendrePoly> {
651        self.polyvec.get(index)
652    }
653
654    /// Get polynomial by index (mutable) - deprecated, use immutable design instead
655    #[deprecated(
656        note = "PiecewiseLegendrePolyVector is designed to be immutable. Use get() and create new instances for modifications."
657    )]
658    pub fn get_mut(&mut self, index: usize) -> Option<&mut PiecewiseLegendrePoly> {
659        self.polyvec.get_mut(index)
660    }
661
662    /// Extract a single polynomial as a vector
663    pub fn slice_single(&self, index: usize) -> Option<Self> {
664        self.polyvec.get(index).map(|poly| Self {
665            polyvec: vec![poly.clone()],
666        })
667    }
668
669    /// Extract multiple polynomials by indices
670    pub fn slice_multi(&self, indices: &[usize]) -> Self {
671        // Validate indices
672        for &idx in indices {
673            if idx >= self.polyvec.len() {
674                panic!("Index {} out of range", idx);
675            }
676        }
677
678        // Check for duplicates
679        {
680            let mut unique_indices = indices.to_vec();
681            unique_indices.sort();
682            unique_indices.dedup();
683            if unique_indices.len() != indices.len() {
684                panic!("Duplicate indices not allowed");
685            }
686        }
687
688        let new_polyvec: Vec<_> = indices
689            .iter()
690            .map(|&idx| self.polyvec[idx].clone())
691            .collect();
692
693        Self {
694            polyvec: new_polyvec,
695        }
696    }
697
698    /// Evaluate all polynomials at a single point
699    pub fn evaluate_at(&self, x: f64) -> Vec<f64> {
700        self.polyvec.iter().map(|poly| poly.evaluate(x)).collect()
701    }
702
703    /// Evaluate all polynomials at multiple points
704    pub fn evaluate_at_many(&self, xs: &[f64]) -> mdarray::DTensor<f64, 2> {
705        let n_funcs = self.polyvec.len();
706        let n_points = xs.len();
707        let mut results = mdarray::DTensor::<f64, 2>::from_elem([n_funcs, n_points], 0.0);
708
709        for (i, poly) in self.polyvec.iter().enumerate() {
710            for (j, &x) in xs.iter().enumerate() {
711                results[[i, j]] = poly.evaluate(x);
712            }
713        }
714
715        results
716    }
717
718    // Accessor methods to match C++ interface
719    pub fn xmin(&self) -> f64 {
720        if self.polyvec.is_empty() {
721            panic!("Cannot get xmin from empty PiecewiseLegendrePolyVector");
722        }
723        self.polyvec[0].xmin
724    }
725
726    pub fn xmax(&self) -> f64 {
727        if self.polyvec.is_empty() {
728            panic!("Cannot get xmax from empty PiecewiseLegendrePolyVector");
729        }
730        self.polyvec[0].xmax
731    }
732
733    pub fn get_knots(&self, tolerance: Option<f64>) -> Vec<f64> {
734        if self.polyvec.is_empty() {
735            panic!("Cannot get knots from empty PiecewiseLegendrePolyVector");
736        }
737        const DEFAULT_TOLERANCE: f64 = 1e-10;
738        let tolerance = tolerance.unwrap_or(DEFAULT_TOLERANCE);
739
740        // Collect all knots from all polynomials
741        let mut all_knots = Vec::new();
742        for poly in &self.polyvec {
743            for &knot in &poly.knots {
744                all_knots.push(knot);
745            }
746        }
747
748        // Sort and remove duplicates
749        {
750            all_knots.sort_by(|a, b| a.partial_cmp(b).unwrap());
751            all_knots.dedup_by(|a, b| (*a - *b).abs() < tolerance);
752        }
753        all_knots
754    }
755
756    pub fn get_delta_x(&self) -> Vec<f64> {
757        if self.polyvec.is_empty() {
758            panic!("Cannot get delta_x from empty PiecewiseLegendrePolyVector");
759        }
760        self.polyvec[0].delta_x.clone()
761    }
762
763    pub fn get_polyorder(&self) -> usize {
764        if self.polyvec.is_empty() {
765            panic!("Cannot get polyorder from empty PiecewiseLegendrePolyVector");
766        }
767        self.polyvec[0].polyorder
768    }
769
770    pub fn get_norms(&self) -> &[f64] {
771        if self.polyvec.is_empty() {
772            panic!("Cannot get norms from empty PiecewiseLegendrePolyVector");
773        }
774        &self.polyvec[0].norms
775    }
776
777    pub fn get_symm(&self) -> Vec<i32> {
778        if self.polyvec.is_empty() {
779            panic!("Cannot get symm from empty PiecewiseLegendrePolyVector");
780        }
781        self.polyvec.iter().map(|poly| poly.symm).collect()
782    }
783
784    /// Get data as 3D tensor: [segment][degree][polynomial]
785    pub fn get_data(&self) -> mdarray::DTensor<f64, 3> {
786        if self.polyvec.is_empty() {
787            panic!("Cannot get data from empty PiecewiseLegendrePolyVector");
788        }
789
790        let nsegments = self.polyvec[0].data.shape().1;
791        let polyorder = self.polyvec[0].polyorder;
792        let npolys = self.polyvec.len();
793
794        let mut data = mdarray::DTensor::<f64, 3>::from_elem([nsegments, polyorder, npolys], 0.0);
795
796        for (poly_idx, poly) in self.polyvec.iter().enumerate() {
797            for segment in 0..nsegments {
798                for degree in 0..polyorder {
799                    data[[segment, degree, poly_idx]] = poly.data[[degree, segment]];
800                }
801            }
802        }
803
804        data
805    }
806
807    /// Find roots of all polynomials
808    pub fn roots(&self, tolerance: Option<f64>) -> Vec<f64> {
809        if self.polyvec.is_empty() {
810            panic!("Cannot get roots from empty PiecewiseLegendrePolyVector");
811        }
812        const DEFAULT_TOLERANCE: f64 = 1e-10;
813        let tolerance = tolerance.unwrap_or(DEFAULT_TOLERANCE);
814        let mut all_roots = Vec::new();
815
816        for poly in &self.polyvec {
817            let poly_roots = poly.roots();
818            for root in poly_roots {
819                all_roots.push(root);
820            }
821        }
822
823        // Sort in descending order and remove duplicates (like C++ implementation)
824        {
825            all_roots.sort_by(|a, b| b.partial_cmp(a).unwrap());
826            all_roots.dedup_by(|a, b| (*a - *b).abs() < tolerance);
827        }
828        all_roots
829    }
830
831    /// Get reference to last polynomial
832    ///
833    /// C++ equivalent: u.polyvec.back()
834    pub fn last(&self) -> &PiecewiseLegendrePoly {
835        self.polyvec
836            .last()
837            .expect("Cannot get last from empty PiecewiseLegendrePolyVector")
838    }
839
840    /// Get the number of roots
841    pub fn nroots(&self, tolerance: Option<f64>) -> usize {
842        if self.polyvec.is_empty() {
843            panic!("Cannot get nroots from empty PiecewiseLegendrePolyVector");
844        }
845        self.roots(tolerance).len()
846    }
847}
848
849impl std::ops::Index<usize> for PiecewiseLegendrePolyVector {
850    type Output = PiecewiseLegendrePoly;
851
852    fn index(&self, index: usize) -> &Self::Output {
853        &self.polyvec[index]
854    }
855}
856
857/// Get default sampling points in [-1, 1]
858///
859/// C++ implementation: libsparseir/include/sparseir/basis.hpp:287-310
860///
861/// For orthogonal polynomials (the high-T limit of IR), we know that the
862/// ideal sampling points for a basis of size L are the roots of the L-th
863/// polynomial. We empirically find that these stay good sampling points
864/// for our kernels (probably because the kernels are totally positive).
865///
866/// If we do not have enough polynomials in the basis, we approximate the
867/// roots of the L'th polynomial by the extrema of the last basis function,
868/// which is sensible due to the strong interleaving property of these
869/// functions' roots.
870pub fn default_sampling_points(u: &PiecewiseLegendrePolyVector, l: usize) -> Vec<f64> {
871    // C++: if (u.xmin() != -1.0 || u.xmax() != 1.0)
872    //          throw std::runtime_error("Expecting unscaled functions here.");
873    if (u.xmin() - (-1.0)).abs() > 1e-10 || (u.xmax() - 1.0).abs() > 1e-10 {
874        panic!("Expecting unscaled functions here.");
875    }
876
877    let x0 = if l < u.polyvec.len() {
878        // C++: return u.polyvec[L].roots();
879        u[l].roots()
880    } else {
881        // C++: PiecewiseLegendrePoly poly = u.polyvec.back();
882        //      Eigen::VectorXd maxima = poly.deriv().roots();
883        let poly = u.last();
884        let poly_deriv = poly.deriv(1);
885        let maxima = poly_deriv.roots();
886
887        // C++: double left = (maxima[0] + poly.xmin) / 2.0;
888        let left = (maxima[0] + poly.xmin) / 2.0;
889
890        // C++: double right = (maxima[maxima.size() - 1] + poly.xmax) / 2.0;
891        let right = (maxima[maxima.len() - 1] + poly.xmax) / 2.0;
892
893        // C++: Eigen::VectorXd x0(maxima.size() + 2);
894        //      x0[0] = left;
895        //      x0.segment(1, maxima.size()) = maxima;
896        //      x0[x0.size() - 1] = right;
897        let mut x0_vec = Vec::with_capacity(maxima.len() + 2);
898        x0_vec.push(left);
899        x0_vec.extend_from_slice(&maxima);
900        x0_vec.push(right);
901        x0_vec
902    };
903
904    // C++: if (x0.size() != L) { warning }
905    if x0.len() != l {
906        eprintln!(
907            "Warning: Expecting to get {} sampling points for corresponding basis function, \
908             instead got {}. This may happen if not enough precision is left in the polynomial.",
909            l,
910            x0.len()
911        );
912    }
913
914    x0
915}
916
917// IndexMut implementation removed - PiecewiseLegendrePolyVector is designed to be immutable
918// If modification is needed, create a new instance instead
919
920// Note: FnOnce implementation removed due to experimental nature
921// Use evaluate_at() and evaluate_at_many() methods directly
922
923#[cfg(test)]
924#[path = "poly_tests.rs"]
925mod poly_tests;