burn_core/nn/interpolate/
mod.rs

1mod interpolate1d;
2mod interpolate2d;
3
4pub use interpolate1d::*;
5pub use interpolate2d::*;
6
7use crate::tensor::ops::InterpolateMode as OpsInterpolateMode;
8
9/// Algorithm used for downsampling and upsampling
10///
11/// This enum defines different interpolation modes for resampling data.
12#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
13pub enum InterpolateMode {
14    /// Nearest-neighbor interpolation
15    ///
16    /// This mode selects the value of the nearest sample point for each output pixel.
17    /// It is applicable for both temporal and spatial data.
18    Nearest,
19
20    /// Linear interpolation
21    ///
22    /// This mode calculates the output value using linear
23    /// interpolation between nearby sample points.
24    ///
25    /// It is applicable for both temporal and spatial data.
26    Linear,
27
28    /// Cubic interpolation
29    ///
30    /// This mode uses cubic interpolation to calculate the output value
31    /// based on surrounding sample points.
32    ///
33    /// It is applicable for both temporal and spatial data and generally
34    /// provides smoother results than linear interpolation.
35    Cubic,
36}
37
38impl From<InterpolateMode> for OpsInterpolateMode {
39    fn from(mode: InterpolateMode) -> Self {
40        match mode {
41            InterpolateMode::Nearest => OpsInterpolateMode::Nearest,
42            InterpolateMode::Linear => OpsInterpolateMode::Bilinear,
43            InterpolateMode::Cubic => OpsInterpolateMode::Bicubic,
44        }
45    }
46}