1use crate::traits::FloatScalar;
2
3use super::{InterpError, find_interval, validate_sorted};
4
5#[derive(Debug, Clone)]
29pub struct CubicSpline<T, const N: usize> {
30 xs: [T; N],
31 coeffs: [[T; 4]; N],
34}
35
36impl<T: FloatScalar, const N: usize> CubicSpline<T, N> {
37 pub fn new(xs: [T; N], ys: [T; N]) -> Result<Self, InterpError> {
42 if N < 3 {
43 return Err(InterpError::TooFewPoints);
44 }
45 validate_sorted(&xs)?;
46
47 let two = T::one() + T::one();
48 let three = two + T::one();
49 let six = three + three;
50
51 let n = N;
60 let mut h = [T::zero(); N];
62 let mut delta = [T::zero(); N];
63 for i in 0..n - 1 {
64 h[i] = xs[i + 1] - xs[i];
65 delta[i] = (ys[i + 1] - ys[i]) / h[i];
66 }
67
68 let mut m = [T::zero(); N];
71
72 if n > 3 {
73 let mut cp = [T::zero(); N]; let mut dp = [T::zero(); N]; let diag = two * (h[0] + h[1]);
79 if diag.abs() < T::epsilon() {
80 return Err(InterpError::IllConditioned);
81 }
82 let rhs = six * (delta[1] - delta[0]);
83 cp[1] = h[1] / diag;
84 dp[1] = rhs / diag;
85
86 for i in 2..n - 1 {
88 let diag_i = two * (h[i - 1] + h[i]) - h[i - 1] * cp[i - 1];
89 if diag_i.abs() < T::epsilon() {
90 return Err(InterpError::IllConditioned);
91 }
92 let rhs_i = six * (delta[i] - delta[i - 1]) - h[i - 1] * dp[i - 1];
93 if i < n - 1 {
94 cp[i] = h[i] / diag_i;
95 }
96 dp[i] = rhs_i / diag_i;
97 }
98
99 m[n - 2] = dp[n - 2];
101 for i in (1..n - 2).rev() {
102 m[i] = dp[i] - cp[i] * m[i + 1];
103 }
104 } else {
105 let diag = two * (h[0] + h[1]);
107 if diag.abs() < T::epsilon() {
108 return Err(InterpError::IllConditioned);
109 }
110 m[1] = six * (delta[1] - delta[0]) / diag;
111 }
112
113 let mut coeffs = [[T::zero(); 4]; N];
115 for i in 0..n - 1 {
116 let a = ys[i];
117 let b = delta[i] - h[i] * (two * m[i] + m[i + 1]) / six;
118 let c = m[i] / two;
119 let d = (m[i + 1] - m[i]) / (six * h[i]);
120 coeffs[i] = [a, b, c, d];
121 }
122
123 Ok(Self { xs, coeffs })
124 }
125
126 pub fn eval(&self, x: T) -> T {
128 let i = find_interval(&self.xs, x);
129 let dx = x - self.xs[i];
130 let [a, b, c, d] = self.coeffs[i];
131 a + dx * (b + dx * (c + dx * d))
133 }
134
135 pub fn eval_derivative(&self, x: T) -> (T, T) {
137 let i = find_interval(&self.xs, x);
138 let dx = x - self.xs[i];
139 let [a, b, c, d] = self.coeffs[i];
140 let two = T::one() + T::one();
141 let three = two + T::one();
142 let val = a + dx * (b + dx * (c + dx * d));
143 let dval = b + dx * (two * c + three * d * dx);
144 (val, dval)
145 }
146
147 pub fn xs(&self) -> &[T; N] {
149 &self.xs
150 }
151}
152
153#[cfg(feature = "alloc")]
156extern crate alloc;
157#[cfg(feature = "alloc")]
158use alloc::vec::Vec;
159
#[cfg(feature = "alloc")]
/// Natural cubic spline interpolant over a runtime-sized set of knots.
///
/// Interval `i` (between `xs[i]` and `xs[i + 1]`) is the cubic
/// `a + b·dx + c·dx² + d·dx³` in the local offset `dx = x - xs[i]`, with
/// `coeffs[i] = [a, b, c, d]`. The second derivative is zero at the first and
/// last knot ("natural" boundary conditions).
#[derive(Debug, Clone)]
pub struct DynCubicSpline<T> {
    /// Knot abscissae, exactly as passed to [`DynCubicSpline::new`].
    xs: Vec<T>,
    /// Per-interval coefficients `[a, b, c, d]`; one entry per interval,
    /// i.e. `xs.len() - 1` entries.
    coeffs: Vec<[T; 4]>,
}

