Skip to main content

runmat_runtime/builtins/math/interpolation/
interp2.rs

//! MATLAB-compatible `interp2` builtin: linear and nearest-neighbour interpolation of dense, real-valued data sampled on a 2-D grid.
2
3use runmat_builtins::{ResolveContext, Tensor, Type, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8    ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
9};
10use crate::builtins::common::tensor;
11use crate::dispatcher;
12
13use super::pp::{
14    interp_error, interval_index, is_vector_shape, out_of_range_value, parse_extrapolation,
15    parse_method, query_points, vector_from_value, Extrapolation, InterpMethod,
16};
17
/// Builtin name used for registration and as the prefix of every error message.
const NAME: &'static str = "interp2";
19
20#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::interpolation::interp2")]
21pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
22    name: NAME,
23    op_kind: GpuOpKind::Custom("interpolation-2d"),
24    supported_precisions: &[ScalarType::F32, ScalarType::F64],
25    broadcast: BroadcastSemantics::Matlab,
26    provider_hooks: &[],
27    constant_strategy: ConstantStrategy::InlineLiteral,
28    residency: ResidencyPolicy::GatherImmediately,
29    nan_mode: ReductionNaN::Include,
30    two_pass_threshold: None,
31    workgroup_size: None,
32    accepts_nan_mode: false,
33    notes: "Initial implementation gathers GPU inputs to the CPU reference path. Bilinear and nearest kernels are good future provider candidates.",
34};
35
36#[runmat_macros::register_fusion_spec(
37    builtin_path = "crate::builtins::math::interpolation::interp2"
38)]
39pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
40    name: NAME,
41    shape: ShapeRequirements::Any,
42    constant_strategy: ConstantStrategy::InlineLiteral,
43    elementwise: None,
44    reduction: None,
45    emits_nan: true,
46    notes: "interp2 is currently a runtime sink.",
47};
48
49fn interp2_type(args: &[Type], _ctx: &ResolveContext) -> Type {
50    let query = match args.len() {
51        0..=2 => return Type::tensor(),
52        3 | 4 => args.get(1),
53        _ => args.get(3),
54    };
55    match query {
56        Some(Type::Num | Type::Int | Type::Bool) => Type::Num,
57        Some(Type::Tensor { shape }) | Some(Type::Logical { shape }) => Type::Tensor {
58            shape: shape.clone(),
59        },
60        _ => Type::tensor(),
61    }
62}
63
64#[runtime_builtin(
65    name = "interp2",
66    category = "math/interpolation",
67    summary = "Two-dimensional interpolation on gridded data.",
68    keywords = "interp2,interpolation,bilinear,nearest,grid,meshgrid",
69    accel = "sink",
70    sink = true,
71    type_resolver(interp2_type),
72    builtin_path = "crate::builtins::math::interpolation::interp2"
73)]
74async fn interp2_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
75    let parsed = ParsedInterp2::parse(args).await?;
76    let data = evaluate_grid(&parsed)?;
77    if data.len() == 1 {
78        return Ok(Value::Num(data[0]));
79    }
80    let tensor = Tensor::new(data, parsed.output_shape)
81        .map_err(|err| interp_error(NAME, format!("{NAME}: {err}")))?;
82    Ok(Value::Tensor(tensor))
83}
84
85struct ParsedInterp2 {
86    x_axis: Vec<f64>,
87    y_axis: Vec<f64>,
88    z: Tensor,
89    xq: Vec<f64>,
90    yq: Vec<f64>,
91    output_shape: Vec<usize>,
92    method: InterpMethod,
93    extrap: Extrapolation,
94}
95
96impl ParsedInterp2 {
97    async fn parse(args: Vec<Value>) -> crate::BuiltinResult<Self> {
98        if args.len() < 3 {
99            return Err(interp_error(
100                NAME,
101                "interp2: expected Z, Xq, and Yq or X, Y, Z, Xq, and Yq",
102            ));
103        }
104
105        let mut method = InterpMethod::Linear;
106        let mut extrap = Extrapolation::Nan;
107        let explicit_axes = args.len() >= 5 && !is_option_arg(&args[3]);
108        let (x_axis, y_axis, z, xq_value, yq_value, options) = if explicit_axes {
109            let mut iter = args.into_iter();
110            let x = iter.next().expect("X");
111            let y = iter.next().expect("Y");
112            let z_value = iter.next().expect("Z");
113            let z = z_tensor(z_value).await?;
114            let (x_axis, y_axis) = axes_from_values(x, y, z.rows, z.cols).await?;
115            let xq = iter.next().expect("Xq");
116            let yq = iter.next().expect("Yq");
117            (x_axis, y_axis, z, xq, yq, iter.collect::<Vec<_>>())
118        } else {
119            let mut iter = args.into_iter();
120            let z_value = iter.next().expect("Z");
121            let z = z_tensor(z_value).await?;
122            let x_axis: Vec<f64> = (1..=z.cols).map(|v| v as f64).collect();
123            let y_axis: Vec<f64> = (1..=z.rows).map(|v| v as f64).collect();
124            let xq = iter.next().expect("Xq");
125            let yq = iter.next().expect("Yq");
126            (x_axis, y_axis, z, xq, yq, iter.collect::<Vec<_>>())
127        };
128
129        validate_axis(&x_axis, "X")?;
130        validate_axis(&y_axis, "Y")?;
131        let xq = query_points(xq_value, NAME).await?;
132        let yq = query_points(yq_value, NAME).await?;
133        let (xq_values, yq_values, output_shape) = align_queries(xq, yq)?;
134
135        for option in &options {
136            if let Some(parsed) = parse_extrapolation(option, NAME).await? {
137                extrap = parsed;
138                continue;
139            }
140            if let Some(parsed) = parse_method(option, NAME)? {
141                match parsed {
142                    InterpMethod::Linear | InterpMethod::Nearest => method = parsed,
143                    _ => {
144                        return Err(interp_error(
145                            NAME,
146                            "interp2: only linear and nearest methods are supported",
147                        ))
148                    }
149                }
150                continue;
151            }
152            return Err(interp_error(
153                NAME,
154                "interp2: unsupported interpolation option",
155            ));
156        }
157
158        Ok(Self {
159            x_axis,
160            y_axis,
161            z,
162            xq: xq_values,
163            yq: yq_values,
164            output_shape,
165            method,
166            extrap,
167        })
168    }
169}
170
171fn is_option_arg(value: &Value) -> bool {
172    crate::builtins::common::random_args::keyword_of(value).is_some()
173}
174
175async fn z_tensor(value: Value) -> crate::BuiltinResult<Tensor> {
176    let gathered = dispatcher::gather_if_needed_async(&value).await?;
177    let z = tensor::value_into_tensor_for(NAME, gathered)
178        .map_err(|err| interp_error(NAME, format!("{NAME}: {err}")))?;
179    if z.shape.len() > 2 {
180        return Err(interp_error(NAME, "interp2: Z must be a 2-D matrix"));
181    }
182    if z.rows < 2 || z.cols < 2 {
183        return Err(interp_error(
184            NAME,
185            "interp2: Z must have at least two rows and two columns",
186        ));
187    }
188    Ok(z)
189}
190
191async fn axes_from_values(
192    x: Value,
193    y: Value,
194    rows: usize,
195    cols: usize,
196) -> crate::BuiltinResult<(Vec<f64>, Vec<f64>)> {
197    let x_axis = axis_from_value(x, rows, cols, true).await?;
198    let y_axis = axis_from_value(y, rows, cols, false).await?;
199    Ok((x_axis, y_axis))
200}
201
202async fn axis_from_value(
203    value: Value,
204    rows: usize,
205    cols: usize,
206    is_x: bool,
207) -> crate::BuiltinResult<Vec<f64>> {
208    let gathered = dispatcher::gather_if_needed_async(&value).await?;
209    let tensor_value = tensor::value_into_tensor_for(NAME, gathered.clone());
210    if let Ok(t) = tensor_value {
211        if is_vector_shape(&t.shape) {
212            let expected = if is_x { cols } else { rows };
213            if t.data.len() != expected {
214                return Err(interp_error(
215                    NAME,
216                    format!("{NAME}: axis vector length must match Z dimensions"),
217                ));
218            }
219            return Ok(t.data);
220        }
221        if t.rows == rows && t.cols == cols {
222            return if is_x {
223                Ok((0..cols).map(|col| t.data[col * rows]).collect())
224            } else {
225                Ok((0..rows).map(|row| t.data[row]).collect())
226            };
227        }
228    }
229    let label = if is_x { "X" } else { "Y" };
230    vector_from_value(gathered, label, NAME).await
231}
232
233fn validate_axis(axis: &[f64], label: &str) -> crate::BuiltinResult<()> {
234    if axis.len() < 2 {
235        return Err(interp_error(
236            NAME,
237            format!("{NAME}: {label} axis must contain at least two points"),
238        ));
239    }
240    if axis.iter().any(|v| !v.is_finite()) {
241        return Err(interp_error(
242            NAME,
243            format!("{NAME}: {label} axis must be finite"),
244        ));
245    }
246    for pair in axis.windows(2) {
247        if pair[1] <= pair[0] {
248            return Err(interp_error(
249                NAME,
250                format!("{NAME}: {label} axis must be strictly increasing"),
251            ));
252        }
253    }
254    Ok(())
255}
256
257fn align_queries(
258    xq: super::pp::QueryPoints,
259    yq: super::pp::QueryPoints,
260) -> crate::BuiltinResult<(Vec<f64>, Vec<f64>, Vec<usize>)> {
261    match (xq.values.len(), yq.values.len()) {
262        (1, 1) => Ok((xq.values, yq.values, vec![1, 1])),
263        (1, len) => Ok((vec![xq.values[0]; len], yq.values, yq.shape)),
264        (len, 1) => Ok((xq.values, vec![yq.values[0]; len], xq.shape)),
265        (left, right) if left == right && xq.shape == yq.shape => {
266            Ok((xq.values, yq.values, xq.shape))
267        }
268        _ => Err(interp_error(
269            NAME,
270            "interp2: Xq and Yq must be scalar or matching-size arrays",
271        )),
272    }
273}
274
275fn evaluate_grid(parsed: &ParsedInterp2) -> crate::BuiltinResult<Vec<f64>> {
276    let mut out = Vec::with_capacity(parsed.xq.len());
277    for (&xq, &yq) in parsed.xq.iter().zip(parsed.yq.iter()) {
278        let value = match parsed.method {
279            InterpMethod::Linear => eval_bilinear(parsed, xq, yq),
280            InterpMethod::Nearest => eval_nearest(parsed, xq, yq),
281            _ => unreachable!("interp2 parse rejects cubic methods"),
282        };
283        out.push(value);
284    }
285    Ok(out)
286}
287
288fn eval_bilinear(parsed: &ParsedInterp2, xq: f64, yq: f64) -> f64 {
289    if !xq.is_finite() || !yq.is_finite() {
290        return f64::NAN;
291    }
292    let allow = matches!(parsed.extrap, Extrapolation::Extrapolate);
293    let Some(col) = interval_index(&parsed.x_axis, xq, allow) else {
294        return out_of_range_value(&parsed.extrap);
295    };
296    let Some(row) = interval_index(&parsed.y_axis, yq, allow) else {
297        return out_of_range_value(&parsed.extrap);
298    };
299    let x0 = parsed.x_axis[col];
300    let x1 = parsed.x_axis[col + 1];
301    let y0 = parsed.y_axis[row];
302    let y1 = parsed.y_axis[row + 1];
303    let tx = (xq - x0) / (x1 - x0);
304    let ty = (yq - y0) / (y1 - y0);
305    let z00 = z_at(&parsed.z, row, col);
306    let z10 = z_at(&parsed.z, row, col + 1);
307    let z01 = z_at(&parsed.z, row + 1, col);
308    let z11 = z_at(&parsed.z, row + 1, col + 1);
309    (1.0 - tx) * (1.0 - ty) * z00 + tx * (1.0 - ty) * z10 + (1.0 - tx) * ty * z01 + tx * ty * z11
310}
311
312fn eval_nearest(parsed: &ParsedInterp2, xq: f64, yq: f64) -> f64 {
313    if !xq.is_finite() || !yq.is_finite() {
314        return f64::NAN;
315    }
316    let Some(col) = nearest_index(&parsed.x_axis, xq, &parsed.extrap) else {
317        return out_of_range_value(&parsed.extrap);
318    };
319    let Some(row) = nearest_index(&parsed.y_axis, yq, &parsed.extrap) else {
320        return out_of_range_value(&parsed.extrap);
321    };
322    z_at(&parsed.z, row, col)
323}
324
325fn z_at(z: &Tensor, row: usize, col: usize) -> f64 {
326    z.data[row + col * z.rows]
327}
328
329fn nearest_index(axis: &[f64], q: f64, extrap: &Extrapolation) -> Option<usize> {
330    if q < axis[0] {
331        return matches!(extrap, Extrapolation::Extrapolate).then_some(0);
332    }
333    let last = axis.len() - 1;
334    if q > axis[last] {
335        return matches!(extrap, Extrapolation::Extrapolate).then_some(last);
336    }
337    match axis.binary_search_by(|probe| probe.partial_cmp(&q).unwrap()) {
338        Ok(index) => Some(index),
339        Err(index) => {
340            let left = index.saturating_sub(1);
341            let right = index.min(last);
342            if (q - axis[left]).abs() <= (axis[right] - q).abs() {
343                Some(left)
344            } else {
345                Some(right)
346            }
347        }
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    /// Wraps `values` in a 1-by-n row-vector tensor value.
    fn row(values: &[f64]) -> Value {
        let tensor = Tensor::new(values.to_vec(), vec![1, values.len()]).expect("tensor");
        Value::Tensor(tensor)
    }

    /// The 2x2 matrix [1 2; 3 4] (stored column-major).
    fn sample_z() -> Value {
        Value::Tensor(Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).expect("tensor"))
    }

    #[test]
    fn interp2_implicit_axes_bilinear_scalar() {
        let args = vec![sample_z(), Value::Num(1.5), Value::Num(1.5)];
        let value = block_on(interp2_builtin(args)).expect("interp2");
        match value {
            // The cell centre averages all four corners: (1 + 2 + 3 + 4) / 4.
            Value::Num(result) => assert!((result - 2.5).abs() < 1e-12),
            _ => panic!("expected scalar"),
        }
    }

    #[test]
    fn interp2_vector_axes_nearest() {
        let args = vec![
            row(&[10.0, 20.0]),
            row(&[100.0, 200.0]),
            sample_z(),
            Value::Num(18.0),
            Value::Num(120.0),
            Value::String("nearest".to_string()),
        ];
        let value = block_on(interp2_builtin(args)).expect("interp2");
        // x = 18 is nearest 20 (column 2), y = 120 is nearest 100 (row 1).
        assert_eq!(value, Value::Num(2.0));
    }
}