1use std::cell::Cell;
8
/// Natural cubic interpolating spline over `(x, y)` knots.
///
/// The curve passes through every knot. Its second derivatives at the knots
/// (stored in `am`) are computed by [`Bspline::prepare`] from a tridiagonal
/// system with the natural end conditions `am[0] = am[num - 1] = 0`. Outside
/// `[x[0], x[num - 1]]` the curve is continued linearly with the end slope.
///
/// Knot abscissae are expected to be strictly increasing; this is not checked,
/// and repeated or unsorted values give meaningless (possibly non-finite)
/// results. While fewer than three knots are stored, every evaluation
/// returns `0.0`.
#[derive(Debug, Clone)]
pub struct Bspline {
    /// Number of knots currently stored; only the first `num` entries of the
    /// buffers below are meaningful.
    num: usize,
    /// Knot abscissae.
    x: Vec<f64>,
    /// Knot ordinates.
    y: Vec<f64>,
    /// Second derivative of the spline at each knot, filled in by `prepare`.
    am: Vec<f64>,
    /// Segment index used by the previous `get_stateful` call; only a lookup
    /// hint, cleared whenever the knots are reset or re-prepared.
    last_idx: Cell<Option<usize>>,
}

impl Bspline {
    /// Creates an empty spline, which evaluates to `0.0` everywhere.
    pub fn new() -> Self {
        Self {
            num: 0,
            x: Vec::new(),
            y: Vec::new(),
            am: Vec::new(),
            last_idx: Cell::new(None),
        }
    }

    /// Creates a spline through the points `(x[i], y[i])`; see [`init`](Self::init).
    pub fn new_with_points(x: &[f64], y: &[f64]) -> Self {
        let mut spline = Self::new();
        spline.init(x, y);
        spline
    }

    /// Clears all knots and, when `max > 2`, sizes the buffers to hold exactly
    /// `max` points for subsequent [`add_point`](Self::add_point) calls.
    ///
    /// With `max <= 2` the buffers keep their current size; only the point
    /// count is reset.
    pub fn init_num(&mut self, max: usize) {
        if max > 2 {
            self.x.resize(max, 0.0);
            self.y.resize(max, 0.0);
            self.am.resize(max, 0.0);
        }
        self.num = 0;
        self.last_idx.set(None);
    }

    /// Appends the knot `(x, y)`.
    ///
    /// The point is silently dropped once the buffers sized by
    /// [`init_num`](Self::init_num) are full. Call [`prepare`](Self::prepare)
    /// after the last point has been added.
    pub fn add_point(&mut self, x: f64, y: f64) {
        if self.num < self.x.len() {
            self.x[self.num] = x;
            self.y[self.num] = y;
            self.num += 1;
        }
    }

    /// Solves for the knot second derivatives (`am`) of the natural spline.
    ///
    /// Leaves the curve untouched when fewer than three knots are stored, and
    /// always clears the `get_stateful` lookup hint.
    pub fn prepare(&mut self) {
        let n = self.num;
        if n > 2 {
            let last = n - 1;
            // Interior equations, for k in 1..last (M = second derivatives):
            //   mu_k * M[k-1] + 2 * M[k] + lambda_k * M[k+1] = rhs_k,
            // with M[0] = M[last] = 0. The Thomas algorithm's forward sweep
            // rewrites each row as M[k] = c[k] * M[k+1] + r[k]; back
            // substitution then recovers M from the right end.
            let mut c = vec![0.0; last];
            let mut r = vec![0.0; last];

            // `h`/`d`: previous/current knot spacing; `f`/`e`: previous/current secant slope.
            let mut d = self.x[1] - self.x[0];
            let mut e = (self.y[1] - self.y[0]) / d;
            for k in 1..last {
                let h = d;
                d = self.x[k + 1] - self.x[k];
                let f = e;
                e = (self.y[k + 1] - self.y[k]) / d;

                let lambda = d / (d + h);
                let mu = 1.0 - lambda;
                let rhs = 6.0 * (e - f) / (h + d);

                // c[0] = r[0] = 0 encodes the natural condition M[0] = 0.
                let p = 1.0 / (mu * c[k - 1] + 2.0);
                c[k] = lambda * -p;
                r[k] = (rhs - mu * r[k - 1]) * p;
            }

            self.am[last] = 0.0;
            self.am[last - 1] = r[last - 1];
            for k in (0..last - 1).rev() {
                self.am[k] = c[k] * self.am[k + 1] + r[k];
            }
        }
        self.last_idx.set(None);
    }

