convex_math/interpolation/
linear.rs

1//! Linear interpolation.
2
3use crate::error::{MathError, MathResult};
4use crate::interpolation::Interpolator;
5
/// Piecewise-linear interpolation between data points.
///
/// The simplest form of interpolation: consecutive points `(xs[i], ys[i])`
/// and `(xs[i + 1], ys[i + 1])` are joined by straight line segments.
/// Queries outside `[xs[0], xs[n - 1]]` are rejected unless extrapolation is
/// enabled with [`LinearInterpolator::with_extrapolation`], in which case the
/// first or last segment is extended.
///
/// # Example
///
/// ```rust
/// use convex_math::interpolation::{LinearInterpolator, Interpolator};
///
/// let xs = vec![0.0, 1.0, 2.0, 3.0];
/// let ys = vec![0.0, 1.0, 4.0, 9.0];
///
/// let interp = LinearInterpolator::new(xs, ys).unwrap();
///
/// // Halfway along the segment between (1, 1) and (2, 4).
/// let y = interp.interpolate(1.5).unwrap();
/// assert!((y - 2.5).abs() < 1e-12);
/// ```
#[derive(Debug, Clone)]
pub struct LinearInterpolator {
    /// Knot x coordinates: at least two, strictly increasing (checked by `new`).
    xs: Vec<f64>,
    /// Knot y coordinates; `ys[i]` is the value at `xs[i]`.
    ys: Vec<f64>,
    /// Whether queries outside `[xs[0], xs[n - 1]]` extend the end segments
    /// instead of producing an error.
    allow_extrapolation: bool,
}
29
30impl LinearInterpolator {
31    /// Creates a new linear interpolator.
32    ///
33    /// # Arguments
34    ///
35    /// * `xs` - X coordinates (must be sorted in ascending order)
36    /// * `ys` - Y coordinates
37    ///
38    /// # Errors
39    ///
40    /// Returns an error if there are fewer than 2 points or if lengths differ.
41    pub fn new(xs: Vec<f64>, ys: Vec<f64>) -> MathResult<Self> {
42        if xs.len() < 2 {
43            return Err(MathError::insufficient_data(2, xs.len()));
44        }
45        if xs.len() != ys.len() {
46            return Err(MathError::invalid_input(format!(
47                "xs and ys must have same length: {} vs {}",
48                xs.len(),
49                ys.len()
50            )));
51        }
52
53        // Check that xs are sorted
54        for i in 1..xs.len() {
55            if xs[i] <= xs[i - 1] {
56                return Err(MathError::invalid_input(
57                    "x values must be strictly increasing",
58                ));
59            }
60        }
61
62        Ok(Self {
63            xs,
64            ys,
65            allow_extrapolation: false,
66        })
67    }
68
69    /// Enables extrapolation beyond the data range.
70    #[must_use]
71    pub fn with_extrapolation(mut self) -> Self {
72        self.allow_extrapolation = true;
73        self
74    }
75
76    /// Finds the index i such that xs[i] <= x < xs[i+1].
77    fn find_segment(&self, x: f64) -> usize {
78        // Binary search
79        match self
80            .xs
81            .binary_search_by(|probe| probe.partial_cmp(&x).unwrap_or(std::cmp::Ordering::Equal))
82        {
83            Ok(i) => i.min(self.xs.len() - 2),
84            Err(i) => (i.saturating_sub(1)).min(self.xs.len() - 2),
85        }
86    }
87}
88
89impl Interpolator for LinearInterpolator {
90    fn interpolate(&self, x: f64) -> MathResult<f64> {
91        // Check bounds
92        if !self.allow_extrapolation && (x < self.xs[0] || x > self.xs[self.xs.len() - 1]) {
93            return Err(MathError::ExtrapolationNotAllowed {
94                x,
95                min: self.xs[0],
96                max: self.xs[self.xs.len() - 1],
97            });
98        }
99
100        let i = self.find_segment(x);
101
102        let x0 = self.xs[i];
103        let x1 = self.xs[i + 1];
104        let y0 = self.ys[i];
105        let y1 = self.ys[i + 1];
106
107        // Linear interpolation formula
108        let t = (x - x0) / (x1 - x0);
109        Ok(y0 + t * (y1 - y0))
110    }
111
112    fn derivative(&self, x: f64) -> MathResult<f64> {
113        // Check bounds
114        if !self.allow_extrapolation && (x < self.xs[0] || x > self.xs[self.xs.len() - 1]) {
115            return Err(MathError::ExtrapolationNotAllowed {
116                x,
117                min: self.xs[0],
118                max: self.xs[self.xs.len() - 1],
119            });
120        }
121
122        let i = self.find_segment(x);
123
124        let x0 = self.xs[i];
125        let x1 = self.xs[i + 1];
126        let y0 = self.ys[i];
127        let y1 = self.ys[i + 1];
128
129        // Derivative of linear interpolation is constant slope
130        Ok((y1 - y0) / (x1 - x0))
131    }
132
133    fn allows_extrapolation(&self) -> bool {
134        self.allow_extrapolation
135    }
136
137    fn min_x(&self) -> f64 {
138        self.xs[0]
139    }
140
141    fn max_x(&self) -> f64 {
142        self.xs[self.xs.len() - 1]
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Samples of y = x² at 0, 1, 2: two segments with slopes 1 and 3.
    fn parabola() -> LinearInterpolator {
        LinearInterpolator::new(vec![0.0, 1.0, 2.0], vec![0.0, 1.0, 4.0]).unwrap()
    }

    #[test]
    fn test_linear_interpolation() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![0.0, 2.0, 4.0];

        let interp = LinearInterpolator::new(xs, ys).unwrap();

        // Test at exact points
        assert_relative_eq!(interp.interpolate(0.0).unwrap(), 0.0, epsilon = 1e-10);
        assert_relative_eq!(interp.interpolate(1.0).unwrap(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(interp.interpolate(2.0).unwrap(), 4.0, epsilon = 1e-10);

        // Test interpolation
        assert_relative_eq!(interp.interpolate(0.5).unwrap(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(interp.interpolate(1.5).unwrap(), 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_extrapolation_disabled() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![0.0, 1.0, 2.0];

        let interp = LinearInterpolator::new(xs, ys).unwrap();

        assert!(interp.interpolate(-0.5).is_err());
        assert!(interp.interpolate(2.5).is_err());
    }

    #[test]
    fn test_extrapolation_error_reports_range() {
        match parabola().interpolate(2.5) {
            Err(MathError::ExtrapolationNotAllowed { x, min, max }) => {
                assert_relative_eq!(x, 2.5);
                assert_relative_eq!(min, 0.0);
                assert_relative_eq!(max, 2.0);
            }
            other => panic!("expected ExtrapolationNotAllowed, got {:?}", other),
        }
    }

    #[test]
    fn test_extrapolation_enabled() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![0.0, 1.0, 2.0];

        let interp = LinearInterpolator::new(xs, ys)
            .unwrap()
            .with_extrapolation();

        // Should extrapolate linearly
        assert_relative_eq!(interp.interpolate(-1.0).unwrap(), -1.0, epsilon = 1e-10);
        assert_relative_eq!(interp.interpolate(3.0).unwrap(), 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_extrapolation_uses_end_segment_slopes() {
        let interp = parabola().with_extrapolation();

        // Below the data: first segment (slope 1). Above: last segment (slope 3).
        assert_relative_eq!(interp.interpolate(-1.0).unwrap(), -1.0, epsilon = 1e-10);
        assert_relative_eq!(interp.interpolate(3.0).unwrap(), 7.0, epsilon = 1e-10);
        assert_relative_eq!(interp.derivative(-1.0).unwrap(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(interp.derivative(3.0).unwrap(), 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_derivative_is_segment_slope() {
        let interp = parabola();

        assert_relative_eq!(interp.derivative(0.5).unwrap(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(interp.derivative(1.5).unwrap(), 3.0, epsilon = 1e-10);

        // At an interior knot the right-hand segment is used; at the last
        // knot, the final segment.
        assert_relative_eq!(interp.derivative(1.0).unwrap(), 3.0, epsilon = 1e-10);
        assert_relative_eq!(interp.derivative(2.0).unwrap(), 3.0, epsilon = 1e-10);

        assert!(interp.derivative(2.5).is_err());
        assert!(interp.derivative(-0.5).is_err());
    }

    #[test]
    fn test_accessors() {
        let interp = parabola();

        assert_relative_eq!(interp.min_x(), 0.0);
        assert_relative_eq!(interp.max_x(), 2.0);
        assert!(!interp.allows_extrapolation());
        assert!(interp.with_extrapolation().allows_extrapolation());
    }

    #[test]
    fn test_insufficient_points() {
        let xs = vec![0.0];
        let ys = vec![1.0];

        assert!(LinearInterpolator::new(xs, ys).is_err());
    }

    #[test]
    fn test_length_mismatch_error() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![0.0, 1.0];

        assert!(LinearInterpolator::new(xs, ys).is_err());
    }

    #[test]
    fn test_unsorted_error() {
        let xs = vec![1.0, 0.0, 2.0]; // Not sorted
        let ys = vec![1.0, 0.0, 2.0];

        assert!(LinearInterpolator::new(xs, ys).is_err());
    }

    #[test]
    fn test_duplicate_x_error() {
        let xs = vec![0.0, 1.0, 1.0]; // Not strictly increasing
        let ys = vec![0.0, 1.0, 2.0];

        assert!(LinearInterpolator::new(xs, ys).is_err());
    }
}