// numeris/interp/spline.rs

1use crate::traits::FloatScalar;
2
3use super::{InterpError, find_interval, validate_sorted};
4
/// Natural cubic spline interpolant (fixed-size, stack-allocated).
///
/// Uses natural boundary conditions (S''(x₀) = S''(x_{N-1}) = 0). The tridiagonal
/// system for the second derivatives at the knots is solved via the Thomas
/// algorithm in O(N). Requires at least 3 points.
///
/// Each segment `i` (between knots `x_i` and `x_{i+1}`) stores coefficients
/// `[a, b, c, d]` for:
/// `S_i(x) = a + b·(x - x_i) + c·(x - x_i)² + d·(x - x_i)³`
///
/// # Example
///
/// ```
/// use numeris::interp::CubicSpline;
///
/// let xs = [0.0_f64, 1.0, 2.0, 3.0];
/// let ys = [0.0, 1.0, 0.0, 1.0];
/// let spline = CubicSpline::new(xs, ys).unwrap();
///
/// // Passes through knots exactly
/// assert!((spline.eval(0.0) - 0.0).abs() < 1e-14);
/// assert!((spline.eval(1.0) - 1.0).abs() < 1e-14);
/// assert!((spline.eval(2.0) - 0.0).abs() < 1e-14);
/// ```
#[derive(Debug, Clone)]
pub struct CubicSpline<T, const N: usize> {
    // Knot x-values, stored exactly as passed to `new`.
    xs: [T; N],
    // Per-segment coefficients [a, b, c, d]. N knots give N - 1 segments, so
    // `new` writes entries 0..=N-2 and leaves the final entry zeroed.
    // We store N entries (not N-1) to avoid unstable generic_const_exprs.
    coeffs: [[T; 4]; N],
}
35
36impl<T: FloatScalar, const N: usize> CubicSpline<T, N> {
37    /// Construct a natural cubic spline from sorted knots.
38    ///
39    /// Returns `InterpError::TooFewPoints` if `N < 3`,
40    /// `InterpError::NotSorted` if `xs` is not strictly increasing.
41    pub fn new(xs: [T; N], ys: [T; N]) -> Result<Self, InterpError> {
42        if N < 3 {
43            return Err(InterpError::TooFewPoints);
44        }
45        validate_sorted(&xs)?;
46
47        let two = T::one() + T::one();
48        let three = two + T::one();
49        let six = three + three;
50
51        // Thomas algorithm for natural cubic spline.
52        // Tridiagonal system for second derivatives m_i (interior points i=1..N-2).
53        // Natural BCs: m_0 = m_{N-1} = 0.
54        //
55        // For interior point i:
56        //   h_{i-1}·m_{i-1} + 2(h_{i-1}+h_i)·m_i + h_i·m_{i+1} = 6·(δ_i - δ_{i-1})
57        // where h_i = x_{i+1} - x_i, δ_i = (y_{i+1} - y_i) / h_i
58
59        let n = N;
60        // h[i] = x[i+1] - x[i], delta[i] = (y[i+1] - y[i]) / h[i]
61        let mut h = [T::zero(); N];
62        let mut delta = [T::zero(); N];
63        for i in 0..n - 1 {
64            h[i] = xs[i + 1] - xs[i];
65            delta[i] = (ys[i + 1] - ys[i]) / h[i];
66        }
67
68        // Solve tridiagonal system for m[1..n-2]
69        // Use m array for the solution, m[0] = m[n-1] = 0
70        let mut m = [T::zero(); N];
71
72        if n > 3 {
73            // Forward sweep arrays (use coeffs temporarily)
74            let mut cp = [T::zero(); N]; // modified upper diagonal
75            let mut dp = [T::zero(); N]; // modified RHS
76
77            // Row i=1 (first interior point)
78            let diag = two * (h[0] + h[1]);
79            if diag.abs() < T::epsilon() {
80                return Err(InterpError::IllConditioned);
81            }
82            let rhs = six * (delta[1] - delta[0]);
83            cp[1] = h[1] / diag;
84            dp[1] = rhs / diag;
85
86            // Forward sweep i=2..n-2
87            for i in 2..n - 1 {
88                let diag_i = two * (h[i - 1] + h[i]) - h[i - 1] * cp[i - 1];
89                if diag_i.abs() < T::epsilon() {
90                    return Err(InterpError::IllConditioned);
91                }
92                let rhs_i = six * (delta[i] - delta[i - 1]) - h[i - 1] * dp[i - 1];
93                if i < n - 1 {
94                    cp[i] = h[i] / diag_i;
95                }
96                dp[i] = rhs_i / diag_i;
97            }
98
99            // Back substitution
100            m[n - 2] = dp[n - 2];
101            for i in (1..n - 2).rev() {
102                m[i] = dp[i] - cp[i] * m[i + 1];
103            }
104        } else {
105            // n == 3: single interior point, direct solve
106            let diag = two * (h[0] + h[1]);
107            if diag.abs() < T::epsilon() {
108                return Err(InterpError::IllConditioned);
109            }
110            m[1] = six * (delta[1] - delta[0]) / diag;
111        }
112
113        // Compute per-segment coefficients
114        let mut coeffs = [[T::zero(); 4]; N];
115        for i in 0..n - 1 {
116            let a = ys[i];
117            let b = delta[i] - h[i] * (two * m[i] + m[i + 1]) / six;
118            let c = m[i] / two;
119            let d = (m[i + 1] - m[i]) / (six * h[i]);
120            coeffs[i] = [a, b, c, d];
121        }
122
123        Ok(Self { xs, coeffs })
124    }
125
126    /// Evaluate the spline at `x`.
127    pub fn eval(&self, x: T) -> T {
128        let i = find_interval(&self.xs, x);
129        let dx = x - self.xs[i];
130        let [a, b, c, d] = self.coeffs[i];
131        // Horner form: a + dx·(b + dx·(c + dx·d))
132        a + dx * (b + dx * (c + dx * d))
133    }
134
135    /// Evaluate the spline and its derivative at `x`.
136    pub fn eval_derivative(&self, x: T) -> (T, T) {
137        let i = find_interval(&self.xs, x);
138        let dx = x - self.xs[i];
139        let [a, b, c, d] = self.coeffs[i];
140        let two = T::one() + T::one();
141        let three = two + T::one();
142        let val = a + dx * (b + dx * (c + dx * d));
143        let dval = b + dx * (two * c + three * d * dx);
144        (val, dval)
145    }
146
147    /// The knot x-values.
148    pub fn xs(&self) -> &[T; N] {
149        &self.xs
150    }
151}
152
// ---------- Dynamic variant ----------
154
155#[cfg(feature = "alloc")]
156extern crate alloc;
157#[cfg(feature = "alloc")]
158use alloc::vec::Vec;
159
/// Natural cubic spline interpolant (heap-allocated, runtime-sized).
///
/// Dynamic counterpart of [`CubicSpline`]. Requires at least 3 points.
///
/// Each segment `i` stores coefficients `[a, b, c, d]` for:
/// `S_i(x) = a + b·(x - x_i) + c·(x - x_i)² + d·(x - x_i)³`
///
/// # Example
///
/// ```
/// use numeris::interp::DynCubicSpline;
///
/// let xs = vec![0.0_f64, 1.0, 2.0, 3.0];
/// let ys = vec![0.0, 1.0, 0.0, 1.0];
/// let spline = DynCubicSpline::new(xs, ys).unwrap();
/// assert!((spline.eval(1.0) - 1.0).abs() < 1e-14);
/// ```
#[cfg(feature = "alloc")]
#[derive(Debug, Clone)]
pub struct DynCubicSpline<T> {
    // Knot x-values, stored exactly as passed to `new`.
    xs: Vec<T>,
    // Per-segment coefficients [a, b, c, d]; one entry per segment, so the
    // length is `xs.len() - 1`.
    coeffs: Vec<[T; 4]>,
}
180
#[cfg(feature = "alloc")]
impl<T: FloatScalar> DynCubicSpline<T> {
    /// Construct a natural cubic spline through the points `(xs[i], ys[i])`.
    ///
    /// The second derivatives `m_i` at the knots satisfy, for every interior
    /// knot `i = 1..=n-2` (with `n = xs.len()`):
    ///
    /// `h_{i-1}·m_{i-1} + 2(h_{i-1}+h_i)·m_i + h_i·m_{i+1} = 6·(δ_i - δ_{i-1})`
    ///
    /// where `h_i = x_{i+1} - x_i` and `δ_i = (y_{i+1} - y_i) / h_i`, with the
    /// natural boundary conditions `m_0 = m_{n-1} = 0`. The system is solved
    /// with a forward sweep followed by back substitution (Thomas algorithm).
    ///
    /// # Errors
    ///
    /// Checked in this order:
    ///
    /// * `InterpError::LengthMismatch` if `xs.len() != ys.len()`.
    /// * `InterpError::TooFewPoints` if fewer than 3 points are given.
    /// * `InterpError::NotSorted` if `xs` is not strictly increasing (any
    ///   error returned by `validate_sorted` is propagated unchanged).
    /// * `InterpError::IllConditioned` if a pivot of the elimination has
    ///   magnitude below `T::epsilon()`. The threshold is absolute, so very
    ///   small knot spacings (e.g. `2·(h_0 + h_1) < T::epsilon()`) are rejected.
    pub fn new(xs: Vec<T>, ys: Vec<T>) -> Result<Self, InterpError> {
        if xs.len() != ys.len() {
            return Err(InterpError::LengthMismatch);
        }
        if xs.len() < 3 {
            return Err(InterpError::TooFewPoints);
        }
        validate_sorted(&xs)?;

        let n = xs.len();
        // Index of the last knot; segments are numbered 0..last.
        let last = n - 1;

        let two = T::one() + T::one();
        let three = two + T::one();
        let six = three + three;

        // h[i] = x[i+1] - x[i], delta[i] = (y[i+1] - y[i]) / h[i].
        let mut h = alloc::vec![T::zero(); n];
        let mut delta = alloc::vec![T::zero(); n];
        for i in 0..last {
            h[i] = xs[i + 1] - xs[i];
            delta[i] = (ys[i + 1] - ys[i]) / h[i];
        }

        // Forward-sweep scratch buffers for the interior rows 1..=n-2.
        let mut cp = alloc::vec![T::zero(); n]; // modified upper diagonal
        let mut dp = alloc::vec![T::zero(); n]; // modified right-hand side

        // Row i = 1 (first interior knot): m_0 = 0, so no sub-diagonal term.
        let diag = two * (h[0] + h[1]);
        if diag.abs() < T::epsilon() {
            return Err(InterpError::IllConditioned);
        }
        cp[1] = h[1] / diag;
        dp[1] = six * (delta[1] - delta[0]) / diag;

        // Forward sweep over rows i = 2..=n-2 (empty when n == 3).
        for i in 2..last {
            let diag_i = two * (h[i - 1] + h[i]) - h[i - 1] * cp[i - 1];
            if diag_i.abs() < T::epsilon() {
                return Err(InterpError::IllConditioned);
            }
            let rhs_i = six * (delta[i] - delta[i - 1]) - h[i - 1] * dp[i - 1];
            cp[i] = h[i] / diag_i;
            dp[i] = rhs_i / diag_i;
        }

        // Back substitution. m[0] and m[n-1] stay zero (natural BCs); for
        // n == 3 the loop is empty and m[1] = dp[1] is the direct solution.
        let mut m = alloc::vec![T::zero(); n];
        m[last - 1] = dp[last - 1];
        for i in (1..last - 1).rev() {
            m[i] = dp[i] - cp[i] * m[i + 1];
        }

        // Convert knot second derivatives into per-segment cubic coefficients.
        let mut coeffs = alloc::vec![[T::zero(); 4]; last];
        for i in 0..last {
            let a = ys[i];
            let b = delta[i] - h[i] * (two * m[i] + m[i + 1]) / six;
            let c = m[i] / two;
            let d = (m[i + 1] - m[i]) / (six * h[i]);
            coeffs[i] = [a, b, c, d];
        }

        Ok(Self { xs, coeffs })
    }

