1use std::sync::Arc;
6
7use fast_polynomial::poly_array;
8use thiserror::Error;
9
10use crate::math::slices::Monotonic;
11
12use super::linear_algebra::tridiagonal::Tridiagonal;
13use super::slices::Diff;
14
/// Minimum number of points accepted by any [`Series`] constructor; fewer points
/// yield `SeriesError::InsufficientPoints`.
const MIN_POINTS_LINEAR: usize = 2;
/// Minimum number of points for cubic-spline interpolation; when fewer points are
/// supplied with `InterpolationType::CubicSpline`, the series falls back to linear
/// interpolation instead.
const MIN_POINTS_SPLINE: usize = 4;
17
18#[derive(Clone, Debug, Error, PartialEq)]
19pub enum SeriesError {
20 #[error("`x` and `y` must have the same length but were {0} and {1}")]
21 DimensionMismatch(usize, usize),
22 #[error("length of `x` and `y` must at least 2 but was {0}")]
23 InsufficientPoints(usize),
24 #[error("x-axis must be strictly monotonic")]
25 NonMonotonic,
26}
27
/// Interpolation scheme stored inside a [`Series`], together with any precomputed data.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Interpolation {
    /// Piecewise-linear interpolation between neighbouring points.
    Linear,
    /// Piecewise-cubic interpolation with one `[c0, c1, c2, c3]` coefficient set per
    /// interval `i`, evaluated via `poly_array` at `h = xp - x[i]`; `c0` is `y[i]`,
    /// so the coefficients are in ascending powers of `h`.
    CubicSpline(Arc<[[f64; 4]]>),
}
34
/// A tabulated one-dimensional function `y(x)` that can be interpolated at
/// arbitrary points.
///
/// The constructors `Series::try_new` and `Series::new` validate that `x` is
/// strictly increasing, that `x` and `y` have equal length, and that there are at
/// least two points.
///
/// NOTE(review): with the `serde` feature, `Deserialize` is derived directly on this
/// struct, so deserialized values do not appear to go through that validation (nor
/// are spline coefficients recomputed) — confirm untrusted input cannot reach it.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Series {
    /// Abscissae (knots); strictly increasing once validated.
    x: Arc<[f64]>,
    /// Ordinates, one per entry of `x`.
    y: Arc<[f64]>,
    /// Interpolation scheme, including any precomputed spline coefficients.
    interpolation: Interpolation,
}
42
/// Interpolation scheme requested when constructing a [`Series`].
///
/// This is a plain selector (no associated data), so it is cheap to copy, compare,
/// and print.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InterpolationType {
    /// Piecewise-linear interpolation.
    Linear,
    /// Cubic-spline interpolation; the `Series` constructors fall back to linear
    /// interpolation when fewer than four points are supplied.
    CubicSpline,
}
47
48impl Series {
49 pub fn try_new(
50 x: impl Into<Arc<[f64]>>,
51 y: impl Into<Arc<[f64]>>,
52 interpolation: InterpolationType,
53 ) -> Result<Self, SeriesError> {
54 let x: Arc<[f64]> = x.into();
55 let y: Arc<[f64]> = y.into();
56
57 Self::check(&x, &y)?;
58
59 Ok(Self::new(x, y, interpolation))
60 }
61
62 pub fn new(
63 x: impl Into<Arc<[f64]>>,
64 y: impl Into<Arc<[f64]>>,
65 interpolation: InterpolationType,
66 ) -> Self {
67 let x: Arc<[f64]> = x.into();
68 let y: Arc<[f64]> = y.into();
69
70 Self::assert(&x, &y);
71
72 match interpolation {
73 InterpolationType::Linear => Self::linear(x, y),
74 InterpolationType::CubicSpline => {
75 let n = x.len();
76 if n < MIN_POINTS_SPLINE {
77 Self::linear(x, y)
78 } else {
79 Self::cubic_spline(x, y)
80 }
81 }
82 }
83 }
84
85 fn linear(x: Arc<[f64]>, y: Arc<[f64]>) -> Self {
86 Self {
87 x,
88 y,
89 interpolation: Interpolation::Linear,
90 }
91 }
92
93 fn cubic_spline(x: Arc<[f64]>, y: Arc<[f64]>) -> Self {
94 let n = x.len();
95
96 let dx = x.diff();
97 let nd = dx.len();
98 let slope: Vec<f64> = y
99 .diff()
100 .iter()
101 .enumerate()
102 .map(|(idx, y)| y / dx[idx])
103 .collect();
104
105 let mut d: Vec<f64> = dx[0..nd - 1]
106 .iter()
107 .enumerate()
108 .map(|(idx, dxi)| 2.0 * (dxi + dx[idx + 1]))
109 .collect();
110 let mut du: Vec<f64> = dx[0..nd - 1].to_vec();
111 let mut dl: Vec<f64> = dx[1..].to_vec();
112 let mut b: Vec<f64> = dx[0..nd - 1]
113 .iter()
114 .enumerate()
115 .map(|(idx, dxi)| 3.0 * (dx[idx + 1] * slope[idx] + dxi * slope[idx + 1]))
116 .collect();
117
118 d.insert(0, dx[1]);
120 du.insert(0, x[2] - x[0]);
121 let delta = x[2] - x[0];
122 b.insert(
123 0,
124 ((dx[0] + 2.0 * delta) * dx[1] * slope[0] + dx[0].powi(2) * slope[1]) / delta,
125 );
126 d.push(dx[nd - 2]);
127 let delta = x[n - 1] - x[n - 3];
128 dl.push(delta);
129 b.push(
130 (dx[nd - 1].powi(2) * slope[nd - 2]
131 + (2.0 * delta + dx[nd - 1]) * dx[nd - 2] * slope[nd - 1])
132 / delta,
133 );
134
135 let tri = Tridiagonal::new(&dl, &d, &du).unwrap_or_else(|err| {
136 unreachable!(
137 "dimensions should be correct for tridiagonal system: {}",
138 err
139 )
140 });
141 let s = tri.solve(&b);
142 let t: Vec<f64> = s[0..n - 1]
143 .iter()
144 .enumerate()
145 .map(|(idx, si)| (si + s[idx + 1] - 2.0 * slope[idx]) / dx[idx])
146 .collect();
147
148 let coeffs: Vec<[f64; 4]> = (0..n - 1)
149 .map(|i| {
150 let c1 = y[i];
151 let c2 = s[i];
152 let c3 = (slope[i] - s[i]) / dx[i] - t[i];
153 let c4 = t[i] / dx[i];
154 [c1, c2, c3, c4]
155 })
156 .collect();
157
158 Self {
159 x,
160 y,
161 interpolation: Interpolation::CubicSpline(coeffs.into()),
162 }
163 }
164
165 #[inline]
166 pub fn find_index(&self, xp: f64) -> usize {
167 let x = self.x.as_ref();
168 let x0 = *x.first().unwrap();
169 let xn = *x.last().unwrap();
170 if xp <= x0 {
171 0
172 } else if xp >= xn {
173 x.len() - 2
174 } else {
175 x.partition_point(|&val| xp > val) - 1
176 }
177 }
178
179 #[inline]
180 pub fn interpolate_at_index(&self, xp: f64, idx: usize) -> f64 {
181 match &self.interpolation {
182 Interpolation::Linear => {
183 let x = self.x.as_ref();
184 let y = self.y.as_ref();
185 let x0 = x[idx];
186 let x1 = x[idx + 1];
187 let y0 = y[idx];
188 let y1 = y[idx + 1];
189 y0 + (y1 - y0) * (xp - x0) / (x1 - x0)
190 }
191 Interpolation::CubicSpline(coeffs) => poly_array(xp - self.x[idx], &coeffs[idx]),
192 }
193 }
194
195 #[inline]
196 pub fn interpolate(&self, xp: f64) -> f64 {
197 let idx = self.find_index(xp);
198 self.interpolate_at_index(xp, idx)
199 }
200
201 pub fn x(&self) -> &[f64] {
202 self.x.as_ref()
203 }
204
205 pub fn y(&self) -> &[f64] {
206 self.y.as_ref()
207 }
208
209 pub fn first(&self) -> (f64, f64) {
210 (*self.x().first().unwrap(), *self.y().first().unwrap())
211 }
212
213 pub fn last(&self) -> (f64, f64) {
214 (*self.x().last().unwrap(), *self.y().last().unwrap())
215 }
216
217 fn check(x: &[f64], y: &[f64]) -> Result<(), SeriesError> {
218 if !x.is_strictly_increasing() {
219 return Err(SeriesError::NonMonotonic);
220 }
221
222 let n = x.len();
223
224 if y.len() != n {
225 return Err(SeriesError::DimensionMismatch(n, y.len()));
226 }
227
228 if n < MIN_POINTS_LINEAR {
229 return Err(SeriesError::InsufficientPoints(n));
230 }
231 Ok(())
232 }
233
234 fn assert(x: &[f64], y: &[f64]) {
235 assert!(x.is_strictly_increasing());
236
237 let n = x.len();
238 assert!(y.len() == n);
239 assert!(n >= MIN_POINTS_LINEAR);
240 }
241}
242
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use lox_test_utils::assert_approx_eq;

    use super::*;

    /// Linear interpolation of the identity function, including extrapolation on
    /// both sides of the tabulated range.
    #[rstest]
    #[case(0.5, 0.5)]
    #[case(1.0, 1.0)]
    #[case(1.5, 1.5)]
    #[case(2.5, 2.5)]
    #[case(5.5, 5.5)]
    fn test_series_linear(#[case] xp: f64, #[case] expected: f64) {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let s = Series::try_new(x, y, InterpolationType::Linear).unwrap();
        let actual = s.interpolate(xp);
        assert_eq!(actual, expected);
    }

    /// Cubic-spline interpolation and extrapolation against reference values.
    #[rstest]
    // Reference values — presumably generated with SciPy's `CubicSpline` using its
    // default not-a-knot boundary condition; TODO confirm the source.
    #[case(0.0, -14.303290471048534)]
    #[case(0.1, -12.036932976759344)]
    #[case(0.2, -9.978070560771739)]
    #[case(0.3, -8.117883404355377)]
    #[case(0.4, -6.447551688779917)]
    #[case(0.5, -4.958255595315013)]
    #[case(0.6, -3.6411753052303184)]
    #[case(0.7, -2.487490999795493)]
    #[case(0.8, -1.4883828602801898)]
    #[case(0.9, -0.6350310679540686)]
    #[case(1.0, 0.08138419591321655)]
    #[case(1.1, 0.6696827500520098)]
    #[case(1.2, 1.1386844131926532)]
    #[case(1.3, 1.4972090040654928)]
    #[case(1.4, 1.754076341400871)]
    #[case(1.5, 1.9181062439291328)]
    #[case(1.6, 1.9981185303806206)]
    #[case(1.7, 2.002933019485679)]
    #[case(1.8, 1.9413695299746523)]
    #[case(1.9, 1.8222478805778837)]
    #[case(2.0, 1.6543878900257172)]
    #[case(2.1, 1.4466093770484965)]
    #[case(2.2, 1.2077321603765656)]
    #[case(2.3, 0.9465760587402696)]
    #[case(2.4, 0.6719608908699499)]
    #[case(2.5, 0.3927064754959517)]
    #[case(2.6, 0.11763263134861876)]
    #[case(2.7, -0.14444082284170534)]
    #[case(2.8, -0.384694068344675)]
    #[case(2.9, -0.5943072864299493)]
    #[case(3.0, -0.7644606583671828)]
    #[case(3.1, -0.8886377407066958)]
    #[case(3.2, -0.9695355911214641)]
    #[case(3.3, -1.012154642565128)]
    #[case(3.4, -1.021495327991328)]
    #[case(3.5, -1.0025580803537035)]
    #[case(3.6, -0.960343332605895)]
    #[case(3.7, -0.8998515177015425)]
    #[case(3.8, -0.8260830685942864)]
    #[case(3.9, -0.744038418237766)]
    #[case(4.0, -0.6587179995856219)]
    #[case(4.1, -0.5751222455914945)]
    #[case(4.2, -0.4982515892090227)]
    #[case(4.3, -0.433106463391848)]
    #[case(4.4, -0.38468730109360944)]
    #[case(4.5, -0.3579945352679478)]
    #[case(4.6, -0.3580285988685027)]
    #[case(4.7, -0.3897899248489146)]
    #[case(4.8, -0.458278946162823)]
    #[case(4.9, -0.5684960957638693)]
    #[case(5.0, -0.7254418066056914)]
    #[case(5.1, -0.9341165116419302)]
    #[case(5.2, -1.1995206438262285)]
    #[case(5.3, -1.5266546361122217)]
    #[case(5.4, -1.9205189214535554)]
    #[case(5.5, -2.3861139328038625)]
    #[case(5.6, -2.9284401031167873)]
    #[case(5.7, -3.5524978653459742)]
    #[case(5.8, -4.263287652445054)]
    #[case(5.9, -5.065809897367678)]
    #[case(6.0, -5.965065033067472)]
    fn test_series_spline(#[case] xp: f64, #[case] expected: f64) {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![
            0.08138419591321655,
            1.6543878900257172,
            -0.7644606583671828,
            -0.6587179995856219,
            -0.7254418066056914,
        ];

        let s = Series::try_new(x, y, InterpolationType::CubicSpline).unwrap();
        let actual = s.interpolate(xp);
        assert_approx_eq!(actual, expected, rtol <= 1e-12);
    }

    /// With exactly `MIN_POINTS_SPLINE` points a real spline is built; a not-a-knot
    /// spline reproduces a quadratic exactly (up to rounding).
    #[test]
    fn test_series_spline_minimum_points() {
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = vec![1.0, 4.0, 9.0, 16.0];

        let s = Series::try_new(x, y, InterpolationType::CubicSpline).unwrap();
        assert!(matches!(s.interpolation, Interpolation::CubicSpline(_)));
        assert_approx_eq!(s.interpolate(2.5), 6.25, rtol <= 1e-12);
    }

    /// Below `MIN_POINTS_SPLINE` points a spline request degrades to linear interpolation.
    #[test]
    fn test_series_spline_falls_back_to_linear() {
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![1.0, 4.0, 9.0];

        let s = Series::try_new(x, y, InterpolationType::CubicSpline).unwrap();
        assert_eq!(s.interpolation, Interpolation::Linear);
        assert_eq!(s.interpolate(1.5), 2.5);
    }

    /// The panicking constructor rejects the same input `try_new` reports as an error.
    #[test]
    #[should_panic]
    fn test_series_new_panics_on_invalid_input() {
        let _ = Series::new(vec![1.0], vec![1.0], InterpolationType::Linear);
    }

    #[rstest]
    #[case(Series::try_new(vec![1.0], vec![1.0], InterpolationType::Linear), Err(SeriesError::InsufficientPoints(1)))]
    #[case(Series::try_new(vec![1.0], vec![1.0], InterpolationType::CubicSpline), Err(SeriesError::InsufficientPoints(1)))]
    #[case(Series::try_new(vec![1.0, 2.0], vec![1.0], InterpolationType::Linear), Err(SeriesError::DimensionMismatch(2, 1)))]
    #[case(Series::try_new(vec![1.0, 2.0], vec![1.0], InterpolationType::CubicSpline), Err(SeriesError::DimensionMismatch(2, 1)))]
    #[case(Series::try_new(vec![2.0, 1.0], vec![1.0, 2.0], InterpolationType::Linear), Err(SeriesError::NonMonotonic))]
    #[case(Series::try_new(vec![1.0, 1.0], vec![1.0, 2.0], InterpolationType::CubicSpline), Err(SeriesError::NonMonotonic))]
    fn test_series_errors(
        #[case] actual: Result<Series, SeriesError>,
        #[case] expected: Result<Series, SeriesError>,
    ) {
        assert_eq!(actual, expected);
    }
}