Skip to main content

lox_core/math/
series.rs

// SPDX-FileCopyrightText: 2024 Helge Eichhorn <git@helgeeichhorn.de>
//
// SPDX-License-Identifier: MPL-2.0

//! Interpolated one-dimensional data series with linear and (not-a-knot) cubic spline support.
6
7use std::sync::Arc;
8
9use fast_polynomial::poly_array;
10use thiserror::Error;
11
12use crate::math::slices::Monotonic;
13
14use super::linear_algebra::tridiagonal::Tridiagonal;
15use super::slices::Diff;
16
/// Smallest number of data points a [`Series`] accepts: two points span one interval.
const MIN_POINTS_LINEAR: usize = 2;
/// Smallest number of data points for which a cubic spline is actually built; below this,
/// [`InterpolationType::CubicSpline`] falls back to linear interpolation.
const MIN_POINTS_SPLINE: usize = 4;
19
20/// Error returned when constructing a [`Series`] with invalid data.
21#[derive(Clone, Debug, Error, PartialEq)]
22pub enum SeriesError {
23    /// The `x` and `y` arrays have different lengths.
24    #[error("`x` and `y` must have the same length but were {0} and {1}")]
25    DimensionMismatch(usize, usize),
26    /// Fewer than 2 data points were provided.
27    #[error("length of `x` and `y` must at least 2 but was {0}")]
28    InsufficientPoints(usize),
29    /// The x-axis is not strictly monotonically increasing.
30    #[error("x-axis must be strictly monotonic")]
31    NonMonotonic,
32}
33
/// Interpolation scheme of a [`Series`] together with any data precomputed for it.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Interpolation {
    /// Straight lines between neighbouring data points; nothing is precomputed.
    Linear,
    /// Piecewise cubic polynomials, one entry per interval `[x[i], x[i + 1]]`.
    ///
    /// Each entry holds four coefficients in ascending powers of `t = xp - x[i]`, the first
    /// being the data value `y[i]` at the start of the interval.
    CubicSpline(Arc<[[f64; 4]]>),
}
43
44/// An interpolated 1-D data series.
45#[derive(Clone, Debug, PartialEq)]
46#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
47pub struct Series {
48    x: Arc<[f64]>,
49    y: Arc<[f64]>,
50    interpolation: Interpolation,
51}
52
/// Selects the interpolation method for a [`Series`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InterpolationType {
    /// Linear interpolation.
    Linear,
    /// Cubic spline interpolation with not-a-knot end conditions (falls back to linear
    /// interpolation if fewer than 4 points are given).
    CubicSpline,
}
60
61impl Series {
62    /// Creates a new series, returning an error if the data is invalid.
63    pub fn try_new(
64        x: impl Into<Arc<[f64]>>,
65        y: impl Into<Arc<[f64]>>,
66        interpolation: InterpolationType,
67    ) -> Result<Self, SeriesError> {
68        let x: Arc<[f64]> = x.into();
69        let y: Arc<[f64]> = y.into();
70
71        Self::check(&x, &y)?;
72
73        Ok(Self::new(x, y, interpolation))
74    }
75
76    /// Creates a new series, panicking if the data is invalid.
77    pub fn new(
78        x: impl Into<Arc<[f64]>>,
79        y: impl Into<Arc<[f64]>>,
80        interpolation: InterpolationType,
81    ) -> Self {
82        let x: Arc<[f64]> = x.into();
83        let y: Arc<[f64]> = y.into();
84
85        Self::assert(&x, &y);
86
87        match interpolation {
88            InterpolationType::Linear => Self::linear(x, y),
89            InterpolationType::CubicSpline => {
90                let n = x.len();
91                if n < MIN_POINTS_SPLINE {
92                    Self::linear(x, y)
93                } else {
94                    Self::cubic_spline(x, y)
95                }
96            }
97        }
98    }
99
100    fn linear(x: Arc<[f64]>, y: Arc<[f64]>) -> Self {
101        Self {
102            x,
103            y,
104            interpolation: Interpolation::Linear,
105        }
106    }
107
108    fn cubic_spline(x: Arc<[f64]>, y: Arc<[f64]>) -> Self {
109        let n = x.len();
110
111        let dx = x.diff();
112        let nd = dx.len();
113        let slope: Vec<f64> = y
114            .diff()
115            .iter()
116            .enumerate()
117            .map(|(idx, y)| y / dx[idx])
118            .collect();
119
120        let mut d: Vec<f64> = dx[0..nd - 1]
121            .iter()
122            .enumerate()
123            .map(|(idx, dxi)| 2.0 * (dxi + dx[idx + 1]))
124            .collect();
125        let mut du: Vec<f64> = dx[0..nd - 1].to_vec();
126        let mut dl: Vec<f64> = dx[1..].to_vec();
127        let mut b: Vec<f64> = dx[0..nd - 1]
128            .iter()
129            .enumerate()
130            .map(|(idx, dxi)| 3.0 * (dx[idx + 1] * slope[idx] + dxi * slope[idx + 1]))
131            .collect();
132
133        // Not-a-knot boundary condition
134        d.insert(0, dx[1]);
135        du.insert(0, x[2] - x[0]);
136        let delta = x[2] - x[0];
137        b.insert(
138            0,
139            ((dx[0] + 2.0 * delta) * dx[1] * slope[0] + dx[0].powi(2) * slope[1]) / delta,
140        );
141        d.push(dx[nd - 2]);
142        let delta = x[n - 1] - x[n - 3];
143        dl.push(delta);
144        b.push(
145            (dx[nd - 1].powi(2) * slope[nd - 2]
146                + (2.0 * delta + dx[nd - 1]) * dx[nd - 2] * slope[nd - 1])
147                / delta,
148        );
149
150        let tri = Tridiagonal::new(&dl, &d, &du).unwrap_or_else(|err| {
151            unreachable!(
152                "dimensions should be correct for tridiagonal system: {}",
153                err
154            )
155        });
156        let s = tri.solve(&b);
157        let t: Vec<f64> = s[0..n - 1]
158            .iter()
159            .enumerate()
160            .map(|(idx, si)| (si + s[idx + 1] - 2.0 * slope[idx]) / dx[idx])
161            .collect();
162
163        let coeffs: Vec<[f64; 4]> = (0..n - 1)
164            .map(|i| {
165                let c1 = y[i];
166                let c2 = s[i];
167                let c3 = (slope[i] - s[i]) / dx[i] - t[i];
168                let c4 = t[i] / dx[i];
169                [c1, c2, c3, c4]
170            })
171            .collect();
172
173        Self {
174            x,
175            y,
176            interpolation: Interpolation::CubicSpline(coeffs.into()),
177        }
178    }
179
180    /// Finds the interval index for interpolation at point `xp`.
181    #[inline]
182    pub fn find_index(&self, xp: f64) -> usize {
183        let x = self.x.as_ref();
184        let x0 = *x.first().unwrap();
185        let xn = *x.last().unwrap();
186        if xp <= x0 {
187            0
188        } else if xp >= xn {
189            x.len() - 2
190        } else {
191            x.partition_point(|&val| xp > val) - 1
192        }
193    }
194
195    /// Interpolates at point `xp` using the precomputed interval `idx`.
196    #[inline]
197    pub fn interpolate_at_index(&self, xp: f64, idx: usize) -> f64 {
198        match &self.interpolation {
199            Interpolation::Linear => {
200                let x = self.x.as_ref();
201                let y = self.y.as_ref();
202                let x0 = x[idx];
203                let x1 = x[idx + 1];
204                let y0 = y[idx];
205                let y1 = y[idx + 1];
206                y0 + (y1 - y0) * (xp - x0) / (x1 - x0)
207            }
208            Interpolation::CubicSpline(coeffs) => poly_array(xp - self.x[idx], &coeffs[idx]),
209        }
210    }
211
212    /// Interpolates the series at point `xp`.
213    #[inline]
214    pub fn interpolate(&self, xp: f64) -> f64 {
215        let idx = self.find_index(xp);
216        self.interpolate_at_index(xp, idx)
217    }
218
219    /// Returns the x data points.
220    pub fn x(&self) -> &[f64] {
221        self.x.as_ref()
222    }
223
224    /// Returns the y data points.
225    pub fn y(&self) -> &[f64] {
226        self.y.as_ref()
227    }
228
229    /// Returns the first `(x, y)` data point.
230    pub fn first(&self) -> (f64, f64) {
231        (*self.x().first().unwrap(), *self.y().first().unwrap())
232    }
233
234    /// Returns the last `(x, y)` data point.
235    pub fn last(&self) -> (f64, f64) {
236        (*self.x().last().unwrap(), *self.y().last().unwrap())
237    }
238
239    fn check(x: &[f64], y: &[f64]) -> Result<(), SeriesError> {
240        if !x.is_strictly_increasing() {
241            return Err(SeriesError::NonMonotonic);
242        }
243
244        let n = x.len();
245
246        if y.len() != n {
247            return Err(SeriesError::DimensionMismatch(n, y.len()));
248        }
249
250        if n < MIN_POINTS_LINEAR {
251            return Err(SeriesError::InsufficientPoints(n));
252        }
253        Ok(())
254    }
255
256    fn assert(x: &[f64], y: &[f64]) {
257        assert!(x.is_strictly_increasing());
258
259        let n = x.len();
260        assert!(y.len() == n);
261        assert!(n >= MIN_POINTS_LINEAR);
262    }
263}
264
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use lox_test_utils::assert_approx_eq;

    use super::*;

    #[rstest]
    #[case(0.5, 0.5)]
    #[case(1.0, 1.0)]
    #[case(1.5, 1.5)]
    #[case(2.5, 2.5)]
    #[case(5.5, 5.5)]
    fn test_series_linear(#[case] xp: f64, #[case] expected: f64) {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let s = Series::try_new(x, y, InterpolationType::Linear).unwrap();
        let actual = s.interpolate(xp);
        assert_eq!(actual, expected);
    }

    // Reference values from AstroBase.jl
    #[rstest]
    #[case(0.0, -14.303290471048534)]
    #[case(0.1, -12.036932976759344)]
    #[case(0.2, -9.978070560771739)]
    #[case(0.3, -8.117883404355377)]
    #[case(0.4, -6.447551688779917)]
    #[case(0.5, -4.958255595315013)]
    #[case(0.6, -3.6411753052303184)]
    #[case(0.7, -2.487490999795493)]
    #[case(0.8, -1.4883828602801898)]
    #[case(0.9, -0.6350310679540686)]
    #[case(1.0, 0.08138419591321655)]
    #[case(1.1, 0.6696827500520098)]
    #[case(1.2, 1.1386844131926532)]
    #[case(1.3, 1.4972090040654928)]
    #[case(1.4, 1.754076341400871)]
    #[case(1.5, 1.9181062439291328)]
    #[case(1.6, 1.9981185303806206)]
    #[case(1.7, 2.002933019485679)]
    #[case(1.8, 1.9413695299746523)]
    #[case(1.9, 1.8222478805778837)]
    #[case(2.0, 1.6543878900257172)]
    #[case(2.1, 1.4466093770484965)]
    #[case(2.2, 1.2077321603765656)]
    #[case(2.3, 0.9465760587402696)]
    #[case(2.4, 0.6719608908699499)]
    #[case(2.5, 0.3927064754959517)]
    #[case(2.6, 0.11763263134861876)]
    #[case(2.7, -0.14444082284170534)]
    #[case(2.8, -0.384694068344675)]
    #[case(2.9, -0.5943072864299493)]
    #[case(3.0, -0.7644606583671828)]
    #[case(3.1, -0.8886377407066958)]
    #[case(3.2, -0.9695355911214641)]
    #[case(3.3, -1.012154642565128)]
    #[case(3.4, -1.021495327991328)]
    #[case(3.5, -1.0025580803537035)]
    #[case(3.6, -0.960343332605895)]
    #[case(3.7, -0.8998515177015425)]
    #[case(3.8, -0.8260830685942864)]
    #[case(3.9, -0.744038418237766)]
    #[case(4.0, -0.6587179995856219)]
    #[case(4.1, -0.5751222455914945)]
    #[case(4.2, -0.4982515892090227)]
    #[case(4.3, -0.433106463391848)]
    #[case(4.4, -0.38468730109360944)]
    #[case(4.5, -0.3579945352679478)]
    #[case(4.6, -0.3580285988685027)]
    #[case(4.7, -0.3897899248489146)]
    #[case(4.8, -0.458278946162823)]
    #[case(4.9, -0.5684960957638693)]
    #[case(5.0, -0.7254418066056914)]
    #[case(5.1, -0.9341165116419302)]
    #[case(5.2, -1.1995206438262285)]
    #[case(5.3, -1.5266546361122217)]
    #[case(5.4, -1.9205189214535554)]
    #[case(5.5, -2.3861139328038625)]
    #[case(5.6, -2.9284401031167873)]
    #[case(5.7, -3.5524978653459742)]
    #[case(5.8, -4.263287652445054)]
    #[case(5.9, -5.065809897367678)]
    #[case(6.0, -5.965065033067472)]
    fn test_series_spline(#[case] xp: f64, #[case] expected: f64) {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![
            0.08138419591321655,
            1.6543878900257172,
            -0.7644606583671828,
            -0.6587179995856219,
            -0.7254418066056914,
        ];

        let s = Series::try_new(x, y, InterpolationType::CubicSpline).unwrap();
        let actual = s.interpolate(xp);
        assert_approx_eq!(actual, expected, rtol <= 1e-12);
    }

    // Points outside the data range are clamped to the first/last interval.
    #[rstest]
    #[case(0.0, 0)]
    #[case(1.0, 0)]
    #[case(2.5, 1)]
    #[case(3.0, 1)]
    #[case(5.0, 3)]
    #[case(10.0, 3)]
    fn test_series_find_index(#[case] xp: f64, #[case] expected: usize) {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let s = Series::try_new(x, y, InterpolationType::Linear).unwrap();
        assert_eq!(s.find_index(xp), expected);
    }

    #[test]
    fn test_series_spline_falls_back_to_linear() {
        let s = Series::new(
            vec![1.0, 2.0, 3.0],
            vec![1.0, 4.0, 9.0],
            InterpolationType::CubicSpline,
        );
        assert_eq!(s.interpolation, Interpolation::Linear);
        assert_eq!(s.interpolate(1.5), 2.5);
    }

    // A not-a-knot cubic spline reproduces any polynomial of degree <= 3 exactly.
    #[test]
    fn test_series_spline_reproduces_quadratic() {
        let s = Series::new(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![1.0, 4.0, 9.0, 16.0],
            InterpolationType::CubicSpline,
        );
        assert!(matches!(s.interpolation, Interpolation::CubicSpline(_)));
        assert_approx_eq!(s.interpolate(2.5), 6.25, rtol <= 1e-12);
    }

    #[test]
    #[should_panic]
    fn test_series_new_panics_on_invalid_data() {
        Series::new(vec![1.0, 2.0], vec![1.0], InterpolationType::Linear);
    }

    #[rstest]
    #[case(Series::try_new(vec![1.0], vec![1.0], InterpolationType::Linear), Err(SeriesError::InsufficientPoints(1)))]
    #[case(Series::try_new(vec![1.0], vec![1.0], InterpolationType::CubicSpline), Err(SeriesError::InsufficientPoints(1)))]
    #[case(Series::try_new(vec![1.0, 2.0], vec![1.0], InterpolationType::Linear), Err(SeriesError::DimensionMismatch(2, 1)))]
    #[case(Series::try_new(vec![1.0, 2.0], vec![1.0], InterpolationType::CubicSpline), Err(SeriesError::DimensionMismatch(2, 1)))]
    #[case(Series::try_new(vec![1.0, 1.0], vec![1.0, 2.0], InterpolationType::Linear), Err(SeriesError::NonMonotonic))]
    #[case(Series::try_new(vec![2.0, 1.0], vec![1.0, 2.0], InterpolationType::CubicSpline), Err(SeriesError::NonMonotonic))]
    fn test_series_errors(
        #[case] actual: Result<Series, SeriesError>,
        #[case] expected: Result<Series, SeriesError>,
    ) {
        assert_eq!(actual, expected);
    }
}