1use numra_core::Scalar;
8
9use crate::error::InterpError;
10use crate::{search_sorted, validate_data, Interpolant};
11
12pub struct Linear<S: Scalar> {
14 x: Vec<S>,
15 y: Vec<S>,
16}
17
18impl<S: Scalar> Linear<S> {
19 pub fn new(x: &[S], y: &[S]) -> Result<Self, InterpError> {
25 validate_data(x, y, 2)?;
26 Ok(Self {
27 x: x.to_vec(),
28 y: y.to_vec(),
29 })
30 }
31}
32
33impl<S: Scalar> Interpolant<S> for Linear<S> {
34 fn interpolate(&self, x: S) -> S {
35 let i = search_sorted(&self.x, x);
36 let h = self.x[i + 1] - self.x[i];
37 let t = (x - self.x[i]) / h;
38 self.y[i] * (S::ONE - t) + self.y[i + 1] * t
39 }
40
41 fn derivative(&self, x: S) -> Option<S> {
42 let i = search_sorted(&self.x, x);
43 let h = self.x[i + 1] - self.x[i];
44 Some((self.y[i + 1] - self.y[i]) / h)
45 }
46
47 fn integrate(&self, a: S, b: S) -> Option<S> {
48 if b.to_f64() <= a.to_f64() {
49 return Some(S::ZERO);
50 }
51 let i_lo = search_sorted(&self.x, a);
52 let i_hi = search_sorted(&self.x, b);
53
54 let mut result = S::ZERO;
55 for i in i_lo..=i_hi {
56 let x_lo = if i == i_lo { a } else { self.x[i] };
57 let x_hi = if i == i_hi { b } else { self.x[i + 1] };
58 let y_lo = self.interpolate(x_lo);
59 let y_hi = self.interpolate(x_hi);
60 result += (y_lo + y_hi) * (x_hi - x_lo) * S::HALF;
61 }
62 Some(result)
63 }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_linear_at_knots() {
        let interp = Linear::new(&[0.0, 1.0, 2.0], &[0.0, 2.0, 1.0]).unwrap();
        assert_relative_eq!(interp.interpolate(0.0), 0.0, epsilon = 1e-14);
        assert_relative_eq!(interp.interpolate(1.0), 2.0, epsilon = 1e-14);
        assert_relative_eq!(interp.interpolate(2.0), 1.0, epsilon = 1e-14);
    }

    #[test]
    fn test_linear_midpoints() {
        let interp = Linear::new(&[0.0, 1.0, 2.0], &[0.0, 2.0, 1.0]).unwrap();
        assert_relative_eq!(interp.interpolate(0.5), 1.0, epsilon = 1e-14);
        assert_relative_eq!(interp.interpolate(1.5), 1.5, epsilon = 1e-14);
    }

    #[test]
    fn test_linear_derivative() {
        let interp = Linear::new(&[0.0, 1.0, 3.0], &[0.0, 2.0, 6.0]).unwrap();
        assert_relative_eq!(interp.derivative(0.5).unwrap(), 2.0, epsilon = 1e-14);
        assert_relative_eq!(interp.derivative(2.0).unwrap(), 2.0, epsilon = 1e-14);
    }

    #[test]
    fn test_linear_derivative_changes_slope() {
        // The slope must switch between segments: +2 on [0, 1], -1 on [1, 2].
        let interp = Linear::new(&[0.0, 1.0, 2.0], &[0.0, 2.0, 1.0]).unwrap();
        assert_relative_eq!(interp.derivative(0.5).unwrap(), 2.0, epsilon = 1e-14);
        assert_relative_eq!(interp.derivative(1.5).unwrap(), -1.0, epsilon = 1e-14);
    }

    #[test]
    fn test_linear_integrate() {
        let interp = Linear::new(&[0.0, 1.0, 2.0], &[0.0, 1.0, 2.0]).unwrap();
        // ∫_0^2 x dx = 2
        assert_relative_eq!(interp.integrate(0.0, 2.0).unwrap(), 2.0, epsilon = 1e-14);
    }

    #[test]
    fn test_linear_integrate_partial_segments() {
        // Limits that cut through segments exercise the clipping of the
        // first and last trapezoids.
        let ramp = Linear::new(&[0.0, 1.0, 2.0], &[0.0, 1.0, 2.0]).unwrap();
        // ∫_0.5^1.5 x dx = 1
        assert_relative_eq!(ramp.integrate(0.5, 1.5).unwrap(), 1.0, epsilon = 1e-14);

        let hat = Linear::new(&[0.0, 1.0, 2.0], &[0.0, 2.0, 1.0]).unwrap();
        // 0.75 on [0.5, 1] plus 0.875 on [1, 1.5]
        assert_relative_eq!(hat.integrate(0.5, 1.5).unwrap(), 1.625, epsilon = 1e-14);
    }

    #[test]
    fn test_linear_integrate_extrapolated() {
        let interp = Linear::new(&[0.0, 1.0], &[0.0, 1.0]).unwrap();
        // ∫_-1^2 x dx = (4 - 1) / 2, using the extrapolated line on both sides.
        assert_relative_eq!(interp.integrate(-1.0, 2.0).unwrap(), 1.5, epsilon = 1e-14);
    }

    #[test]
    fn test_linear_integrate_empty_interval() {
        let interp = Linear::new(&[0.0, 1.0, 2.0], &[0.0, 2.0, 1.0]).unwrap();
        assert_relative_eq!(interp.integrate(1.0, 1.0).unwrap(), 0.0, epsilon = 1e-14);
    }

    #[test]
    fn test_linear_extrapolation() {
        let interp = Linear::new(&[0.0, 1.0], &[0.0, 1.0]).unwrap();
        // Outside [0, 1] the single segment's line is extended.
        assert_relative_eq!(interp.interpolate(-0.5), -0.5, epsilon = 1e-14);
        assert_relative_eq!(interp.interpolate(1.5), 1.5, epsilon = 1e-14);
    }

    #[test]
    fn test_linear_errors() {
        assert!(Linear::<f64>::new(&[1.0], &[1.0]).is_err());
        assert!(Linear::new(&[1.0, 2.0], &[1.0]).is_err());
        assert!(Linear::new(&[2.0, 1.0], &[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_linear_f32() {
        let interp = Linear::new(&[0.0f32, 1.0, 2.0], &[0.0, 1.0, 0.0]).unwrap();
        assert!((interp.interpolate(0.5f32) - 0.5).abs() < 1e-6);
    }
}