#[cfg(feature = "alloc")]
impl<T: FloatScalar> DynCubicSpline<T> {
    /// Builds a natural cubic spline through the points `(xs[i], ys[i])`.
    ///
    /// Interior second derivatives are obtained by solving the standard
    /// tridiagonal continuity system with the Thomas algorithm; the end
    /// second derivatives are fixed at zero.
    ///
    /// # Errors
    ///
    /// - [`InterpError::LengthMismatch`] if `xs` and `ys` differ in length.
    /// - [`InterpError::TooFewPoints`] if fewer than 3 points are given.
    /// - Whatever `validate_sorted` reports for `xs`.
    /// - [`InterpError::IllConditioned`] if a pivot of the tridiagonal solve is
    ///   negligible (at most `T::epsilon()`) *relative to* the unmodified
    ///   diagonal of its row. The test is relative so that knots with very
    ///   small spacing are not spuriously rejected, while zero-width
    ///   intervals still are.
    pub fn new(xs: Vec<T>, ys: Vec<T>) -> Result<Self, InterpError> {
        if xs.len() != ys.len() {
            return Err(InterpError::LengthMismatch);
        }
        if xs.len() < 3 {
            return Err(InterpError::TooFewPoints);
        }
        validate_sorted(&xs)?;

        let n = xs.len();
        let two = T::one() + T::one();
        let three = two + T::one();
        let six = three + three;

        // Interval widths and secant slopes (n - 1 entries each).
        let h: Vec<T> = xs.windows(2).map(|w| w[1] - w[0]).collect();
        let delta: Vec<T> = ys
            .windows(2)
            .zip(&h)
            .map(|(w, &width)| (w[1] - w[0]) / width)
            .collect();

        // Second derivatives at the knots. m[0] and m[n - 1] stay zero (natural
        // boundary); for i in 1..=n-2 the interior values satisfy
        //   h[i-1]·m[i-1] + 2(h[i-1] + h[i])·m[i] + h[i]·m[i+1] = 6(delta[i] - delta[i-1]).
        let mut m = alloc::vec![T::zero(); n];

        // Thomas forward sweep. cp[0] = dp[0] = 0, so the first row (i = 1)
        // needs no special case; this also covers n == 3 (a single unknown).
        let mut cp = alloc::vec![T::zero(); n];
        let mut dp = alloc::vec![T::zero(); n];
        for i in 1..n - 1 {
            let diag = two * (h[i - 1] + h[i]);
            let pivot = diag - h[i - 1] * cp[i - 1];
            // Scale-invariant singularity test: compare the pivot with the
            // row's own diagonal rather than with an absolute epsilon.
            if pivot.abs() <= T::epsilon() * diag.abs() {
                return Err(InterpError::IllConditioned);
            }
            cp[i] = h[i] / pivot;
            dp[i] = (six * (delta[i] - delta[i - 1]) - h[i - 1] * dp[i - 1]) / pivot;
        }

        // Back-substitution over the interior unknowns.
        m[n - 2] = dp[n - 2];
        for i in (1..n - 2).rev() {
            m[i] = dp[i] - cp[i] * m[i + 1];
        }

        let coeffs: Vec<[T; 4]> = (0..n - 1)
            .map(|i| {
                [
                    ys[i],
                    delta[i] - h[i] * (two * m[i] + m[i + 1]) / six,
                    m[i] / two,
                    (m[i + 1] - m[i]) / (six * h[i]),
                ]
            })
            .collect();

        Ok(Self { xs, coeffs })
    }

    /// Evaluates the spline at `x`.
    ///
    /// The polynomial piece is chosen by `find_interval`; how `x` outside
    /// `[xs[0], xs[n - 1]]` is handled depends on that helper.
    pub fn eval(&self, x: T) -> T {
        let i = find_interval(&self.xs, x);
        let dx = x - self.xs[i];
        let [a, b, c, d] = self.coeffs[i];
        // Horner form of a + b·dx + c·dx² + d·dx³.
        a + dx * (b + dx * (c + dx * d))
    }

    /// Evaluates the spline and its first derivative at `x`.
    ///
    /// Returns `(value, d/dx value)`, using the same piece selection as
    /// [`DynCubicSpline::eval`].
    pub fn eval_derivative(&self, x: T) -> (T, T) {
        let i = find_interval(&self.xs, x);
        let dx = x - self.xs[i];
        let [a, b, c, d] = self.coeffs[i];
        let two = T::one() + T::one();
        let three = two + T::one();
        let val = a + dx * (b + dx * (c + dx * d));
        // Derivative of the cubic: b + 2c·dx + 3d·dx².
        let dval = b + dx * (two * c + three * d * dx);
        (val, dval)
    }

    /// Returns the knot abscissae the spline was built from.
    pub fn xs(&self) -> &[T] {
        &self.xs
    }
}