/// Reasons why a set of control points cannot define a [`CubicSpline`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SplineError {
    /// Two neighbouring control points do not have strictly increasing x values.
    NonAscendingX,
    /// Fewer than two control points were supplied.
    InsufficientPoints,
}

impl std::fmt::Display for SplineError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            SplineError::NonAscendingX => "control point x values must be strictly ascending",
            SplineError::InsufficientPoints => "spline requires at least 2 points",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SplineError {}
26
27pub struct CubicSpline {
29 points: Vec<(f32, f32)>,
31 coefficients: Vec<[f32; 4]>,
33}
34
35impl CubicSpline {
36 pub fn new(points: &[(f32, f32)]) -> Result<Self, SplineError> {
37 if points.len() < 2 {
38 return Err(SplineError::InsufficientPoints);
39 }
40 for i in 1..points.len() {
41 if points[i].0 <= points[i - 1].0 {
42 return Err(SplineError::NonAscendingX);
43 }
44 }
45
46 if points.len() == 2 {
47 let (x0, y0) = points[0];
48 let (x1, y1) = points[1];
49 let slope = (y1 - y0) / (x1 - x0);
50 return Ok(Self {
51 points: points.to_vec(),
52 coefficients: vec![[y0, slope, 0.0, 0.0]],
53 });
54 }
55
56 let n = points.len() - 1;
57 let h: Vec<f32> = (0..n).map(|i| points[i + 1].0 - points[i].0).collect();
58 let f: Vec<f32> = (0..n)
59 .map(|i| (points[i + 1].1 - points[i].1) / h[i])
60 .collect();
61
62 let m = n - 1;
63 if m == 0 {
64 unreachable!();
65 }
66
67 let mut diag: Vec<f32> = Vec::with_capacity(m);
68 let mut sup: Vec<f32> = Vec::with_capacity(m);
69 let mut sub: Vec<f32> = Vec::with_capacity(m);
70 let mut rhs: Vec<f32> = Vec::with_capacity(m);
71
72 for i in 1..n {
73 let idx = i - 1;
74 diag.push(2.0 * (h[i - 1] + h[i]));
75 rhs.push(3.0 * (f[i] - f[i - 1]));
76 if idx > 0 {
77 sub.push(h[i - 1]);
78 }
79 if idx < m - 1 {
80 sup.push(h[i]);
81 }
82 }
83
84 for i in 1..m {
85 let factor = sub[i - 1] / diag[i - 1];
86 diag[i] -= factor * sup[i - 1];
87 rhs[i] -= factor * rhs[i - 1];
88 }
89
90 let mut c_inner = vec![0.0f32; m];
91 c_inner[m - 1] = rhs[m - 1] / diag[m - 1];
92 for i in (0..m - 1).rev() {
93 c_inner[i] = (rhs[i] - sup[i] * c_inner[i + 1]) / diag[i];
94 }
95
96 let mut c = vec![0.0f32; n + 1];
97 c[1..(m + 1)].copy_from_slice(&c_inner[..m]);
98
99 let mut coefficients = Vec::with_capacity(n);
100 for i in 0..n {
101 let a = points[i].1;
102 let b = f[i] - h[i] * (2.0 * c[i] + c[i + 1]) / 3.0;
103 let d = (c[i + 1] - c[i]) / (3.0 * h[i]);
104 coefficients.push([a, b, c[i], d]);
105 }
106
107 Ok(Self {
108 points: points.to_vec(),
109 coefficients,
110 })
111 }
112
113 pub fn evaluate(&self, x: f32) -> f32 {
114 self.evaluate_with_index(x, None).0
115 }
116
117 pub fn evaluate_batch(&self, xs: &[f32]) -> Vec<f32> {
122 if xs.is_empty() {
123 return Vec::new();
124 }
125 if xs.windows(2).all(|w| w[0] <= w[1]) {
126 let mut out = Vec::with_capacity(xs.len());
127 let mut segment = 0usize;
128 for &x in xs {
129 let (value, next_segment) = self.evaluate_with_index(x, Some(segment));
130 out.push(value);
131 segment = next_segment;
132 }
133 out
134 } else {
135 xs.iter().map(|&x| self.evaluate(x)).collect()
136 }
137 }
138
139 fn evaluate_with_index(&self, x: f32, start_segment: Option<usize>) -> (f32, usize) {
140 if x <= self.points[0].0 {
141 return (self.points[0].1, 0);
142 }
143 let last = self.points.len() - 1;
144 if x >= self.points[last].0 {
145 return (self.points[last].1, self.coefficients.len() - 1);
146 }
147
148 let segment = match start_segment {
149 Some(mut idx) if idx < self.coefficients.len() => {
150 while idx + 1 < self.points.len() && x >= self.points[idx + 1].0 {
151 idx += 1;
152 }
153 idx
154 }
155 _ => self.find_segment(x),
156 };
157
158 (self.evaluate_segment(segment, x), segment)
159 }
160
161 fn find_segment(&self, x: f32) -> usize {
162 let mut lo = 0;
163 let mut hi = self.points.len() - 1;
164 while lo < hi - 1 {
165 let mid = (lo + hi) / 2;
166 if x < self.points[mid].0 {
167 hi = mid;
168 } else {
169 lo = mid;
170 }
171 }
172 lo
173 }
174
175 fn evaluate_segment(&self, segment: usize, x: f32) -> f32 {
176 let dx = x - self.points[segment].0;
177 let [a, b, c, d] = self.coefficients[segment];
178 a + b * dx + c * dx * dx + d * dx * dx * dx
179 }
180}
181
182mod bezier;
183
184pub use bezier::BezierCubic;
185
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    /// Checks that `evaluate_batch` agrees element-wise with `evaluate`.
    fn assert_batch_matches_pointwise(spline: &CubicSpline, xs: &[f32]) {
        let batch = spline.evaluate_batch(xs);
        let pointwise: Vec<f32> = xs.iter().map(|&x| spline.evaluate(x)).collect();
        assert_eq!(batch.len(), pointwise.len());
        for (actual, expected) in batch.iter().zip(&pointwise) {
            assert!(close(*actual, *expected, 1e-6));
        }
    }

    #[test]
    fn linear_two_points() {
        let spline = CubicSpline::new(&[(0.0, 0.0), (1.0, 1.0)]).unwrap();
        // On the identity line every sample should return its own x.
        for &x in &[0.5_f32, 0.25, 0.75] {
            assert!(close(spline.evaluate(x), x, 1e-6));
        }
    }

    #[test]
    fn three_points_passes_through_control_points() {
        let points: [(f32, f32); 3] = [(0.0, 0.0), (0.5, 0.8), (1.0, 1.0)];
        let spline = CubicSpline::new(&points).unwrap();
        for &(x, y) in points.iter() {
            let got = spline.evaluate(x);
            assert!(
                close(got, y, 1e-5),
                "value mismatch at control point ({}, {}): {}",
                x,
                y,
                got
            );
        }
    }

    #[test]
    fn three_points_interpolation() {
        let spline = CubicSpline::new(&[(0.0, 0.0), (0.5, 0.8), (1.0, 1.0)]).unwrap();

        let left = spline.evaluate(0.25);
        let left_ok = left > 0.0 && left < 0.8;
        assert!(left_ok, "unexpected value {} at 0.25", left);

        let right = spline.evaluate(0.75);
        let right_ok = right > 0.7 && right < 1.1;
        assert!(right_ok, "unexpected value {} at 0.75", right);
    }

    #[test]
    fn clamp_outside_range() {
        let spline = CubicSpline::new(&[(0.2, 0.3), (0.8, 0.9)]).unwrap();
        let cases: [(f32, f32); 4] = [(0.0, 0.3), (-1.0, 0.3), (1.0, 0.9), (10.0, 0.9)];
        for &(x, expected) in cases.iter() {
            assert!(close(spline.evaluate(x), expected, 1e-6));
        }
    }

    #[test]
    fn error_on_non_ascending_x() {
        assert!(CubicSpline::new(&[(0.5, 0.0), (0.3, 1.0)]).is_err());
    }

    #[test]
    fn error_on_duplicate_x() {
        assert!(CubicSpline::new(&[(0.0, 0.0), (0.5, 0.5), (0.5, 0.8), (1.0, 1.0)]).is_err());
    }

    #[test]
    fn error_on_single_point() {
        assert!(CubicSpline::new(&[(0.5, 0.5)]).is_err());
    }

    #[test]
    fn four_points_smoothness() {
        let knots: [(f32, f32); 4] = [(0.0, 0.0), (0.25, 0.4), (0.75, 0.9), (1.0, 1.0)];
        let spline = CubicSpline::new(&knots).unwrap();
        for &(x, y) in knots.iter() {
            assert!(close(spline.evaluate(x), y, 1e-5));
        }

        // Sample x = 0.00, 0.01, ..., 1.00 and require non-decreasing values.
        let samples: Vec<(f32, f32)> = (0..=100)
            .map(|i| {
                let x = i as f32 / 100.0;
                (x, spline.evaluate(x))
            })
            .collect();
        for pair in samples.windows(2) {
            let (_, prev) = pair[0];
            let (x, val) = pair[1];
            assert!(
                val >= prev - 1e-5,
                "monotonicity broken at x={}: prev={}, val={}",
                x,
                prev,
                val
            );
        }
    }

    #[test]
    fn evaluate_batch_matches_pointwise_for_sorted_inputs() {
        let spline = CubicSpline::new(&[(0.0, 0.0), (0.25, 0.4), (0.75, 0.9), (1.0, 1.0)]).unwrap();
        let xs: [f32; 6] = [0.0, 0.1, 0.25, 0.5, 0.75, 1.0];
        assert_batch_matches_pointwise(&spline, &xs);
    }

    #[test]
    fn evaluate_batch_falls_back_for_unsorted_inputs() {
        let spline = CubicSpline::new(&[(0.0, 0.0), (0.5, 0.8), (1.0, 1.0)]).unwrap();
        let xs: [f32; 5] = [0.75, 0.25, 1.0, -1.0, 0.5];
        assert_batch_matches_pointwise(&spline, &xs);
    }
}