1use std::sync::Arc;
8
9use fast_polynomial::poly_array;
10use thiserror::Error;
11
12use crate::math::slices::Monotonic;
13
14use super::linear_algebra::tridiagonal::Tridiagonal;
15use super::slices::Diff;
16
/// Minimum number of points accepted by `Series::try_new` / `Series::new`.
const MIN_POINTS_LINEAR: usize = 2;
/// Minimum number of points required to build a cubic spline; shorter series requested with
/// `InterpolationType::CubicSpline` fall back to linear interpolation.
const MIN_POINTS_SPLINE: usize = 4;
19
20#[derive(Clone, Debug, Error, PartialEq)]
22pub enum SeriesError {
23 #[error("`x` and `y` must have the same length but were {0} and {1}")]
25 DimensionMismatch(usize, usize),
26 #[error("length of `x` and `y` must at least 2 but was {0}")]
28 InsufficientPoints(usize),
29 #[error("x-axis must be strictly monotonic")]
31 NonMonotonic,
32}
33
/// Interpolation scheme stored in a [`Series`], together with any precomputed data.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Interpolation {
    /// Piecewise-linear interpolation between neighbouring points.
    Linear,
    /// Piecewise-cubic spline. Holds one coefficient array per interval `[x[i], x[i + 1]]`,
    /// constant term first (element 0 is `y[i]`), evaluated with `poly_array` at `xp - x[i]`.
    CubicSpline(Arc<[[f64; 4]]>),
}
43
/// A one-dimensional data series `y(x)` that can be interpolated at arbitrary `x`.
///
/// The constructors (`Series::try_new` / `Series::new`) ensure that `x` is strictly
/// increasing, that `x` and `y` have the same length, and that there are at least two points.
/// Data is held in `Arc`s, so cloning a `Series` does not copy the samples.
///
/// NOTE(review): the derived `Deserialize` does not run those checks, so a deserialized
/// `Series` may violate them — confirm whether inputs are trusted.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Series {
    // Knots (abscissae), strictly increasing.
    x: Arc<[f64]>,
    // Values at each knot; same length as `x`.
    y: Arc<[f64]>,
    // Interpolation scheme and, for splines, the per-interval polynomial coefficients.
    interpolation: Interpolation,
}
52
/// Interpolation scheme requested when constructing a [`Series`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InterpolationType {
    /// Piecewise-linear interpolation.
    Linear,
    /// Cubic-spline interpolation. Series with fewer than four points fall back to
    /// linear interpolation.
    CubicSpline,
}
60
61impl Series {
62 pub fn try_new(
64 x: impl Into<Arc<[f64]>>,
65 y: impl Into<Arc<[f64]>>,
66 interpolation: InterpolationType,
67 ) -> Result<Self, SeriesError> {
68 let x: Arc<[f64]> = x.into();
69 let y: Arc<[f64]> = y.into();
70
71 Self::check(&x, &y)?;
72
73 Ok(Self::new(x, y, interpolation))
74 }
75
76 pub fn new(
78 x: impl Into<Arc<[f64]>>,
79 y: impl Into<Arc<[f64]>>,
80 interpolation: InterpolationType,
81 ) -> Self {
82 let x: Arc<[f64]> = x.into();
83 let y: Arc<[f64]> = y.into();
84
85 Self::assert(&x, &y);
86
87 match interpolation {
88 InterpolationType::Linear => Self::linear(x, y),
89 InterpolationType::CubicSpline => {
90 let n = x.len();
91 if n < MIN_POINTS_SPLINE {
92 Self::linear(x, y)
93 } else {
94 Self::cubic_spline(x, y)
95 }
96 }
97 }
98 }
99
100 fn linear(x: Arc<[f64]>, y: Arc<[f64]>) -> Self {
101 Self {
102 x,
103 y,
104 interpolation: Interpolation::Linear,
105 }
106 }
107
108 fn cubic_spline(x: Arc<[f64]>, y: Arc<[f64]>) -> Self {
109 let n = x.len();
110
111 let dx = x.diff();
112 let nd = dx.len();
113 let slope: Vec<f64> = y
114 .diff()
115 .iter()
116 .enumerate()
117 .map(|(idx, y)| y / dx[idx])
118 .collect();
119
120 let mut d: Vec<f64> = dx[0..nd - 1]
121 .iter()
122 .enumerate()
123 .map(|(idx, dxi)| 2.0 * (dxi + dx[idx + 1]))
124 .collect();
125 let mut du: Vec<f64> = dx[0..nd - 1].to_vec();
126 let mut dl: Vec<f64> = dx[1..].to_vec();
127 let mut b: Vec<f64> = dx[0..nd - 1]
128 .iter()
129 .enumerate()
130 .map(|(idx, dxi)| 3.0 * (dx[idx + 1] * slope[idx] + dxi * slope[idx + 1]))
131 .collect();
132
133 d.insert(0, dx[1]);
135 du.insert(0, x[2] - x[0]);
136 let delta = x[2] - x[0];
137 b.insert(
138 0,
139 ((dx[0] + 2.0 * delta) * dx[1] * slope[0] + dx[0].powi(2) * slope[1]) / delta,
140 );
141 d.push(dx[nd - 2]);
142 let delta = x[n - 1] - x[n - 3];
143 dl.push(delta);
144 b.push(
145 (dx[nd - 1].powi(2) * slope[nd - 2]
146 + (2.0 * delta + dx[nd - 1]) * dx[nd - 2] * slope[nd - 1])
147 / delta,
148 );
149
150 let tri = Tridiagonal::new(&dl, &d, &du).unwrap_or_else(|err| {
151 unreachable!(
152 "dimensions should be correct for tridiagonal system: {}",
153 err
154 )
155 });
156 let s = tri.solve(&b);
157 let t: Vec<f64> = s[0..n - 1]
158 .iter()
159 .enumerate()
160 .map(|(idx, si)| (si + s[idx + 1] - 2.0 * slope[idx]) / dx[idx])
161 .collect();
162
163 let coeffs: Vec<[f64; 4]> = (0..n - 1)
164 .map(|i| {
165 let c1 = y[i];
166 let c2 = s[i];
167 let c3 = (slope[i] - s[i]) / dx[i] - t[i];
168 let c4 = t[i] / dx[i];
169 [c1, c2, c3, c4]
170 })
171 .collect();
172
173 Self {
174 x,
175 y,
176 interpolation: Interpolation::CubicSpline(coeffs.into()),
177 }
178 }
179
180 #[inline]
182 pub fn find_index(&self, xp: f64) -> usize {
183 let x = self.x.as_ref();
184 let x0 = *x.first().unwrap();
185 let xn = *x.last().unwrap();
186 if xp <= x0 {
187 0
188 } else if xp >= xn {
189 x.len() - 2
190 } else {
191 x.partition_point(|&val| xp > val) - 1
192 }
193 }
194
195 #[inline]
197 pub fn interpolate_at_index(&self, xp: f64, idx: usize) -> f64 {
198 match &self.interpolation {
199 Interpolation::Linear => {
200 let x = self.x.as_ref();
201 let y = self.y.as_ref();
202 let x0 = x[idx];
203 let x1 = x[idx + 1];
204 let y0 = y[idx];
205 let y1 = y[idx + 1];
206 y0 + (y1 - y0) * (xp - x0) / (x1 - x0)
207 }
208 Interpolation::CubicSpline(coeffs) => poly_array(xp - self.x[idx], &coeffs[idx]),
209 }
210 }
211
212 #[inline]
214 pub fn interpolate(&self, xp: f64) -> f64 {
215 let idx = self.find_index(xp);
216 self.interpolate_at_index(xp, idx)
217 }
218
219 pub fn x(&self) -> &[f64] {
221 self.x.as_ref()
222 }
223
224 pub fn y(&self) -> &[f64] {
226 self.y.as_ref()
227 }
228
229 pub fn first(&self) -> (f64, f64) {
231 (*self.x().first().unwrap(), *self.y().first().unwrap())
232 }
233
234 pub fn last(&self) -> (f64, f64) {
236 (*self.x().last().unwrap(), *self.y().last().unwrap())
237 }
238
239 fn check(x: &[f64], y: &[f64]) -> Result<(), SeriesError> {
240 if !x.is_strictly_increasing() {
241 return Err(SeriesError::NonMonotonic);
242 }
243
244 let n = x.len();
245
246 if y.len() != n {
247 return Err(SeriesError::DimensionMismatch(n, y.len()));
248 }
249
250 if n < MIN_POINTS_LINEAR {
251 return Err(SeriesError::InsufficientPoints(n));
252 }
253 Ok(())
254 }
255
256 fn assert(x: &[f64], y: &[f64]) {
257 assert!(x.is_strictly_increasing());
258
259 let n = x.len();
260 assert!(y.len() == n);
261 assert!(n >= MIN_POINTS_LINEAR);
262 }
263}
264
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use lox_test_utils::assert_approx_eq;

    use super::*;

    #[rstest]
    #[case(0.5, 0.5)]
    #[case(1.0, 1.0)]
    #[case(1.5, 1.5)]
    #[case(2.5, 2.5)]
    #[case(5.5, 5.5)]
    fn test_series_linear(#[case] xp: f64, #[case] expected: f64) {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let s = Series::try_new(x, y, InterpolationType::Linear).unwrap();
        let actual = s.interpolate(xp);
        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case(0.0, -14.303290471048534)]
    #[case(0.1, -12.036932976759344)]
    #[case(0.2, -9.978070560771739)]
    #[case(0.3, -8.117883404355377)]
    #[case(0.4, -6.447551688779917)]
    #[case(0.5, -4.958255595315013)]
    #[case(0.6, -3.6411753052303184)]
    #[case(0.7, -2.487490999795493)]
    #[case(0.8, -1.4883828602801898)]
    #[case(0.9, -0.6350310679540686)]
    #[case(1.0, 0.08138419591321655)]
    #[case(1.1, 0.6696827500520098)]
    #[case(1.2, 1.1386844131926532)]
    #[case(1.3, 1.4972090040654928)]
    #[case(1.4, 1.754076341400871)]
    #[case(1.5, 1.9181062439291328)]
    #[case(1.6, 1.9981185303806206)]
    #[case(1.7, 2.002933019485679)]
    #[case(1.8, 1.9413695299746523)]
    #[case(1.9, 1.8222478805778837)]
    #[case(2.0, 1.6543878900257172)]
    #[case(2.1, 1.4466093770484965)]
    #[case(2.2, 1.2077321603765656)]
    #[case(2.3, 0.9465760587402696)]
    #[case(2.4, 0.6719608908699499)]
    #[case(2.5, 0.3927064754959517)]
    #[case(2.6, 0.11763263134861876)]
    #[case(2.7, -0.14444082284170534)]
    #[case(2.8, -0.384694068344675)]
    #[case(2.9, -0.5943072864299493)]
    #[case(3.0, -0.7644606583671828)]
    #[case(3.1, -0.8886377407066958)]
    #[case(3.2, -0.9695355911214641)]
    #[case(3.3, -1.012154642565128)]
    #[case(3.4, -1.021495327991328)]
    #[case(3.5, -1.0025580803537035)]
    #[case(3.6, -0.960343332605895)]
    #[case(3.7, -0.8998515177015425)]
    #[case(3.8, -0.8260830685942864)]
    #[case(3.9, -0.744038418237766)]
    #[case(4.0, -0.6587179995856219)]
    #[case(4.1, -0.5751222455914945)]
    #[case(4.2, -0.4982515892090227)]
    #[case(4.3, -0.433106463391848)]
    #[case(4.4, -0.38468730109360944)]
    #[case(4.5, -0.3579945352679478)]
    #[case(4.6, -0.3580285988685027)]
    #[case(4.7, -0.3897899248489146)]
    #[case(4.8, -0.458278946162823)]
    #[case(4.9, -0.5684960957638693)]
    #[case(5.0, -0.7254418066056914)]
    #[case(5.1, -0.9341165116419302)]
    #[case(5.2, -1.1995206438262285)]
    #[case(5.3, -1.5266546361122217)]
    #[case(5.4, -1.9205189214535554)]
    #[case(5.5, -2.3861139328038625)]
    #[case(5.6, -2.9284401031167873)]
    #[case(5.7, -3.5524978653459742)]
    #[case(5.8, -4.263287652445054)]
    #[case(5.9, -5.065809897367678)]
    #[case(6.0, -5.965065033067472)]
    fn test_series_spline(#[case] xp: f64, #[case] expected: f64) {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![
            0.08138419591321655,
            1.6543878900257172,
            -0.7644606583671828,
            -0.6587179995856219,
            -0.7254418066056914,
        ];

        let s = Series::try_new(x, y, InterpolationType::CubicSpline).unwrap();
        let actual = s.interpolate(xp);
        assert_approx_eq!(actual, expected, rtol <= 1e-12);
    }

    // A spline request with fewer than `MIN_POINTS_SPLINE` points must fall back to linear.
    #[test]
    fn test_series_spline_falls_back_to_linear() {
        let s = Series::try_new(
            vec![1.0, 2.0, 3.0],
            vec![1.0, 4.0, 9.0],
            InterpolationType::CubicSpline,
        )
        .unwrap();
        assert_eq!(s.interpolation, Interpolation::Linear);
        assert_eq!(s.interpolate(1.5), 2.5);
        assert_eq!(s.interpolate(2.5), 6.5);
    }

    #[test]
    fn test_series_accessors_and_index_clamping() {
        let s = Series::new(
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            InterpolationType::Linear,
        );
        assert_eq!(s.x(), &[1.0, 2.0, 3.0]);
        assert_eq!(s.y(), &[4.0, 5.0, 6.0]);
        assert_eq!(s.first(), (1.0, 4.0));
        assert_eq!(s.last(), (3.0, 6.0));
        assert_eq!(s.find_index(0.0), 0);
        assert_eq!(s.find_index(1.0), 0);
        assert_eq!(s.find_index(2.0), 0);
        assert_eq!(s.find_index(2.5), 1);
        assert_eq!(s.find_index(3.0), 1);
        assert_eq!(s.find_index(10.0), 1);
    }

    #[test]
    #[should_panic]
    fn test_series_new_panics_on_invalid_input() {
        let _ = Series::new(vec![1.0, 2.0], vec![1.0], InterpolationType::Linear);
    }

    #[rstest]
    #[case(Series::try_new(vec![1.0], vec![1.0], InterpolationType::Linear), Err(SeriesError::InsufficientPoints(1)))]
    #[case(Series::try_new(vec![1.0], vec![1.0], InterpolationType::CubicSpline), Err(SeriesError::InsufficientPoints(1)))]
    #[case(Series::try_new(vec![1.0, 2.0], vec![1.0], InterpolationType::Linear), Err(SeriesError::DimensionMismatch(2, 1)))]
    #[case(Series::try_new(vec![1.0, 2.0], vec![1.0], InterpolationType::CubicSpline), Err(SeriesError::DimensionMismatch(2, 1)))]
    #[case(Series::try_new(vec![2.0, 1.0], vec![1.0, 2.0], InterpolationType::Linear), Err(SeriesError::NonMonotonic))]
    #[case(Series::try_new(vec![1.0, 1.0], vec![1.0, 2.0], InterpolationType::CubicSpline), Err(SeriesError::NonMonotonic))]
    fn test_series_errors(
        #[case] actual: Result<Series, SeriesError>,
        #[case] expected: Result<Series, SeriesError>,
    ) {
        assert_eq!(actual, expected);
    }
}