convex_math/interpolation/
cubic_spline.rs

1//! Natural cubic spline interpolation.
2
3use crate::error::{MathError, MathResult};
4use crate::interpolation::Interpolator;
5
/// Natural cubic spline interpolation.
///
/// Constructs a smooth curve through data points using piecewise cubic
/// polynomials with continuous first and second derivatives.
///
/// "Natural" means the second derivative is zero at the endpoints.
///
/// A spline built with `new` starts with extrapolation disabled; calling
/// `with_extrapolation` on it turns extrapolation on.
///
/// # Example
///
/// ```rust
/// use convex_math::interpolation::{CubicSpline, Interpolator};
///
/// let xs = vec![0.0, 1.0, 2.0, 3.0];
/// let ys = vec![0.0, 1.0, 4.0, 9.0];
///
/// let spline = CubicSpline::new(xs, ys).unwrap();
/// let y = spline.interpolate(1.5).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct CubicSpline {
    /// X coordinates of the knots; `new` only accepts strictly increasing values.
    xs: Vec<f64>,
    /// Y coordinates of the knots; `new` requires the same length as `xs`.
    ys: Vec<f64>,
    /// Second derivatives at each knot
    y2s: Vec<f64>,
    /// When `false`, evaluating outside `[xs[0], xs[len - 1]]` returns an
    /// `ExtrapolationNotAllowed` error instead of a value.
    allow_extrapolation: bool,
}
32
33impl CubicSpline {
34    /// Creates a natural cubic spline interpolator.
35    ///
36    /// # Arguments
37    ///
38    /// * `xs` - X coordinates (must be sorted in ascending order)
39    /// * `ys` - Y coordinates
40    ///
41    /// # Errors
42    ///
43    /// Returns an error if there are fewer than 3 points or if lengths differ.
44    pub fn new(xs: Vec<f64>, ys: Vec<f64>) -> MathResult<Self> {
45        if xs.len() < 3 {
46            return Err(MathError::insufficient_data(3, xs.len()));
47        }
48        if xs.len() != ys.len() {
49            return Err(MathError::invalid_input(format!(
50                "xs and ys must have same length: {} vs {}",
51                xs.len(),
52                ys.len()
53            )));
54        }
55
56        // Check that xs are sorted
57        for i in 1..xs.len() {
58            if xs[i] <= xs[i - 1] {
59                return Err(MathError::invalid_input(
60                    "x values must be strictly increasing",
61                ));
62            }
63        }
64
65        let y2s = compute_second_derivatives(&xs, &ys);
66
67        Ok(Self {
68            xs,
69            ys,
70            y2s,
71            allow_extrapolation: false,
72        })
73    }
74
75    /// Enables extrapolation beyond the data range.
76    #[must_use]
77    pub fn with_extrapolation(mut self) -> Self {
78        self.allow_extrapolation = true;
79        self
80    }
81
82    /// Finds the index i such that xs[i] <= x < xs[i+1].
83    fn find_segment(&self, x: f64) -> usize {
84        match self
85            .xs
86            .binary_search_by(|probe| probe.partial_cmp(&x).unwrap_or(std::cmp::Ordering::Equal))
87        {
88            Ok(i) => i.min(self.xs.len() - 2),
89            Err(i) => (i.saturating_sub(1)).min(self.xs.len() - 2),
90        }
91    }
92}
93
94impl Interpolator for CubicSpline {
95    #[allow(clippy::many_single_char_names)]
96    #[allow(clippy::similar_names)]
97    fn interpolate(&self, x: f64) -> MathResult<f64> {
98        // Check bounds
99        if !self.allow_extrapolation && (x < self.xs[0] || x > self.xs[self.xs.len() - 1]) {
100            return Err(MathError::ExtrapolationNotAllowed {
101                x,
102                min: self.xs[0],
103                max: self.xs[self.xs.len() - 1],
104            });
105        }
106
107        let i = self.find_segment(x);
108
109        let x_lo = self.xs[i];
110        let x_hi = self.xs[i + 1];
111        let y_lo = self.ys[i];
112        let y_hi = self.ys[i + 1];
113        let y2_lo = self.y2s[i];
114        let y2_hi = self.y2s[i + 1];
115
116        let h = x_hi - x_lo;
117        let a = (x_hi - x) / h;
118        let b = (x - x_lo) / h;
119
120        // Cubic spline formula
121        let y = a * y_lo
122            + b * y_hi
123            + ((a * a * a - a) * y2_lo + (b * b * b - b) * y2_hi) * (h * h) / 6.0;
124
125        Ok(y)
126    }
127
128    #[allow(clippy::many_single_char_names)]
129    #[allow(clippy::similar_names)]
130    fn derivative(&self, x: f64) -> MathResult<f64> {
131        // Check bounds
132        if !self.allow_extrapolation && (x < self.xs[0] || x > self.xs[self.xs.len() - 1]) {
133            return Err(MathError::ExtrapolationNotAllowed {
134                x,
135                min: self.xs[0],
136                max: self.xs[self.xs.len() - 1],
137            });
138        }
139
140        let i = self.find_segment(x);
141
142        let x_lo = self.xs[i];
143        let x_hi = self.xs[i + 1];
144        let y_lo = self.ys[i];
145        let y_hi = self.ys[i + 1];
146        let y2_lo = self.y2s[i];
147        let y2_hi = self.y2s[i + 1];
148
149        let h = x_hi - x_lo;
150        let a = (x_hi - x) / h;
151        let b = (x - x_lo) / h;
152
153        // Derivative of cubic spline formula
154        // dy/dx = (y_hi - y_lo)/h - (3*a^2 - 1)/6 * h * y2_lo + (3*b^2 - 1)/6 * h * y2_hi
155        let dy = (y_hi - y_lo) / h - (3.0 * a * a - 1.0) / 6.0 * h * y2_lo
156            + (3.0 * b * b - 1.0) / 6.0 * h * y2_hi;
157
158        Ok(dy)
159    }
160
161    fn allows_extrapolation(&self) -> bool {
162        self.allow_extrapolation
163    }
164
165    fn min_x(&self) -> f64 {
166        self.xs[0]
167    }
168
169    fn max_x(&self) -> f64 {
170        self.xs[self.xs.len() - 1]
171    }
172}
173
/// Computes the knot second derivatives of a natural cubic spline.
///
/// Solves the tridiagonal spline system with a forward elimination sweep
/// followed by back-substitution. The first and last entries of the result
/// are zero (the "natural" boundary condition).
///
/// `xs` and `ys` are indexed in lockstep; the only caller in this file
/// passes equal-length slices with at least 3 strictly increasing `xs`.
fn compute_second_derivatives(xs: &[f64], ys: &[f64]) -> Vec<f64> {
    let count = xs.len();
    // During the forward sweep `curv[k]` temporarily holds the elimination
    // factor for row `k`; back-substitution turns it into the second derivative.
    let mut curv = vec![0.0; count];
    let mut rhs = vec![0.0; count - 1];

    // Left boundary of the natural spline: zero curvature, zero right-hand side.
    curv[0] = 0.0;
    rhs[0] = 0.0;

    // Forward elimination over the interior knots.
    for mid in 1..count - 1 {
        let (prev, next) = (mid - 1, mid + 1);
        let span = xs[next] - xs[prev];
        let ratio = (xs[mid] - xs[prev]) / span;
        let pivot = ratio * curv[prev] + 2.0;
        curv[mid] = (ratio - 1.0) / pivot;

        // Change in secant slope across the knot.
        let slope_jump =
            (ys[next] - ys[mid]) / (xs[next] - xs[mid]) - (ys[mid] - ys[prev]) / (xs[mid] - xs[prev]);
        rhs[mid] = (6.0 * slope_jump / span - ratio * rhs[prev]) / pivot;
    }

    // Right boundary of the natural spline: zero curvature.
    curv[count - 1] = 0.0;

    // Back-substitution, walking from the second-to-last knot down to the first.
    for upper in (1..count).rev() {
        let k = upper - 1;
        curv[k] = curv[k] * curv[upper] + rhs[k];
    }

    curv
}
204
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Samples of y = x^2 at x = 0, 1, 2, 3, shared by several tests.
    fn quadratic_knots() -> (Vec<f64>, Vec<f64>) {
        (vec![0.0, 1.0, 2.0, 3.0], vec![0.0, 1.0, 4.0, 9.0])
    }

    #[test]
    fn test_cubic_spline_through_points() {
        let (xs, ys) = quadratic_knots();
        let spline = CubicSpline::new(xs.clone(), ys.clone()).unwrap();

        // The spline must reproduce every knot value.
        for (&x, &y) in xs.iter().zip(&ys) {
            assert_relative_eq!(spline.interpolate(x).unwrap(), y, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_cubic_spline_smoothness() {
        // Zig-zag data: the curve between the first two knots should stay near them.
        let spline = CubicSpline::new(
            vec![0.0, 1.0, 2.0, 3.0, 4.0],
            vec![0.0, 1.0, 0.0, 1.0, 0.0],
        )
        .unwrap();

        let y = spline.interpolate(0.5).unwrap();
        assert!(y > 0.0);
        assert!(y < 1.5);
    }

    #[test]
    fn test_cubic_spline_extrapolation_error() {
        let (xs, ys) = quadratic_knots();
        let spline = CubicSpline::new(xs, ys).unwrap();

        // One probe on each side of the knot range.
        for outside in [-0.5, 3.5] {
            assert!(spline.interpolate(outside).is_err());
        }
    }

    #[test]
    fn test_cubic_spline_extrapolation_enabled() {
        let (xs, ys) = quadratic_knots();
        let spline = CubicSpline::new(xs, ys).unwrap().with_extrapolation();

        // The same probes succeed once extrapolation is switched on.
        for outside in [-0.5, 3.5] {
            assert!(spline.interpolate(outside).is_ok());
        }
    }

    #[test]
    fn test_insufficient_points() {
        // Two points are below the three-point minimum enforced by `new`.
        let result = CubicSpline::new(vec![0.0, 1.0], vec![0.0, 1.0]);
        assert!(result.is_err());
    }
}
266}