1use numra_core::Scalar;
10
11use crate::error::InterpError;
12use crate::{eval_piecewise_cubic, eval_piecewise_cubic_deriv, integrate_piecewise_cubic};
13use crate::{validate_data, Interpolant};
14
15pub struct CubicSpline<S: Scalar> {
20 x: Vec<S>,
21 a: Vec<S>,
22 b: Vec<S>,
23 c: Vec<S>,
24 d: Vec<S>,
25}
26
27impl<S: Scalar> CubicSpline<S> {
28 pub fn natural(x: &[S], y: &[S]) -> Result<Self, InterpError> {
30 validate_data(x, y, 2)?;
31 let n = x.len();
32 if n == 2 {
33 return Self::from_linear(x, y);
34 }
35
36 let h = compute_h(x);
37 let mut m = vec![S::ZERO; n]; let n_int = n - 2;
41 let mut sub = vec![S::ZERO; n_int];
42 let mut diag = vec![S::ZERO; n_int];
43 let mut sup = vec![S::ZERO; n_int];
44 let mut rhs = vec![S::ZERO; n_int];
45
46 for k in 0..n_int {
47 let i = k + 1;
48 if k > 0 {
49 sub[k] = h[i - 1];
50 }
51 diag[k] = S::TWO * (h[i - 1] + h[i]);
52 if k < n_int - 1 {
53 sup[k] = h[i];
54 }
55 let s_prev = (y[i] - y[i - 1]) / h[i - 1];
56 let s_next = (y[i + 1] - y[i]) / h[i];
57 rhs[k] = S::from_f64(6.0) * (s_next - s_prev);
58 }
59
60 thomas_solve(&sub, &diag, &sup, &mut rhs);
61 m[1..n_int + 1].copy_from_slice(&rhs[..n_int]);
62
63 Ok(Self::from_second_derivatives(x, y, &h, &m))
64 }
65
66 pub fn clamped(x: &[S], y: &[S], dy_left: S, dy_right: S) -> Result<Self, InterpError> {
68 validate_data(x, y, 2)?;
69 let n = x.len();
70 if n == 2 {
71 return Self::from_linear(x, y);
72 }
73
74 let h = compute_h(x);
75
76 let mut sub = vec![S::ZERO; n];
78 let mut diag = vec![S::ZERO; n];
79 let mut sup = vec![S::ZERO; n];
80 let mut rhs = vec![S::ZERO; n];
81
82 let s0 = (y[1] - y[0]) / h[0];
84 diag[0] = S::TWO * h[0];
85 sup[0] = h[0];
86 rhs[0] = S::from_f64(6.0) * (s0 - dy_left);
87
88 for i in 1..n - 1 {
90 sub[i] = h[i - 1];
91 diag[i] = S::TWO * (h[i - 1] + h[i]);
92 sup[i] = h[i];
93 let s_prev = (y[i] - y[i - 1]) / h[i - 1];
94 let s_next = (y[i + 1] - y[i]) / h[i];
95 rhs[i] = S::from_f64(6.0) * (s_next - s_prev);
96 }
97
98 let sn = (y[n - 1] - y[n - 2]) / h[n - 2];
100 sub[n - 1] = h[n - 2];
101 diag[n - 1] = S::TWO * h[n - 2];
102 rhs[n - 1] = S::from_f64(6.0) * (dy_right - sn);
103
104 thomas_solve(&sub, &diag, &sup, &mut rhs);
105
106 Ok(Self::from_second_derivatives(x, y, &h, &rhs))
107 }
108
109 pub fn not_a_knot(x: &[S], y: &[S]) -> Result<Self, InterpError> {
113 validate_data(x, y, 2)?;
114 let n = x.len();
115 if n <= 3 {
116 return Self::natural(x, y);
117 }
118
119 let h = compute_h(x);
120 let n_int = n - 2;
121
122 let mut sub = vec![S::ZERO; n_int];
127 let mut diag = vec![S::ZERO; n_int];
128 let mut sup = vec![S::ZERO; n_int];
129 let mut rhs = vec![S::ZERO; n_int];
130
131 for (k, rhs_k) in rhs.iter_mut().enumerate().take(n_int) {
133 let i = k + 1;
134 let s_prev = (y[i] - y[i - 1]) / h[i - 1];
135 let s_next = (y[i + 1] - y[i]) / h[i];
136 *rhs_k = S::from_f64(6.0) * (s_next - s_prev);
137 }
138
139 let alpha1 = (h[0] + h[1]) / h[1];
142 let alpha2 = -h[0] / h[1];
143 diag[0] = h[0] * alpha1 + S::TWO * (h[0] + h[1]);
144 sup[0] = h[0] * alpha2 + h[1];
145
146 for k in 1..n_int - 1 {
148 let i = k + 1;
149 sub[k] = h[i - 1];
150 diag[k] = S::TWO * (h[i - 1] + h[i]);
151 sup[k] = h[i];
152 }
153
154 let beta1 = (h[n - 3] + h[n - 2]) / h[n - 3];
157 let beta2 = -h[n - 2] / h[n - 3];
158 sub[n_int - 1] = h[n - 3] + h[n - 2] * beta2;
159 diag[n_int - 1] = S::TWO * (h[n - 3] + h[n - 2]) + h[n - 2] * beta1;
160
161 thomas_solve(&sub, &diag, &sup, &mut rhs);
162
163 let mut m = vec![S::ZERO; n];
165 m[1..n_int + 1].copy_from_slice(&rhs[..n_int]);
166 m[0] = alpha1 * m[1] + alpha2 * m[2];
168 m[n - 1] = beta1 * m[n - 2] + beta2 * m[n - 3];
170
171 Ok(Self::from_second_derivatives(x, y, &h, &m))
172 }
173
174 fn from_second_derivatives(x: &[S], y: &[S], h: &[S], m: &[S]) -> Self {
176 let n = x.len();
177 let n_seg = n - 1;
178 let mut a = Vec::with_capacity(n_seg);
179 let mut b = Vec::with_capacity(n_seg);
180 let mut c = Vec::with_capacity(n_seg);
181 let mut d = Vec::with_capacity(n_seg);
182
183 let six = S::from_f64(6.0);
184 for i in 0..n_seg {
185 a.push(y[i]);
186 b.push((y[i + 1] - y[i]) / h[i] - h[i] * (S::TWO * m[i] + m[i + 1]) / six);
187 c.push(m[i] * S::HALF);
188 d.push((m[i + 1] - m[i]) / (six * h[i]));
189 }
190
191 Self {
192 x: x.to_vec(),
193 a,
194 b,
195 c,
196 d,
197 }
198 }
199
200 fn from_linear(x: &[S], y: &[S]) -> Result<Self, InterpError> {
202 let h = x[1] - x[0];
203 Ok(Self {
204 x: x.to_vec(),
205 a: vec![y[0]],
206 b: vec![(y[1] - y[0]) / h],
207 c: vec![S::ZERO],
208 d: vec![S::ZERO],
209 })
210 }
211}
212
213impl<S: Scalar> Interpolant<S> for CubicSpline<S> {
214 fn interpolate(&self, x: S) -> S {
215 eval_piecewise_cubic(&self.x, &self.a, &self.b, &self.c, &self.d, x)
216 }
217
218 fn derivative(&self, x: S) -> Option<S> {
219 Some(eval_piecewise_cubic_deriv(
220 &self.x, &self.b, &self.c, &self.d, x,
221 ))
222 }
223
224 fn integrate(&self, a: S, b: S) -> Option<S> {
225 Some(integrate_piecewise_cubic(
226 &self.x, &self.a, &self.b, &self.c, &self.d, a, b,
227 ))
228 }
229}
230
231fn compute_h<S: Scalar>(x: &[S]) -> Vec<S> {
237 (0..x.len() - 1).map(|i| x[i + 1] - x[i]).collect()
238}
239
240fn thomas_solve<S: Scalar>(sub: &[S], diag: &[S], sup: &[S], rhs: &mut [S]) {
245 let n = diag.len();
246 if n == 0 {
247 return;
248 }
249 if n == 1 {
250 rhs[0] /= diag[0];
251 return;
252 }
253
254 let mut cp = vec![S::ZERO; n];
255 let mut dp = vec![S::ZERO; n];
256
257 cp[0] = sup[0] / diag[0];
258 dp[0] = rhs[0] / diag[0];
259
260 for i in 1..n {
261 let m = diag[i] - sub[i] * cp[i - 1];
262 cp[i] = if i < n - 1 { sup[i] / m } else { S::ZERO };
263 dp[i] = (rhs[i] - sub[i] * dp[i - 1]) / m;
264 }
265
266 rhs[n - 1] = dp[n - 1];
267 for i in (0..n - 1).rev() {
268 rhs[i] = dp[i] - cp[i] * rhs[i + 1];
269 }
270}
271
272pub(crate) fn coefficients_from_slopes<S: Scalar>(
275 x: &[S],
276 y: &[S],
277 slopes: &[S],
278) -> (Vec<S>, Vec<S>, Vec<S>, Vec<S>) {
279 let n_seg = x.len() - 1;
280 let mut a = Vec::with_capacity(n_seg);
281 let mut b = Vec::with_capacity(n_seg);
282 let mut c = Vec::with_capacity(n_seg);
283 let mut d = Vec::with_capacity(n_seg);
284
285 for i in 0..n_seg {
286 let h = x[i + 1] - x[i];
287 let s = (y[i + 1] - y[i]) / h;
288 a.push(y[i]);
289 b.push(slopes[i]);
290 c.push((S::from_f64(3.0) * s - S::TWO * slopes[i] - slopes[i + 1]) / h);
291 d.push((slopes[i] + slopes[i + 1] - S::TWO * s) / (h * h));
292 }
293 (a, b, c, d)
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Samples `sin` at `n` equally spaced points covering `[0, 2*pi]`.
    fn sample_sin(n: usize) -> (Vec<f64>, Vec<f64>) {
        let x: Vec<f64> = (0..n)
            .map(|i| i as f64 * core::f64::consts::PI * 2.0 / (n - 1) as f64)
            .collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi.sin()).collect();
        (x, y)
    }

    #[test]
    fn test_natural_at_knots() {
        // The spline must pass through every sample.
        let (x, y) = sample_sin(10);
        let cs = CubicSpline::natural(&x, &y).unwrap();
        for (xi, yi) in x.iter().zip(y.iter()) {
            assert_relative_eq!(cs.interpolate(*xi), *yi, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_natural_smooth() {
        // Between knots the spline should stay close to the sampled function.
        let (x, y) = sample_sin(20);
        let cs = CubicSpline::natural(&x, &y).unwrap();
        let test_x = 1.0;
        let err = (cs.interpolate(test_x) - test_x.sin()).abs();
        assert!(err < 1e-4, "Error too large: {}", err);
    }

    #[test]
    fn test_clamped_polynomial() {
        // y = x^3 with the exact end slopes 0 and 27 is reproduced exactly.
        let x = vec![0.0, 1.0, 2.0, 3.0];
        let y: Vec<f64> = x.iter().map(|&xi| xi.powi(3)).collect();
        let cs = CubicSpline::clamped(&x, &y, 0.0, 27.0).unwrap();
        assert_relative_eq!(cs.interpolate(0.5), 0.125, epsilon = 1e-10);
        assert_relative_eq!(cs.interpolate(1.5), 3.375, epsilon = 1e-10);
        assert_relative_eq!(cs.interpolate(2.5), 15.625, epsilon = 1e-10);
    }

    #[test]
    fn test_not_a_knot_cubic() {
        // A not-a-knot spline reproduces a cubic polynomial exactly.
        let x = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let y: Vec<f64> = x.iter().map(|&xi| xi.powi(3) - 2.0 * xi).collect();
        let cs = CubicSpline::not_a_knot(&x, &y).unwrap();
        for t in [0.25, 0.75, 1.5, 2.5, 3.5] {
            let expected = t.powi(3) - 2.0 * t;
            assert_relative_eq!(cs.interpolate(t), expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_derivative() {
        let x = vec![0.0, 1.0, 2.0, 3.0];
        let y: Vec<f64> = x.iter().map(|&xi| xi * xi).collect();
        let cs = CubicSpline::natural(&x, &y).unwrap();
        let deriv = cs.derivative(1.5).unwrap();
        // d/dx x^2 = 3 at x = 1.5; the natural end conditions do not match
        // x^2, hence the loose tolerance.
        assert!(
            (deriv - 3.0).abs() < 0.5,
            "Derivative error too large: {}",
            (deriv - 3.0).abs()
        );
    }

    #[test]
    fn test_integrate() {
        // The integral of x^2 over [0, 3] is 9; the natural spline is only approximate.
        let x = vec![0.0, 1.0, 2.0, 3.0];
        let y: Vec<f64> = x.iter().map(|&xi| xi * xi).collect();
        let cs = CubicSpline::natural(&x, &y).unwrap();
        let integral = cs.integrate(0.0, 3.0).unwrap();
        assert_relative_eq!(integral, 9.0, epsilon = 0.1);
    }

    #[test]
    fn test_two_points() {
        let cs = CubicSpline::natural(&[0.0, 1.0], &[0.0, 1.0]).unwrap();
        assert_relative_eq!(cs.interpolate(0.5), 0.5, epsilon = 1e-14);
    }

    #[test]
    fn test_c2_continuity() {
        let (x, y) = sample_sin(10);
        let cs = CubicSpline::natural(&x, &y).unwrap();
        for i in 1..x.len() - 1 {
            // First derivative: compare the evaluator just left and right of the knot.
            let eps = 1e-8;
            let d_left = cs.derivative(x[i] - eps).unwrap();
            let d_right = cs.derivative(x[i] + eps).unwrap();
            assert!(
                (d_left - d_right).abs() < 1e-4,
                "C1 discontinuity at x[{}]={}: left={}, right={}",
                i,
                x[i],
                d_left,
                d_right
            );

            // Second derivative: compare the stored coefficients directly,
            // taking segment `i` as a + b*t + c*t^2 + d*t^3 with t = x - x[i].
            // The left segment's value at its right end must equal the right
            // segment's value at its start.
            let h = cs.x[i] - cs.x[i - 1];
            let dd_left = 2.0 * cs.c[i - 1] + 6.0 * cs.d[i - 1] * h;
            let dd_right = 2.0 * cs.c[i];
            assert!(
                (dd_left - dd_right).abs() < 1e-10,
                "C2 discontinuity at x[{}]={}: left={}, right={}",
                i,
                x[i],
                dd_left,
                dd_right
            );
        }
    }

    #[test]
    fn test_f32() {
        // Smoke test: the generic code compiles and runs for `f32`.
        let cs = CubicSpline::natural(&[0.0f32, 1.0, 2.0, 3.0], &[0.0, 1.0, 0.0, 1.0]).unwrap();
        let _ = cs.interpolate(1.5f32);
    }
}