Skip to main content

lox_core/math/
series.rs

1// SPDX-FileCopyrightText: 2024 Helge Eichhorn <git@helgeeichhorn.de>
2//
3// SPDX-License-Identifier: MPL-2.0
4
5use std::sync::Arc;
6
7use fast_polynomial::poly_array;
8use thiserror::Error;
9
10use crate::math::slices::Monotonic;
11
12use super::linear_algebra::tridiagonal::Tridiagonal;
13use super::slices::Diff;
14
/// Smallest number of samples a `Series` accepts; linear interpolation reads the
/// two samples bounding an interval.
const MIN_POINTS_LINEAR: usize = 2;
/// Smallest number of samples for which a cubic spline is fitted; a `CubicSpline`
/// request with fewer samples is served with linear interpolation instead.
const MIN_POINTS_SPLINE: usize = 4;
17
18#[derive(Clone, Debug, Error, PartialEq)]
19pub enum SeriesError {
20    #[error("`x` and `y` must have the same length but were {0} and {1}")]
21    DimensionMismatch(usize, usize),
22    #[error("length of `x` and `y` must at least 2 but was {0}")]
23    InsufficientPoints(usize),
24    #[error("x-axis must be strictly monotonic")]
25    NonMonotonic,
26}
27
/// Interpolation scheme held by a `Series`, together with any data precomputed for it.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Interpolation {
    /// Straight-line interpolation between the two samples bounding an interval.
    Linear,
    /// Piecewise cubic polynomials: one row of four coefficients per interval
    /// between consecutive samples.
    CubicSpline(std::sync::Arc<[[f64; 4]]>),
}
34
35#[derive(Clone, Debug, PartialEq)]
36#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
37pub struct Series {
38    x: Arc<[f64]>,
39    y: Arc<[f64]>,
40    interpolation: Interpolation,
41}
42
/// Interpolation scheme requested when constructing a `Series`.
///
/// Unlike `Interpolation`, this carries no precomputed data; it only selects which
/// scheme the constructor builds.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InterpolationType {
    /// Straight-line interpolation between neighbouring samples.
    Linear,
    /// Cubic spline with not-a-knot boundary conditions; inputs with fewer than four
    /// samples fall back to linear interpolation.
    CubicSpline,
}
47
48impl Series {
49    pub fn try_new(
50        x: impl Into<Arc<[f64]>>,
51        y: impl Into<Arc<[f64]>>,
52        interpolation: InterpolationType,
53    ) -> Result<Self, SeriesError> {
54        let x: Arc<[f64]> = x.into();
55        let y: Arc<[f64]> = y.into();
56
57        Self::check(&x, &y)?;
58
59        Ok(Self::new(x, y, interpolation))
60    }
61
62    pub fn new(
63        x: impl Into<Arc<[f64]>>,
64        y: impl Into<Arc<[f64]>>,
65        interpolation: InterpolationType,
66    ) -> Self {
67        let x: Arc<[f64]> = x.into();
68        let y: Arc<[f64]> = y.into();
69
70        Self::assert(&x, &y);
71
72        match interpolation {
73            InterpolationType::Linear => Self::linear(x, y),
74            InterpolationType::CubicSpline => {
75                let n = x.len();
76                if n < MIN_POINTS_SPLINE {
77                    Self::linear(x, y)
78                } else {
79                    Self::cubic_spline(x, y)
80                }
81            }
82        }
83    }
84
85    fn linear(x: Arc<[f64]>, y: Arc<[f64]>) -> Self {
86        Self {
87            x,
88            y,
89            interpolation: Interpolation::Linear,
90        }
91    }
92
93    fn cubic_spline(x: Arc<[f64]>, y: Arc<[f64]>) -> Self {
94        let n = x.len();
95
96        let dx = x.diff();
97        let nd = dx.len();
98        let slope: Vec<f64> = y
99            .diff()
100            .iter()
101            .enumerate()
102            .map(|(idx, y)| y / dx[idx])
103            .collect();
104
105        let mut d: Vec<f64> = dx[0..nd - 1]
106            .iter()
107            .enumerate()
108            .map(|(idx, dxi)| 2.0 * (dxi + dx[idx + 1]))
109            .collect();
110        let mut du: Vec<f64> = dx[0..nd - 1].to_vec();
111        let mut dl: Vec<f64> = dx[1..].to_vec();
112        let mut b: Vec<f64> = dx[0..nd - 1]
113            .iter()
114            .enumerate()
115            .map(|(idx, dxi)| 3.0 * (dx[idx + 1] * slope[idx] + dxi * slope[idx + 1]))
116            .collect();
117
118        // Not-a-knot boundary condition
119        d.insert(0, dx[1]);
120        du.insert(0, x[2] - x[0]);
121        let delta = x[2] - x[0];
122        b.insert(
123            0,
124            ((dx[0] + 2.0 * delta) * dx[1] * slope[0] + dx[0].powi(2) * slope[1]) / delta,
125        );
126        d.push(dx[nd - 2]);
127        let delta = x[n - 1] - x[n - 3];
128        dl.push(delta);
129        b.push(
130            (dx[nd - 1].powi(2) * slope[nd - 2]
131                + (2.0 * delta + dx[nd - 1]) * dx[nd - 2] * slope[nd - 1])
132                / delta,
133        );
134
135        let tri = Tridiagonal::new(&dl, &d, &du).unwrap_or_else(|err| {
136            unreachable!(
137                "dimensions should be correct for tridiagonal system: {}",
138                err
139            )
140        });
141        let s = tri.solve(&b);
142        let t: Vec<f64> = s[0..n - 1]
143            .iter()
144            .enumerate()
145            .map(|(idx, si)| (si + s[idx + 1] - 2.0 * slope[idx]) / dx[idx])
146            .collect();
147
148        let coeffs: Vec<[f64; 4]> = (0..n - 1)
149            .map(|i| {
150                let c1 = y[i];
151                let c2 = s[i];
152                let c3 = (slope[i] - s[i]) / dx[i] - t[i];
153                let c4 = t[i] / dx[i];
154                [c1, c2, c3, c4]
155            })
156            .collect();
157
158        Self {
159            x,
160            y,
161            interpolation: Interpolation::CubicSpline(coeffs.into()),
162        }
163    }
164
165    #[inline]
166    pub fn find_index(&self, xp: f64) -> usize {
167        let x = self.x.as_ref();
168        let x0 = *x.first().unwrap();
169        let xn = *x.last().unwrap();
170        if xp <= x0 {
171            0
172        } else if xp >= xn {
173            x.len() - 2
174        } else {
175            x.partition_point(|&val| xp > val) - 1
176        }
177    }
178
179    #[inline]
180    pub fn interpolate_at_index(&self, xp: f64, idx: usize) -> f64 {
181        match &self.interpolation {
182            Interpolation::Linear => {
183                let x = self.x.as_ref();
184                let y = self.y.as_ref();
185                let x0 = x[idx];
186                let x1 = x[idx + 1];
187                let y0 = y[idx];
188                let y1 = y[idx + 1];
189                y0 + (y1 - y0) * (xp - x0) / (x1 - x0)
190            }
191            Interpolation::CubicSpline(coeffs) => poly_array(xp - self.x[idx], &coeffs[idx]),
192        }
193    }
194
195    #[inline]
196    pub fn interpolate(&self, xp: f64) -> f64 {
197        let idx = self.find_index(xp);
198        self.interpolate_at_index(xp, idx)
199    }
200
201    pub fn x(&self) -> &[f64] {
202        self.x.as_ref()
203    }
204
205    pub fn y(&self) -> &[f64] {
206        self.y.as_ref()
207    }
208
209    pub fn first(&self) -> (f64, f64) {
210        (*self.x().first().unwrap(), *self.y().first().unwrap())
211    }
212
213    pub fn last(&self) -> (f64, f64) {
214        (*self.x().last().unwrap(), *self.y().last().unwrap())
215    }
216
217    fn check(x: &[f64], y: &[f64]) -> Result<(), SeriesError> {
218        if !x.is_strictly_increasing() {
219            return Err(SeriesError::NonMonotonic);
220        }
221
222        let n = x.len();
223
224        if y.len() != n {
225            return Err(SeriesError::DimensionMismatch(n, y.len()));
226        }
227
228        if n < MIN_POINTS_LINEAR {
229            return Err(SeriesError::InsufficientPoints(n));
230        }
231        Ok(())
232    }
233
234    fn assert(x: &[f64], y: &[f64]) {
235        assert!(x.is_strictly_increasing());
236
237        let n = x.len();
238        assert!(y.len() == n);
239        assert!(n >= MIN_POINTS_LINEAR);
240    }
241}
242
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use lox_test_utils::assert_approx_eq;

    use super::*;

    #[rstest]
    #[case(0.5, 0.5)]
    #[case(1.0, 1.0)]
    #[case(1.5, 1.5)]
    #[case(2.5, 2.5)]
    #[case(5.5, 5.5)]
    fn test_series_linear(#[case] xp: f64, #[case] expected: f64) {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let s = Series::try_new(x, y, InterpolationType::Linear).unwrap();
        let actual = s.interpolate(xp);
        assert_eq!(actual, expected);
    }

    // Reference values from AstroBase.jl
    #[rstest]
    #[case(0.0, -14.303290471048534)]
    #[case(0.1, -12.036932976759344)]
    #[case(0.2, -9.978070560771739)]
    #[case(0.3, -8.117883404355377)]
    #[case(0.4, -6.447551688779917)]
    #[case(0.5, -4.958255595315013)]
    #[case(0.6, -3.6411753052303184)]
    #[case(0.7, -2.487490999795493)]
    #[case(0.8, -1.4883828602801898)]
    #[case(0.9, -0.6350310679540686)]
    #[case(1.0, 0.08138419591321655)]
    #[case(1.1, 0.6696827500520098)]
    #[case(1.2, 1.1386844131926532)]
    #[case(1.3, 1.4972090040654928)]
    #[case(1.4, 1.754076341400871)]
    #[case(1.5, 1.9181062439291328)]
    #[case(1.6, 1.9981185303806206)]
    #[case(1.7, 2.002933019485679)]
    #[case(1.8, 1.9413695299746523)]
    #[case(1.9, 1.8222478805778837)]
    #[case(2.0, 1.6543878900257172)]
    #[case(2.1, 1.4466093770484965)]
    #[case(2.2, 1.2077321603765656)]
    #[case(2.3, 0.9465760587402696)]
    #[case(2.4, 0.6719608908699499)]
    #[case(2.5, 0.3927064754959517)]
    #[case(2.6, 0.11763263134861876)]
    #[case(2.7, -0.14444082284170534)]
    #[case(2.8, -0.384694068344675)]
    #[case(2.9, -0.5943072864299493)]
    #[case(3.0, -0.7644606583671828)]
    #[case(3.1, -0.8886377407066958)]
    #[case(3.2, -0.9695355911214641)]
    #[case(3.3, -1.012154642565128)]
    #[case(3.4, -1.021495327991328)]
    #[case(3.5, -1.0025580803537035)]
    #[case(3.6, -0.960343332605895)]
    #[case(3.7, -0.8998515177015425)]
    #[case(3.8, -0.8260830685942864)]
    #[case(3.9, -0.744038418237766)]
    #[case(4.0, -0.6587179995856219)]
    #[case(4.1, -0.5751222455914945)]
    #[case(4.2, -0.4982515892090227)]
    #[case(4.3, -0.433106463391848)]
    #[case(4.4, -0.38468730109360944)]
    #[case(4.5, -0.3579945352679478)]
    #[case(4.6, -0.3580285988685027)]
    #[case(4.7, -0.3897899248489146)]
    #[case(4.8, -0.458278946162823)]
    #[case(4.9, -0.5684960957638693)]
    #[case(5.0, -0.7254418066056914)]
    #[case(5.1, -0.9341165116419302)]
    #[case(5.2, -1.1995206438262285)]
    #[case(5.3, -1.5266546361122217)]
    #[case(5.4, -1.9205189214535554)]
    #[case(5.5, -2.3861139328038625)]
    #[case(5.6, -2.9284401031167873)]
    #[case(5.7, -3.5524978653459742)]
    #[case(5.8, -4.263287652445054)]
    #[case(5.9, -5.065809897367678)]
    #[case(6.0, -5.965065033067472)]
    fn test_series_spline(#[case] xp: f64, #[case] expected: f64) {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![
            0.08138419591321655,
            1.6543878900257172,
            -0.7644606583671828,
            -0.6587179995856219,
            -0.7254418066056914,
        ];

        let s = Series::try_new(x, y, InterpolationType::CubicSpline).unwrap();
        let actual = s.interpolate(xp);
        assert_approx_eq!(actual, expected, rtol <= 1e-12);
    }

    // A cubic spline needs at least `MIN_POINTS_SPLINE` samples; with fewer, the
    // constructor falls back to linear interpolation.
    #[test]
    fn test_series_spline_falls_back_to_linear() {
        let s = Series::try_new(
            vec![1.0, 2.0, 3.0],
            vec![1.0, 4.0, 9.0],
            InterpolationType::CubicSpline,
        )
        .unwrap();
        assert_eq!(s.interpolation, Interpolation::Linear);
        assert_eq!(s.interpolate(1.5), 2.5);
    }

    #[test]
    fn test_series_first_last() {
        let s = Series::new(
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            InterpolationType::Linear,
        );
        assert_eq!(s.first(), (1.0, 4.0));
        assert_eq!(s.last(), (3.0, 6.0));
    }

    #[test]
    #[should_panic]
    fn test_series_new_panics_on_invalid_input() {
        let _ = Series::new(vec![1.0], vec![1.0], InterpolationType::Linear);
    }

    #[rstest]
    #[case(Series::try_new(vec![1.0], vec![1.0], InterpolationType::Linear), Err(SeriesError::InsufficientPoints(1)))]
    #[case(Series::try_new(vec![1.0], vec![1.0], InterpolationType::CubicSpline), Err(SeriesError::InsufficientPoints(1)))]
    #[case(Series::try_new(vec![1.0, 2.0], vec![1.0], InterpolationType::Linear), Err(SeriesError::DimensionMismatch(2, 1)))]
    #[case(Series::try_new(vec![1.0, 2.0], vec![1.0], InterpolationType::CubicSpline), Err(SeriesError::DimensionMismatch(2, 1)))]
    #[case(Series::try_new(vec![2.0, 1.0], vec![1.0, 2.0], InterpolationType::Linear), Err(SeriesError::NonMonotonic))]
    #[case(Series::try_new(vec![1.0, 1.0], vec![1.0, 2.0], InterpolationType::CubicSpline), Err(SeriesError::NonMonotonic))]
    fn test_series_errors(
        #[case] actual: Result<Series, SeriesError>,
        #[case] expected: Result<Series, SeriesError>,
    ) {
        assert_eq!(actual, expected);
    }
}