Skip to main content

numeris/interp/
hermite.rs

1use crate::traits::FloatScalar;
2
3use super::{InterpError, find_interval, validate_sorted};
4
/// Fixed-size cubic Hermite interpolant whose data lives inline (no heap).
///
/// Every interval `[xs[i], xs[i + 1]]` is covered by a cubic polynomial that
/// reproduces both the supplied value (`ys`) and the supplied first
/// derivative (`dys`) at each of its two end knots. Neighbouring pieces
/// therefore agree in value and slope where they meet. At least two knots are
/// required.
///
/// # Example
///
/// ```
/// use numeris::interp::HermiteInterp;
///
/// // Rebuild sin(x) from its samples and its exact derivative cos(x).
/// let knots = [0.0_f64, 1.0, 2.0];
/// let values = [0.0_f64.sin(), 1.0_f64.sin(), 2.0_f64.sin()];
/// let slopes = [0.0_f64.cos(), 1.0_f64.cos(), 2.0_f64.cos()];
/// let interp = HermiteInterp::new(knots, values, slopes).unwrap();
/// let probe = 0.5;
/// assert!((interp.eval(probe) - probe.sin()).abs() < 0.002);
/// ```
#[derive(Debug, Clone)]
pub struct HermiteInterp<T, const N: usize> {
    /// Knot abscissae.
    xs: [T; N],
    /// Function values at the knots.
    ys: [T; N],
    /// First derivatives `dy/dx` at the knots.
    dys: [T; N],
}
30
31impl<T: FloatScalar, const N: usize> HermiteInterp<T, N> {
32    /// Construct a Hermite interpolant from knots and derivatives.
33    pub fn new(xs: [T; N], ys: [T; N], dys: [T; N]) -> Result<Self, InterpError> {
34        if N < 2 {
35            return Err(InterpError::TooFewPoints);
36        }
37        validate_sorted(&xs)?;
38        Ok(Self { xs, ys, dys })
39    }
40
41    /// Evaluate the interpolant at `x`.
42    pub fn eval(&self, x: T) -> T {
43        let i = find_interval(&self.xs, x);
44        let h = self.xs[i + 1] - self.xs[i];
45        let t = (x - self.xs[i]) / h;
46        let t2 = t * t;
47        let t3 = t2 * t;
48        let two = T::one() + T::one();
49        let three = two + T::one();
50
51        // Hermite basis: h00 = 2t³ - 3t² + 1, h10 = t³ - 2t² + t
52        //                h01 = -2t³ + 3t², h11 = t³ - t²
53        let h00 = two * t3 - three * t2 + T::one();
54        let h10 = t3 - two * t2 + t;
55        let h01 = (T::zero() - two) * t3 + three * t2;
56        let h11 = t3 - t2;
57
58        h00 * self.ys[i] + h10 * h * self.dys[i] + h01 * self.ys[i + 1] + h11 * h * self.dys[i + 1]
59    }
60
61    /// Evaluate the interpolant and its derivative at `x`.
62    pub fn eval_derivative(&self, x: T) -> (T, T) {
63        let i = find_interval(&self.xs, x);
64        let h = self.xs[i + 1] - self.xs[i];
65        let t = (x - self.xs[i]) / h;
66        let t2 = t * t;
67        let t3 = t2 * t;
68        let two = T::one() + T::one();
69        let three = two + T::one();
70        let six = three + three;
71
72        let h00 = two * t3 - three * t2 + T::one();
73        let h10 = t3 - two * t2 + t;
74        let h01 = (T::zero() - two) * t3 + three * t2;
75        let h11 = t3 - t2;
76
77        let val =
78            h00 * self.ys[i] + h10 * h * self.dys[i] + h01 * self.ys[i + 1] + h11 * h * self.dys[i + 1];
79
80        // d/dx = (1/h) d/dt of the basis
81        // h00' = 6t² - 6t, h10' = 3t² - 4t + 1
82        // h01' = -6t² + 6t, h11' = 3t² - 2t
83        let dh00 = six * t2 - six * t;
84        let dh10 = three * t2 - (two + two) * t + T::one();
85        let dh01 = (T::zero() - six) * t2 + six * t;
86        let dh11 = three * t2 - two * t;
87
88        let dval = (dh00 * self.ys[i] + dh10 * h * self.dys[i] + dh01 * self.ys[i + 1]
89            + dh11 * h * self.dys[i + 1])
90            / h;
91
92        (val, dval)
93    }
94
95    /// The knot x-values.
96    pub fn xs(&self) -> &[T; N] {
97        &self.xs
98    }
99
100    /// The knot y-values.
101    pub fn ys(&self) -> &[T; N] {
102        &self.ys
103    }
104}
105
106// ---------- Dynamic variant ----------
107
108#[cfg(feature = "alloc")]
109extern crate alloc;
110#[cfg(feature = "alloc")]
111use alloc::vec::Vec;
112
/// Runtime-sized cubic Hermite interpolant backed by heap vectors.
///
/// This is the dynamically sized sibling of [`HermiteInterp`]. It takes its
/// knots, values and derivatives as `Vec`s whose length is only known at run
/// time. At least two knots are required.
///
/// # Example
///
/// ```
/// use numeris::interp::DynHermiteInterp;
///
/// let knots = vec![0.0_f64, 1.0, 2.0];
/// let values = vec![0.0, 1.0, 0.0];
/// let slopes = vec![1.0, 0.0, -1.0];
/// let interp = DynHermiteInterp::new(knots, values, slopes).unwrap();
/// // Midpoint of the first segment: 0.125 * 1 (start slope) + 0.5 * 1 (end value).
/// assert!((interp.eval(0.5) - 0.625).abs() < 1e-14);
/// ```
#[cfg(feature = "alloc")]
#[derive(Debug, Clone)]
pub struct DynHermiteInterp<T> {
    /// Knot abscissae.
    xs: Vec<T>,
    /// Function values at the knots.
    ys: Vec<T>,
    /// First derivatives `dy/dx` at the knots.
    dys: Vec<T>,
}
136
#[cfg(feature = "alloc")]
impl<T: FloatScalar> DynHermiteInterp<T> {
    /// Construct a Hermite interpolant from knots and derivatives.
    ///
    /// `ys[i]` and `dys[i]` are the value and first derivative `dy/dx` at
    /// knot `xs[i]`.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// - [`InterpError::LengthMismatch`] when `xs`, `ys` and `dys` differ in length.
    /// - [`InterpError::TooFewPoints`] when fewer than two knots are given.
    /// - Any error returned by `validate_sorted` for `xs` (presumably when the
    ///   knots are not in increasing order — confirm against that helper).
    pub fn new(xs: Vec<T>, ys: Vec<T>, dys: Vec<T>) -> Result<Self, InterpError> {
        if xs.len() != ys.len() || xs.len() != dys.len() {
            return Err(InterpError::LengthMismatch);
        }
        if xs.len() < 2 {
            return Err(InterpError::TooFewPoints);
        }
        validate_sorted(&xs)?;
        Ok(Self { xs, ys, dys })
    }

