cubic_splinterpol/
lib.rs

1//! A library for using cubic spline interpolation on no_std.
2
3#![deny(unsafe_code)]
4#![deny(missing_docs)]
5#![cfg_attr(not(test), no_std)]
6
7mod plot_spline;
8mod thomas_algorithm;
9
/// Errors reported by the functions of this crate.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Error {
    /// A slice argument did not have the length the operation requires.
    InvalidSliceLength,
}
16
17/// Given xs and ys of same length n, calculate the coefficients of n-1 cubic
18/// polynomials.
19pub fn splinterpol<const N: usize>(
20    xs: &[f32],
21    ys: &[f32],
22    coefficients: &mut [(f32, f32, f32, f32)],
23) -> Result<(), Error> {
24    // Array size const expression workaround
25    let mut diagonal = [0f32; N];
26    let mut diagonal = &mut diagonal[0..N - 2];
27
28    calc_diagonal::<N>(&xs, &mut diagonal).unwrap();
29
30    let mut r = [0f32; N];
31    let mut r = &mut r[0..N - 2];
32
33    if let Err(e) = calc_r::<N>(&xs, &ys, &mut r) {
34        return Err(e);
35    }
36
37    let mut sub_diagonal = [0f32; N];
38    let mut sub_diagonal = &mut sub_diagonal[0..N - 3];
39
40    if let Err(e) = calc_subdiagonal(&xs, &mut sub_diagonal) {
41        return Err(e);
42    }
43
44    let c = {
45        let mut c = [0f32; N];
46        let mut c_body = &mut c[1..N - 1];
47        if let Err(e) = thomas_algorithm::thomas_algorithm_symmetric(
48            &sub_diagonal,
49            &mut diagonal,
50            &mut r,
51            &mut c_body,
52        ) {
53            return Err(e);
54        }
55        c
56    };
57
58    let mut b = [0f32; N];
59    let mut b = &mut b[0..N - 1];
60
61    if let Err(e) = calc_b::<N>(&xs, &ys, &c, &mut b) {
62        return Err(e);
63    }
64
65    let mut d = [0f32; N];
66    let mut d = &mut d[0..N - 1];
67
68    if let Err(e) = calc_d::<N>(&xs, &c, &mut d) {
69        return Err(e);
70    }
71
72    for i in 0..N - 1 {
73        coefficients[i].0 = ys[i];
74        coefficients[i].1 = b[i];
75        coefficients[i].2 = c[i];
76        coefficients[i].3 = d[i];
77    }
78    Ok(())
79}
80
81fn calc_subdiagonal(vals: &[f32], sub: &mut [f32]) -> Result<(), Error> {
82    if vals.len() != sub.len() + 3 {
83        return Err(Error::InvalidSliceLength);
84    }
85    let n = vals.len();
86    for i in 0..(n - 3) {
87        sub[i] = vals[i + 2] - vals[i + 1];
88    }
89    Ok(())
90}
91
/// Sample the cubic `a + b*x + c*x^2 + d*x^3` into `vec`: element `i` receives the
/// value at `x = i * step_size`, so `vec[0]` is always `a`.
fn cubic_spline(a: f32, b: f32, c: f32, d: f32, vec: &mut [f32], step_size: f32) {
    vec.iter_mut().enumerate().for_each(|(index, slot)| {
        let x = index as f32 * step_size;
        let x_squared = x * x;
        let x_cubed = x * x * x;
        *slot = a + b * x + c * x_squared + d * x_cubed;
    });
}
99
/// Width of the `i`-th interval, `vals[i + 1] - vals[i]`.
///
/// Panics if `i + 1` is out of bounds for `vals`.
fn h(i: usize, vals: &[f32]) -> f32 {
    // Read the right end first so an out-of-bounds index panics exactly as before.
    let right = vals[i + 1];
    let left = vals[i];
    right - left
}
103
104fn calc_diagonal<const N: usize>(xs: &[f32], result: &mut [f32]) -> Result<(), Error> {
105    if xs.len() != N {
106        return Err(Error::InvalidSliceLength);
107    }
108    for i in 0..N - 2 {
109        result[i] = 2f32 * (h(i, &xs) + h(i + 1, &xs));
110    }
111    Ok(())
112}
113
114fn calc_r<const N: usize>(xs: &[f32], ys: &[f32], r: &mut [f32]) -> Result<(), Error> {
115    if r.len() != N - 2 {
116        return Err(Error::InvalidSliceLength);
117    }
118    if xs.len() != N {
119        return Err(Error::InvalidSliceLength);
120    }
121    if ys.len() != N {
122        return Err(Error::InvalidSliceLength);
123    }
124    for i in 0..N - 2 {
125        let div1 = (ys[i + 2] - ys[i + 1]) / (h(i + 1, &xs));
126        let div2 = (ys[i + 1] - ys[i]) / (h(i, &xs));
127        r[i] = 3f32 * (div1 - div2);
128    }
129    Ok(())
130}
131
132fn calc_b<const N: usize>(xs: &[f32], ys: &[f32], cs: &[f32], b: &mut [f32]) -> Result<(), Error> {
133    if cs.len() != N {
134        return Err(Error::InvalidSliceLength);
135    }
136    if b.len() != N - 1 {
137        return Err(Error::InvalidSliceLength);
138    }
139    for i in 0..N - 1 {
140        let div_1 = (ys[i + 1] - ys[i]) / (h(i, &xs));
141        let div_2 = (2f32 * cs[i] + cs[i + 1]) / 3f32;
142        b[i] = div_1 - div_2 * h(i, &xs);
143    }
144    Ok(())
145}
146
147fn calc_d<const N: usize>(xs: &[f32], cs: &[f32], d: &mut [f32]) -> Result<(), Error> {
148    if xs.len() != N {
149        return Err(Error::InvalidSliceLength);
150    }
151    if cs.len() != N {
152        return Err(Error::InvalidSliceLength);
153    }
154    if d.len() != N - 1 {
155        return Err(Error::InvalidSliceLength);
156    }
157    for i in 0..N - 1 {
158        d[i] = (cs[i + 1] - cs[i]) / (3f32 * h(i, &xs));
159    }
160    Ok(())
161}
162
163/// Plot given coefficients into the buffer according to the intervals given in xs
164pub fn plot_coeffs_into(
165    buffer: &mut [f32],
166    coefficients: &[(f32, f32, f32, f32)],
167    xs: &[f32],
168) -> Result<(), ()> {
169    let x_range = xs.last().unwrap() - xs.first().unwrap();
170    let step_size = x_range as f64 / buffer.len() as f64;
171    let mut current_index = 0;
172    for i in 0..coefficients.len() {
173        let range = xs[i + 1] - xs[i];
174        let ratio = range / x_range;
175        // f32::round not available in no_std
176        let buffer_ratio = {
177            let r = buffer.len() as f32 * ratio;
178            if r - ((r as u32) as f32) < 0.5 {
179                r as u32
180            } else {
181                r as u32 + 1
182            }
183        };
184        let mut upper = current_index + buffer_ratio as usize;
185        if upper >= buffer.len() {
186            upper = buffer.len()
187        };
188        let mut current_slice = &mut buffer[current_index..upper];
189        cubic_spline(
190            coefficients[i].0,
191            coefficients[i].1,
192            coefficients[i].2,
193            coefficients[i].3,
194            &mut current_slice,
195            step_size as f32,
196        );
197        current_index += buffer_ratio as usize;
198    }
199    Ok(())
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_plot_coeffs() {
        let coeffs: [(f32, f32, f32, f32); 15] = [
            (0.0, -0.16381307, 0.0, 0.6552523),
            (0.0, 0.32762617, 0.98287845, -0.31050465),
            (1.0, 1.3618692, 0.051364563, -0.41323376),
            (2.0, 0.22489715, -1.1883367, 1.2848629),
            (4.0, 5.3327103, 4.593546, -6.517933),
            (7.0, 5.0378065, -5.1833534, 2.145547),
            (9.0, 1.1077404, 1.2532874, -1.3610278),
            (10.0, -0.46876848, -2.829796, 1.2985644),
            (8.0, -2.2326672, 1.0658972, -0.83323),
            (6.0, -2.6005628, -1.4337928, 1.0343556),
            (3.0, -2.3650815, 1.669274, -0.35799825),
            (2.0, 0.22625208, 0.05828173, -1.0215718),
            (2.0, -0.48164505, -1.4740759, 0.9557209),
            (1.0, -0.56263405, 1.3930869, -0.8304529),
            (1.0, -0.26781887, -1.0982717, 0.36609057),
        ];
        let mut buffer = [0f32; 100];
        let xs = [
            0.5f32, 1f32, 2f32, 3f32, 4.5f32, 5f32, 6f32, 7f32, 8f32, 9f32, 10f32, 11.5f32, 12f32,
            13f32, 14f32, 15f32,
        ];
        plot_coeffs_into(&mut buffer, &coeffs, &xs).unwrap();
        // The plotted samples must all be real numbers.
        assert!(buffer.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_splinterpol() {
        let xs = [
            0.5f32, 1f32, 2f32, 3f32, 4.5f32, 5f32, 6f32, 7f32, 8f32, 9f32, 10f32, 11.5f32, 12f32,
            13f32, 14f32, 15f32,
        ];
        let ys = [
            0f32, 0f32, 1f32, 2f32, 4f32, 7f32, 9f32, 10f32, 8f32, 6f32, 3f32, 2f32, 2f32, 1f32,
            1f32, 0f32,
        ];
        let mut coeffs = [(0f32, 0f32, 0f32, 0f32); 15];
        splinterpol::<16>(&xs, &ys, &mut coeffs).unwrap();
        let expected: [(f32, f32, f32, f32); 15] = [
            (0.0, -0.16381307, 0.0, 0.6552523),
            (0.0, 0.32762617, 0.98287845, -0.31050465),
            (1.0, 1.3618692, 0.051364563, -0.41323376),
            (2.0, 0.22489715, -1.1883367, 1.2848629),
            (4.0, 5.3327103, 4.593546, -6.517933),
            (7.0, 5.0378065, -5.1833534, 2.145547),
            (9.0, 1.1077404, 1.2532874, -1.3610278),
            (10.0, -0.46876848, -2.829796, 1.2985644),
            (8.0, -2.2326672, 1.0658972, -0.83323),
            (6.0, -2.6005628, -1.4337928, 1.0343556),
            (3.0, -2.3650815, 1.669274, -0.35799825),
            (2.0, 0.22625208, 0.05828173, -1.0215718),
            (2.0, -0.48164505, -1.4740759, 0.9557209),
            (1.0, -0.56263405, 1.3930869, -0.8304529),
            (1.0, -0.26781887, -1.0982717, 0.36609057),
        ];
        assert_eq!(expected, coeffs);
    }

    #[test]
    fn test_splinterpol_8x8() {
        let xs = [0.5f32, 1f32, 2f32, 3f32, 4.5f32, 5f32, 6f32, 7f32];
        let ys = [0f32, 0f32, 1f32, 2f32, 4f32, 7f32, 9f32, 10f32];
        let mut coeffs = [(0f32, 0f32, 0f32, 0f32); 7];
        splinterpol::<8>(&xs, &ys, &mut coeffs).unwrap();

        let mut buffer = [0f32; 1000];
        plot_coeffs_into(&mut buffer, &coeffs, &xs).unwrap();

        let expected = [
            (0.0, -0.16399321, 0.0, 0.65597284),
            (0.0, 0.32798642, 0.98395926, -0.31194568),
            (1.0, 1.360068, 0.048122242, -0.40819016),
            (2.0, 0.2317419, -1.1764482, 1.273895),
            (4.0, 5.301188, 4.5560794, -6.316911),
            (7.0, 5.119584, -4.9192877, 1.7997031),
            (9.0, 0.6801188, 0.47982174, -0.15994059),
        ];
        assert_eq!(expected, coeffs);
    }

    #[test]
    fn plot_splinterpol() {
        use plotters::prelude::*;

        let xs = [
            0.5f32, 1f32, 2f32, 3f32, 4.5f32, 5f32, 6f32, 7f32, 8f32, 9f32, 10f32, 11.5f32, 12f32,
            13f32, 14f32, 15f32,
        ];

        let coeffs: [(f32, f32, f32, f32); 15] = [
            (0.0, -0.16381307, 0.0, 0.6552523),
            (0.0, 0.32762617, 0.98287845, -0.31050465),
            (1.0, 1.3618692, 0.051364563, -0.41323376),
            (2.0, 0.22489715, -1.1883367, 1.2848629),
            (4.0, 5.3327103, 4.593546, -6.517933),
            (7.0, 5.0378065, -5.1833534, 2.145547),
            (9.0, 1.1077404, 1.2532874, -1.3610278),
            (10.0, -0.46876848, -2.829796, 1.2985644),
            (8.0, -2.2326672, 1.0658972, -0.83323),
            (6.0, -2.6005628, -1.4337928, 1.0343556),
            (3.0, -2.3650815, 1.669274, -0.35799825),
            (2.0, 0.22625208, 0.05828173, -1.0215718),
            (2.0, -0.48164505, -1.4740759, 0.9557209),
            (1.0, -0.56263405, 1.3930869, -0.8304529),
            (1.0, -0.26781887, -1.0982717, 0.36609057),
        ];

        let mut buffer = [0f32; 1000];
        plot_coeffs_into(&mut buffer, &coeffs, &xs).unwrap();

        let root = BitMapBackend::new("0.png", (640, 480)).into_drawing_area();
        root.fill(&WHITE).unwrap();
        let mut chart = ChartBuilder::on(&root)
            .caption("spline", ("sans-serif", 50).into_font())
            .margin(5)
            .x_label_area_size(30)
            .y_label_area_size(30)
            .build_cartesian_2d(0f32..1000f32, 0.0f32..15f32)
            .unwrap();

        chart.configure_mesh().draw().unwrap();

        chart
            .draw_series(LineSeries::new(
                buffer.iter().enumerate().map(|(i, v)| (i as f32, *v)),
                &RED,
            ))
            .unwrap();

        chart
            .configure_series_labels()
            .background_style(&WHITE.mix(0.8))
            .border_style(&BLACK)
            .draw()
            .unwrap()
    }

    #[test]
    fn test_calc_subdiagonal() {
        let xs = [
            0.5f32, 1f32, 2f32, 3f32, 4.5f32, 5f32, 6f32, 7f32, 8f32, 9f32, 10f32, 11.5f32, 12f32,
            13f32, 14f32, 15f32,
        ];
        let mut sub = [0f32; 13];
        calc_subdiagonal(&xs, &mut sub).unwrap();
        let expected = [
            1.0, 1.0, 1.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 0.5, 1.0, 1.0,
        ];
        assert_eq!(expected, sub);

        // The output must be exactly three shorter than the input.
        let mut too_short = [0f32; 12];
        assert_eq!(
            calc_subdiagonal(&xs, &mut too_short),
            Err(Error::InvalidSliceLength)
        );
    }

    #[test]
    fn do_cubic_spline() {
        let mut xs = [0f32; 64];
        cubic_spline(4.0, 2.0, 2.0, 1.5, &mut xs, 0.05);
        let expected = [
            4.0, 4.1051874, 4.2215, 4.350063, 4.492, 4.6484375, 4.8205, 5.009312, 5.2160006,
            5.4416876, 5.6875, 5.9545627, 6.244, 6.556938, 6.8945003, 7.2578125, 7.6480002,
            8.066188, 8.5135, 8.991062, 9.5, 10.041438, 10.6165, 11.226313, 11.872001, 12.5546875,
            13.275501, 14.035563, 14.836, 15.677938, 16.5625, 17.490814, 18.464, 19.483187,
            20.5495, 21.664063, 22.828003, 24.042439, 25.3085, 26.627316, 28.0, 29.427689,
            30.911507, 32.452568, 34.052002, 35.710938, 37.4305, 39.21182, 41.056, 42.964188,
            44.9375, 46.97706, 49.084003, 51.25944, 53.5045, 55.820313, 58.208, 60.668694,
            63.203506, 65.81357, 68.5, 71.26394, 74.10651, 77.028824,
        ];
        assert_eq!(expected, xs);
    }

    #[test]
    fn diagonal_test() {
        const N: usize = 16;
        let mut xs = [0f32; N];
        xs.iter_mut().enumerate().for_each(|(i, v)| {
            *v = i as f32;
        });
        xs[4] = 4.5f32;
        let mut diagonal = [0f32; N - 2];
        calc_diagonal::<N>(&xs, &mut diagonal).unwrap();
        let expected = [
            4f32, 4f32, 5f32, 4f32, 3f32, 4f32, 4f32, 4f32, 4f32, 4f32, 4f32, 4f32, 4f32, 4f32,
        ];
        assert_eq!(expected, diagonal);
    }

    #[test]
    fn diagonal_test_2() {
        const N: usize = 16;
        let mut xs = [0f32; N];
        xs.iter_mut().enumerate().for_each(|(i, v)| {
            *v = i as f32;
        });

        xs[0] = 0.5f32;
        xs[4] = 4.5f32;
        xs[11] = 11.5f32;

        let mut diagonal = [0f32; N - 2];
        calc_diagonal::<N>(&xs, &mut diagonal).unwrap();
        let expected = [
            3f32, 4f32, 5f32, 4f32, 3f32, 4f32, 4f32, 4f32, 4f32, 5f32, 4f32, 3f32, 4f32, 4f32,
        ];
        assert_eq!(expected, diagonal);
    }

    #[test]
    fn calc_r_test() {
        const N: usize = 16;
        let mut xs = [0f32; N];
        xs.iter_mut().enumerate().for_each(|(i, v)| {
            *v = i as f32;
        });

        xs[0] = 0.5f32;
        xs[4] = 4.5f32;
        xs[11] = 11.5f32;

        let ys = [
            0f32, 0f32, 1f32, 2f32, 4f32, 7f32, 9f32, 10f32, 8f32, 6f32, 3f32, 2f32, 2f32, 1f32,
            1f32, 0f32,
        ];

        let mut r = [0f32; N - 2];
        calc_r::<N>(&xs, &ys, &mut r).unwrap();
        let expected = [
            3f32, 0f32, 1f32, 14f32, -12f32, -3f32, -9f32, 0f32, -3f32, 7f32, 2f32, -3f32, 3f32,
            -3f32,
        ];
        assert_eq!(r.len(), expected.len());
        for (r, expected) in r.iter().zip(&expected) {
            // Compare the absolute difference: a signed difference would accept any value
            // below the expected one.
            assert!((r - expected).abs() < 0.0001, "{} != {}", r, expected);
        }
    }

    #[test]
    fn calc_b_test() {
        const N: usize = 16;
        let xs: [f32; N] = [
            0.0, 1.0, 3.0, 6.0, 8.0, 9.0, 10.0, 12.0, 13.0, 14.0, 16.0, 17.0, 18.0, 19.0, 20.0,
            21.0,
        ];

        let ys: [f32; N] = [
            0.0, 1.0, -2.0, 4.0, 1.0, -1.0, 0.0, 0.0, 1.0, 2.0, 4.0, 5.0, 4.0, 3.0, 2.0, 0.0,
        ];
        let cs: [f32; N] = [
            0.0, -1.8847, 1.9041, -1.5906, -0.15336, 2.6013, -1.2517, 0.95437, -0.22289, -0.062811,
            0.29988, -1.6737, 0.39473, 0.094739, -0.77368, 0.0,
        ];
        let mut b = [0f32; N - 1];
        calc_b::<N>(&xs, &ys, &cs, &mut b).unwrap();
        let expected: [f32; N - 1] = [
            1.6282333,
            -0.25646675,
            -0.21759987,
            0.72303987,
            -2.76486,
            -0.31696665,
            1.0326867,
            0.43804997,
            1.1695304,
            0.883828,
            1.35798,
            -0.015776694,
            -1.294733,
            -0.805266,
            -1.4842134,
        ];
        assert_eq!(expected, b);
    }

    #[test]
    fn calc_d_test_1() {
        const N: usize = 16;
        let xs: [f32; N] = [
            0.0, 1.0, 3.0, 6.0, 8.0, 9.0, 10.0, 12.0, 13.0, 14.0, 16.0, 17.0, 18.0, 19.0, 20.0,
            21.0,
        ];

        let cs = [
            0f32, -1.8847, 1.9041, -1.5906, -0.15336, 2.6013, -1.2517, 0.95437, -0.22289,
            -0.062811, 0.29988, -1.6737, 0.39473, 0.094739, -0.77368, 0f32,
        ];
        let mut d = [0f32; N - 1];
        calc_d::<N>(&xs, &cs, &mut d).unwrap();
        let expected: [f32; N - 1] = [
            -0.6282333,
            0.6314666,
            -0.3883,
            0.23954,
            0.91822,
            -1.2843333,
            0.3676783,
            -0.39242002,
            0.05335967,
            0.060448498,
            -0.65786,
            0.68947667,
            -0.09999701,
            -0.289473,
            0.25789332,
        ];
        assert_eq!(expected, d);
    }

    #[test]
    fn calc_d_test_2() {
        const N: usize = 16;
        let xs: [f32; N] = [
            0.5, 1.0, 2.0, 3.0, 4.5, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.5, 12.0, 13.0, 14.0, 15.0,
        ];

        let cs = [
            0.0, 0.98288, 0.051365, -1.1883, 4.5935, -5.1834, 1.2533, -2.8298, 1.0659, -1.4338,
            1.6693, 0.058282, -1.4741, 1.3931, -1.0983, 0.0,
        ];

        let mut d = [0f32; N - 1];
        calc_d::<N>(&xs, &cs, &mut d).unwrap();
        let expected: [f32; N - 1] = [
            0.65525335,
            -0.310505,
            -0.4132217,
            1.2848445,
            -6.5179334,
            2.1455667,
            -1.3610333,
            1.2985667,
            -0.83323336,
            1.0343666,
            -0.35800397,
            -1.021588,
            0.9557333,
            -0.8304667,
            0.36609998,
        ];
        assert_eq!(expected, d);
    }
}