Skip to main content

neco_spline/
lib.rs

//! Zero-dependency natural cubic spline interpolation.
2
/// Spline interpolation error type.
///
/// The variants carry no data, so the type is `Copy` and comparable; this lets
/// callers and tests check for an exact failure with `==` / `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SplineError {
    /// X values are not strictly ascending.
    NonAscendingX,
    /// Fewer than 2 points provided.
    InsufficientPoints,
}
11
12impl std::fmt::Display for SplineError {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        match self {
15            SplineError::NonAscendingX => {
16                write!(f, "control point x values must be strictly ascending")
17            }
18            SplineError::InsufficientPoints => {
19                write!(f, "spline requires at least 2 points")
20            }
21        }
22    }
23}
24
25impl std::error::Error for SplineError {}
26
/// Natural cubic spline interpolation.
///
/// Holds the control points together with one cubic polynomial per segment
/// between consecutive points. `Debug` is derived so that
/// `Result<CubicSpline, _>` supports `unwrap_err` / `expect_err`, and `Clone`
/// so that a built spline can be duplicated without recomputing it.
#[derive(Debug, Clone, PartialEq)]
pub struct CubicSpline {
    /// Control points, sorted by ascending x.
    points: Vec<(f32, f32)>,
    /// Per-segment coefficients `[a, b, c, d]`:
    /// `y = a + b*(x-xi) + c*(x-xi)^2 + d*(x-xi)^3`
    coefficients: Vec<[f32; 4]>,
}
34
35impl CubicSpline {
36    pub fn new(points: &[(f32, f32)]) -> Result<Self, SplineError> {
37        if points.len() < 2 {
38            return Err(SplineError::InsufficientPoints);
39        }
40        for i in 1..points.len() {
41            if points[i].0 <= points[i - 1].0 {
42                return Err(SplineError::NonAscendingX);
43            }
44        }
45
46        if points.len() == 2 {
47            let (x0, y0) = points[0];
48            let (x1, y1) = points[1];
49            let slope = (y1 - y0) / (x1 - x0);
50            return Ok(Self {
51                points: points.to_vec(),
52                coefficients: vec![[y0, slope, 0.0, 0.0]],
53            });
54        }
55
56        let n = points.len() - 1;
57        let h: Vec<f32> = (0..n).map(|i| points[i + 1].0 - points[i].0).collect();
58        let f: Vec<f32> = (0..n)
59            .map(|i| (points[i + 1].1 - points[i].1) / h[i])
60            .collect();
61
62        let m = n - 1;
63        if m == 0 {
64            unreachable!();
65        }
66
67        let mut diag: Vec<f32> = Vec::with_capacity(m);
68        let mut sup: Vec<f32> = Vec::with_capacity(m);
69        let mut sub: Vec<f32> = Vec::with_capacity(m);
70        let mut rhs: Vec<f32> = Vec::with_capacity(m);
71
72        for i in 1..n {
73            let idx = i - 1;
74            diag.push(2.0 * (h[i - 1] + h[i]));
75            rhs.push(3.0 * (f[i] - f[i - 1]));
76            if idx > 0 {
77                sub.push(h[i - 1]);
78            }
79            if idx < m - 1 {
80                sup.push(h[i]);
81            }
82        }
83
84        for i in 1..m {
85            let factor = sub[i - 1] / diag[i - 1];
86            diag[i] -= factor * sup[i - 1];
87            rhs[i] -= factor * rhs[i - 1];
88        }
89
90        let mut c_inner = vec![0.0f32; m];
91        c_inner[m - 1] = rhs[m - 1] / diag[m - 1];
92        for i in (0..m - 1).rev() {
93            c_inner[i] = (rhs[i] - sup[i] * c_inner[i + 1]) / diag[i];
94        }
95
96        let mut c = vec![0.0f32; n + 1];
97        c[1..(m + 1)].copy_from_slice(&c_inner[..m]);
98
99        let mut coefficients = Vec::with_capacity(n);
100        for i in 0..n {
101            let a = points[i].1;
102            let b = f[i] - h[i] * (2.0 * c[i] + c[i + 1]) / 3.0;
103            let d = (c[i + 1] - c[i]) / (3.0 * h[i]);
104            coefficients.push([a, b, c[i], d]);
105        }
106
107        Ok(Self {
108            points: points.to_vec(),
109            coefficients,
110        })
111    }
112
113    pub fn evaluate(&self, x: f32) -> f32 {
114        self.evaluate_with_index(x, None).0
115    }
116
117    /// Evaluate the spline for multiple x values.
118    ///
119    /// When `xs` is sorted in non-decreasing order, span lookup is amortized O(n).
120    /// Unsorted input falls back to pointwise binary-search evaluation.
121    pub fn evaluate_batch(&self, xs: &[f32]) -> Vec<f32> {
122        if xs.is_empty() {
123            return Vec::new();
124        }
125        if xs.windows(2).all(|w| w[0] <= w[1]) {
126            let mut out = Vec::with_capacity(xs.len());
127            let mut segment = 0usize;
128            for &x in xs {
129                let (value, next_segment) = self.evaluate_with_index(x, Some(segment));
130                out.push(value);
131                segment = next_segment;
132            }
133            out
134        } else {
135            xs.iter().map(|&x| self.evaluate(x)).collect()
136        }
137    }
138
139    fn evaluate_with_index(&self, x: f32, start_segment: Option<usize>) -> (f32, usize) {
140        if x <= self.points[0].0 {
141            return (self.points[0].1, 0);
142        }
143        let last = self.points.len() - 1;
144        if x >= self.points[last].0 {
145            return (self.points[last].1, self.coefficients.len() - 1);
146        }
147
148        let segment = match start_segment {
149            Some(mut idx) if idx < self.coefficients.len() => {
150                while idx + 1 < self.points.len() && x >= self.points[idx + 1].0 {
151                    idx += 1;
152                }
153                idx
154            }
155            _ => self.find_segment(x),
156        };
157
158        (self.evaluate_segment(segment, x), segment)
159    }
160
161    fn find_segment(&self, x: f32) -> usize {
162        let mut lo = 0;
163        let mut hi = self.points.len() - 1;
164        while lo < hi - 1 {
165            let mid = (lo + hi) / 2;
166            if x < self.points[mid].0 {
167                hi = mid;
168            } else {
169                lo = mid;
170            }
171        }
172        lo
173    }
174
175    fn evaluate_segment(&self, segment: usize, x: f32) -> f32 {
176        let dx = x - self.points[segment].0;
177        let [a, b, c, d] = self.coefficients[segment];
178        a + b * dx + c * dx * dx + d * dx * dx * dx
179    }
180}
181
182mod bezier;
183
184pub use bezier::BezierCubic;
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Control points shared by the three-point tests.
    const THREE_POINTS: [(f32, f32); 3] = [(0.0, 0.0), (0.5, 0.8), (1.0, 1.0)];
    /// Control points shared by the four-point tests.
    const FOUR_POINTS: [(f32, f32); 4] = [(0.0, 0.0), (0.25, 0.4), (0.75, 0.9), (1.0, 1.0)];

    /// True when `actual` lies strictly within `tolerance` of `expected`.
    fn close(actual: f32, expected: f32, tolerance: f32) -> bool {
        (actual - expected).abs() < tolerance
    }

    /// Checks that `evaluate_batch` agrees with calling `evaluate` per element.
    fn assert_batch_matches_pointwise(spline: &CubicSpline, xs: &[f32]) {
        let batch = spline.evaluate_batch(xs);
        let pointwise: Vec<f32> = xs.iter().map(|&x| spline.evaluate(x)).collect();
        assert_eq!(batch.len(), pointwise.len());
        for (actual, expected) in batch.iter().zip(pointwise.iter()) {
            assert!(close(*actual, *expected, 1e-6));
        }
    }

    #[test]
    fn linear_two_points() {
        let spline = CubicSpline::new(&[(0.0, 0.0), (1.0, 1.0)]).unwrap();
        // On a two-point spline y equals x everywhere between the points.
        for &x in [0.5, 0.25, 0.75].iter() {
            assert!(close(spline.evaluate(x), x, 1e-6));
        }
    }

    #[test]
    fn three_points_passes_through_control_points() {
        let spline = CubicSpline::new(&THREE_POINTS).unwrap();
        for &(x, y) in THREE_POINTS.iter() {
            let got = spline.evaluate(x);
            assert!(
                close(got, y, 1e-5),
                "value mismatch at control point ({}, {}): {}",
                x,
                y,
                got
            );
        }
    }

    #[test]
    fn three_points_interpolation() {
        let spline = CubicSpline::new(&THREE_POINTS).unwrap();

        let val = spline.evaluate(0.25);
        assert!(0.0 < val && val < 0.8, "unexpected value {} at 0.25", val);

        let val_mid = spline.evaluate(0.75);
        assert!(
            0.7 < val_mid && val_mid < 1.1,
            "unexpected value {} at 0.75",
            val_mid
        );
    }

    #[test]
    fn clamp_outside_range() {
        let spline = CubicSpline::new(&[(0.2, 0.3), (0.8, 0.9)]).unwrap();
        // (x outside the control range, y of the nearest end point)
        let cases = [(0.0, 0.3), (-1.0, 0.3), (1.0, 0.9), (10.0, 0.9)];
        for &(x, expected) in cases.iter() {
            assert!(close(spline.evaluate(x), expected, 1e-6));
        }
    }

    #[test]
    fn error_on_non_ascending_x() {
        assert!(CubicSpline::new(&[(0.5, 0.0), (0.3, 1.0)]).is_err());
    }

    #[test]
    fn error_on_duplicate_x() {
        let with_duplicate = [(0.0, 0.0), (0.5, 0.5), (0.5, 0.8), (1.0, 1.0)];
        assert!(CubicSpline::new(&with_duplicate).is_err());
    }

    #[test]
    fn error_on_single_point() {
        assert!(CubicSpline::new(&[(0.5, 0.5)]).is_err());
    }

    #[test]
    fn four_points_smoothness() {
        let spline = CubicSpline::new(&FOUR_POINTS).unwrap();
        for &(x, y) in FOUR_POINTS.iter() {
            assert!(close(spline.evaluate(x), y, 1e-5));
        }

        // Sample the range in 1/100 steps; each value may drop by at most 1e-5.
        let mut prev = spline.evaluate(0.0);
        for x in (1..=100).map(|step| step as f32 / 100.0) {
            let val = spline.evaluate(x);
            assert!(
                val >= prev - 1e-5,
                "monotonicity broken at x={}: prev={}, val={}",
                x,
                prev,
                val
            );
            prev = val;
        }
    }

    #[test]
    fn evaluate_batch_matches_pointwise_for_sorted_inputs() {
        let spline = CubicSpline::new(&FOUR_POINTS).unwrap();
        assert_batch_matches_pointwise(&spline, &[0.0, 0.1, 0.25, 0.5, 0.75, 1.0]);
    }

    #[test]
    fn evaluate_batch_falls_back_for_unsorted_inputs() {
        let spline = CubicSpline::new(&THREE_POINTS).unwrap();
        assert_batch_matches_pointwise(&spline, &[0.75, 0.25, 1.0, -1.0, 0.5]);
    }
}