Skip to main content

runmat_runtime/builtins/math/interpolation/
interp1.rs

//! MATLAB-compatible `interp1` builtin for one-dimensional interpolation of dense real numeric data.
2
3use runmat_builtins::{ResolveContext, Type, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8    ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
9};
10
11use super::pp::{
12    build_pchip_pp, build_spline_pp, evaluate_linear_or_nearest, evaluate_pp,
13    implicit_series_from_values, interp_error, parse_extrapolation, parse_method, query_points,
14    series_from_values, Extrapolation, InterpMethod,
15};
16
/// Builtin name, shared by the GPU/fusion specs and every error raised here.
const NAME: &str = "interp1";
18
19#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::interpolation::interp1")]
20pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
21    name: NAME,
22    op_kind: GpuOpKind::Custom("interpolation-1d"),
23    supported_precisions: &[ScalarType::F32, ScalarType::F64],
24    broadcast: BroadcastSemantics::Matlab,
25    provider_hooks: &[],
26    constant_strategy: ConstantStrategy::InlineLiteral,
27    residency: ResidencyPolicy::GatherImmediately,
28    nan_mode: ReductionNaN::Include,
29    two_pass_threshold: None,
30    workgroup_size: None,
31    accepts_nan_mode: false,
32    notes: "Initial implementation gathers GPU inputs to the CPU reference path. Provider kernels can later accelerate linear and nearest evaluation.",
33};
34
35#[runmat_macros::register_fusion_spec(
36    builtin_path = "crate::builtins::math::interpolation::interp1"
37)]
38pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
39    name: NAME,
40    shape: ShapeRequirements::Any,
41    constant_strategy: ConstantStrategy::InlineLiteral,
42    elementwise: None,
43    reduction: None,
44    emits_nan: true,
45    notes: "Interpolation is currently a runtime sink.",
46};
47
48fn interp1_type(args: &[Type], _ctx: &ResolveContext) -> Type {
49    let query = match args.len() {
50        0 | 1 => return Type::tensor(),
51        2 => args.get(1),
52        _ => args.get(2),
53    };
54    match query {
55        Some(Type::Num | Type::Int | Type::Bool) => Type::Num,
56        Some(Type::Tensor { shape }) | Some(Type::Logical { shape }) => Type::Tensor {
57            shape: shape.clone(),
58        },
59        _ => Type::tensor(),
60    }
61}
62
63#[runtime_builtin(
64    name = "interp1",
65    category = "math/interpolation",
66    summary = "One-dimensional interpolation for sampled data.",
67    keywords = "interp1,interpolation,linear,nearest,spline,pchip",
68    accel = "sink",
69    sink = true,
70    type_resolver(interp1_type),
71    builtin_path = "crate::builtins::math::interpolation::interp1"
72)]
73async fn interp1_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
74    let parsed = ParsedInterp1::parse(args).await?;
75    match parsed.method {
76        InterpMethod::Linear | InterpMethod::Nearest => evaluate_linear_or_nearest(
77            &parsed.series,
78            &parsed.query,
79            parsed.method,
80            &parsed.extrap,
81            NAME,
82        ),
83        InterpMethod::Spline => {
84            let pp = build_spline_pp(&parsed.series, NAME)?;
85            evaluate_pp(&pp, &parsed.query, &parsed.extrap_for_cubic(), NAME)
86        }
87        InterpMethod::Pchip => {
88            let pp = build_pchip_pp(&parsed.series, NAME)?;
89            evaluate_pp(&pp, &parsed.query, &parsed.extrap_for_cubic(), NAME)
90        }
91    }
92}
93
94struct ParsedInterp1 {
95    series: super::pp::NumericSeries,
96    query: super::pp::QueryPoints,
97    method: InterpMethod,
98    extrap: Extrapolation,
99}
100
101impl ParsedInterp1 {
102    async fn parse(args: Vec<Value>) -> crate::BuiltinResult<Self> {
103        if args.len() < 2 {
104            return Err(interp_error(
105                NAME,
106                "interp1: expected at least Y and Xq arguments",
107            ));
108        }
109
110        let mut method = InterpMethod::Linear;
111        let mut extrap = Extrapolation::Nan;
112        let (series, query, options) = if args.len() == 2 || third_arg_is_option(&args) {
113            let mut iter = args.into_iter();
114            let y = iter.next().expect("Y argument");
115            let xq = iter.next().expect("Xq argument");
116            let series = implicit_series_from_values(y, NAME).await?;
117            let query = query_points(xq, NAME).await?;
118            (series, query, iter.collect::<Vec<_>>())
119        } else {
120            let mut iter = args.into_iter();
121            let x = iter.next().expect("X argument");
122            let y = iter.next().expect("Y argument");
123            let xq = iter.next().expect("Xq argument");
124            let series = series_from_values(x, y, NAME).await?;
125            let query = query_points(xq, NAME).await?;
126            (series, query, iter.collect::<Vec<_>>())
127        };
128
129        for option in &options {
130            if let Some(parsed) = parse_extrapolation(option, NAME).await? {
131                extrap = parsed;
132                continue;
133            }
134            if let Some(parsed) = parse_method(option, NAME)? {
135                method = parsed;
136                continue;
137            }
138            return Err(interp_error(
139                NAME,
140                "interp1: unsupported interpolation option",
141            ));
142        }
143
144        Ok(Self {
145            series,
146            query,
147            method,
148            extrap,
149        })
150    }
151
152    fn extrap_for_cubic(&self) -> Extrapolation {
153        self.extrap.clone()
154    }
155}
156
157fn third_arg_is_option(args: &[Value]) -> bool {
158    args.get(2)
159        .and_then(|value| crate::builtins::common::random_args::keyword_of(value))
160        .is_some()
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;

    /// Builds a 1-by-N row vector value.
    fn row(values: &[f64]) -> Value {
        Value::Tensor(Tensor::new(values.to_vec(), vec![1, values.len()]).expect("tensor"))
    }

    /// Runs the builtin synchronously.
    fn run(args: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(interp1_builtin(args))
    }

    /// Extracts the data of a tensor result, panicking on any other value.
    fn tensor_data(value: Value) -> Vec<f64> {
        let Value::Tensor(tensor) = value else {
            panic!("expected tensor");
        };
        tensor.data
    }

    #[test]
    fn interp1_linear_midpoints() {
        let result = run(vec![
            row(&[1.0, 2.0, 3.0]),
            row(&[10.0, 20.0, 40.0]),
            row(&[1.5, 2.5]),
        ])
        .expect("interp1");
        assert_eq!(tensor_data(result), vec![15.0, 30.0]);
    }

    #[test]
    fn interp1_nearest() {
        let result = run(vec![
            row(&[1.0, 2.0, 3.0]),
            row(&[10.0, 20.0, 40.0]),
            row(&[1.2, 2.8]),
            Value::String("nearest".to_string()),
        ])
        .expect("interp1");
        assert_eq!(tensor_data(result), vec![10.0, 40.0]);
    }

    #[test]
    fn interp1_implicit_x_uses_one_based_sample_points() {
        // interp1(Y, Xq) samples Y at 1..=numel(Y).
        let result = run(vec![row(&[10.0, 20.0, 40.0]), row(&[1.5, 2.5])]).expect("interp1");
        assert_eq!(tensor_data(result), vec![15.0, 30.0]);
    }

    #[test]
    fn interp1_implicit_x_accepts_method_keyword() {
        let result = run(vec![
            row(&[10.0, 20.0, 40.0]),
            row(&[1.2, 2.8]),
            Value::String("nearest".to_string()),
        ])
        .expect("interp1");
        assert_eq!(tensor_data(result), vec![10.0, 40.0]);
    }

    #[test]
    fn interp1_default_out_of_range_is_nan() {
        let result =
            run(vec![row(&[1.0, 2.0]), row(&[10.0, 20.0]), Value::Num(0.0)]).expect("interp1");
        let Value::Num(value) = result else {
            panic!("expected scalar");
        };
        assert!(value.is_nan());
    }

    #[test]
    fn interp1_extrapolates_when_requested() {
        let result = run(vec![
            row(&[1.0, 2.0]),
            row(&[10.0, 20.0]),
            Value::Num(0.0),
            Value::String("extrap".to_string()),
        ])
        .expect("interp1");
        assert_eq!(result, Value::Num(0.0));
    }

    #[test]
    fn interp1_requires_at_least_two_arguments() {
        assert!(run(vec![row(&[1.0, 2.0])]).is_err());
        assert!(run(Vec::new()).is_err());
    }

    #[test]
    fn interp1_rejects_unknown_option() {
        let result = run(vec![
            row(&[1.0, 2.0]),
            row(&[10.0, 20.0]),
            Value::Num(1.5),
            Value::String("definitely-not-a-method".to_string()),
        ]);
        assert!(result.is_err());
    }
}