    /// Replaces all knots with `(x[i], y[i])` and prepares the spline.
    ///
    /// Uses the first `min(x.len(), y.len())` pairs. With fewer than three
    /// pairs the spline is cleared (evaluates to `0.0`) instead of silently
    /// keeping a previously prepared curve.
    pub fn init(&mut self, x: &[f64], y: &[f64]) {
        let count = x.len().min(y.len());
        if count > 2 {
            self.init_num(count);
            for (&xi, &yi) in x.iter().zip(y) {
                self.add_point(xi, yi);
            }
            self.prepare();
        } else {
            // Mirror init_num(<= 2): drop any previous curve.
            self.num = 0;
        }
        self.last_idx.set(None);
    }

    /// Evaluates the spline at `x`.
    ///
    /// Returns `0.0` while fewer than three knots are stored.
    pub fn get(&self, x: f64) -> f64 {
        self.evaluate(x, |t| self.bsearch(t))
    }

    /// Evaluates the spline at `x`, using the segment from the previous call
    /// as a starting guess.
    ///
    /// When `x` falls in the same or an adjacent segment, the binary search is
    /// skipped. For strictly increasing knots, the segment chosen (and so the
    /// result) is the same as [`get`](Self::get)'s.
    pub fn get_stateful(&self, x: f64) -> f64 {
        self.evaluate(x, |t| self.locate_cached(t))
    }

    /// Shared evaluation path. Handles the degenerate and extrapolated cases,
    /// and calls `locate` to pick the segment when `x[0] <= x < x[num - 1]`.
    fn evaluate<F: FnOnce(f64) -> usize>(&self, x: f64, locate: F) -> f64 {
        let n = self.num;
        if n <= 2 {
            return 0.0;
        }
        if x < self.x[0] {
            self.extrapolation_left(x)
        } else if x >= self.x[n - 1] {
            self.extrapolation_right(x)
        } else {
            self.interpolation(x, locate(x))
        }
    }

    /// Finds the segment containing `x`, checking the cached segment and its
    /// two neighbours before falling back to [`bsearch`](Self::bsearch).
    ///
    /// Uses the same half-open rule `x[i] <= x < x[i + 1]` as `bsearch`, so
    /// both lookups agree even when `x` is exactly a knot.
    fn locate_cached(&self, x: f64) -> usize {
        let n = self.num;
        let xs = &self.x;
        let fits = |i: usize| i + 1 < n && xs[i] <= x && x < xs[i + 1];
        let seg = match self.last_idx.get() {
            Some(i) if fits(i) => i,
            Some(i) if fits(i + 1) => i + 1,
            Some(i) if i > 0 && fits(i - 1) => i - 1,
            _ => self.bsearch(x),
        };
        self.last_idx.set(Some(seg));
        seg
    }

    /// Binary search for the segment index `lo` with `x[lo] <= x0 < x[lo + 1]`.
    ///
    /// Expects `x[0] <= x0 < x[num - 1]` and increasing knots.
    fn bsearch(&self, x0: f64) -> usize {
        let mut lo = 0usize;
        let mut hi = self.num - 1;
        while hi - lo > 1 {
            let mid = (lo + hi) >> 1;
            if x0 < self.x[mid] {
                hi = mid;
            } else {
                lo = mid;
            }
        }
        lo
    }