    /// Evaluate the spline at `x`.
    ///
    /// The segment index is chosen by `find_interval(&self.xs, x)` and the
    /// cubic is evaluated about that segment's left knot. `coeffs` holds only
    /// `xs.len() - 1` entries, so this relies on `find_interval` returning a
    /// valid segment index.
    pub fn eval(&self, x: T) -> T {
        let i = find_interval(&self.xs, x);
        let dx = x - self.xs[i];
        let [a, b, c, d] = self.coeffs[i];
        // Horner form: a + dx·(b + dx·(c + dx·d))
        a + dx * (b + dx * (c + dx * d))
    }

    /// Evaluate the spline and its derivative at `x`.
    ///
    /// Returns `(S(x), S'(x))`, both computed from the same segment selected
    /// by `find_interval(&self.xs, x)`.
    pub fn eval_derivative(&self, x: T) -> (T, T) {
        let i = find_interval(&self.xs, x);
        let dx = x - self.xs[i];
        let [a, b, c, d] = self.coeffs[i];
        let two = T::one() + T::one();
        let three = two + T::one();
        let val = a + dx * (b + dx * (c + dx * d));
        // d/dx of the segment cubic: b + 2c·dx + 3d·dx²
        let dval = b + dx * (two * c + three * d * dx);
        (val, dval)
    }

    /// The knot x-values.
    pub fn xs(&self) -> &[T] {
        &self.xs
    }
}