// numerics_rs/interp.rs

/// Enum to define the type of interpolation.
///
/// The type is a plain configuration tag, so it is cheap to copy and can be compared
/// for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InterpolationType {
    /// Linear interpolation (order 1).
    Linear,
    /// Quadratic spline interpolation (order 2).
    Quadratic,
    /// Cubic spline interpolation (order 3).
    Cubic,
    /// Constant interpolation taking the previous value.
    ConstantBackward,
    /// Constant interpolation taking the next value.
    ConstantForward,
}

/// Enum to define the extrapolation strategy used for x values outside the sampled range.
///
/// The strategy is a plain configuration tag, so it is cheap to copy and can be compared
/// for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtrapolationStrategy {
    /// Do not extrapolate, panic on out-of-bounds.
    None,
    /// Use the closest y-value for out-of-bounds x.
    Constant,
    /// Use the same spline function as interpolation.
    ExtendSpline,
}

19#[derive(Debug)]
20pub struct Interpolator {
21    x_values: Vec<f64>,
22    y_values: Vec<f64>,
23    b_coeffs: Vec<f64>,
24    c_coeffs: Vec<f64>,
25    d_coeffs: Vec<f64>,
26    interpolation_type: InterpolationType,
27    extrap_strategy: ExtrapolationStrategy,
28}
29
30impl Interpolator {
31    /// Creates a new Interpolator with the given points
32    pub fn new(
33        x_values: Vec<f64>,
34        y_values: Vec<f64>,
35        interpolation_type: InterpolationType,
36        extrap_strategy: ExtrapolationStrategy,
37    ) -> Self {
38        if x_values.len() != y_values.len() || x_values.len() < 2 {
39            panic!(
40                "x_values and y_values must have the same length and contain at least two points."
41            );
42        }
43
44        // Precompute spline coefficients
45        let (b_coeffs, c_coeffs, d_coeffs) =
46            compute_spline_coefficients(&x_values, &y_values, &interpolation_type);
47        Self {
48            x_values,
49            y_values,
50            b_coeffs,
51            c_coeffs,
52            d_coeffs,
53            interpolation_type,
54            extrap_strategy,
55        }
56    }
57
58    /// Performs interpolation for a given x value using the specified type
59    pub fn interpolate(&self, x: f64) -> f64 {
60        for j in 0..self.x_values.len() - 1 {
61            if self.x_values[j] <= x && x <= self.x_values[j + 1] {
62                // We found where the value is bracketed
63                let dx = x - self.x_values[j];
64                return match self.interpolation_type {
65                    InterpolationType::Cubic
66                    | InterpolationType::Quadratic
67                    | InterpolationType::Linear => {
68                        self.y_values[j]
69                            + self.b_coeffs[j] * dx
70                            + self.c_coeffs[j] * dx.powi(2)
71                            + self.d_coeffs[j] * dx.powi(3)
72                    }
73                    InterpolationType::ConstantBackward => self.y_values[j],
74                    InterpolationType::ConstantForward => self.y_values[j + 1],
75                };
76            }
77        }
78        if x < *self.x_values.first().unwrap() || x > *self.x_values.last().unwrap() {
79            return self.extrapolate(x);
80        }
81        unreachable!("This could not be reached as the x is either bracketed or extrapolated");
82    }
83
84    /// Handles extrapolation for out-of-bounds x values
85    fn extrapolate(&self, x: f64) -> f64 {
86        match self.extrap_strategy {
87            ExtrapolationStrategy::None => {
88                panic!(
89                    "Value x = {} is out of bounds and no extrapolation is enabled.",
90                    x
91                );
92            }
93            ExtrapolationStrategy::Constant => {
94                if x < *self.x_values.first().unwrap() {
95                    return *self.y_values.first().unwrap();
96                }
97                *self.y_values.last().unwrap()
98            }
99            ExtrapolationStrategy::ExtendSpline => {
100                let j = if x < *self.x_values.first().unwrap() {
101                    0
102                } else {
103                    self.x_values.len() - 2
104                };
105                let dx = x - self.x_values[j];
106                self.y_values[j]
107                    + self.b_coeffs[j] * dx
108                    + self.c_coeffs[j] * dx.powi(2)
109                    + self.d_coeffs[j] * dx.powi(3)
110            }
111        }
112    }
113}
114
115/// Computes the coefficients for cubic spline interpolation
116fn compute_spline_coefficients(
117    x: &[f64],
118    y: &[f64],
119    interpolation_type: &InterpolationType,
120) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
121    if matches!(
122        interpolation_type,
123        InterpolationType::ConstantForward | InterpolationType::ConstantBackward
124    ) {
125        return (vec![], vec![], vec![]);
126    }
127
128    let n = x.len() - 1; // Number of segments
129    let dx: Vec<f64> = (0..n).map(|i| x[i + 1] - x[i]).collect(); // Spacing between x-values
130    let dy: Vec<f64> = (0..n).map(|i| y[i + 1] - y[i]).collect(); // Spacing between y-values
131    let slopes = (0..n).map(|i| dy[i] / dx[i]).collect();
132
133    match interpolation_type {
134        InterpolationType::Linear => (slopes, vec![0.0; n], vec![0.0; n]),
135        // TODO: The code below could use more declarative way of solving equations, to be changed when I add matrix API
136        InterpolationType::Quadratic => {
137            let mut c = vec![0.0; n]; // Quadratic coefficients
138
139            for i in 1..n - 1 {
140                c[i] = (slopes[i] - slopes[i - 1]) / (dx[i - 1] * 2.0); // Derivative change over interval
141            }
142            c[n - 1] = 0.0; // Natural boundary at the last interval
143            (slopes, c, vec![0.0; n])
144        }
145        InterpolationType::Cubic => {
146            let mut alpha = vec![0.0; n - 1];
147            for i in 1..n {
148                alpha[i - 1] =
149                    3.0 / dx[i] * (y[i + 1] - y[i]) - 3.0 / dx[i - 1] * (y[i] - y[i - 1]);
150            }
151            let mut b = vec![0.0; n];
152            let mut c = vec![0.0; n + 1];
153            let mut d = vec![0.0; n];
154            let mut l = vec![1.0; n + 1];
155            let mut mu = vec![0.0; n];
156            let mut z = vec![0.0; n + 1];
157
158            for i in 1..n {
159                l[i] = 2.0 * (x[i + 1] - x[i - 1]) - dx[i - 1] * mu[i - 1];
160                mu[i] = dx[i] / l[i];
161                z[i] = (alpha[i - 1] - dx[i - 1] * z[i - 1]) / l[i];
162            }
163            for j in (0..n).rev() {
164                c[j] = z[j] - mu[j] * c[j + 1];
165                b[j] = dy[j] / dx[j] - dx[j] * (c[j + 1] + 2.0 * c[j]) / 3.0;
166                d[j] = (c[j + 1] - c[j]) / (3.0 * dx[j]);
167            }
168            (b, c, d)
169        }
170        _ => panic!(
171            "Interpolation type {:?} is not supported.",
172            interpolation_type
173        ),
174    }
175}