convex_math/interpolation/
cubic_spline.rs1use crate::error::{MathError, MathResult};
4use crate::interpolation::Interpolator;
5
/// A natural cubic spline through a set of strictly increasing knots.
///
/// Build one with [`CubicSpline::new`] and evaluate it through the [`Interpolator`] trait.
/// Queries outside the knot range are rejected unless
/// [`CubicSpline::with_extrapolation`] has been called.
#[derive(Clone, Debug)]
pub struct CubicSpline {
    /// Knot abscissae; strictly increasing (validated at construction).
    xs: Vec<f64>,
    /// Data value at each knot, index-aligned with `xs`.
    ys: Vec<f64>,
    /// Second derivative of the spline at each knot, precomputed at construction.
    y2s: Vec<f64>,
    /// When `true`, points outside `[xs[0], xs[xs.len() - 1]]` are evaluated rather than rejected.
    allow_extrapolation: bool,
}
32
33impl CubicSpline {
34 pub fn new(xs: Vec<f64>, ys: Vec<f64>) -> MathResult<Self> {
45 if xs.len() < 3 {
46 return Err(MathError::insufficient_data(3, xs.len()));
47 }
48 if xs.len() != ys.len() {
49 return Err(MathError::invalid_input(format!(
50 "xs and ys must have same length: {} vs {}",
51 xs.len(),
52 ys.len()
53 )));
54 }
55
56 for i in 1..xs.len() {
58 if xs[i] <= xs[i - 1] {
59 return Err(MathError::invalid_input(
60 "x values must be strictly increasing",
61 ));
62 }
63 }
64
65 let y2s = compute_second_derivatives(&xs, &ys);
66
67 Ok(Self {
68 xs,
69 ys,
70 y2s,
71 allow_extrapolation: false,
72 })
73 }
74
75 #[must_use]
77 pub fn with_extrapolation(mut self) -> Self {
78 self.allow_extrapolation = true;
79 self
80 }
81
82 fn find_segment(&self, x: f64) -> usize {
84 match self
85 .xs
86 .binary_search_by(|probe| probe.partial_cmp(&x).unwrap_or(std::cmp::Ordering::Equal))
87 {
88 Ok(i) => i.min(self.xs.len() - 2),
89 Err(i) => (i.saturating_sub(1)).min(self.xs.len() - 2),
90 }
91 }
92}
93
94impl Interpolator for CubicSpline {
95 #[allow(clippy::many_single_char_names)]
96 #[allow(clippy::similar_names)]
97 fn interpolate(&self, x: f64) -> MathResult<f64> {
98 if !self.allow_extrapolation && (x < self.xs[0] || x > self.xs[self.xs.len() - 1]) {
100 return Err(MathError::ExtrapolationNotAllowed {
101 x,
102 min: self.xs[0],
103 max: self.xs[self.xs.len() - 1],
104 });
105 }
106
107 let i = self.find_segment(x);
108
109 let x_lo = self.xs[i];
110 let x_hi = self.xs[i + 1];
111 let y_lo = self.ys[i];
112 let y_hi = self.ys[i + 1];
113 let y2_lo = self.y2s[i];
114 let y2_hi = self.y2s[i + 1];
115
116 let h = x_hi - x_lo;
117 let a = (x_hi - x) / h;
118 let b = (x - x_lo) / h;
119
120 let y = a * y_lo
122 + b * y_hi
123 + ((a * a * a - a) * y2_lo + (b * b * b - b) * y2_hi) * (h * h) / 6.0;
124
125 Ok(y)
126 }
127
128 #[allow(clippy::many_single_char_names)]
129 #[allow(clippy::similar_names)]
130 fn derivative(&self, x: f64) -> MathResult<f64> {
131 if !self.allow_extrapolation && (x < self.xs[0] || x > self.xs[self.xs.len() - 1]) {
133 return Err(MathError::ExtrapolationNotAllowed {
134 x,
135 min: self.xs[0],
136 max: self.xs[self.xs.len() - 1],
137 });
138 }
139
140 let i = self.find_segment(x);
141
142 let x_lo = self.xs[i];
143 let x_hi = self.xs[i + 1];
144 let y_lo = self.ys[i];
145 let y_hi = self.ys[i + 1];
146 let y2_lo = self.y2s[i];
147 let y2_hi = self.y2s[i + 1];
148
149 let h = x_hi - x_lo;
150 let a = (x_hi - x) / h;
151 let b = (x - x_lo) / h;
152
153 let dy = (y_hi - y_lo) / h - (3.0 * a * a - 1.0) / 6.0 * h * y2_lo
156 + (3.0 * b * b - 1.0) / 6.0 * h * y2_hi;
157
158 Ok(dy)
159 }
160
161 fn allows_extrapolation(&self) -> bool {
162 self.allow_extrapolation
163 }
164
165 fn min_x(&self) -> f64 {
166 self.xs[0]
167 }
168
169 fn max_x(&self) -> f64 {
170 self.xs[self.xs.len() - 1]
171 }
172}
173
/// Solves for the second derivative of the natural cubic spline at every knot.
///
/// Each interior knot `i` satisfies the continuity equation
/// `h[i-1]·M[i-1] + 2(h[i-1] + h[i])·M[i] + h[i]·M[i+1] = 6·(Δ[i] - Δ[i-1])`,
/// where `h` are the interval widths and `Δ` the interval slopes. Both end values are fixed
/// at zero (natural boundary conditions).
///
/// The tridiagonal system is solved with one forward-elimination pass and one
/// back-substitution pass. This is the same recurrence as the `spline` routine in
/// *Numerical Recipes*.
///
/// Assumes `xs` is strictly increasing and `ys.len() >= xs.len()`; `CubicSpline::new`
/// validates both before calling. Fewer than three knots (including none) yield all zeros.
fn compute_second_derivatives(xs: &[f64], ys: &[f64]) -> Vec<f64> {
    let n = xs.len();
    // With no interior knot there is nothing to solve: the natural boundary conditions
    // alone force every second derivative to zero. Returning early also avoids the
    // `n - 1` underflow an empty slice would trigger below.
    if n < 3 {
        return vec![0.0; n];
    }

    // Forward pass. `y2s[i]` temporarily holds the elimination factor for row i and
    // `u[i]` its reduced right-hand side. Row 0 is the natural boundary, so both start at 0.
    let mut y2s = vec![0.0; n];
    let mut u = vec![0.0; n - 1];

    for i in 1..n - 1 {
        let width_left = xs[i] - xs[i - 1];
        let width_right = xs[i + 1] - xs[i];
        let span = xs[i + 1] - xs[i - 1];

        let sig = width_left / span;
        let p = sig * y2s[i - 1] + 2.0;
        y2s[i] = (sig - 1.0) / p;

        // Change in slope across knot i.
        let slope_jump = (ys[i + 1] - ys[i]) / width_right - (ys[i] - ys[i - 1]) / width_left;
        u[i] = (6.0 * slope_jump / span - sig * u[i - 1]) / p;
    }

    // Back substitution from the right-hand natural boundary (y2s[n - 1] stays 0.0).
    for i in (0..n - 1).rev() {
        y2s[i] = y2s[i] * y2s[i + 1] + u[i];
    }

    y2s
}
204
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Knots sampled from y = x^2 on a unit grid, shared by several tests.
    fn quadratic_knots() -> (Vec<f64>, Vec<f64>) {
        (vec![0.0, 1.0, 2.0, 3.0], vec![0.0, 1.0, 4.0, 9.0])
    }

    #[test]
    fn test_cubic_spline_through_points() {
        let (xs, ys) = quadratic_knots();
        let spline = CubicSpline::new(xs.clone(), ys.clone()).unwrap();

        // An interpolating spline must reproduce every knot value.
        for (x, y) in xs.iter().zip(ys.iter()) {
            assert_relative_eq!(spline.interpolate(*x).unwrap(), *y, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_cubic_spline_smoothness() {
        let xs = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let ys = vec![0.0, 1.0, 0.0, 1.0, 0.0];

        let spline = CubicSpline::new(xs, ys).unwrap();

        // Midway between (0, 0) and (1, 1) the curve should stay in a loose band
        // around the chord value 0.5.
        let y = spline.interpolate(0.5).unwrap();
        assert!(y > 0.0 && y < 1.5);
    }

    #[test]
    fn test_natural_spline_known_values() {
        // For these knots the natural-spline system gives second derivatives
        // [0, 12/5, 12/5, 0], which yields the closed-form values below.
        let (xs, ys) = quadratic_knots();
        let spline = CubicSpline::new(xs, ys).unwrap();

        assert_relative_eq!(spline.interpolate(0.5).unwrap(), 0.35, epsilon = 1e-12);
        assert_relative_eq!(spline.interpolate(1.5).unwrap(), 2.2, epsilon = 1e-12);
        assert_relative_eq!(spline.derivative(0.5).unwrap(), 0.9, epsilon = 1e-12);
    }

    #[test]
    fn test_derivative_matches_finite_difference() {
        let xs = vec![0.0, 0.5, 2.0, 3.0, 5.0];
        let ys = vec![1.0, -1.0, 2.0, 0.0, 4.0];
        let spline = CubicSpline::new(xs, ys).unwrap();

        // Sample points sit well inside their intervals, so x ± step never crosses a knot.
        let step = 1e-6;
        for &x in &[0.2, 1.1, 2.7, 4.4] {
            let numeric = (spline.interpolate(x + step).unwrap()
                - spline.interpolate(x - step).unwrap())
                / (2.0 * step);
            assert_relative_eq!(spline.derivative(x).unwrap(), numeric, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_cubic_spline_extrapolation_error() {
        let (xs, ys) = quadratic_knots();
        let spline = CubicSpline::new(xs, ys).unwrap();

        assert!(spline.interpolate(-0.5).is_err());
        assert!(spline.interpolate(3.5).is_err());
        assert!(spline.derivative(-0.5).is_err());
        assert!(spline.derivative(3.5).is_err());
    }

    #[test]
    fn test_cubic_spline_extrapolation_enabled() {
        let (xs, ys) = quadratic_knots();
        let spline = CubicSpline::new(xs, ys).unwrap().with_extrapolation();

        assert!(spline.interpolate(-0.5).is_ok());
        assert!(spline.interpolate(3.5).is_ok());
        assert!(spline.derivative(3.5).is_ok());
    }

    #[test]
    fn test_bounds_and_extrapolation_flag() {
        let (xs, ys) = quadratic_knots();
        let spline = CubicSpline::new(xs, ys).unwrap();

        assert!(!spline.allows_extrapolation());
        assert_relative_eq!(spline.min_x(), 0.0);
        assert_relative_eq!(spline.max_x(), 3.0);
        assert!(spline.with_extrapolation().allows_extrapolation());
    }

    #[test]
    fn test_insufficient_points() {
        let xs = vec![0.0, 1.0];
        let ys = vec![0.0, 1.0];

        assert!(CubicSpline::new(xs, ys).is_err());
    }

    #[test]
    fn test_mismatched_lengths() {
        assert!(CubicSpline::new(vec![0.0, 1.0, 2.0], vec![0.0, 1.0]).is_err());
    }

    #[test]
    fn test_non_increasing_xs() {
        // Both duplicate and descending knots must be rejected.
        assert!(CubicSpline::new(vec![0.0, 1.0, 1.0], vec![0.0, 1.0, 2.0]).is_err());
        assert!(CubicSpline::new(vec![0.0, 2.0, 1.0], vec![0.0, 1.0, 2.0]).is_err());
    }
}