// ndarray_interp/interp1d/strategies/cubic_spline.rs

//! The cubic spline interpolation strategy
//!
//! This module defines the [`CubicSpline`] struct, which can be used with
//! [`Interp1DBuilder::strategy()`](super::super::Interp1DBuilder::strategy).
//!
//! # Boundary conditions
//! The cubic spline strategy can be customized with boundary conditions.
//! There are three levels of boundary conditions:
//!  - [`BoundaryCondition`]: the top-level boundary, applies to the whole dataset
//!  - [`RowBoundary`]: applies to a single row in the dataset (use with [`BoundaryCondition::Individual`])
//!  - [`SingleBoundary`]: applies to one boundary of a single row (use with [`RowBoundary::Mixed`])
//!
13
14use std::{
15    fmt::Debug,
16    ops::{Add, Neg, Sub, SubAssign},
17};
18
19use ndarray::{
20    s, Array, Array1, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, Dimension, FoldWhile, Ix1,
21    IxDyn, RemoveAxis, ScalarOperand, Slice, Zip,
22};
23use num_traits::{cast, Euclid, Num, NumCast, Pow};
24
25use crate::{interp1d::Interp1D, BuilderError, InterpolateError};
26
27use super::{Interp1DStrategy, Interp1DStrategyBuilder};
28
/// Shorthand for the first axis, along which the data points (one per `x` value) are laid out.
const AX0: Axis = Axis(0);
30
31/// Marker trait that is implemented for anything that satisfies
32/// the trait bounds required to be used as an element in the CubicSpline
33/// strategy.
34pub trait SplineNum:
35    Debug
36    + Num
37    + Copy
38    + PartialOrd
39    + Sub
40    + SubAssign
41    + Neg<Output = Self>
42    + NumCast
43    + Add
44    + Pow<Self, Output = Self>
45    + ScalarOperand
46    + Euclid
47    + Send
48{
49}
50
/// The CubicSpline 1d interpolation Strategy (Builder)
///
/// Configure it with [`CubicSpline::extrapolate`] and [`CubicSpline::boundary`], then pass it
/// to `Interp1DBuilder::strategy`. By default (see [`CubicSpline::new`]) it does not
/// extrapolate and uses the [`BoundaryCondition::NotAKnot`] boundary.
///
/// # Example
/// From [Wikipedia](https://en.wikipedia.org/wiki/Spline_interpolation#Example)
/// ```
/// # use ndarray_interp::*;
/// # use ndarray_interp::interp1d::*;
/// # use ndarray_interp::interp1d::cubic_spline::*;
/// # use ndarray::*;
/// # use approx::*;
///
/// let y = array![ 0.5, 0.0, 3.0];
/// let x = array![-1.0, 0.0, 3.0];
/// let query = Array::linspace(-1.0, 3.0, 10);
/// let interpolator = Interp1DBuilder::new(y)
///     .strategy(CubicSpline::new())
///     .x(x)
///     .build().unwrap();
///
/// let result = interpolator.interp_array(&query).unwrap();
/// let expect = array![
///     0.5,
///     0.1851851851851852,
///     0.01851851851851853,
///     -5.551115123125783e-17,
///     0.12962962962962965,
///     0.40740740740740755,
///     0.8333333333333331,
///     1.407407407407407,
///     2.1296296296296293, 3.0
/// ];
/// # assert_abs_diff_eq!(result, expect, epsilon=f64::EPSILON);
/// ```
#[derive(Debug)]
pub struct CubicSpline<T, D: Dimension> {
    /// If `true`, queries outside of the data range are evaluated instead of rejected.
    extrapolate: bool,
    /// Boundary condition(s) used when solving for the spline coefficients.
    boundary: BoundaryCondition<T, D>,
}
89
/// The CubicSpline 1d interpolation Strategy (Implementation)
///
/// This is constructed by [`CubicSpline`]. It stores the precomputed per-interval
/// coefficients, so evaluating a query only needs the two data points around it.
#[derive(Debug)]
pub struct CubicSplineStrategy<Sd, D>
where
    Sd: Data,
    D: Dimension + RemoveAxis,
{
    /// Coefficient `a` of every interval; shape of the data with axis 0 shortened by one.
    a: Array<Sd::Elem, D>,
    /// Coefficient `b` of every interval; same shape as `a`.
    b: Array<Sd::Elem, D>,
    /// How queries outside of the data range are handled.
    extrapolate: Extrapolate,
}
103
/// Boundary conditions for the whole dataset
///
/// The boundary condition is structured in three hierarchical enums:
///  - [`BoundaryCondition`]: the top-level boundary, applies to the whole dataset
///  - [`RowBoundary`]: applies to a single row in the dataset
///  - [`SingleBoundary`]: applies to an individual boundary of a single row
///
/// The default is the [`NotAKnot`](BoundaryCondition::NotAKnot) boundary in each level.
///
/// The following boundary conditions are available in the different levels:
///  - [`NotAKnot`](BoundaryCondition::NotAKnot) - all levels
///  - [`Natural`](BoundaryCondition::Natural) - all levels (same as `SecondDeriv(0.0)`)
///  - [`Clamped`](BoundaryCondition::Clamped) - all levels (same as `FirstDeriv(0.0)`)
///  - [`Periodic`](BoundaryCondition::Periodic) - only in [`BoundaryCondition`]
///    (periodicity applies to the whole dataset, never to single rows)
///  - [`FirstDeriv`](SingleBoundary::FirstDeriv) - only in [`SingleBoundary`]
///  - [`SecondDeriv`](SingleBoundary::SecondDeriv) - only in [`SingleBoundary`]
///
/// ## Example
/// In a complex case, all boundaries can be set individually:
/// ``` rust
/// # use ndarray_interp::*;
/// # use ndarray_interp::interp1d::*;
/// # use ndarray_interp::interp1d::cubic_spline::*;
/// # use ndarray::*;
/// # use approx::*;
///
/// let y = array![
///     [0.5, 1.0],
///     [0.0, 1.5],
///     [3.0, 0.5],
/// ];
/// let x = array![-1.0, 0.0, 3.0];
///
/// // first data column: natural
/// // second data column top (left end, x = -1): NotAKnot
/// // second data column bottom (right end, x = 3): first derivative == 0.5
/// let boundaries = array![
///     [
///         RowBoundary::Natural,
///         RowBoundary::Mixed { left: SingleBoundary::NotAKnot, right: SingleBoundary::FirstDeriv(0.5)}
///     ],
/// ];
/// let strat = CubicSpline::new().boundary(BoundaryCondition::Individual(boundaries));
/// let interpolator = Interp1DBuilder::new(y)
///     .x(x)
///     .strategy(strat)
///     .build().unwrap();
///
/// ```
#[derive(Debug, PartialEq, Eq)]
pub enum BoundaryCondition<T, D: Dimension> {
    /// Not-a-knot boundary. The first and second segment at a curve end are the same polynomial.
    NotAKnot,
    /// Natural boundary. The second derivative at the curve end is 0.
    Natural,
    /// Clamped boundary. The first derivative at the curve end is 0.
    Clamped,
    /// Periodic spline.
    /// The interpolated function is assumed to be periodic.
    /// The first and last element in the data must be equal, otherwise building the
    /// interpolator fails with a `BuilderError::ValueError`.
    /// If extrapolation is enabled, out-of-range queries are wrapped into the data range.
    Periodic,
    /// Set individual boundary conditions for each row in the data
    /// and/or individual conditions for the left and right boundary.
    ///
    /// The array must have the shape of the data with the first axis set to length 1,
    /// otherwise building fails with a `BuilderError::ShapeError`.
    Individual(Array<RowBoundary<T>, D>),
}
169
/// Boundary condition for a single data row
///
/// Used as the element type of [`BoundaryCondition::Individual`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum RowBoundary<T> {
    /// Same as [`BoundaryCondition::NotAKnot`], but for this row only
    NotAKnot,
    /// Same as [`BoundaryCondition::Natural`], but for this row only
    Natural,
    /// Same as [`BoundaryCondition::Clamped`], but for this row only
    Clamped,
    /// Set individual boundary conditions at the left and right end of the curve
    Mixed {
        /// condition at the first data point (smallest index) of the row
        left: SingleBoundary<T>,
        /// condition at the last data point (largest index) of the row
        right: SingleBoundary<T>,
    },
}
185
/// This is essentially [`RowBoundary`], but including the `Periodic` variant.
/// The periodic variant cannot be applied to a single row; it applies to all rows or none.
/// It is still needed here because the coefficients are calculated by the same solver,
/// which may or may not be called for each row individually.
#[derive(Debug)]
enum InternalBoundary<T> {
    NotAKnot,
    Natural,
    Clamped,
    Periodic,
    Mixed {
        left: SingleBoundary<T>,
        right: SingleBoundary<T>,
    },
}
201
/// Boundary condition for a single boundary (one side of one data row)
///
/// Used in [`RowBoundary::Mixed`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum SingleBoundary<T> {
    /// Same as [`BoundaryCondition::NotAKnot`], but for this side only
    NotAKnot,
    /// This is the same as `SingleBoundary::SecondDeriv(0.0)`,
    /// see [`BoundaryCondition::Natural`]
    Natural,
    /// This is the same as `SingleBoundary::FirstDeriv(0.0)`,
    /// see [`BoundaryCondition::Clamped`]
    Clamped,
    /// Set a value for the first derivative at the curve end
    FirstDeriv(T),
    /// Set a value for the second derivative at the curve end
    SecondDeriv(T),
}
218
/// How [`CubicSplineStrategy`] treats query points outside of the data range.
///
/// Chosen at build time from [`CubicSpline::extrapolate`] and the boundary condition.
#[derive(Debug)]
enum Extrapolate {
    /// Out-of-range queries are not rejected; they are evaluated with the polynomial of the
    /// interval returned by `Interp1D::get_index_left_of`.
    Yes,
    /// Out-of-range queries fail with `InterpolateError::OutOfBounds`.
    No,
    /// Out-of-range queries are wrapped back into the data range (periodic boundary only).
    Periodic,
}
225
226impl<T> SplineNum for T where
227    T: Debug
228        + Num
229        + Copy
230        + PartialOrd
231        + Sub
232        + SubAssign
233        + Neg<Output = T>
234        + NumCast
235        + Add
236        + Pow<Self, Output = Self>
237        + ScalarOperand
238        + Euclid
239        + Send
240{
241}
242
243impl<T, D: Dimension> Default for BoundaryCondition<T, D> {
244    fn default() -> Self {
245        Self::NotAKnot
246    }
247}
248
249impl<T: SplineNum> Default for RowBoundary<T> {
250    fn default() -> Self {
251        Self::NotAKnot
252    }
253}
254
255impl<T: SplineNum> InternalBoundary<T> {
256    fn specialize(self) -> Self {
257        use SingleBoundary::*;
258        match self {
259            InternalBoundary::Natural => Self::Mixed {
260                left: Natural,
261                right: Natural,
262            },
263            InternalBoundary::NotAKnot => Self::Mixed {
264                left: NotAKnot,
265                right: NotAKnot,
266            },
267            InternalBoundary::Clamped => Self::Mixed {
268                left: Clamped,
269                right: Clamped,
270            },
271            _ => self,
272        }
273    }
274}
275
276impl<T> From<RowBoundary<T>> for InternalBoundary<T> {
277    fn from(val: RowBoundary<T>) -> Self {
278        match val {
279            RowBoundary::NotAKnot => InternalBoundary::NotAKnot,
280            RowBoundary::Natural => InternalBoundary::Natural,
281            RowBoundary::Clamped => InternalBoundary::Clamped,
282            RowBoundary::Mixed { left, right } => InternalBoundary::Mixed { left, right },
283        }
284    }
285}
286
287impl<T: SplineNum> SingleBoundary<T> {
288    fn specialize(self) -> Self {
289        use SingleBoundary::*;
290        match self {
291            SingleBoundary::Natural => SecondDeriv(cast(0.0).unwrap_or_else(|| unimplemented!())),
292            SingleBoundary::Clamped => FirstDeriv(cast(0.0).unwrap_or_else(|| unimplemented!())),
293            _ => self,
294        }
295    }
296}
297
298impl<T: SplineNum> Default for SingleBoundary<T> {
299    fn default() -> Self {
300        Self::NotAKnot
301    }
302}
303
304impl<T, D> CubicSpline<T, D>
305where
306    D: Dimension + RemoveAxis,
307    T: SplineNum,
308{
309    /// Calculate the coefficients `a` and `b`
310    fn calc_coefficients<Sd, Sx>(
311        &self,
312        x: &ArrayBase<Sx, Ix1>,
313        data: &ArrayBase<Sd, D>,
314    ) -> Result<(Array<Sd::Elem, D>, Array<Sd::Elem, D>), BuilderError>
315    where
316        Sd: Data<Elem = T>,
317        Sx: Data<Elem = T>,
318    {
319        let dim = data.raw_dim();
320        let len = dim[0];
321        let mut k = Array::zeros(dim.clone());
322        let kv = k.view_mut();
323        match self.boundary {
324            BoundaryCondition::Periodic => {
325                Self::solve_for_k(kv, x, data, InternalBoundary::Periodic)
326            }
327            BoundaryCondition::Natural => Self::solve_for_k(kv, x, data, InternalBoundary::Natural),
328            BoundaryCondition::Clamped => Self::solve_for_k(kv, x, data, InternalBoundary::Clamped),
329            BoundaryCondition::NotAKnot => {
330                Self::solve_for_k(kv, x, data, InternalBoundary::NotAKnot)
331            }
332            BoundaryCondition::Individual(ref bounds) => {
333                let mut bounds_shape = kv.raw_dim();
334                bounds_shape[0] = 1;
335                if bounds_shape != bounds.raw_dim() {
336                    return Err(BuilderError::ShapeError(format!(
337                        "Boundary conditions array has wrong shape. Expected: {bounds_shape:?}, got: {:?}",
338                        bounds.raw_dim()
339                    )));
340                }
341                Self::solve_for_k_individual(
342                    kv.into_dyn(),
343                    x,
344                    data.view().into_dyn(),
345                    bounds.view().into_dyn(),
346                )
347            }
348        }?;
349
350        let mut a_b_dim = data.raw_dim();
351        a_b_dim[0] -= 1;
352        let mut c_a = Array::zeros(a_b_dim.clone());
353        let mut c_b = Array::zeros(a_b_dim);
354        for index in 0..len - 1 {
355            Zip::from(c_a.index_axis_mut(AX0, index))
356                .and(c_b.index_axis_mut(AX0, index))
357                .and(k.index_axis(AX0, index))
358                .and(k.index_axis(AX0, index + 1))
359                .and(data.index_axis(AX0, index))
360                .and(data.index_axis(AX0, index + 1))
361                .for_each(|c_a, c_b, &k, &k_right, &y, &y_right| {
362                    *c_a = k * (x[index + 1] - x[index]) - (y_right - y);
363                    *c_b = (y_right - y) - k_right * (x[index + 1] - x[index]);
364                })
365        }
366
367        Ok((c_a, c_b))
368    }
369
370    fn solve_for_k_individual<Sx>(
371        mut k: ArrayViewMut<T, IxDyn>,
372        x: &ArrayBase<Sx, Ix1>,
373        data: ArrayView<T, IxDyn>,
374        boundary: ArrayView<RowBoundary<T>, IxDyn>,
375    ) -> Result<(), BuilderError>
376    where
377        Sx: Data<Elem = T>,
378    {
379        if k.ndim() > 1 {
380            let ax = Axis(k.ndim() - 1);
381            Zip::from(k.axis_iter_mut(ax))
382                .and(data.axis_iter(ax))
383                .and(boundary.axis_iter(ax))
384                .fold_while(Ok(()), |_, k, data, boundary| {
385                    Self::solve_for_k_individual(k, x, data, boundary).map_or_else(
386                        |err| FoldWhile::Done(Err(err)),
387                        |_| FoldWhile::Continue(Ok(())),
388                    )
389                })
390                .into_inner()
391        } else {
392            Self::solve_for_k(
393                k,
394                x,
395                &data,
396                boundary
397                    .first()
398                    .cloned()
399                    .unwrap_or_else(|| unreachable!())
400                    .into(),
401            )
402        }
403    }
404
405    /// solves the linear equation `A * k = rhs` with the [`RowBoundary`] used for
406    /// each row in the data
407    ///  
408    /// **returns** k
409    fn solve_for_k<Sd, Sx, _D>(
410        mut k: ArrayViewMut<T, _D>,
411        x: &ArrayBase<Sx, Ix1>,
412        data: &ArrayBase<Sd, _D>,
413        boundary: InternalBoundary<T>,
414    ) -> Result<(), BuilderError>
415    where
416        _D: Dimension + RemoveAxis,
417        Sd: Data<Elem = T>,
418        Sx: Data<Elem = T>,
419    {
420        let dim = data.raw_dim();
421        let len = dim[0];
422
423        /*
424         * Calculate the coefficients c_a and c_b for the cubic spline the method is outlined on
425         * https://en.wikipedia.org/wiki/Spline_interpolation#Example
426         *
427         * This requires solving the Linear equation A * k = rhs
428         */
429
430        // upper, middle and lower diagonal of A
431        let mut a_up = Array::zeros(len);
432        let mut a_mid = Array::zeros(len);
433        let mut a_low = Array::zeros(len);
434
435        let zero: T = cast(0.0).unwrap_or_else(|| unimplemented!());
436        let one: T = cast(1.0).unwrap_or_else(|| unimplemented!());
437        let two: T = cast(2.0).unwrap_or_else(|| unimplemented!());
438        let three: T = cast(3.0).unwrap_or_else(|| unimplemented!());
439
440        Zip::from(a_up.slice_mut(s![1..-1]))
441            .and(a_mid.slice_mut(s![1..-1]))
442            .and(a_low.slice_mut(s![1..-1]))
443            .and(x.windows(3))
444            .for_each(|a_up, a_mid, a_low, x| {
445                let dxn = x[2] - x[1];
446                let dxn_1 = x[1] - x[0];
447
448                *a_up = dxn_1;
449                *a_mid = two * (dxn + dxn_1);
450                *a_low = dxn;
451            });
452
453        // RHS vector
454        let mut rhs = Array::zeros(dim.clone());
455
456        for n in 1..len - 1 {
457            let rhs = rhs.index_axis_mut(AX0, n);
458            let y_left = data.index_axis(AX0, n - 1);
459            let y_mid = data.index_axis(AX0, n);
460            let y_right = data.index_axis(AX0, n + 1);
461
462            let dxn = x[n + 1] - x[n]; // dx(n)
463            let dxn_1 = x[n] - x[n - 1]; // dx(n-1)
464
465            Zip::from(y_left).and(y_mid).and(y_right).map_assign_into(
466                rhs,
467                |&y_left, &y_mid, &y_right| {
468                    three * (dxn * (y_mid - y_left) / dxn_1 + dxn_1 * (y_right - y_mid) / dxn)
469                },
470            );
471        }
472
473        let dx0 = x[1] - x[0];
474        let dx1 = x[2] - x[1];
475        let dx_1 = x[len - 1] - x[len - 2];
476        let dx_2 = x[len - 2] - x[len - 3];
477
478        // apply boundary conditions
479        match (boundary.specialize(), len) {
480            (InternalBoundary::Periodic, 3) => {
481                let y0 = data.index_axis(AX0, 0);
482                let y2 = data.index_axis(AX0, 2);
483                if y0 != y2 {
484                    if data.ndim() == 1 {
485                        return Err(BuilderError::ValueError(format!("for periodic boundary condition the first and last value must be equal. First: {:?}, last: {:?}", data.first().unwrap_or_else(||unreachable!()), data.last().unwrap_or_else(||unreachable!()))));
486                    } else {
487                        return Err(BuilderError::ValueError(format!("for periodic boundary condition the first and last value must be equal. First: {y0:?}, last: {y2:?}")));
488                    }
489                }
490
491                let y1 = data.index_axis(AX0, 1);
492                let slope0: Array<T, _D::Smaller> = (&y1 - &y0) / dx0;
493                let slope1: Array<T, _D::Smaller> = (&y2 - &y1) / dx1;
494                k.assign(&((slope0 / dx0 + slope1 / dx1) / (one / dx0 + one / dx1)));
495                return Ok(());
496            }
497
498            (InternalBoundary::Periodic, _) => {
499                let y0 = data.index_axis(AX0, 0);
500                let y_1 = data.index_axis(AX0, len - 1);
501                if y0 != y_1 {
502                    if data.ndim() == 1 {
503                        return Err(BuilderError::ValueError(format!("for periodic boundary condition the first and last value must be equal. First: {:?}, last: {:?}", data.first().unwrap_or_else(||unreachable!()), data.last().unwrap_or_else(||unreachable!()))));
504                    } else {
505                        return Err(BuilderError::ValueError(format!("for periodic boundary condition the first and last value must be equal. First: {y0:?}, last: {y_1:?}")));
506                    }
507                }
508
509                // due to the preriodicity we need to solve one less equation
510                // the system matrix a is also condensed
511                // https://web.archive.org/web/20151220180652/http://www.cfm.brown.edu/people/gk/chap6/node14.html
512                a_up.slice_axis_inplace(AX0, Slice::from(0..-2));
513                a_mid.slice_axis_inplace(AX0, Slice::from(0..-2));
514                a_low.slice_axis_inplace(AX0, Slice::from(0..-2));
515                rhs.slice_axis_inplace(AX0, Slice::from(0..-1));
516
517                a_mid[0] = two * (dx_1 + dx0);
518                a_up[0] = dx_1;
519
520                let y1 = data.index_axis(AX0, 1);
521                let slope0: Array<T, _D::Smaller> = (&y1 - &y0) / dx0;
522
523                let y_1 = data.index_axis(AX0, len - 1);
524                let y_2 = data.index_axis(AX0, len - 2);
525                let y_3 = data.index_axis(AX0, len - 3);
526                let slope_1: Array<T, _D::Smaller> = (&y_1 - &y_2) / dx_1;
527                let slope_2: Array<T, _D::Smaller> = (&y_2 - &y_3) / dx_2;
528
529                rhs.index_axis_mut(AX0, 0)
530                    .assign(&((&slope_1 * dx0 + &slope0 * dx_1) * three));
531                rhs.index_axis_mut(AX0, len - 1 - 1)
532                    .assign(&((slope_2 * dx_1 + slope_1 * dx_2) * three));
533
534                let rhs1 = rhs.slice_axis(AX0, Slice::from(0..-1)).to_owned();
535                let mut rhs2 = Array::zeros(rhs1.raw_dim());
536                rhs2.index_axis_mut(AX0, 0).fill(-dx0); // = -dx0;
537                let dx_3 = x[len - 3] - x[len - 4];
538                rhs2.index_axis_mut(AX0, len - 3).fill(-dx_3);
539
540                let mut k1 = Array::zeros(rhs1.raw_dim());
541                let mut k2 = Array::zeros(rhs1.raw_dim());
542
543                Self::thomas(
544                    k1.view_mut(),
545                    a_up.clone(),
546                    a_mid.clone(),
547                    a_low.clone(),
548                    rhs1,
549                );
550                Self::thomas(k2.view_mut(), a_up, a_mid, a_low, rhs2);
551
552                let k_m1 = (&rhs.index_axis(AX0, len - 2)
553                    - &k1.index_axis(AX0, 0) * dx_2
554                    - &k1.index_axis(AX0, len - 3) * dx_1)
555                    / (&k2.index_axis(AX0, 0) * dx_2
556                        + &k2.index_axis(AX0, len - 3) * dx_1
557                        + two * (dx_1 + dx_2));
558
559                k.slice_axis_mut(AX0, Slice::from(0..-2))
560                    .assign(&(k1 + &k_m1 * k2));
561                k.index_axis_mut(AX0, len - 2).assign(&k_m1);
562                let k0 = k.index_axis(AX0, 0).to_owned();
563                k.index_axis_mut(AX0, len - 1).assign(&k0);
564                return Ok(());
565            }
566            (InternalBoundary::Clamped, _) => unreachable!(),
567            (InternalBoundary::Natural, _) => unreachable!(),
568            (InternalBoundary::NotAKnot, _) => unreachable!(),
569            (
570                InternalBoundary::Mixed {
571                    left: SingleBoundary::NotAKnot,
572                    right: SingleBoundary::NotAKnot,
573                },
574                3,
575            ) => {
576                // We handle this case by constructing a parabola passing through given points.
577
578                let y0 = data.index_axis(AX0, 0);
579                let y1 = data.index_axis(AX0, 1);
580                let y2 = data.index_axis(AX0, 2);
581                let slope0 = (y1.to_owned() - y0) / dx0;
582                let slope1 = (y2.to_owned() - y1) / dx1;
583
584                a_mid[0] = one; // [0, 0]
585                a_up[0] = one; // [0, 1]
586                a_low[1] = dx1; // [1, 0]
587                a_mid[1] = two * (dx0 + dx1); // [1, 1]
588                a_up[1] = dx0; // [1, 2]
589                a_low[2] = one; // [2, 1]
590                a_mid[2] = one; // [2, 2]
591
592                rhs.index_axis_mut(AX0, 0).assign(&(&slope0 * two));
593                rhs.index_axis_mut(AX0, 1)
594                    .assign(&((&slope1 * dx0 + &slope0 * dx1) * three));
595                rhs.index_axis_mut(AX0, 2).assign(&(slope1 * two));
596            }
597            (InternalBoundary::Mixed { left, right }, _) => {
598                match left.specialize() {
599                    SingleBoundary::NotAKnot => {
600                        a_mid[0] = dx1;
601                        let d = x[2] - x[0];
602                        a_up[0] = d;
603                        let tmp1 = (dx0 + two * d) * dx1;
604                        Zip::from(rhs.index_axis_mut(AX0, 0))
605                            .and(data.index_axis(AX0, 0))
606                            .and(data.index_axis(AX0, 1))
607                            .and(data.index_axis(AX0, 2))
608                            .for_each(|b, &y0, &y1, &y2| {
609                                *b = (tmp1 * (y1 - y0) / dx0 + dx0.pow(two) * (y2 - y1) / dx1) / d;
610                            });
611                    }
612                    SingleBoundary::Natural => unreachable!(),
613                    SingleBoundary::Clamped => unreachable!(),
614                    SingleBoundary::FirstDeriv(deriv) => {
615                        a_mid[0] = one;
616                        a_up[0] = zero;
617                        rhs.index_axis_mut(AX0, 0).fill(deriv);
618                    }
619                    SingleBoundary::SecondDeriv(deriv) => {
620                        a_up[0] = dx0;
621                        a_mid[0] = two * dx0;
622                        let rhs_0 = rhs.index_axis_mut(AX0, 0);
623                        let data_0 = data.index_axis(AX0, 0);
624                        let data_1 = data.index_axis(AX0, 1);
625                        Zip::from(rhs_0)
626                            .and(data_0)
627                            .and(data_1)
628                            .for_each(|rhs_0, &y_0, &y_1| {
629                                *rhs_0 = three * (y_1 - y_0) - deriv * dx0.pow(two) / two;
630                            });
631                    }
632                };
633                match right.specialize() {
634                    SingleBoundary::NotAKnot => {
635                        a_mid[len - 1] = dx_1;
636                        let d = x[len - 1] - x[len - 3];
637                        a_low[len - 1] = d;
638                        let tmp1 = (two * d + dx_1) * dx_2;
639                        Zip::from(rhs.index_axis_mut(AX0, len - 1))
640                            .and(data.index_axis(AX0, len - 1))
641                            .and(data.index_axis(AX0, len - 2))
642                            .and(data.index_axis(AX0, len - 3))
643                            .for_each(|b, &y_1, &y_2, &y_3| {
644                                *b = (dx_1.pow(two) * (y_2 - y_3) / dx_2
645                                    + tmp1 * (y_1 - y_2) / dx_1)
646                                    / d;
647                            });
648                    }
649                    SingleBoundary::Natural => unreachable!(),
650                    SingleBoundary::Clamped => unreachable!(),
651                    SingleBoundary::FirstDeriv(deriv) => {
652                        a_mid[len - 1] = one;
653                        a_low[len - 1] = zero;
654                        rhs.index_axis_mut(AX0, len - 1).fill(deriv);
655                    }
656                    SingleBoundary::SecondDeriv(deriv) => {
657                        a_mid[len - 1] = two * dx_1;
658                        a_low[len - 1] = dx_1;
659                        let rhs_n = rhs.index_axis_mut(AX0, len - 1);
660                        let data_n = data.index_axis(AX0, len - 1);
661                        let data_n1 = data.index_axis(AX0, len - 2);
662                        Zip::from(rhs_n)
663                            .and(data_n)
664                            .and(data_n1)
665                            .for_each(|rhs_n, &y_n, &y_n1| {
666                                *rhs_n = three * (y_n - y_n1) + deriv * dx_1.pow(two) / two;
667                            });
668                    }
669                };
670            }
671        }
672        Self::thomas(k, a_up, a_mid, a_low, rhs);
673        Ok(())
674    }
675
676    /// The Thomas algorithm is used, because the matrix A will be tridiagonal and diagonally dominant
677    /// [https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm]
678    fn thomas<_D>(
679        mut k: ArrayViewMut<T, _D>,
680        a_up: Array1<T>,
681        mut a_mid: Array1<T>,
682        a_low: Array1<T>,
683        mut rhs: Array<T, _D>,
684    ) where
685        _D: Dimension + RemoveAxis,
686    {
687        let dim = rhs.raw_dim();
688        let len = dim[0];
689        let mut rhs_left = rhs.index_axis(AX0, 0).into_owned();
690        for i in 1..len {
691            let w = a_low[i] / a_mid[i - 1];
692            a_mid[i] -= w * a_up[i - 1];
693
694            let rhs = rhs.index_axis_mut(AX0, i);
695            Zip::from(rhs)
696                .and(rhs_left.view_mut())
697                .for_each(|rhs, rhs_left| {
698                    let new_rhs = *rhs - w * *rhs_left;
699                    *rhs = new_rhs;
700                    *rhs_left = new_rhs;
701                });
702        }
703
704        Zip::from(k.index_axis_mut(AX0, len - 1))
705            .and(rhs.index_axis(AX0, len - 1))
706            .for_each(|k, &rhs| {
707                *k = rhs / a_mid[len - 1];
708            });
709
710        let mut k_right = k.index_axis(AX0, len - 1).into_owned();
711        for i in (0..len - 1).rev() {
712            Zip::from(k.index_axis_mut(AX0, i))
713                .and(k_right.view_mut())
714                .and(rhs.index_axis(AX0, i))
715                .for_each(|k, k_right, &rhs| {
716                    let new_k = (rhs - a_up[i] * *k_right) / a_mid[i];
717                    *k = new_k;
718                    *k_right = new_k;
719                })
720        }
721    }
722
723    /// create a cubic-spline interpolation stratgy
724    pub fn new() -> Self {
725        Self {
726            extrapolate: false,
727            boundary: BoundaryCondition::NotAKnot,
728        }
729    }
730
731    /// does the strategy extrapolate? Default is `false`
732    pub fn extrapolate(mut self, extrapolate: bool) -> Self {
733        self.extrapolate = extrapolate;
734        self
735    }
736
737    /// set the boundary condition. default is [`BoundaryCondition::Natural`]
738    pub fn boundary(mut self, boundary: BoundaryCondition<T, D>) -> Self {
739        self.boundary = boundary;
740        self
741    }
742}
743
744impl<Sd, Sx, D> Interp1DStrategyBuilder<Sd, Sx, D> for CubicSpline<Sd::Elem, D>
745where
746    Sd: Data,
747    Sd::Elem: SplineNum,
748    Sx: Data<Elem = Sd::Elem>,
749    D: Dimension + RemoveAxis,
750{
751    const MINIMUM_DATA_LENGHT: usize = 3;
752    type FinishedStrat = CubicSplineStrategy<Sd, D>;
753
754    fn build<Sx2>(
755        self,
756        x: &ArrayBase<Sx2, Ix1>,
757        data: &ArrayBase<Sd, D>,
758    ) -> Result<Self::FinishedStrat, BuilderError>
759    where
760        Sx2: Data<Elem = Sd::Elem>,
761    {
762        let (a, b) = self.calc_coefficients(x, data)?;
763        let extrapolate = if !self.extrapolate {
764            Extrapolate::No
765        } else if matches!(self.boundary, BoundaryCondition::Periodic) {
766            Extrapolate::Periodic
767        } else {
768            Extrapolate::Yes
769        };
770        Ok(CubicSplineStrategy { a, b, extrapolate })
771    }
772}
773
774impl<T, D> Default for CubicSpline<T, D>
775where
776    D: Dimension + RemoveAxis,
777    T: SplineNum,
778{
779    fn default() -> Self {
780        Self::new()
781    }
782}
783
784impl<Sd, Sx, D> Interp1DStrategy<Sd, Sx, D> for CubicSplineStrategy<Sd, D>
785where
786    Sd: Data,
787    Sd::Elem: SplineNum,
788    Sx: Data<Elem = Sd::Elem>,
789    D: Dimension + RemoveAxis,
790{
791    fn interp_into(
792        &self,
793        interp: &Interp1D<Sd, Sx, D, Self>,
794        target: ArrayViewMut<'_, <Sd>::Elem, <D as Dimension>::Smaller>,
795        x: <Sx>::Elem,
796    ) -> Result<(), InterpolateError> {
797        let in_range = interp.is_in_range(x);
798        if matches!(self.extrapolate, Extrapolate::No) && !in_range {
799            return Err(InterpolateError::OutOfBounds(format!(
800                "x = {x:#?} is not in range",
801            )));
802        }
803
804        let mut x = x;
805        if matches!(self.extrapolate, Extrapolate::Periodic) && !in_range {
806            let x0 = interp.x[0];
807            let xn = interp.x[interp.x.len() - 1];
808            x = ((x - x0).rem_euclid(&(xn - x0))) + x0;
809        }
810
811        let idx = interp.get_index_left_of(x);
812        let (x_left, data_left) = interp.index_point(idx);
813        let (x_right, data_right) = interp.index_point(idx + 1);
814        let a_left = self.a.index_axis(AX0, idx);
815        let b_left = self.b.index_axis(AX0, idx);
816        let one: Sd::Elem = cast(1.0).unwrap_or_else(|| unimplemented!());
817
818        let t = (x - x_left) / (x_right - x_left);
819        Zip::from(data_left)
820            .and(data_right)
821            .and(a_left)
822            .and(b_left)
823            .and(target)
824            .for_each(|&y_left, &y_right, &a_left, &b_left, y| {
825                *y = (one - t) * y_left
826                    + t * y_right
827                    + t * (one - t) * (a_left * (one - t) + b_left * t);
828            });
829        Ok(())
830    }
831}