    /// Evaluates the cubic piece on segment `[x[i], x[i + 1]]` at `x`.
    fn interpolation(&self, x: f64, i: usize) -> f64 {
        let j = i + 1;
        // Sign convention: `d` is the negated segment width, and `h` / `r`
        // are the negated distances from `x` to the right / left knot. The
        // signs cancel in every quotient below.
        let d = self.x[i] - self.x[j];
        let h = x - self.x[j];
        let r = self.x[i] - x;
        let p = d * d / 6.0;
        (self.am[j] * r * r * r + self.am[i] * h * h * h) / 6.0 / d
            + ((self.y[j] - self.am[j] * p) * r + (self.y[i] - self.am[i] * p) * h) / d
    }

    /// Linear continuation left of `x[0]`, using the spline's slope at `x[0]`.
    ///
    /// Relies on the natural condition `am[0] = 0`.
    fn extrapolation_left(&self, x: f64) -> f64 {
        let d = self.x[1] - self.x[0];
        (-d * self.am[1] / 6.0 + (self.y[1] - self.y[0]) / d) * (x - self.x[0]) + self.y[0]
    }

    /// Linear continuation right of `x[num - 1]`, using the spline's slope there.
    ///
    /// Relies on the natural condition `am[num - 1] = 0`.
    fn extrapolation_right(&self, x: f64) -> f64 {
        let d = self.x[self.num - 1] - self.x[self.num - 2];
        (d * self.am[self.num - 2] / 6.0 + (self.y[self.num - 1] - self.y[self.num - 2]) / d)
            * (x - self.x[self.num - 1])
            + self.y[self.num - 1]
    }
}

