convex_math/interpolation/
linear.rs1use crate::error::{MathError, MathResult};
4use crate::interpolation::Interpolator;
5
/// A piecewise-linear interpolator over a strictly increasing grid of x values.
///
/// Instances are created with [`LinearInterpolator::new`]. Evaluation outside the
/// grid is refused unless [`LinearInterpolator::with_extrapolation`] has been called.
#[derive(Clone, Debug)]
pub struct LinearInterpolator {
    /// Knot abscissae: at least two values, strictly increasing (checked by `new`).
    xs: Vec<f64>,
    /// Knot ordinates; `ys[i]` is the value at `xs[i]`.
    ys: Vec<f64>,
    /// Whether evaluation outside `[xs[0], xs[len - 1]]` is permitted.
    allow_extrapolation: bool,
}

30impl LinearInterpolator {
31 pub fn new(xs: Vec<f64>, ys: Vec<f64>) -> MathResult<Self> {
42 if xs.len() < 2 {
43 return Err(MathError::insufficient_data(2, xs.len()));
44 }
45 if xs.len() != ys.len() {
46 return Err(MathError::invalid_input(format!(
47 "xs and ys must have same length: {} vs {}",
48 xs.len(),
49 ys.len()
50 )));
51 }
52
53 for i in 1..xs.len() {
55 if xs[i] <= xs[i - 1] {
56 return Err(MathError::invalid_input(
57 "x values must be strictly increasing",
58 ));
59 }
60 }
61
62 Ok(Self {
63 xs,
64 ys,
65 allow_extrapolation: false,
66 })
67 }
68
69 #[must_use]
71 pub fn with_extrapolation(mut self) -> Self {
72 self.allow_extrapolation = true;
73 self
74 }
75
76 fn find_segment(&self, x: f64) -> usize {
78 match self
80 .xs
81 .binary_search_by(|probe| probe.partial_cmp(&x).unwrap_or(std::cmp::Ordering::Equal))
82 {
83 Ok(i) => i.min(self.xs.len() - 2),
84 Err(i) => (i.saturating_sub(1)).min(self.xs.len() - 2),
85 }
86 }
87}
88
89impl Interpolator for LinearInterpolator {
90 fn interpolate(&self, x: f64) -> MathResult<f64> {
91 if !self.allow_extrapolation && (x < self.xs[0] || x > self.xs[self.xs.len() - 1]) {
93 return Err(MathError::ExtrapolationNotAllowed {
94 x,
95 min: self.xs[0],
96 max: self.xs[self.xs.len() - 1],
97 });
98 }
99
100 let i = self.find_segment(x);
101
102 let x0 = self.xs[i];
103 let x1 = self.xs[i + 1];
104 let y0 = self.ys[i];
105 let y1 = self.ys[i + 1];
106
107 let t = (x - x0) / (x1 - x0);
109 Ok(y0 + t * (y1 - y0))
110 }
111
112 fn derivative(&self, x: f64) -> MathResult<f64> {
113 if !self.allow_extrapolation && (x < self.xs[0] || x > self.xs[self.xs.len() - 1]) {
115 return Err(MathError::ExtrapolationNotAllowed {
116 x,
117 min: self.xs[0],
118 max: self.xs[self.xs.len() - 1],
119 });
120 }
121
122 let i = self.find_segment(x);
123
124 let x0 = self.xs[i];
125 let x1 = self.xs[i + 1];
126 let y0 = self.ys[i];
127 let y1 = self.ys[i + 1];
128
129 Ok((y1 - y0) / (x1 - x0))
131 }
132
133 fn allows_extrapolation(&self) -> bool {
134 self.allow_extrapolation
135 }
136
137 fn min_x(&self) -> f64 {
138 self.xs[0]
139 }
140
141 fn max_x(&self) -> f64 {
142 self.xs[self.xs.len() - 1]
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Three knots whose segments have different slopes (2.0, then 1.0).
    fn kinked() -> LinearInterpolator {
        LinearInterpolator::new(vec![0.0, 1.0, 2.0], vec![0.0, 2.0, 3.0]).unwrap()
    }

    #[test]
    fn test_linear_interpolation() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![0.0, 2.0, 4.0];

        let interp = LinearInterpolator::new(xs, ys).unwrap();

        // Exact knots.
        assert_relative_eq!(interp.interpolate(0.0).unwrap(), 0.0, epsilon = 1e-10);
        assert_relative_eq!(interp.interpolate(1.0).unwrap(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(interp.interpolate(2.0).unwrap(), 4.0, epsilon = 1e-10);

        // Segment midpoints.
        assert_relative_eq!(interp.interpolate(0.5).unwrap(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(interp.interpolate(1.5).unwrap(), 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_extrapolation_disabled() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![0.0, 1.0, 2.0];

        let interp = LinearInterpolator::new(xs, ys).unwrap();

        assert!(interp.interpolate(-0.5).is_err());
        assert!(interp.interpolate(2.5).is_err());
        assert!(matches!(
            interp.interpolate(2.5),
            Err(MathError::ExtrapolationNotAllowed { .. })
        ));
        assert!(!interp.allows_extrapolation());
    }

    #[test]
    fn test_extrapolation_enabled() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![0.0, 1.0, 2.0];

        let interp = LinearInterpolator::new(xs, ys)
            .unwrap()
            .with_extrapolation();

        assert!(interp.allows_extrapolation());
        assert_relative_eq!(interp.interpolate(-1.0).unwrap(), -1.0, epsilon = 1e-10);
        assert_relative_eq!(interp.interpolate(3.0).unwrap(), 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_extrapolation_uses_end_segments() {
        let interp = kinked().with_extrapolation();

        // Below the grid: first segment (slope 2). Above: last segment (slope 1).
        assert_relative_eq!(interp.interpolate(-1.0).unwrap(), -2.0, epsilon = 1e-10);
        assert_relative_eq!(interp.interpolate(3.0).unwrap(), 4.0, epsilon = 1e-10);
        assert_relative_eq!(interp.derivative(-1.0).unwrap(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(interp.derivative(3.0).unwrap(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_derivative() {
        let interp = kinked();

        assert_relative_eq!(interp.derivative(0.0).unwrap(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(interp.derivative(0.5).unwrap(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(interp.derivative(1.5).unwrap(), 1.0, epsilon = 1e-10);
        // An interior knot takes the slope on its right; the last knot takes the final slope.
        assert_relative_eq!(interp.derivative(1.0).unwrap(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(interp.derivative(2.0).unwrap(), 1.0, epsilon = 1e-10);

        assert!(interp.derivative(-0.1).is_err());
        assert!(interp.derivative(2.1).is_err());
    }

    #[test]
    fn test_domain_bounds() {
        let interp = kinked();

        assert_relative_eq!(interp.min_x(), 0.0);
        assert_relative_eq!(interp.max_x(), 2.0);
    }

    #[test]
    fn test_insufficient_points() {
        let xs = vec![0.0];
        let ys = vec![1.0];

        assert!(LinearInterpolator::new(xs, ys).is_err());
    }

    #[test]
    fn test_length_mismatch_error() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![0.0, 1.0];

        assert!(LinearInterpolator::new(xs, ys).is_err());
    }

    #[test]
    fn test_unsorted_error() {
        let xs = vec![1.0, 0.0, 2.0];
        let ys = vec![1.0, 0.0, 2.0];

        assert!(LinearInterpolator::new(xs, ys).is_err());
    }

    #[test]
    fn test_duplicate_x_error() {
        let xs = vec![0.0, 1.0, 1.0];
        let ys = vec![0.0, 1.0, 2.0];

        assert!(LinearInterpolator::new(xs, ys).is_err());
    }
}