convex_math/interpolation/
log_linear.rs1use crate::error::{MathError, MathResult};
7use crate::interpolation::Interpolator;
8
/// Interpolator that is linear in `ln(y)`, i.e. piecewise exponential in `y`.
///
/// Between neighbouring knots `(x0, y0)` and `(x1, y1)` the value at `x` is
/// `exp(ln y0 + t * (ln y1 - ln y0))` with `t = (x - x0) / (x1 - x0)`, a
/// geometric blend of the two endpoint values. Because the result is an
/// exponential, interpolated values stay positive, which suits quantities such
/// as discount factors.
#[derive(Clone, Debug)]
pub struct LogLinearInterpolator {
    /// Knot abscissae; `new` rejects inputs that are not strictly increasing.
    xs: Vec<f64>,
    /// Knot ordinates, stored exactly as supplied to `new`.
    ys: Vec<f64>,
    /// `ln(ys[i])` for every knot, cached so each evaluation needs one `exp`.
    log_ys: Vec<f64>,
    /// When `true`, queries outside `[xs[0], xs[last]]` are evaluated rather
    /// than rejected.
    allow_extrapolation: bool,
}
42
43impl LogLinearInterpolator {
44 pub fn new(xs: Vec<f64>, ys: Vec<f64>) -> MathResult<Self> {
58 if xs.len() < 2 {
59 return Err(MathError::insufficient_data(2, xs.len()));
60 }
61 if xs.len() != ys.len() {
62 return Err(MathError::invalid_input(format!(
63 "xs and ys must have same length: {} vs {}",
64 xs.len(),
65 ys.len()
66 )));
67 }
68
69 for i in 1..xs.len() {
71 if xs[i] <= xs[i - 1] {
72 return Err(MathError::invalid_input(
73 "x values must be strictly increasing",
74 ));
75 }
76 }
77
78 let mut log_ys = Vec::with_capacity(ys.len());
80 for (i, &y) in ys.iter().enumerate() {
81 if y <= 0.0 {
82 return Err(MathError::invalid_input(format!(
83 "y[{i}] = {y} is not positive; log-linear requires positive values"
84 )));
85 }
86 log_ys.push(y.ln());
87 }
88
89 Ok(Self {
90 xs,
91 ys,
92 log_ys,
93 allow_extrapolation: false,
94 })
95 }
96
97 #[must_use]
99 pub fn with_extrapolation(mut self) -> Self {
100 self.allow_extrapolation = true;
101 self
102 }
103
104 fn find_segment(&self, x: f64) -> usize {
106 match self
107 .xs
108 .binary_search_by(|probe| probe.partial_cmp(&x).unwrap_or(std::cmp::Ordering::Equal))
109 {
110 Ok(i) => i.min(self.xs.len() - 2),
111 Err(i) => (i.saturating_sub(1)).min(self.xs.len() - 2),
112 }
113 }
114
115 #[must_use]
117 pub fn y_values(&self) -> &[f64] {
118 &self.ys
119 }
120}
121
122impl Interpolator for LogLinearInterpolator {
123 fn interpolate(&self, x: f64) -> MathResult<f64> {
124 if !self.allow_extrapolation && (x < self.xs[0] || x > self.xs[self.xs.len() - 1]) {
126 return Err(MathError::ExtrapolationNotAllowed {
127 x,
128 min: self.xs[0],
129 max: self.xs[self.xs.len() - 1],
130 });
131 }
132
133 let i = self.find_segment(x);
134
135 let x0 = self.xs[i];
136 let x1 = self.xs[i + 1];
137 let log_y0 = self.log_ys[i];
138 let log_y1 = self.log_ys[i + 1];
139
140 let t = (x - x0) / (x1 - x0);
142 let log_y = log_y0 + t * (log_y1 - log_y0);
143
144 Ok(log_y.exp())
145 }
146
147 fn derivative(&self, x: f64) -> MathResult<f64> {
148 if !self.allow_extrapolation && (x < self.xs[0] || x > self.xs[self.xs.len() - 1]) {
150 return Err(MathError::ExtrapolationNotAllowed {
151 x,
152 min: self.xs[0],
153 max: self.xs[self.xs.len() - 1],
154 });
155 }
156
157 let i = self.find_segment(x);
158
159 let x0 = self.xs[i];
160 let x1 = self.xs[i + 1];
161 let log_y0 = self.log_ys[i];
162 let log_y1 = self.log_ys[i + 1];
163
164 let t = (x - x0) / (x1 - x0);
167 let log_y = log_y0 + t * (log_y1 - log_y0);
168 let y = log_y.exp();
169 let d_log_y = (log_y1 - log_y0) / (x1 - x0);
170
171 Ok(y * d_log_y)
172 }
173
174 fn allows_extrapolation(&self) -> bool {
175 self.allow_extrapolation
176 }
177
178 fn min_x(&self) -> f64 {
179 self.xs[0]
180 }
181
182 fn max_x(&self) -> f64 {
183 self.xs[self.xs.len() - 1]
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Knots sampled from `exp(-r * t)` at t = 0, 1, 2, 3. Log-linear
    /// interpolation reproduces such a curve exactly, which gives the tests a
    /// closed-form reference.
    fn exponential_knots(r: f64) -> (Vec<f64>, Vec<f64>) {
        let xs: Vec<f64> = vec![0.0, 1.0, 2.0, 3.0];
        let ys: Vec<f64> = xs.iter().map(|&t| (-r * t).exp()).collect();
        (xs, ys)
    }

    #[test]
    fn test_log_linear_through_points() {
        let xs = vec![0.0, 1.0, 2.0, 3.0];
        let ys = vec![1.0, 0.97, 0.94, 0.91];

        let interp = LogLinearInterpolator::new(xs.clone(), ys.clone()).unwrap();

        // The interpolant must reproduce every knot (up to rounding).
        for (&x, &y) in xs.iter().zip(&ys) {
            assert_relative_eq!(interp.interpolate(x).unwrap(), y, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_log_linear_positive_values() {
        let xs = vec![0.0, 1.0, 2.0, 3.0];
        let ys = vec![1.0, 0.5, 0.25, 0.125];

        let interp = LogLinearInterpolator::new(xs, ys).unwrap();

        for x in [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] {
            let y = interp.interpolate(x).unwrap();
            assert!(y > 0.0, "y({}) = {} should be positive", x, y);
        }
    }

    #[test]
    fn test_log_linear_exponential_decay() {
        let r: f64 = 0.05;
        let (xs, ys) = exponential_knots(r);
        let interp = LogLinearInterpolator::new(xs, ys).unwrap();

        let t: f64 = 1.5;
        assert_relative_eq!(interp.interpolate(t).unwrap(), (-r * t).exp(), epsilon = 1e-10);
    }

    #[test]
    fn test_log_linear_derivative() {
        let r: f64 = 0.05;
        let (xs, ys) = exponential_knots(r);
        let interp = LogLinearInterpolator::new(xs, ys).unwrap();

        let t: f64 = 1.5;
        let expected_derivative = -r * (-r * t).exp();
        assert_relative_eq!(
            interp.derivative(t).unwrap(),
            expected_derivative,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_log_linear_rejects_non_positive() {
        // Check each offending value on its own; a single vector with both
        // would only ever exercise whichever the constructor meets first.
        for bad in [0.0, -1.0] {
            let ys = vec![1.0, 0.5, bad];
            assert!(
                LogLinearInterpolator::new(vec![0.0, 1.0, 2.0], ys).is_err(),
                "y = {} should be rejected",
                bad
            );
        }
    }

    #[test]
    fn test_log_linear_rejects_invalid_grid() {
        // Too few knots.
        assert!(LogLinearInterpolator::new(vec![0.0], vec![1.0]).is_err());
        // Mismatched lengths.
        assert!(LogLinearInterpolator::new(vec![0.0, 1.0], vec![1.0]).is_err());
        // Repeated abscissa.
        assert!(LogLinearInterpolator::new(vec![0.0, 0.0, 1.0], vec![1.0, 0.9, 0.8]).is_err());
        // Decreasing abscissa.
        assert!(LogLinearInterpolator::new(vec![0.0, 2.0, 1.0], vec![1.0, 0.9, 0.8]).is_err());
    }

    #[test]
    fn test_log_linear_extrapolation_disabled() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![1.0, 0.9, 0.8];

        let interp = LogLinearInterpolator::new(xs, ys).unwrap();

        assert!(interp.interpolate(-0.5).is_err());
        assert!(interp.interpolate(2.5).is_err());
        assert!(interp.derivative(-0.5).is_err());
        assert!(interp.derivative(2.5).is_err());

        // The endpoints themselves are inside the domain.
        assert!(interp.interpolate(0.0).is_ok());
        assert!(interp.interpolate(2.0).is_ok());
    }

    #[test]
    fn test_log_linear_extrapolation_enabled() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![1.0, 0.9, 0.81];

        let interp = LogLinearInterpolator::new(xs, ys)
            .unwrap()
            .with_extrapolation();

        let y_left = interp.interpolate(-0.5).unwrap();
        let y_right = interp.interpolate(2.5).unwrap();

        assert!(y_left > 0.0);
        assert!(y_right > 0.0);
    }

    #[test]
    fn test_log_linear_extrapolation_follows_end_segments() {
        // Extrapolation continues the end segments' log-slope, so an
        // exponential curve is reproduced exactly beyond the grid too.
        let r: f64 = 0.05;
        let (xs, ys) = exponential_knots(r);
        let interp = LogLinearInterpolator::new(xs, ys)
            .unwrap()
            .with_extrapolation();

        for t in [-0.5_f64, 4.0] {
            assert_relative_eq!(interp.interpolate(t).unwrap(), (-r * t).exp(), epsilon = 1e-10);
            assert_relative_eq!(
                interp.derivative(t).unwrap(),
                -r * (-r * t).exp(),
                epsilon = 1e-10
            );
        }
    }

    #[test]
    fn test_log_linear_accessors() {
        let ys = vec![1.0, 0.9, 0.81];
        let interp = LogLinearInterpolator::new(vec![0.5, 1.0, 2.0], ys.clone()).unwrap();

        assert_eq!(interp.y_values(), &ys[..]);
        assert_relative_eq!(interp.min_x(), 0.5);
        assert_relative_eq!(interp.max_x(), 2.0);
        assert!(!interp.allows_extrapolation());
        assert!(interp.with_extrapolation().allows_extrapolation());
    }

    #[test]
    fn test_log_linear_discount_factors() {
        let times = vec![0.25, 0.5, 1.0, 2.0, 3.0, 5.0];
        let dfs = vec![0.9975, 0.9950, 0.9901, 0.9802, 0.9706, 0.9512];
        let first_time = times[0];

        let interp = LogLinearInterpolator::new(times, dfs).unwrap();

        // Strictly decreasing knots must give a strictly decreasing curve.
        let mut prev = interp.interpolate(first_time).unwrap();
        for t in [0.3, 0.75, 1.5, 2.5, 4.0] {
            let current = interp.interpolate(t).unwrap();
            assert!(
                current < prev,
                "DF should decrease: DF({}) = {} should be < {}",
                t,
                current,
                prev
            );
            prev = current;
        }
    }
}