impl Default for Bspline {
    fn default() -> Self {
        Self::new()
    }
}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `|actual - expected| < tol` with a readable failure message.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {}, got {} (tol {})",
            expected,
            actual,
            tol
        );
    }

    #[test]
    fn test_linear_interpolation() {
        // For collinear knots every second derivative is zero, so the spline is
        // the line itself: exact up to rounding everywhere, not only near knots.
        let x = [0.0, 1.0, 2.0, 3.0, 4.0];
        let y = [0.0, 2.0, 4.0, 6.0, 8.0];
        let s = Bspline::new_with_points(&x, &y);

        for i in 0..=40 {
            let xi = i as f64 * 0.1;
            assert_close(s.get(xi), 2.0 * xi, 1e-12);
        }
        // The linear extrapolation continues the same line on both sides.
        assert_close(s.get(-1.0), -2.0, 1e-12);
        assert_close(s.get(5.0), 10.0, 1e-12);
    }

    #[test]
    fn test_data_points_exact() {
        let x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [0.0, 1.0, 0.0, -1.0, 0.0, 1.0];
        let s = Bspline::new_with_points(&x, &y);

        for (&xi, &yi) in x.iter().zip(&y) {
            let got = s.get(xi);
            assert!(
                (got - yi).abs() < 1e-9,
                "at x={}, expected {}, got {}",
                xi,
                yi,
                got
            );
        }
    }

    #[test]
    fn test_extrapolation_left() {
        let x = [0.0, 1.0, 2.0, 3.0];
        let y = [0.0, 1.0, 4.0, 9.0];
        let s = Bspline::new_with_points(&x, &y);

        // Left of the first knot the curve is a straight line ...
        let v1 = s.get(-1.0);
        let v2 = s.get(-2.0);
        let slope = v1 - s.get(0.0);
        assert_close(v2 - v1, slope, 1e-6);

        // ... whose slope matches the spline's derivative at x[0].
        let inner = (s.get(1e-6) - s.get(0.0)) / 1e-6;
        assert_close(inner, -slope, 1e-4);
    }

    #[test]
    fn test_extrapolation_right() {
        let x = [0.0, 1.0, 2.0, 3.0];
        let y = [0.0, 1.0, 4.0, 9.0];
        let s = Bspline::new_with_points(&x, &y);

        // Right of the last knot the curve is a straight line ...
        let v1 = s.get(4.0);
        let v2 = s.get(5.0);
        let slope = v1 - s.get(3.0);
        assert_close(v2 - v1, slope, 1e-6);

        // ... whose slope matches the spline's derivative at x[num - 1].
        let inner = (s.get(3.0) - s.get(3.0 - 1e-6)) / 1e-6;
        assert_close(inner, slope, 1e-4);
    }

    #[test]
    fn test_get_stateful_matches_get() {
        let x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [0.0, 0.5, 1.0, 1.5, 1.0, 0.0];
        let s = Bspline::new_with_points(&x, &y);

        // Forward sweep, backward sweep, then scattered jumps (including knots
        // and points outside the knot range) so every lookup path is exercised.
        let forward = (0..60).map(|i| i as f64 * 0.1 - 0.5);
        let backward = (0..60).rev().map(|i| i as f64 * 0.1 - 0.5);
        let jumps = [4.5, 0.5, 3.2, 3.1, 1.9, 5.0, -1.0, 2.0, 2.0, 0.0, 4.99];

        for xi in forward.chain(backward).chain(jumps.iter().copied()) {
            let v1 = s.get(xi);
            let v2 = s.get_stateful(xi);
            assert!(
                (v1 - v2).abs() < 1e-10,
                "at x={}, get={}, get_stateful={}",
                xi,
                v1,
                v2
            );
        }
    }

    #[test]
    fn test_empty_spline() {
        let s = Bspline::new();
        assert_eq!(s.get(1.0), 0.0);
        assert_eq!(s.get_stateful(1.0), 0.0);

        // Two points are not enough for a spline either.
        let s = Bspline::new_with_points(&[0.0, 1.0], &[0.0, 1.0]);
        assert_eq!(s.get(0.5), 0.0);
        assert_eq!(s.get_stateful(0.5), 0.0);
    }

    #[test]
    fn test_point_by_point_api() {
        let mut s = Bspline::new();
        s.init_num(4);
        s.add_point(0.0, 0.0);
        s.add_point(1.0, 1.0);
        s.add_point(2.0, 0.0);
        s.add_point(3.0, 1.0);
        // Capacity is 4, so this extra point must be ignored.
        s.add_point(10.0, 100.0);
        s.prepare();

        assert_close(s.get(0.0), 0.0, 1e-6);
        assert_close(s.get(1.0), 1.0, 1e-6);
        assert_close(s.get(2.0), 0.0, 1e-6);
        assert_close(s.get(3.0), 1.0, 1e-6);

        // x = 3.0 is still the last knot, so the curve is linear beyond it.
        let d1 = s.get(4.0) - s.get(3.0);
        let d2 = s.get(5.0) - s.get(4.0);
        assert_close(d2, d1, 1e-9);
    }

    #[test]
    fn test_mismatched_lengths_use_shorter_input() {
        let s = Bspline::new_with_points(&[0.0, 1.0, 2.0, 3.0], &[0.0, 1.0, 4.0]);
        // Only three knots are used, so x = 2.0 is already the right end.
        assert_close(s.get(2.0), 4.0, 1e-12);
        let slope = s.get(3.0) - s.get(2.0);
        assert_close(s.get(4.0) - s.get(3.0), slope, 1e-9);
    }

    #[test]
    fn test_monotonic_in_monotonic_region() {
        let x = [0.0, 1.0, 2.0, 3.0, 4.0];
        let y = [0.0, 1.0, 2.0, 3.0, 4.0];
        let s = Bspline::new_with_points(&x, &y);

        let mut prev = s.get(0.0);
        for i in 1..40 {
            let xi = i as f64 * 0.1;
            let v = s.get(xi);
            assert!(v >= prev - 1e-6, "not monotonic at x={}", xi);
            prev = v;
        }
    }
}