rsl_interpolation/types/
cubic.rs

1use ndarray::Array1;
2use ndarray_linalg::{Lapack, MatrixLayout, SolveTridiagonal, Tridiagonal};
3use num::One;
4
5use crate::Accelerator;
6use crate::DomainError;
7use crate::InterpType;
8use crate::Interpolation;
9use crate::InterpolationError;
10use crate::types::utils::integ_eval;
11use crate::types::utils::{check_if_inbounds, check1d_data, diff};
12
/// Minimum number of data points accepted by both [`Cubic`] and [`CubicPeriodic`]: each `build`
/// passes it to `check1d_data`, and each `min_size` reports it.
const MIN_SIZE: usize = 3;
14
/// Cubic Interpolation type.
///
/// Cubic Interpolation with natural boundary conditions. The resulting curve is piecewise cubic on
/// each interval, with matching first and second derivatives at the supplied data-points. The
/// second derivative is chosen to be zero at the first and last point.
///
/// At least three data points are required. Building produces a [`CubicInterp`], which provides
/// the evaluation methods.
///
/// ## Reference
///
/// Numerical Algorithms with C - Gisela Engeln-Mullges, Frank Uhlig - 1996 -
/// Algorithm 10.1, pg 254
#[doc(alias = "gsl_interp_cspline")]
pub struct Cubic;
27
28impl<T> InterpType<T> for Cubic
29where
30    T: crate::Num + Lapack,
31{
32    type Interpolation = CubicInterp<T>;
33
34    /// Constructs a Cubic Interpolator.
35    ///
36    /// # Example
37    ///
38    /// ```
39    /// # use rsl_interpolation::*;
40    /// #
41    /// # fn main() -> Result<(), InterpolationError>{
42    /// let xa = [0.0, 1.0, 2.0];
43    /// let ya = [0.0, 2.0, 4.0];
44    /// let interp = Cubic.build(&xa, &ya)?;
45    /// # Ok(())
46    /// # }
47    /// ```
48    fn build(&self, xa: &[T], ya: &[T]) -> Result<CubicInterp<T>, InterpolationError> {
49        check1d_data(xa, ya, MIN_SIZE)?;
50
51        // Engeln-Mullges G. - Uhlig F.: Algorithm 10.1, pg 254
52        let sys_size = xa.len() - 2;
53
54        let h = diff(xa);
55        debug_assert_eq!(h.len(), xa.len() - 1);
56
57        let two = T::from(2).unwrap();
58        let three = T::from(3).unwrap();
59
60        // Ac=g setup
61        let mut g = Vec::<T>::with_capacity(sys_size);
62        let mut diag = Vec::<T>::with_capacity(sys_size);
63        let mut offdiag = Vec::<T>::with_capacity(sys_size);
64        for i in 0..sys_size {
65            g.push(if h[i].is_zero() {
66                T::zero()
67            } else {
68                three * (ya[i + 2] - ya[i + 1]) / h[i + 1] - three * (ya[i + 1] - ya[i]) / h[i]
69            });
70            diag.push(two * (h[i] + h[i + 1]));
71            offdiag.push(h[i + 1]);
72        }
73        // The last element of offdiag is not actually valid, by definition. Popping it is not
74        // really needed though, since the solver ignores it. However, it is needed in the
75        // CubicPeriodic case, since it represents the cyclical term.
76        offdiag.pop();
77        debug_assert_eq!(diag.len(), offdiag.len() + 1);
78
79        let matrix = Tridiagonal {
80            l: MatrixLayout::C {
81                row: (sys_size) as i32,
82                lda: (sys_size) as i32,
83            },
84            d: diag.clone(),
85            dl: offdiag.clone(),
86            du: offdiag.clone(),
87        };
88
89        // Ac=g solving
90        let mut c = Vec::<T>::with_capacity(xa.len());
91        c.push(T::zero());
92        if sys_size.is_one() {
93            c.push(g[0] / diag[0]);
94        } else {
95            let coeffs = match matrix.solve_tridiagonal(&Array1::from_vec(g.clone())) {
96                Ok(coeffs) => coeffs,
97                Err(err) => {
98                    return Err(InterpolationError::BLASTridiagError {
99                        which_interp: "Cubic".into(),
100                        source: err,
101                    });
102                }
103            };
104            c = [c, coeffs.to_vec()].concat();
105        }
106        c.push(T::zero());
107
108        // g, diag, and offdiag are only needed for the calculation of c and are not used anywhere
109        // else from this point, but lets keep them.
110        let state = CubicInterp {
111            c,
112            g,
113            diag,
114            offdiag,
115        };
116        Ok(state)
117    }
118
119    fn name(&self) -> &str {
120        "Cubic"
121    }
122
123    fn min_size(&self) -> usize {
124        MIN_SIZE
125    }
126}
127
128// ===============================================================================================
129
/// Cubic Interpolator.
///
/// Provides all the evaluation methods.
///
/// Should be constructed through the [`Cubic`] type.
// Only `c` is read by the evaluation methods in this file; `g`, `diag` and `offdiag` are stored
// by `build` but never read, hence the lint allowance.
#[allow(dead_code)]
pub struct CubicInterp<T>
where
    T: crate::Num,
{
    /// Spline coefficients, one per data point; `c[i]` is the quadratic coefficient of the cubic
    /// on interval `i`, and the first and last entries are zero.
    c: Vec<T>,
    /// Right-hand side `g` of the linear system `A c = g` solved in `build`.
    g: Vec<T>,
    /// Main diagonal of `A`.
    diag: Vec<T>,
    /// Sub/super-diagonal of the symmetric `A`.
    offdiag: Vec<T>,
}
145
// Every evaluation method delegates to the free `cubic_eval*` functions of this file, passing the
// coefficients `c` computed by `Cubic::build`.
// NOTE(review): `xa`/`ya` are presumably the same slices that were passed to `build`; `self.c` is
// only meaningful for those data — confirm callers guarantee this.
impl<T> Interpolation<T> for CubicInterp<T>
where
    T: crate::Num + Lapack,
{
    // Value of the spline at `x`.
    fn eval(&self, xa: &[T], ya: &[T], x: T, acc: &mut Accelerator) -> Result<T, DomainError> {
        cubic_eval(xa, ya, &self.c, x, acc)
    }

    // First derivative of the spline at `x`.
    fn eval_deriv(
        &self,
        xa: &[T],
        ya: &[T],
        x: T,
        acc: &mut Accelerator,
    ) -> Result<T, DomainError> {
        cubic_eval_deriv(xa, ya, &self.c, x, acc)
    }

    // Second derivative of the spline at `x`.
    fn eval_deriv2(
        &self,
        xa: &[T],
        ya: &[T],
        x: T,
        acc: &mut Accelerator,
    ) -> Result<T, DomainError> {
        cubic_eval_deriv2(xa, ya, &self.c, x, acc)
    }

    // Integral of the spline from `a` to `b`.
    fn eval_integ(
        &self,
        xa: &[T],
        ya: &[T],
        a: T,
        b: T,
        acc: &mut Accelerator,
    ) -> Result<T, DomainError> {
        cubic_eval_integ(xa, ya, &self.c, a, b, acc)
    }
}
185
186//=================================================================================================
187
/// Cubic Periodic Interpolation type.
///
/// Cubic Spline with periodic boundary conditions. The resulting curve is piecewise cubic on each
/// interval, with matching first and second derivatives at the supplied data-points. The
/// derivatives at the first and last points are also matched. Note that the last point in the data
/// must have the same y-value as the first point, otherwise the resulting periodic interpolation
/// will have a discontinuity at the boundary.
///
/// At least three data points are required. Building produces a [`CubicPeriodicInterp`], which
/// provides the evaluation methods.
///
/// ## Reference
///
/// Numerical Algorithms with C - Gisela Engeln-Mullges, Frank Uhlig - 1996 -
/// Algorithm 10.2, pg 255
#[doc(alias = "gsl_interp_cspline_periodic")]
pub struct CubicPeriodic;
202
203impl<T> InterpType<T> for CubicPeriodic
204where
205    T: crate::Num + Lapack,
206{
207    type Interpolation = CubicPeriodicInterp<T>;
208
209    /// Constructs a Cubic Periodic Interpolator.
210    ///
211    /// # Example
212    ///
213    /// ```
214    /// # use rsl_interpolation::*;
215    /// #
216    /// # fn main() -> Result<(), InterpolationError>{
217    /// let xa = [0.0, 1.0, 2.0];
218    /// let ya = [0.0, 2.0, 4.0];
219    /// let interp = CubicPeriodic.build(&xa, &ya)?;
220    /// # Ok(())
221    /// # }
222    /// ```
223    ///
224    fn build(&self, xa: &[T], ya: &[T]) -> Result<CubicPeriodicInterp<T>, InterpolationError> {
225        check1d_data(xa, ya, MIN_SIZE)?;
226
227        // Engeln-Mullges G. - Uhlig F.: Algorithm 10.2, pg 255
228        let sys_size = xa.len() - 1;
229
230        let h = diff(xa);
231        debug_assert!(h.len() == xa.len() - 1);
232
233        let two = T::from(2).unwrap();
234        let three = T::from(3).unwrap();
235
236        // Ac=g setup
237        let mut c = Vec::<T>::with_capacity(xa.len());
238        let mut g = Vec::<T>::with_capacity(sys_size);
239        let mut diag = Vec::<T>::with_capacity(sys_size);
240        let mut offdiag = Vec::<T>::with_capacity(sys_size);
241
242        if sys_size == 2 {
243            let h0 = xa[1] - xa[0];
244            let h1 = xa[2] - xa[1];
245
246            let a = two * (h0 + h1);
247            let b = h0 + h1;
248
249            g.push(three * ((ya[2] - ya[1]) / h1 - (ya[1] - ya[0]) / h0));
250            g.push(three * ((ya[1] - ya[2]) / h0 - (ya[2] - ya[1]) / h1));
251
252            let det = three * (h0 + h1) * (h0 + h1);
253            c.push((-b * g[0] + a * g[1]) / det);
254            c.push((a * g[0] - b * g[1]) / det);
255            c.push(c[0]);
256        } else {
257            // Same as in Cubic case
258            for i in 0..sys_size - 1 {
259                g.push(if h[i].is_zero() {
260                    T::zero()
261                } else {
262                    three * (ya[i + 2] - ya[i + 1]) / h[i + 1] - three * (ya[i + 1] - ya[i]) / h[i]
263                });
264                diag.push(two * (h[i] + h[i + 1]));
265                offdiag.push(h[i + 1]);
266            }
267
268            // But we must add the last point
269            let i = sys_size - 1;
270            let hi = xa[i + 1] - xa[i];
271            let hiplus1 = xa[1] - xa[0];
272            let ydiffi = ya[i + 1] - ya[i];
273            let ydiffplus1 = ya[1] - ya[0];
274            let gi = if !hi.is_zero() {
275                T::one() / hi
276            } else {
277                T::zero()
278            };
279            let giplus1 = if !hiplus1.is_zero() {
280                T::one() / hiplus1
281            } else {
282                T::zero()
283            };
284            offdiag.push(hiplus1);
285            diag.push(two * (hiplus1 + hi));
286            g.push(three * (ydiffplus1 * giplus1 - ydiffi * gi));
287            // offdiag's last element represents the cyclical term
288            debug_assert_eq!(diag.len(), offdiag.len());
289
290            let matrix = Tridiagonal {
291                l: MatrixLayout::C {
292                    row: (sys_size) as i32,
293                    lda: (sys_size) as i32,
294                },
295                d: diag.clone(),
296                dl: offdiag.clone(),
297                du: offdiag.clone(),
298            };
299
300            // Ac=g solving
301            c.push(T::zero());
302            if sys_size.is_one() {
303                c.push(g[0] / diag[0]);
304            } else {
305                // This must solve a cyclically tridiagonal matrix, but its not implemented yet :(
306                // The corner element is stored at the end of the offdiag vec.
307                let coeffs = match matrix.solve_tridiagonal(&Array1::from_vec(g.clone())) {
308                    Ok(coeffs) => coeffs,
309                    Err(err) => {
310                        return Err(InterpolationError::BLASTridiagError {
311                            which_interp: "Cubic Periodic".into(),
312                            source: err,
313                        });
314                    }
315                };
316                c = [c, coeffs.to_vec()].concat();
317            }
318            c[0] = c[sys_size];
319            panic!(
320                "\nNot implemented: Cubic Periodic Splines with more than 3 points require a solver for\
321                cyclically tridiagonal matrices, which is currently not implemented by ndarray_linalg.\n"
322            )
323        }
324
325        // g, diag, and offdiag are only needed for the calculation of c and are not used anywhere
326        // else from this point, but lets keep them.
327        let state = CubicPeriodicInterp {
328            c,
329            g,
330            diag,
331            offdiag,
332        };
333        Ok(state)
334    }
335
336    fn name(&self) -> &str {
337        "Cubic Periodic"
338    }
339
340    fn min_size(&self) -> usize {
341        MIN_SIZE
342    }
343}
344
345// ===============================================================================================
346
347/// Cubic Periodic interpolator.
348///
349/// Provides all the evaluation methods.
350///
351/// Should be constructed through the [`CubicPeriodic`] type.
352#[allow(dead_code)]
353#[doc(alias = "gsl_interp_cspline_periodic")]
354pub struct CubicPeriodicInterp<T>
355where
356    T: crate::Num + Lapack,
357{
358    c: Vec<T>,
359    g: Vec<T>,
360    diag: Vec<T>,
361    offdiag: Vec<T>,
362}
363
// Every evaluation method delegates to the same free `cubic_eval*` functions used by
// `CubicInterp`; only the coefficients `c` (computed by `CubicPeriodic::build`) differ.
// NOTE(review): `xa`/`ya` are presumably the same slices that were passed to `build` — confirm.
impl<T> Interpolation<T> for CubicPeriodicInterp<T>
where
    T: crate::Num + Lapack,
{
    // Value of the spline at `x`.
    fn eval(&self, xa: &[T], ya: &[T], x: T, acc: &mut Accelerator) -> Result<T, DomainError> {
        cubic_eval(xa, ya, &self.c, x, acc)
    }

    // First derivative of the spline at `x`.
    fn eval_deriv(
        &self,
        xa: &[T],
        ya: &[T],
        x: T,
        acc: &mut Accelerator,
    ) -> Result<T, DomainError> {
        cubic_eval_deriv(xa, ya, &self.c, x, acc)
    }

    // Second derivative of the spline at `x`.
    fn eval_deriv2(
        &self,
        xa: &[T],
        ya: &[T],
        x: T,
        acc: &mut Accelerator,
    ) -> Result<T, DomainError> {
        cubic_eval_deriv2(xa, ya, &self.c, x, acc)
    }

    // Integral of the spline from `a` to `b`.
    fn eval_integ(
        &self,
        xa: &[T],
        ya: &[T],
        a: T,
        b: T,
        acc: &mut Accelerator,
    ) -> Result<T, DomainError> {
        cubic_eval_integ(xa, ya, &self.c, a, b, acc)
    }
}
403
404//=================================================================================================
405
406#[inline(always)]
407fn cubic_eval<T>(xa: &[T], ya: &[T], c: &[T], x: T, acc: &mut Accelerator) -> Result<T, DomainError>
408where
409    T: crate::Num + Lapack,
410{
411    check_if_inbounds(xa, x)?;
412    let index = acc.find(xa, x);
413
414    let xlo = xa[index];
415    let xhi = xa[index + 1];
416    let ylo = ya[index];
417    let yhi = ya[index + 1];
418
419    let dx = xhi - xlo;
420    let dy = yhi - ylo;
421
422    let delx = x - xlo;
423    let (b, c, d) = coeff_calc(c, dx, dy, index);
424
425    debug_assert!(dx > T::zero());
426    Ok(ylo + delx * (b + delx * (c + delx * d)))
427}
428
429fn cubic_eval_deriv<T>(
430    xa: &[T],
431    ya: &[T],
432    c: &[T],
433    x: T,
434    acc: &mut Accelerator,
435) -> Result<T, DomainError>
436where
437    T: crate::Num + Lapack,
438{
439    check_if_inbounds(xa, x)?;
440    let index = acc.find(xa, x);
441
442    let xlo = xa[index];
443    let xhi = xa[index + 1];
444    let ylo = ya[index];
445    let yhi = ya[index + 1];
446
447    let dx = xhi - xlo;
448    let dy = yhi - ylo;
449
450    let delx = x - xlo;
451    let (b, c, d) = coeff_calc(c, dx, dy, index);
452
453    let two = T::from(2).unwrap();
454    let three = T::from(3).unwrap();
455
456    debug_assert!(dx > T::zero());
457    Ok(b + delx * (two * c + three * d * delx))
458}
459
460#[inline(always)]
461fn cubic_eval_deriv2<T>(
462    xa: &[T],
463    ya: &[T],
464    c: &[T],
465    x: T,
466    acc: &mut Accelerator,
467) -> Result<T, DomainError>
468where
469    T: crate::Num + Lapack,
470{
471    check_if_inbounds(xa, x)?;
472    let index = acc.find(xa, x);
473
474    let xlo = xa[index];
475    let xhi = xa[index + 1];
476    let ylo = ya[index];
477    let yhi = ya[index + 1];
478
479    let dx = xhi - xlo;
480    let dy = yhi - ylo;
481
482    let delx = x - xlo;
483    let (_, c, d) = coeff_calc(c, dx, dy, index);
484
485    let two = T::from(2).unwrap();
486    let six = T::from(6).unwrap();
487
488    debug_assert!(dx > T::zero());
489    Ok(two * c + six * delx * d)
490}
491
492#[inline(always)]
493fn cubic_eval_integ<T>(
494    xa: &[T],
495    ya: &[T],
496    c: &[T],
497    a: T,
498    b: T,
499    acc: &mut Accelerator,
500) -> Result<T, DomainError>
501where
502    T: crate::Num + Lapack,
503{
504    check_if_inbounds(xa, a)?;
505    check_if_inbounds(xa, b)?;
506    let index_a = acc.find(xa, a);
507    let index_b = acc.find(xa, b);
508
509    let quarter = T::from(0.25).unwrap();
510    let half = T::from(0.5).unwrap();
511    let third = T::from(1.0 / 3.0).unwrap();
512
513    let mut result = T::zero();
514
515    for i in index_a..=index_b {
516        let xlo = xa[i];
517        let xhi = xa[i + 1];
518        let ylo = ya[i];
519        let yhi = ya[i + 1];
520
521        let dx = xhi - xlo;
522        let dy = yhi - ylo;
523
524        // If two x points are the same
525        if dx.is_zero() {
526            continue;
527        }
528
529        let (bi, ci, di) = coeff_calc(c, dx, dy, i);
530
531        if (i == index_a) | (i == index_b) {
532            let x1 = if i == index_a { a } else { xlo };
533            let x2 = if i == index_b { b } else { xhi };
534            result += integ_eval(ylo, bi, ci, di, xlo, x1, x2);
535        } else {
536            result += dx * (ylo + dx * (half * bi + dx * (third * ci + quarter * di * dx)))
537        }
538    }
539    Ok(result)
540}
541/// Function for common coefficient determination. No inline.
542fn coeff_calc<T>(carray: &[T], dx: T, dy: T, index: usize) -> (T, T, T)
543where
544    T: crate::Num + Lapack,
545{
546    let two = T::from(2).unwrap();
547    let three = T::from(3).unwrap();
548
549    let c = carray[index];
550    let cplus1 = carray[index + 1];
551
552    let b = (dy / dx) - dx * (cplus1 + two * c) / three;
553    let d = (cplus1 - c) / (three * dx);
554    (b, c, d)
555}