Skip to main content

scirs2_interpolate/
extrapolation_wrapper.rs

1//! Generic extrapolation wrapper for 1-D interpolators.
2//!
3//! [`ExtrapolatingInterpolator`] wraps any `Fn(f64) -> f64` closure (or any
4//! type that implements [`Interpolate1D`]) and handles out-of-domain queries by
5//! applying a configurable [`ExtrapolationMode`].
6//!
7//! # Examples
8//!
9//! ```rust
10//! use scirs2_interpolate::extrapolation_wrapper::{
11//!     ExtrapolatingInterpolator, ExtrapolationMode,
12//! };
13//!
//! // A simple stand-in for an interpolant on [0, 1] (here the exact function x²).
15//! let inner = |x: f64| x * x;          // any Fn(f64) -> f64
16//! let interp = ExtrapolatingInterpolator::new(inner, 0.0, 1.0,
17//!     ExtrapolationMode::Fill(f64::NAN));
18//!
19//! assert!((interp.eval(0.5).unwrap() - 0.25).abs() < 1e-12);
20//! assert!(interp.eval(-0.1).unwrap().is_nan());
21//! ```
22//!
23//! ## Periodic wrapping
24//!
25//! ```rust
26//! use std::f64::consts::PI;
27//! use scirs2_interpolate::extrapolation_wrapper::{
28//!     ExtrapolatingInterpolator, ExtrapolationMode,
29//! };
30//!
31//! let inner = |x: f64| x.sin();
32//! let interp = ExtrapolatingInterpolator::new(inner, 0.0, 2.0 * PI,
33//!     ExtrapolationMode::Periodic);
34//!
35//! // x = 3π is equivalent to x = π (period = 2π)
36//! let y1 = interp.eval(PI).unwrap();
37//! let y2 = interp.eval(3.0 * PI).unwrap();
38//! assert!((y1 - y2).abs() < 1e-10);
39//! ```
40
41use crate::error::{InterpolateError, InterpolateResult};
42
43// ─────────────────────────────────────────────────────────────────────────────
44// ExtrapolationMode
45// ─────────────────────────────────────────────────────────────────────────────
46
47/// How to handle queries outside the interpolation domain `[x_min, x_max]`.
48#[derive(Debug, Clone, PartialEq)]
49pub enum ExtrapolationMode {
50    /// Clamp `x` to the nearest boundary and evaluate the interpolant there.
51    Nearest,
52    /// Linear extrapolation using a finite-difference gradient at the boundary.
53    ///
54    /// The gradient is estimated with step `h` (defaults to
55    /// `(x_max - x_min) * 1e-5`).
56    Linear,
57    /// Polynomial extrapolation of degree `deg` using `deg+1` equally-spaced
58    /// points near the boundary.
59    Polynomial(usize),
60    /// Reflect `x` about the boundary and evaluate at the reflected point.
61    ///
62    /// For a lower boundary at `x_min`: reflected point is `2 x_min - x`.
63    /// For an upper boundary at `x_max`: reflected point is `2 x_max - x`.
64    Reflect,
65    /// Wrap `x` periodically so the period equals `x_max - x_min`.
66    Periodic,
67    /// Return `v` for every out-of-domain query.
68    Fill(f64),
69    /// Return [`InterpolateError::OutOfBounds`] for out-of-domain queries.
70    Error,
71}
72
73// ─────────────────────────────────────────────────────────────────────────────
74// Trait for 1-D interpolants
75// ─────────────────────────────────────────────────────────────────────────────
76
77/// Minimal trait for 1-D interpolants that can be wrapped with extrapolation
78/// handling.  Implemented automatically for `Fn(f64) -> f64`.
79pub trait Interpolate1D {
80    /// Evaluate the interpolant at `x`.  May assume `x` is within the
81    /// training domain.
82    fn interpolate(&self, x: f64) -> f64;
83}
84
85impl<F: Fn(f64) -> f64> Interpolate1D for F {
86    fn interpolate(&self, x: f64) -> f64 {
87        (self)(x)
88    }
89}
90
91// ─────────────────────────────────────────────────────────────────────────────
92// ExtrapolatingInterpolator
93// ─────────────────────────────────────────────────────────────────────────────
94
95/// A 1-D interpolant wrapped with configurable out-of-domain behaviour.
96///
97/// The same mode applies to both the lower (`x < x_min`) and upper
98/// (`x > x_max`) boundaries.  Use [`ExtrapolatingInterpolatorAsymmetric`] if
99/// you need different modes per boundary.
100pub struct ExtrapolatingInterpolator<I: Interpolate1D> {
101    inner: I,
102    x_min: f64,
103    x_max: f64,
104    mode: ExtrapolationMode,
105}
106
107impl<I: Interpolate1D> ExtrapolatingInterpolator<I> {
108    /// Create a new wrapped interpolant.
109    ///
110    /// # Arguments
111    /// * `inner`  – The underlying interpolant.
112    /// * `x_min`  – Lower bound of the valid domain.
113    /// * `x_max`  – Upper bound of the valid domain (must be > `x_min`).
114    /// * `mode`   – Extrapolation strategy.
115    pub fn new(inner: I, x_min: f64, x_max: f64, mode: ExtrapolationMode) -> Self {
116        assert!(
117            x_max > x_min,
118            "x_max ({x_max}) must be strictly greater than x_min ({x_min})"
119        );
120        Self {
121            inner,
122            x_min,
123            x_max,
124            mode,
125        }
126    }
127
128    /// Evaluate the (possibly extrapolated) interpolant at `x`.
129    pub fn eval(&self, x: f64) -> InterpolateResult<f64> {
130        if x >= self.x_min && x <= self.x_max {
131            return Ok(self.inner.interpolate(x));
132        }
133        let period = self.x_max - self.x_min;
134        match &self.mode {
135            ExtrapolationMode::Nearest => {
136                let clamped = x.clamp(self.x_min, self.x_max);
137                Ok(self.inner.interpolate(clamped))
138            }
139
140            ExtrapolationMode::Linear => {
141                let h = period * 1e-5;
142                if x < self.x_min {
143                    // Gradient at left boundary by forward difference.
144                    let f0 = self.inner.interpolate(self.x_min);
145                    let f1 = self.inner.interpolate(self.x_min + h);
146                    let slope = (f1 - f0) / h;
147                    Ok(f0 + slope * (x - self.x_min))
148                } else {
149                    // Gradient at right boundary by backward difference.
150                    let f0 = self.inner.interpolate(self.x_max);
151                    let f1 = self.inner.interpolate(self.x_max - h);
152                    let slope = (f0 - f1) / h;
153                    Ok(f0 + slope * (x - self.x_max))
154                }
155            }
156
157            ExtrapolationMode::Polynomial(deg) => {
158                let deg = *deg;
159                self.poly_extrapolate(x, deg)
160            }
161
162            ExtrapolationMode::Reflect => {
163                let mapped = if x < self.x_min {
164                    2.0 * self.x_min - x
165                } else {
166                    2.0 * self.x_max - x
167                };
168                // The reflected point may still be outside domain; clamp to be safe.
169                let clamped = mapped.clamp(self.x_min, self.x_max);
170                Ok(self.inner.interpolate(clamped))
171            }
172
173            ExtrapolationMode::Periodic => {
174                let wrapped = wrap_periodic(x, self.x_min, self.x_max);
175                Ok(self.inner.interpolate(wrapped))
176            }
177
178            ExtrapolationMode::Fill(v) => Ok(*v),
179
180            ExtrapolationMode::Error => Err(InterpolateError::OutOfBounds(format!(
181                "x={x:.6} is outside domain [{:.6}, {:.6}]",
182                self.x_min, self.x_max
183            ))),
184        }
185    }
186
187    /// Polynomial extrapolation of degree `deg` using `deg+1` boundary samples.
188    ///
189    /// Samples are taken in the interior near the violated boundary and the
190    /// polynomial is evaluated at `x` via Lagrange interpolation.
191    fn poly_extrapolate(&self, x: f64, deg: usize) -> InterpolateResult<f64> {
192        let n = deg + 1;
193        let h = period_step(self.x_min, self.x_max, n);
194        // Choose `n` sample nodes near the boundary that was violated.
195        let nodes: Vec<f64> = if x < self.x_min {
196            (0..n).map(|k| self.x_min + k as f64 * h).collect()
197        } else {
198            (0..n)
199                .map(|k| self.x_max - (n - 1 - k) as f64 * h)
200                .collect()
201        };
202        let ys: Vec<f64> = nodes.iter().map(|&xi| self.inner.interpolate(xi)).collect();
203        Ok(lagrange_eval(&nodes, &ys, x))
204    }
205
206    /// Domain lower bound.
207    pub fn x_min(&self) -> f64 {
208        self.x_min
209    }
210
211    /// Domain upper bound.
212    pub fn x_max(&self) -> f64 {
213        self.x_max
214    }
215
216    /// Reference to the extrapolation mode.
217    pub fn mode(&self) -> &ExtrapolationMode {
218        &self.mode
219    }
220}
221
222// ─────────────────────────────────────────────────────────────────────────────
223// Asymmetric wrapper (different modes per boundary)
224// ─────────────────────────────────────────────────────────────────────────────
225
226/// Like [`ExtrapolatingInterpolator`] but with independent extrapolation modes
227/// for the lower and upper boundaries.
228pub struct ExtrapolatingInterpolatorAsymmetric<I: Interpolate1D> {
229    inner: I,
230    x_min: f64,
231    x_max: f64,
232    lower_mode: ExtrapolationMode,
233    upper_mode: ExtrapolationMode,
234}
235
236impl<I: Interpolate1D> ExtrapolatingInterpolatorAsymmetric<I> {
237    /// Create a new asymmetric wrapper.
238    pub fn new(
239        inner: I,
240        x_min: f64,
241        x_max: f64,
242        lower_mode: ExtrapolationMode,
243        upper_mode: ExtrapolationMode,
244    ) -> Self {
245        assert!(x_max > x_min);
246        Self {
247            inner,
248            x_min,
249            x_max,
250            lower_mode,
251            upper_mode,
252        }
253    }
254
255    /// Evaluate, applying the appropriate boundary mode.
256    pub fn eval(&self, x: f64) -> InterpolateResult<f64> {
257        if x >= self.x_min && x <= self.x_max {
258            return Ok(self.inner.interpolate(x));
259        }
260        let mode = if x < self.x_min {
261            &self.lower_mode
262        } else {
263            &self.upper_mode
264        };
265        // Delegate to the symmetric impl by constructing a temporary one.
266        let tmp = ExtrapolatingInterpolator {
267            inner: DummyInner(&self.inner),
268            x_min: self.x_min,
269            x_max: self.x_max,
270            mode: mode.clone(),
271        };
272        tmp.eval(x)
273    }
274}
275
276/// Helper wrapper so we can borrow the inner interpolant.
277struct DummyInner<'a, I: Interpolate1D>(&'a I);
278
279impl<'a, I: Interpolate1D> Interpolate1D for DummyInner<'a, I> {
280    fn interpolate(&self, x: f64) -> f64 {
281        self.0.interpolate(x)
282    }
283}
284
285// ─────────────────────────────────────────────────────────────────────────────
286// Internal helpers
287// ─────────────────────────────────────────────────────────────────────────────
288
289/// Wrap `x` into `[x_min, x_max)` with period = `x_max - x_min`.
290fn wrap_periodic(x: f64, x_min: f64, x_max: f64) -> f64 {
291    let period = x_max - x_min;
292    let shifted = x - x_min;
293    let wrapped = shifted - period * (shifted / period).floor();
294    (x_min + wrapped).clamp(x_min, x_max)
295}
296
297/// Spacing for `n` equidistant nodes that fit inside `[x_min, x_max]`.
298fn period_step(x_min: f64, x_max: f64, n: usize) -> f64 {
299    if n <= 1 {
300        0.0
301    } else {
302        (x_max - x_min) / (n - 1) as f64
303    }
304}
305
306/// Evaluate the Lagrange interpolating polynomial defined by `(nodes, values)`
307/// at `x`.
308fn lagrange_eval(nodes: &[f64], values: &[f64], x: f64) -> f64 {
309    let n = nodes.len();
310    let mut result = 0.0_f64;
311    for i in 0..n {
312        let mut basis = 1.0_f64;
313        for j in 0..n {
314            if i != j {
315                let denom = nodes[i] - nodes[j];
316                if denom.abs() < 1e-300 {
317                    continue;
318                }
319                basis *= (x - nodes[j]) / denom;
320            }
321        }
322        result += values[i] * basis;
323    }
324    result
325}
326
327// ─────────────────────────────────────────────────────────────────────────────
328// Tests
329// ─────────────────────────────────────────────────────────────────────────────
330
331#[cfg(test)]
332mod tests {
333    use super::*;
334    use std::f64::consts::PI;
335
336    // Helper: linear interpolant on [0, 1] with slope 1.
337    fn linear_unit() -> impl Fn(f64) -> f64 {
338        |x| x
339    }
340
341    // Helper: sin on [0, 2π].
342    fn sin_interp() -> impl Fn(f64) -> f64 {
343        |x: f64| x.sin()
344    }
345
346    #[test]
347    fn test_extrapolation_nearest_below() {
348        let interp =
349            ExtrapolatingInterpolator::new(linear_unit(), 0.0, 1.0, ExtrapolationMode::Nearest);
350        // x < 0 → clamp to 0.0 → f(0) = 0.
351        let val = interp.eval(-0.5).expect("should succeed");
352        assert!((val - 0.0).abs() < 1e-12, "nearest below: {val}");
353    }
354
355    #[test]
356    fn test_extrapolation_nearest_above() {
357        let interp =
358            ExtrapolatingInterpolator::new(linear_unit(), 0.0, 1.0, ExtrapolationMode::Nearest);
359        // x > 1 → clamp to 1.0 → f(1) = 1.
360        let val = interp.eval(2.0).expect("should succeed");
361        assert!((val - 1.0).abs() < 1e-12, "nearest above: {val}");
362    }
363
364    #[test]
365    fn test_extrapolation_linear_below() {
366        // f(x) = 2x + 3. Slope = 2.  Evaluate at x = -1.
367        let inner = |x: f64| 2.0 * x + 3.0;
368        let interp = ExtrapolatingInterpolator::new(inner, 0.0, 1.0, ExtrapolationMode::Linear);
369        // linear extrapolation: f(0) + slope * (-1 - 0) = 3 + 2*(-1) = 1
370        let val = interp.eval(-1.0).expect("linear below");
371        assert!((val - 1.0).abs() < 1e-4, "linear extrap below: {val}");
372    }
373
374    #[test]
375    fn test_extrapolation_linear_above() {
376        let inner = |x: f64| 2.0 * x + 3.0;
377        let interp = ExtrapolatingInterpolator::new(inner, 0.0, 1.0, ExtrapolationMode::Linear);
378        // f(1) = 5, slope ≈ 2, at x=2: 5 + 2*(2-1) = 7
379        let val = interp.eval(2.0).expect("linear above");
380        assert!((val - 7.0).abs() < 1e-3, "linear extrap above: {val}");
381    }
382
383    #[test]
384    fn test_extrapolation_fill() {
385        let fill_val = -999.0;
386        let interp = ExtrapolatingInterpolator::new(
387            linear_unit(),
388            0.0,
389            1.0,
390            ExtrapolationMode::Fill(fill_val),
391        );
392        assert_eq!(interp.eval(-5.0).unwrap(), fill_val);
393        assert_eq!(interp.eval(5.0).unwrap(), fill_val);
394    }
395
396    #[test]
397    fn test_extrapolation_fill_nan() {
398        let interp = ExtrapolatingInterpolator::new(
399            linear_unit(),
400            0.0,
401            1.0,
402            ExtrapolationMode::Fill(f64::NAN),
403        );
404        assert!(interp.eval(-1.0).unwrap().is_nan());
405    }
406
407    #[test]
408    fn test_extrapolation_error_mode() {
409        let interp =
410            ExtrapolatingInterpolator::new(linear_unit(), 0.0, 1.0, ExtrapolationMode::Error);
411        assert!(interp.eval(-0.1).is_err(), "Should error below range");
412        assert!(interp.eval(1.1).is_err(), "Should error above range");
413        // Inside domain is fine.
414        assert!(interp.eval(0.5).is_ok());
415    }
416
417    #[test]
418    fn test_extrapolation_periodic() {
419        // sin on [0, 2π] is periodic.
420        let interp = ExtrapolatingInterpolator::new(
421            sin_interp(),
422            0.0,
423            2.0 * PI,
424            ExtrapolationMode::Periodic,
425        );
426        // sin(π) ≈ sin(3π) since period = 2π
427        let y1 = interp.eval(PI).unwrap();
428        let y2 = interp.eval(3.0 * PI).unwrap();
429        assert!(
430            (y1 - y2).abs() < 1e-10,
431            "Periodic: sin(π)={y1} should equal sin(3π)={y2}"
432        );
433    }
434
435    #[test]
436    fn test_extrapolation_periodic_negative() {
437        let interp = ExtrapolatingInterpolator::new(
438            sin_interp(),
439            0.0,
440            2.0 * PI,
441            ExtrapolationMode::Periodic,
442        );
443        // sin(x) = sin(x + 2π), so sin(-π/2) ≈ sin(3π/2)
444        let y1 = interp.eval(-PI / 2.0).unwrap();
445        let y2 = interp.eval(3.0 * PI / 2.0).unwrap();
446        assert!((y1 - y2).abs() < 1e-10, "Periodic negative: {y1} vs {y2}");
447    }
448
449    #[test]
450    fn test_extrapolation_reflect_below() {
451        // f(x) = x² on [0, 1]. Reflect x=-0.3 → x=0.3.
452        let interp =
453            ExtrapolatingInterpolator::new(|x: f64| x * x, 0.0, 1.0, ExtrapolationMode::Reflect);
454        let val = interp.eval(-0.3).unwrap();
455        let expected = 0.3_f64 * 0.3;
456        assert!(
457            (val - expected).abs() < 1e-12,
458            "reflect below: {val} vs {expected}"
459        );
460    }
461
462    #[test]
463    fn test_extrapolation_reflect_above() {
464        let interp =
465            ExtrapolatingInterpolator::new(|x: f64| x * x, 0.0, 1.0, ExtrapolationMode::Reflect);
466        // Reflect x=1.4 → 2*1-1.4 = 0.6
467        let val = interp.eval(1.4).unwrap();
468        let expected = 0.6_f64 * 0.6;
469        assert!(
470            (val - expected).abs() < 1e-12,
471            "reflect above: {val} vs {expected}"
472        );
473    }
474
475    #[test]
476    fn test_extrapolation_polynomial_linear_exact() {
477        // f(x) = x + 1 (degree 1). Polynomial extrapolation with deg=1 should reproduce this.
478        let inner = |x: f64| x + 1.0;
479        let interp =
480            ExtrapolatingInterpolator::new(inner, 0.0, 1.0, ExtrapolationMode::Polynomial(1));
481        let val = interp.eval(-0.5).unwrap();
482        let expected = -0.5 + 1.0; // = 0.5
483        assert!(
484            (val - expected).abs() < 1e-8,
485            "poly extrap degree 1: {val} vs {expected}"
486        );
487    }
488
489    #[test]
490    fn test_extrapolation_polynomial_quadratic() {
491        // f(x) = x². Polynomial(2) should extrapolate x²  exactly.
492        let inner = |x: f64| x * x;
493        let interp =
494            ExtrapolatingInterpolator::new(inner, 0.0, 1.0, ExtrapolationMode::Polynomial(2));
495        let val = interp.eval(2.0).unwrap();
496        // Lagrange extrapolation of x² should give 4.
497        assert!(
498            (val - 4.0).abs() < 1e-6,
499            "poly extrap degree 2 above: {val}"
500        );
501    }
502
503    #[test]
504    fn test_inside_domain_uses_inner() {
505        let interp =
506            ExtrapolatingInterpolator::new(|x: f64| x * x, 0.0, 1.0, ExtrapolationMode::Error);
507        assert!((interp.eval(0.5).unwrap() - 0.25).abs() < 1e-15);
508    }
509
510    #[test]
511    fn test_asymmetric_different_modes() {
512        // Lower: Nearest, Upper: Error
513        let interp = ExtrapolatingInterpolatorAsymmetric::new(
514            linear_unit(),
515            0.0,
516            1.0,
517            ExtrapolationMode::Nearest,
518            ExtrapolationMode::Error,
519        );
520        // Below: nearest
521        let below = interp.eval(-0.5).expect("lower Nearest");
522        assert!((below - 0.0).abs() < 1e-12);
523        // Above: error
524        let above = interp.eval(1.5);
525        assert!(above.is_err(), "upper Error mode should fail");
526    }
527
528    #[test]
529    fn test_extrapolation_in_range_all_modes() {
530        let modes = vec![
531            ExtrapolationMode::Nearest,
532            ExtrapolationMode::Linear,
533            ExtrapolationMode::Polynomial(2),
534            ExtrapolationMode::Reflect,
535            ExtrapolationMode::Periodic,
536            ExtrapolationMode::Fill(0.0),
537            ExtrapolationMode::Error,
538        ];
539        for mode in modes {
540            let interp = ExtrapolatingInterpolator::new(|x: f64| x, 0.0, 1.0, mode);
541            // In-range queries should always succeed.
542            let val = interp
543                .eval(0.5)
544                .expect("in-range should succeed for any mode");
545            assert!((val - 0.5).abs() < 1e-12, "in-range eval failed: {val}");
546        }
547    }
548
549    #[test]
550    fn test_lagrange_eval_linear() {
551        let nodes = vec![0.0, 1.0];
552        let vals = vec![1.0, 3.0]; // f(x) = 2x+1
553        let y = lagrange_eval(&nodes, &vals, 2.0);
554        assert!((y - 5.0).abs() < 1e-10, "Lagrange extrapolation: {y}");
555    }
556
557    #[test]
558    fn test_wrap_periodic() {
559        // 5.0 is within [0, 2π] ≈ [0, 6.283], so it should wrap to itself.
560        let wrapped_inside = wrap_periodic(5.0, 0.0, 2.0 * std::f64::consts::PI);
561        assert!(
562            wrapped_inside >= 0.0 && wrapped_inside <= 2.0 * std::f64::consts::PI,
563            "inside wrap failed: {wrapped_inside}"
564        );
565        assert!(
566            (wrapped_inside - 5.0).abs() < 1e-12,
567            "inside wrap should be 5.0, got {wrapped_inside}"
568        );
569
570        // 7.0 > 2π: should wrap to 7.0 - 2π ≈ 0.717
571        let wrapped_above = wrap_periodic(7.0, 0.0, 2.0 * std::f64::consts::PI);
572        let expected_above = 7.0 - 2.0 * std::f64::consts::PI;
573        assert!(
574            (wrapped_above - expected_above).abs() < 1e-12,
575            "above wrap: {wrapped_above} vs {expected_above}"
576        );
577
578        // -1.0 < 0: should wrap to -1.0 + 2π
579        let wrapped_below = wrap_periodic(-1.0, 0.0, 2.0 * std::f64::consts::PI);
580        let expected_below = -1.0 + 2.0 * std::f64::consts::PI;
581        assert!(
582            (wrapped_below - expected_below).abs() < 1e-12,
583            "below wrap: {wrapped_below} vs {expected_below}"
584        );
585    }
586}