    /// Evaluate the interpolant at `x`.
    ///
    /// The segment used is the one selected by `find_interval`. How points
    /// outside the knot range are handled (e.g. by evaluating an end cubic)
    /// is decided by that helper.
    pub fn eval(&self, x: T) -> T {
        let (i, h, t) = self.segment(x);
        self.combine(i, h, Self::basis(t))
    }

    /// Evaluate the interpolant and its derivative at `x`.
    ///
    /// Returns `(value, d value / dx)`.
    pub fn eval_derivative(&self, x: T) -> (T, T) {
        let (i, h, t) = self.segment(x);
        let val = self.combine(i, h, Self::basis(t));
        // The basis is expressed in t = (x - xs[i]) / h, so d/dx = (1/h) d/dt.
        let dval = self.combine(i, h, Self::basis_dt(t)) / h;
        (val, dval)
    }

    /// The knot x-values.
    pub fn xs(&self) -> &[T] {
        &self.xs
    }

    /// The knot y-values.
    pub fn ys(&self) -> &[T] {
        &self.ys
    }

    /// The knot derivatives `dy/dx`, as passed to [`new`](Self::new).
    pub fn dys(&self) -> &[T] {
        &self.dys
    }

    /// Locate the segment used for `x`.
    ///
    /// Returns the left knot index `i`, the segment width
    /// `h = xs[i + 1] - xs[i]`, and the local coordinate
    /// `t = (x - xs[i]) / h` (in `[0, 1]` when `x` lies inside the segment).
    fn segment(&self, x: T) -> (usize, T, T) {
        let i = find_interval(&self.xs, x);
        let h = self.xs[i + 1] - self.xs[i];
        let t = (x - self.xs[i]) / h;
        (i, h, t)
    }

    /// Weight the data of segment `i` by the basis weights `[w00, w10, w01, w11]`.
    ///
    /// The slope terms are scaled by `h` because `dys` are `d/dx` slopes while
    /// the basis is written in the local coordinate `t`.
    fn combine(&self, i: usize, h: T, [w00, w10, w01, w11]: [T; 4]) -> T {
        w00 * self.ys[i]
            + w10 * h * self.dys[i]
            + w01 * self.ys[i + 1]
            + w11 * h * self.dys[i + 1]
    }

    /// Cubic Hermite basis `[h00, h10, h01, h11]` at `t`:
    /// `h00 = 2t³ - 3t² + 1`, `h10 = t³ - 2t² + t`,
    /// `h01 = -2t³ + 3t²`, `h11 = t³ - t²`.
    fn basis(t: T) -> [T; 4] {
        let one = T::one();
        let two = one + one;
        let three = two + one;
        let t2 = t * t;
        let t3 = t2 * t;
        [
            two * t3 - three * t2 + one,
            t3 - two * t2 + t,
            (T::zero() - two) * t3 + three * t2,
            t3 - t2,
        ]
    }

    /// `t`-derivatives of the Hermite basis, in the same order as [`Self::basis`]:
    /// `h00' = 6t² - 6t`, `h10' = 3t² - 4t + 1`,
    /// `h01' = -6t² + 6t`, `h11' = 3t² - 2t`.
    fn basis_dt(t: T) -> [T; 4] {
        let one = T::one();
        let two = one + one;
        let three = two + one;
        let six = three + three;
        let t2 = t * t;
        [
            six * t2 - six * t,
            three * t2 - (two + two) * t + one,
            (T::zero() - six) * t2 + six * t,
            three * t2 - two * t,
        ]
    }
}