Skip to main content

numra_interp/
cubic_spline.rs

//! Cubic spline interpolation.
//!
//! Supports natural, clamped, and not-a-knot boundary conditions.
//!
//! Author: Moussa Leblouba
//! Date: 9 February 2026
//! Modified: 2 May 2026
8
9use numra_core::Scalar;
10
11use crate::error::InterpError;
12use crate::{eval_piecewise_cubic, eval_piecewise_cubic_deriv, integrate_piecewise_cubic};
13use crate::{validate_data, Interpolant};
14
15/// Cubic spline interpolant.
16///
17/// On each interval `[x_i, x_{i+1}]`, the spline is a cubic polynomial
18/// `S_i(x) = a_i + b_i*(x-x_i) + c_i*(x-x_i)^2 + d_i*(x-x_i)^3`.
19pub struct CubicSpline<S: Scalar> {
20    x: Vec<S>,
21    a: Vec<S>,
22    b: Vec<S>,
23    c: Vec<S>,
24    d: Vec<S>,
25}
26
27impl<S: Scalar> CubicSpline<S> {
28    /// Natural cubic spline: S''(x_0) = S''(x_{n-1}) = 0.
29    pub fn natural(x: &[S], y: &[S]) -> Result<Self, InterpError> {
30        validate_data(x, y, 2)?;
31        let n = x.len();
32        if n == 2 {
33            return Self::from_linear(x, y);
34        }
35
36        let h = compute_h(x);
37        let mut m = vec![S::ZERO; n]; // second derivatives
38
39        // Solve (n-2) x (n-2) tridiagonal for interior m[1..n-2]
40        let n_int = n - 2;
41        let mut sub = vec![S::ZERO; n_int];
42        let mut diag = vec![S::ZERO; n_int];
43        let mut sup = vec![S::ZERO; n_int];
44        let mut rhs = vec![S::ZERO; n_int];
45
46        for k in 0..n_int {
47            let i = k + 1;
48            if k > 0 {
49                sub[k] = h[i - 1];
50            }
51            diag[k] = S::TWO * (h[i - 1] + h[i]);
52            if k < n_int - 1 {
53                sup[k] = h[i];
54            }
55            let s_prev = (y[i] - y[i - 1]) / h[i - 1];
56            let s_next = (y[i + 1] - y[i]) / h[i];
57            rhs[k] = S::from_f64(6.0) * (s_next - s_prev);
58        }
59
60        thomas_solve(&sub, &diag, &sup, &mut rhs);
61        m[1..n_int + 1].copy_from_slice(&rhs[..n_int]);
62
63        Ok(Self::from_second_derivatives(x, y, &h, &m))
64    }
65
66    /// Clamped cubic spline with specified endpoint derivatives.
67    pub fn clamped(x: &[S], y: &[S], dy_left: S, dy_right: S) -> Result<Self, InterpError> {
68        validate_data(x, y, 2)?;
69        let n = x.len();
70        if n == 2 {
71            return Self::from_linear(x, y);
72        }
73
74        let h = compute_h(x);
75
76        // Solve n x n tridiagonal
77        let mut sub = vec![S::ZERO; n];
78        let mut diag = vec![S::ZERO; n];
79        let mut sup = vec![S::ZERO; n];
80        let mut rhs = vec![S::ZERO; n];
81
82        // Row 0: clamped left BC
83        let s0 = (y[1] - y[0]) / h[0];
84        diag[0] = S::TWO * h[0];
85        sup[0] = h[0];
86        rhs[0] = S::from_f64(6.0) * (s0 - dy_left);
87
88        // Interior rows
89        for i in 1..n - 1 {
90            sub[i] = h[i - 1];
91            diag[i] = S::TWO * (h[i - 1] + h[i]);
92            sup[i] = h[i];
93            let s_prev = (y[i] - y[i - 1]) / h[i - 1];
94            let s_next = (y[i + 1] - y[i]) / h[i];
95            rhs[i] = S::from_f64(6.0) * (s_next - s_prev);
96        }
97
98        // Row n-1: clamped right BC
99        let sn = (y[n - 1] - y[n - 2]) / h[n - 2];
100        sub[n - 1] = h[n - 2];
101        diag[n - 1] = S::TWO * h[n - 2];
102        rhs[n - 1] = S::from_f64(6.0) * (dy_right - sn);
103
104        thomas_solve(&sub, &diag, &sup, &mut rhs);
105
106        Ok(Self::from_second_derivatives(x, y, &h, &rhs))
107    }
108
109    /// Not-a-knot cubic spline: third derivative continuous at `x[1]` and `x[n-2]`.
110    ///
111    /// Requires at least 4 points. Falls back to natural for fewer points.
112    pub fn not_a_knot(x: &[S], y: &[S]) -> Result<Self, InterpError> {
113        validate_data(x, y, 2)?;
114        let n = x.len();
115        if n <= 3 {
116            return Self::natural(x, y);
117        }
118
119        let h = compute_h(x);
120        let n_int = n - 2;
121
122        // Build (n-2) x (n-2) tridiagonal for m[1..n-2], with modified first and last rows
123        // Not-a-knot left: m_0 = ((h0+h1)/h1)*m_1 - (h0/h1)*m_2
124        // Not-a-knot right: m_{n-1} = ((h_{n-3}+h_{n-2})/h_{n-3})*m_{n-2} - (h_{n-2}/h_{n-3})*m_{n-3}
125
126        let mut sub = vec![S::ZERO; n_int];
127        let mut diag = vec![S::ZERO; n_int];
128        let mut sup = vec![S::ZERO; n_int];
129        let mut rhs = vec![S::ZERO; n_int];
130
131        // Standard interior equations
132        for (k, rhs_k) in rhs.iter_mut().enumerate().take(n_int) {
133            let i = k + 1;
134            let s_prev = (y[i] - y[i - 1]) / h[i - 1];
135            let s_next = (y[i + 1] - y[i]) / h[i];
136            *rhs_k = S::from_f64(6.0) * (s_next - s_prev);
137        }
138
139        // Row 0 (i=1): h[0]*m_0 + 2*(h[0]+h[1])*m_1 + h[1]*m_2 = rhs[0]
140        // Substitute m_0 = alpha1*m_1 + alpha2*m_2
141        let alpha1 = (h[0] + h[1]) / h[1];
142        let alpha2 = -h[0] / h[1];
143        diag[0] = h[0] * alpha1 + S::TWO * (h[0] + h[1]);
144        sup[0] = h[0] * alpha2 + h[1];
145
146        // Fill standard interior rows 1..n_int-2
147        for k in 1..n_int - 1 {
148            let i = k + 1;
149            sub[k] = h[i - 1];
150            diag[k] = S::TWO * (h[i - 1] + h[i]);
151            sup[k] = h[i];
152        }
153
154        // Last row (i=n-2): h[n-3]*m_{n-3} + 2*(h[n-3]+h[n-2])*m_{n-2} + h[n-2]*m_{n-1} = rhs[n_int-1]
155        // Substitute m_{n-1} = beta1*m_{n-2} + beta2*m_{n-3}
156        let beta1 = (h[n - 3] + h[n - 2]) / h[n - 3];
157        let beta2 = -h[n - 2] / h[n - 3];
158        sub[n_int - 1] = h[n - 3] + h[n - 2] * beta2;
159        diag[n_int - 1] = S::TWO * (h[n - 3] + h[n - 2]) + h[n - 2] * beta1;
160
161        thomas_solve(&sub, &diag, &sup, &mut rhs);
162
163        // Recover full m array
164        let mut m = vec![S::ZERO; n];
165        m[1..n_int + 1].copy_from_slice(&rhs[..n_int]);
166        // m_0 from not-a-knot left
167        m[0] = alpha1 * m[1] + alpha2 * m[2];
168        // m_{n-1} from not-a-knot right
169        m[n - 1] = beta1 * m[n - 2] + beta2 * m[n - 3];
170
171        Ok(Self::from_second_derivatives(x, y, &h, &m))
172    }
173
174    /// Build from second derivatives.
175    fn from_second_derivatives(x: &[S], y: &[S], h: &[S], m: &[S]) -> Self {
176        let n = x.len();
177        let n_seg = n - 1;
178        let mut a = Vec::with_capacity(n_seg);
179        let mut b = Vec::with_capacity(n_seg);
180        let mut c = Vec::with_capacity(n_seg);
181        let mut d = Vec::with_capacity(n_seg);
182
183        let six = S::from_f64(6.0);
184        for i in 0..n_seg {
185            a.push(y[i]);
186            b.push((y[i + 1] - y[i]) / h[i] - h[i] * (S::TWO * m[i] + m[i + 1]) / six);
187            c.push(m[i] * S::HALF);
188            d.push((m[i + 1] - m[i]) / (six * h[i]));
189        }
190
191        Self {
192            x: x.to_vec(),
193            a,
194            b,
195            c,
196            d,
197        }
198    }
199
200    /// Trivial linear case for 2 points.
201    fn from_linear(x: &[S], y: &[S]) -> Result<Self, InterpError> {
202        let h = x[1] - x[0];
203        Ok(Self {
204            x: x.to_vec(),
205            a: vec![y[0]],
206            b: vec![(y[1] - y[0]) / h],
207            c: vec![S::ZERO],
208            d: vec![S::ZERO],
209        })
210    }
211}
212
213impl<S: Scalar> Interpolant<S> for CubicSpline<S> {
214    fn interpolate(&self, x: S) -> S {
215        eval_piecewise_cubic(&self.x, &self.a, &self.b, &self.c, &self.d, x)
216    }
217
218    fn derivative(&self, x: S) -> Option<S> {
219        Some(eval_piecewise_cubic_deriv(
220            &self.x, &self.b, &self.c, &self.d, x,
221        ))
222    }
223
224    fn integrate(&self, a: S, b: S) -> Option<S> {
225        Some(integrate_piecewise_cubic(
226            &self.x, &self.a, &self.b, &self.c, &self.d, a, b,
227        ))
228    }
229}
230
231// ============================================================================
232// Internal helpers
233// ============================================================================
234
235/// Compute interval widths h[i] = x[i+1] - x[i].
236fn compute_h<S: Scalar>(x: &[S]) -> Vec<S> {
237    (0..x.len() - 1).map(|i| x[i + 1] - x[i]).collect()
238}
239
240/// Thomas algorithm for tridiagonal system.
241///
242/// Solves sub[i]*x_{i-1} + diag[i]*x_i + sup[i]*x_{i+1} = rhs[i].
243/// sub[0] and sup[n-1] are ignored. Solution is returned in `rhs`.
244fn thomas_solve<S: Scalar>(sub: &[S], diag: &[S], sup: &[S], rhs: &mut [S]) {
245    let n = diag.len();
246    if n == 0 {
247        return;
248    }
249    if n == 1 {
250        rhs[0] /= diag[0];
251        return;
252    }
253
254    let mut cp = vec![S::ZERO; n];
255    let mut dp = vec![S::ZERO; n];
256
257    cp[0] = sup[0] / diag[0];
258    dp[0] = rhs[0] / diag[0];
259
260    for i in 1..n {
261        let m = diag[i] - sub[i] * cp[i - 1];
262        cp[i] = if i < n - 1 { sup[i] / m } else { S::ZERO };
263        dp[i] = (rhs[i] - sub[i] * dp[i - 1]) / m;
264    }
265
266    rhs[n - 1] = dp[n - 1];
267    for i in (0..n - 1).rev() {
268        rhs[i] = dp[i] - cp[i] * rhs[i + 1];
269    }
270}
271
272/// Build piecewise cubic coefficients from slopes at each knot.
273/// Used by PCHIP and Akima.
274pub(crate) fn coefficients_from_slopes<S: Scalar>(
275    x: &[S],
276    y: &[S],
277    slopes: &[S],
278) -> (Vec<S>, Vec<S>, Vec<S>, Vec<S>) {
279    let n_seg = x.len() - 1;
280    let mut a = Vec::with_capacity(n_seg);
281    let mut b = Vec::with_capacity(n_seg);
282    let mut c = Vec::with_capacity(n_seg);
283    let mut d = Vec::with_capacity(n_seg);
284
285    for i in 0..n_seg {
286        let h = x[i + 1] - x[i];
287        let s = (y[i + 1] - y[i]) / h;
288        a.push(y[i]);
289        b.push(slopes[i]);
290        c.push((S::from_f64(3.0) * s - S::TWO * slopes[i] - slopes[i + 1]) / h);
291        d.push((slopes[i] + slopes[i + 1] - S::TWO * s) / (h * h));
292    }
293    (a, b, c, d)
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Sample `sin` at `n` equally spaced points on `[0, 2π]`.
    fn sample_sin(n: usize) -> (Vec<f64>, Vec<f64>) {
        let x: Vec<f64> = (0..n)
            .map(|i| i as f64 * core::f64::consts::PI * 2.0 / (n - 1) as f64)
            .collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi.sin()).collect();
        (x, y)
    }

    #[test]
    fn test_natural_at_knots() {
        let (x, y) = sample_sin(10);
        let cs = CubicSpline::natural(&x, &y).unwrap();
        for (xi, yi) in x.iter().zip(y.iter()) {
            assert_relative_eq!(cs.interpolate(*xi), *yi, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_natural_smooth() {
        let (x, y) = sample_sin(20);
        let cs = CubicSpline::natural(&x, &y).unwrap();
        // Interpolation should be close to sin(x)
        let test_x = 1.0;
        let err = (cs.interpolate(test_x) - test_x.sin()).abs();
        assert!(err < 1e-4, "Error too large: {}", err);
    }

    #[test]
    fn test_clamped_polynomial() {
        // Cubic polynomial f(x) = x^3, f'(x) = 3x^2
        // Clamped spline with exact derivatives should reproduce exactly
        let x = vec![0.0, 1.0, 2.0, 3.0];
        let y: Vec<f64> = x.iter().map(|&xi| xi.powi(3)).collect();
        let cs = CubicSpline::clamped(&x, &y, 0.0, 27.0).unwrap();
        // Check at midpoints
        assert_relative_eq!(cs.interpolate(0.5), 0.125, epsilon = 1e-10);
        assert_relative_eq!(cs.interpolate(1.5), 3.375, epsilon = 1e-10);
        assert_relative_eq!(cs.interpolate(2.5), 15.625, epsilon = 1e-10);
    }

    #[test]
    fn test_clamped_endpoint_slopes() {
        // The prescribed end slopes must be reproduced by the end segments.
        let (x, y) = sample_sin(8);
        let x_last = *x.last().unwrap();
        let (dy_left, dy_right) = (1.0, x_last.cos());
        let cs = CubicSpline::clamped(&x, &y, dy_left, dy_right).unwrap();

        // S_0'(x_0) is the linear coefficient of the first segment.
        assert_relative_eq!(cs.b[0], dy_left, epsilon = 1e-12);

        // S_{n-2}'(x_{n-1}) = b + 2*c*h + 3*d*h^2 on the last segment.
        let last = cs.b.len() - 1;
        let h = x[last + 1] - x[last];
        let right = cs.b[last] + 2.0 * cs.c[last] * h + 3.0 * cs.d[last] * h * h;
        assert_relative_eq!(right, dy_right, epsilon = 1e-10);
    }

    #[test]
    fn test_not_a_knot_cubic() {
        // Not-a-knot should reproduce a cubic polynomial exactly
        let x = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let y: Vec<f64> = x.iter().map(|&xi| xi.powi(3) - 2.0 * xi).collect();
        let cs = CubicSpline::not_a_knot(&x, &y).unwrap();
        // Check at non-knot points
        for t in [0.25, 0.75, 1.5, 2.5, 3.5] {
            let expected = t.powi(3) - 2.0 * t;
            assert_relative_eq!(cs.interpolate(t), expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_derivative() {
        let x = vec![0.0, 1.0, 2.0, 3.0];
        let y: Vec<f64> = x.iter().map(|&xi| xi * xi).collect();
        let cs = CubicSpline::natural(&x, &y).unwrap();
        // Derivative of x^2 is 2x; not exact for natural spline but close
        let deriv = cs.derivative(1.5).unwrap();
        assert!(
            (deriv - 3.0).abs() < 0.5,
            "Derivative error too large: {}",
            (deriv - 3.0).abs()
        );
    }

    #[test]
    fn test_integrate() {
        // Integral of x^2 from 0 to 3 = 9
        let x = vec![0.0, 1.0, 2.0, 3.0];
        let y: Vec<f64> = x.iter().map(|&xi| xi * xi).collect();
        let cs = CubicSpline::natural(&x, &y).unwrap();
        let integral = cs.integrate(0.0, 3.0).unwrap();
        assert_relative_eq!(integral, 9.0, epsilon = 0.1);
    }

    #[test]
    fn test_two_points() {
        let cs = CubicSpline::natural(&[0.0, 1.0], &[0.0, 1.0]).unwrap();
        assert_relative_eq!(cs.interpolate(0.5), 0.5, epsilon = 1e-14);
    }

    #[test]
    fn test_c2_continuity() {
        let (x, y) = sample_sin(10);
        let cs = CubicSpline::natural(&x, &y).unwrap();
        for i in 1..x.len() - 1 {
            // C1: first derivative just left and right of the knot.
            let eps = 1e-8;
            let d_left = cs.derivative(x[i] - eps).unwrap();
            let d_right = cs.derivative(x[i] + eps).unwrap();
            assert!(
                (d_left - d_right).abs() < 1e-4,
                "C1 discontinuity at x[{}]={}: left={}, right={}",
                i,
                x[i],
                d_left,
                d_right
            );

            // C2: second derivative of both adjacent segment polynomials,
            // evaluated exactly at the knot from their coefficients.
            let h = x[i] - x[i - 1];
            let dd_left = 2.0 * cs.c[i - 1] + 6.0 * cs.d[i - 1] * h;
            let dd_right = 2.0 * cs.c[i];
            assert!(
                (dd_left - dd_right).abs() < 1e-10,
                "C2 discontinuity at x[{}]={}: left={}, right={}",
                i,
                x[i],
                dd_left,
                dd_right
            );
        }
    }

    #[test]
    fn test_thomas_solve() {
        // [[2,1,0],[1,2,1],[0,1,2]] * [1,2,3] = [4,8,8]
        let mut rhs = vec![4.0, 8.0, 8.0];
        thomas_solve(&[0.0, 1.0, 1.0], &[2.0, 2.0, 2.0], &[1.0, 1.0, 0.0], &mut rhs);
        for (got, want) in rhs.iter().zip([1.0, 2.0, 3.0]) {
            assert_relative_eq!(*got, want, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_f32() {
        let cs = CubicSpline::natural(&[0.0f32, 1.0, 2.0, 3.0], &[0.0, 1.0, 0.0, 1.0]).unwrap();
        assert!(cs.interpolate(1.5f32).is_finite());
        // The spline must pass through the data at a knot.
        assert!((cs.interpolate(1.0f32) - 1.0).abs() < 1e-5);
    }
}