1mod bicubic;
19mod bilinear;
20mod bspline;
21mod cubic_spline;
22mod linear;
23pub(super) mod thomas;
24
25pub use bicubic::Bicubic2d;
26pub use bilinear::Bilinear2d;
27pub use bspline::BSpline;
28pub use cubic_spline::CubicSpline;
29pub use linear::Linear1d;
30
31use scivex_core::Float;
32
33use crate::error::{OptimError, Result};
34
/// Selects the algorithm used by [`interp1d`].
///
/// Implements `Hash` in addition to the comparison traits so a method can be
/// used as a key in hash-based collections (e.g. caching per-method results).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Interp1dMethod {
    /// Piecewise-linear interpolation ([`Linear1d`]).
    Linear,
    /// Cubic spline interpolation ([`CubicSpline`]); [`interp1d`] builds it
    /// with [`SplineBoundary::Natural`] end conditions.
    CubicSpline,
    /// B-spline interpolation ([`BSpline`]); [`interp1d`] fits it with
    /// degree 3.
    BSpline,
}
49
/// Selects the algorithm used by [`interp2d`].
///
/// Implements `Hash` in addition to the comparison traits so a method can be
/// used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Interp2dMethod {
    /// Bilinear interpolation on a rectilinear grid ([`Bilinear2d`]).
    Bilinear,
    /// Bicubic interpolation on a rectilinear grid ([`Bicubic2d`]).
    Bicubic,
}
58
/// End conditions for a [`CubicSpline`].
///
/// Generic over the value type `T` so clamped end values share the spline's
/// float type.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SplineBoundary<T> {
    /// "Natural" end conditions. Presumably the conventional meaning (zero
    /// second derivative at both ends); confirm against `CubicSpline`.
    Natural,
    /// "Clamped" end conditions with caller-supplied values at each end.
    /// Presumably prescribed first derivatives; confirm against `CubicSpline`.
    Clamped {
        /// Value imposed at the left (first) data point.
        left: T,
        /// Value imposed at the right (last) data point.
        right: T,
    },
}
67
/// How a query point outside the data range `[xs[0], xs[last]]` is handled.
///
/// Defaults to [`Extrapolate::Error`]. The per-variant behavior below is what
/// `find_interval` implements. Implements `Hash` so the mode can be used as a
/// key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Extrapolate {
    /// Reject out-of-range queries with an `InvalidParameter` error.
    #[default]
    Error,
    /// Use the first/last interval and replace the query with the nearest
    /// data endpoint.
    Clamp,
    /// Use the first/last interval but keep the original query point, so the
    /// edge piece is extended beyond the data.
    Extend,
}
79
80#[inline]
94pub(crate) fn find_interval<T: Float>(xs: &[T], x: T, extrap: Extrapolate) -> Result<(usize, T)> {
95 debug_assert!(xs.len() >= 2);
96 let n = xs.len();
97
98 if x < xs[0] {
99 return match extrap {
100 Extrapolate::Error => Err(OptimError::InvalidParameter {
101 name: "x",
102 reason: "query point is below data range",
103 }),
104 Extrapolate::Clamp => Ok((0, xs[0])),
105 Extrapolate::Extend => Ok((0, x)),
106 };
107 }
108
109 if x > xs[n - 1] {
110 return match extrap {
111 Extrapolate::Error => Err(OptimError::InvalidParameter {
112 name: "x",
113 reason: "query point is above data range",
114 }),
115 Extrapolate::Clamp => Ok((n - 2, xs[n - 1])),
116 Extrapolate::Extend => Ok((n - 2, x)),
117 };
118 }
119
120 if x == xs[n - 1] {
122 return Ok((n - 2, x));
123 }
124
125 let mut lo: usize = 0;
127 let mut hi: usize = n - 1;
128 while hi - lo > 1 {
129 let mid = lo + (hi - lo) / 2;
130 if xs[mid] <= x {
131 lo = mid;
132 } else {
133 hi = mid;
134 }
135 }
136
137 Ok((lo, x))
138}
139
140pub(crate) fn validate_sorted<T: Float>(xs: &[T], min_len: usize) -> Result<()> {
142 if xs.len() < min_len {
143 return Err(OptimError::InvalidParameter {
144 name: "xs",
145 reason: "not enough data points",
146 });
147 }
148 for i in 1..xs.len() {
149 if xs[i] <= xs[i - 1] {
150 return Err(OptimError::InvalidParameter {
151 name: "xs",
152 reason: "knots must be strictly increasing",
153 });
154 }
155 }
156 Ok(())
157}
158
159pub(crate) fn validate_finite<T: Float>(vals: &[T], name: &'static str) -> Result<()> {
161 for &v in vals {
162 if !v.is_finite() {
163 return Err(OptimError::NonFiniteValue { context: name });
164 }
165 }
166 Ok(())
167}
168
169pub fn interp1d<T: Float>(
177 xs: &[T],
178 ys: &[T],
179 query: &[T],
180 method: Interp1dMethod,
181) -> Result<Vec<T>> {
182 match method {
183 Interp1dMethod::Linear => {
184 let interp = Linear1d::new(xs, ys, Extrapolate::Error)?;
185 interp.eval_many(query)
186 }
187 Interp1dMethod::CubicSpline => {
188 let interp = CubicSpline::new(xs, ys, SplineBoundary::Natural, Extrapolate::Error)?;
189 interp.eval_many(query)
190 }
191 Interp1dMethod::BSpline => {
192 let interp = BSpline::fit(xs, ys, 3, Extrapolate::Error)?;
193 interp.eval_many(query)
194 }
195 }
196}
197
198pub fn interp2d<T: Float>(
202 xs: Vec<T>,
203 ys: Vec<T>,
204 zs: Vec<Vec<T>>,
205 query: &[(T, T)],
206 method: Interp2dMethod,
207) -> Result<Vec<T>> {
208 match method {
209 Interp2dMethod::Bilinear => {
210 let interp = Bilinear2d::new(xs, ys, zs, Extrapolate::Error)?;
211 interp.eval_many(query)
212 }
213 Interp2dMethod::Bicubic => {
214 let interp = Bicubic2d::new(xs, ys, &zs, Extrapolate::Error)?;
215 interp.eval_many(query)
216 }
217 }
218}
219
#[cfg(test)]
// Some assertions compare floats exactly on purpose: clamped endpoints and
// passed-through query points are returned bit-for-bit, not recomputed.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn test_find_interval_basic() {
        let xs = [0.0, 1.0, 2.0, 3.0];
        let (i, x) = find_interval(&xs, 1.5, Extrapolate::Error).unwrap();
        assert_eq!(i, 1);
        assert!((x - 1.5).abs() < 1e-15);
    }

    #[test]
    fn test_find_interval_first_point() {
        let xs = [0.0, 1.0, 2.0, 3.0];
        let (i, x) = find_interval(&xs, 0.0, Extrapolate::Error).unwrap();
        assert_eq!(i, 0);
        assert_eq!(x, 0.0);
    }

    #[test]
    fn test_find_interval_interior_knot() {
        // A query exactly on an interior knot belongs to the interval that
        // starts at that knot.
        let xs = [0.0, 1.0, 2.0, 3.0];
        let (i, x) = find_interval(&xs, 2.0, Extrapolate::Error).unwrap();
        assert_eq!(i, 2);
        assert_eq!(x, 2.0);
    }

    #[test]
    fn test_find_interval_last_point() {
        let xs = [0.0, 1.0, 2.0, 3.0];
        let (i, x) = find_interval(&xs, 3.0, Extrapolate::Error).unwrap();
        assert_eq!(i, 2);
        assert!((x - 3.0).abs() < 1e-15);
    }

    #[test]
    fn test_find_interval_two_knots() {
        let xs = [0.0, 1.0];
        for q in [0.0, 0.5, 1.0].iter().copied() {
            let (i, x) = find_interval(&xs, q, Extrapolate::Error).unwrap();
            assert_eq!(i, 0);
            assert_eq!(x, q);
        }
    }

    #[test]
    fn test_find_interval_error_below() {
        let xs = [0.0, 1.0, 2.0];
        let res = find_interval(&xs, -0.1, Extrapolate::Error);
        assert!(res.is_err());
    }

    #[test]
    fn test_find_interval_error_above() {
        let xs = [0.0, 1.0, 2.0];
        let res = find_interval(&xs, 2.1, Extrapolate::Error);
        assert!(res.is_err());
    }

    #[test]
    fn test_find_interval_clamp_below() {
        let xs = [0.0, 1.0, 2.0];
        let (i, x) = find_interval(&xs, -5.0, Extrapolate::Clamp).unwrap();
        assert_eq!(i, 0);
        assert_eq!(x, 0.0);
    }

    #[test]
    fn test_find_interval_clamp_above() {
        let xs = [0.0, 1.0, 2.0];
        let (i, x) = find_interval(&xs, 5.0, Extrapolate::Clamp).unwrap();
        assert_eq!(i, 1);
        assert!((x - 2.0).abs() < 1e-15);
    }

    #[test]
    fn test_find_interval_extend_below() {
        let xs = [0.0, 1.0, 2.0];
        let (i, x) = find_interval(&xs, -1.0, Extrapolate::Extend).unwrap();
        assert_eq!(i, 0);
        assert!((x - (-1.0)).abs() < 1e-15);
    }

    #[test]
    fn test_find_interval_extend_above() {
        let xs = [0.0, 1.0, 2.0];
        let (i, x) = find_interval(&xs, 3.5, Extrapolate::Extend).unwrap();
        assert_eq!(i, 1);
        assert_eq!(x, 3.5);
    }

    #[test]
    fn test_validate_sorted_ok() {
        assert!(validate_sorted(&[0.0, 1.0, 2.0], 2).is_ok());
    }

    #[test]
    fn test_validate_sorted_too_few() {
        assert!(validate_sorted(&[0.0_f64], 2).is_err());
    }

    #[test]
    fn test_validate_sorted_not_increasing() {
        assert!(validate_sorted(&[0.0, 2.0, 1.0], 2).is_err());
    }

    #[test]
    fn test_validate_finite() {
        assert!(validate_finite(&[0.0, 1.0, -2.5], "ys").is_ok());
        assert!(validate_finite::<f64>(&[], "ys").is_ok());
        assert!(validate_finite(&[0.0, f64::NAN], "ys").is_err());
        assert!(validate_finite(&[f64::INFINITY], "ys").is_err());
    }

    #[test]
    fn test_interp1d_linear() {
        let result = interp1d(
            &[0.0, 1.0, 2.0],
            &[0.0, 2.0, 4.0],
            &[0.5, 1.5],
            Interp1dMethod::Linear,
        )
        .unwrap();
        assert!((result[0] - 1.0).abs() < 1e-12);
        assert!((result[1] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_interp1d_cubic_spline() {
        let result = interp1d(
            &[0.0, 1.0, 2.0, 3.0],
            &[0.0, 1.0, 4.0, 9.0],
            &[1.0, 2.0],
            Interp1dMethod::CubicSpline,
        )
        .unwrap();
        assert!((result[0] - 1.0).abs() < 1e-10);
        assert!((result[1] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_interp1d_bspline() {
        let result = interp1d(
            &[0.0, 1.0, 2.0, 3.0, 4.0],
            &[0.0, 1.0, 4.0, 9.0, 16.0],
            &[2.0],
            Interp1dMethod::BSpline,
        )
        .unwrap();
        assert!((result[0] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_interp2d_bilinear() {
        let xs = vec![0.0, 1.0];
        let ys = vec![0.0, 1.0];
        let zs = vec![vec![0.0, 2.0], vec![1.0, 3.0]];
        let result = interp2d(xs, ys, zs, &[(0.5, 0.5)], Interp2dMethod::Bilinear).unwrap();
        assert!((result[0] - 1.5).abs() < 1e-12);
    }

    #[test]
    fn test_interp2d_bicubic() {
        let xs = vec![0.0, 1.0, 2.0, 3.0];
        let ys = vec![0.0, 1.0, 2.0, 3.0];
        let zs: Vec<Vec<f64>> = (0..4)
            .map(|i| (0..4).map(|j| f64::from(i) + 2.0 * f64::from(j)).collect())
            .collect();
        let result = interp2d(xs, ys, zs, &[(1.5, 1.5)], Interp2dMethod::Bicubic).unwrap();
        assert!((result[0] - 4.5).abs() < 1e-10);
    }
}