1use runmat_builtins::{ResolveContext, Tensor, Type, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
9};
10use crate::builtins::common::tensor;
11use crate::dispatcher;
12
13use super::pp::{
14 interp_error, interval_index, is_vector_shape, out_of_range_value, parse_extrapolation,
15 parse_method, query_points, vector_from_value, Extrapolation, InterpMethod,
16};
17
/// Builtin name, used in spec registration and as the prefix of formatted error messages.
const NAME: &str = "interp2";
19
/// GPU registration metadata for `interp2`.
///
/// No provider hooks are registered and residency is `GatherImmediately`: as the notes say,
/// device-resident inputs are gathered and evaluated on the CPU reference path.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::interpolation::interp2")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: NAME,
    op_kind: GpuOpKind::Custom("interpolation-2d"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Initial implementation gathers GPU inputs to the CPU reference path. Bilinear and nearest kernels are good future provider candidates.",
};
35
/// Fusion metadata for `interp2`.
///
/// There is no elementwise or reduction template, so the builtin is treated as a runtime
/// sink (see `notes`). `emits_nan` is set because non-finite and out-of-range queries can
/// evaluate to NaN.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::math::interpolation::interp2"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: NAME,
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "interp2 is currently a runtime sink.",
};
48
49fn interp2_type(args: &[Type], _ctx: &ResolveContext) -> Type {
50 let query = match args.len() {
51 0..=2 => return Type::tensor(),
52 3 | 4 => args.get(1),
53 _ => args.get(3),
54 };
55 match query {
56 Some(Type::Num | Type::Int | Type::Bool) => Type::Num,
57 Some(Type::Tensor { shape }) | Some(Type::Logical { shape }) => Type::Tensor {
58 shape: shape.clone(),
59 },
60 _ => Type::tensor(),
61 }
62}
63
64#[runtime_builtin(
65 name = "interp2",
66 category = "math/interpolation",
67 summary = "Two-dimensional interpolation on gridded data.",
68 keywords = "interp2,interpolation,bilinear,nearest,grid,meshgrid",
69 accel = "sink",
70 sink = true,
71 type_resolver(interp2_type),
72 builtin_path = "crate::builtins::math::interpolation::interp2"
73)]
74async fn interp2_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
75 let parsed = ParsedInterp2::parse(args).await?;
76 let data = evaluate_grid(&parsed)?;
77 if data.len() == 1 {
78 return Ok(Value::Num(data[0]));
79 }
80 let tensor = Tensor::new(data, parsed.output_shape)
81 .map_err(|err| interp_error(NAME, format!("{NAME}: {err}")))?;
82 Ok(Value::Tensor(tensor))
83}
84
85struct ParsedInterp2 {
86 x_axis: Vec<f64>,
87 y_axis: Vec<f64>,
88 z: Tensor,
89 xq: Vec<f64>,
90 yq: Vec<f64>,
91 output_shape: Vec<usize>,
92 method: InterpMethod,
93 extrap: Extrapolation,
94}
95
96impl ParsedInterp2 {
97 async fn parse(args: Vec<Value>) -> crate::BuiltinResult<Self> {
98 if args.len() < 3 {
99 return Err(interp_error(
100 NAME,
101 "interp2: expected Z, Xq, and Yq or X, Y, Z, Xq, and Yq",
102 ));
103 }
104
105 let mut method = InterpMethod::Linear;
106 let mut extrap = Extrapolation::Nan;
107 let explicit_axes = args.len() >= 5 && !is_option_arg(&args[3]);
108 let (x_axis, y_axis, z, xq_value, yq_value, options) = if explicit_axes {
109 let mut iter = args.into_iter();
110 let x = iter.next().expect("X");
111 let y = iter.next().expect("Y");
112 let z_value = iter.next().expect("Z");
113 let z = z_tensor(z_value).await?;
114 let (x_axis, y_axis) = axes_from_values(x, y, z.rows, z.cols).await?;
115 let xq = iter.next().expect("Xq");
116 let yq = iter.next().expect("Yq");
117 (x_axis, y_axis, z, xq, yq, iter.collect::<Vec<_>>())
118 } else {
119 let mut iter = args.into_iter();
120 let z_value = iter.next().expect("Z");
121 let z = z_tensor(z_value).await?;
122 let x_axis: Vec<f64> = (1..=z.cols).map(|v| v as f64).collect();
123 let y_axis: Vec<f64> = (1..=z.rows).map(|v| v as f64).collect();
124 let xq = iter.next().expect("Xq");
125 let yq = iter.next().expect("Yq");
126 (x_axis, y_axis, z, xq, yq, iter.collect::<Vec<_>>())
127 };
128
129 validate_axis(&x_axis, "X")?;
130 validate_axis(&y_axis, "Y")?;
131 let xq = query_points(xq_value, NAME).await?;
132 let yq = query_points(yq_value, NAME).await?;
133 let (xq_values, yq_values, output_shape) = align_queries(xq, yq)?;
134
135 for option in &options {
136 if let Some(parsed) = parse_extrapolation(option, NAME).await? {
137 extrap = parsed;
138 continue;
139 }
140 if let Some(parsed) = parse_method(option, NAME)? {
141 match parsed {
142 InterpMethod::Linear | InterpMethod::Nearest => method = parsed,
143 _ => {
144 return Err(interp_error(
145 NAME,
146 "interp2: only linear and nearest methods are supported",
147 ))
148 }
149 }
150 continue;
151 }
152 return Err(interp_error(
153 NAME,
154 "interp2: unsupported interpolation option",
155 ));
156 }
157
158 Ok(Self {
159 x_axis,
160 y_axis,
161 z,
162 xq: xq_values,
163 yq: yq_values,
164 output_shape,
165 method,
166 extrap,
167 })
168 }
169}
170
171fn is_option_arg(value: &Value) -> bool {
172 crate::builtins::common::random_args::keyword_of(value).is_some()
173}
174
175async fn z_tensor(value: Value) -> crate::BuiltinResult<Tensor> {
176 let gathered = dispatcher::gather_if_needed_async(&value).await?;
177 let z = tensor::value_into_tensor_for(NAME, gathered)
178 .map_err(|err| interp_error(NAME, format!("{NAME}: {err}")))?;
179 if z.shape.len() > 2 {
180 return Err(interp_error(NAME, "interp2: Z must be a 2-D matrix"));
181 }
182 if z.rows < 2 || z.cols < 2 {
183 return Err(interp_error(
184 NAME,
185 "interp2: Z must have at least two rows and two columns",
186 ));
187 }
188 Ok(z)
189}
190
191async fn axes_from_values(
192 x: Value,
193 y: Value,
194 rows: usize,
195 cols: usize,
196) -> crate::BuiltinResult<(Vec<f64>, Vec<f64>)> {
197 let x_axis = axis_from_value(x, rows, cols, true).await?;
198 let y_axis = axis_from_value(y, rows, cols, false).await?;
199 Ok((x_axis, y_axis))
200}
201
202async fn axis_from_value(
203 value: Value,
204 rows: usize,
205 cols: usize,
206 is_x: bool,
207) -> crate::BuiltinResult<Vec<f64>> {
208 let gathered = dispatcher::gather_if_needed_async(&value).await?;
209 let tensor_value = tensor::value_into_tensor_for(NAME, gathered.clone());
210 if let Ok(t) = tensor_value {
211 if is_vector_shape(&t.shape) {
212 let expected = if is_x { cols } else { rows };
213 if t.data.len() != expected {
214 return Err(interp_error(
215 NAME,
216 format!("{NAME}: axis vector length must match Z dimensions"),
217 ));
218 }
219 return Ok(t.data);
220 }
221 if t.rows == rows && t.cols == cols {
222 return if is_x {
223 Ok((0..cols).map(|col| t.data[col * rows]).collect())
224 } else {
225 Ok((0..rows).map(|row| t.data[row]).collect())
226 };
227 }
228 }
229 let label = if is_x { "X" } else { "Y" };
230 vector_from_value(gathered, label, NAME).await
231}
232
233fn validate_axis(axis: &[f64], label: &str) -> crate::BuiltinResult<()> {
234 if axis.len() < 2 {
235 return Err(interp_error(
236 NAME,
237 format!("{NAME}: {label} axis must contain at least two points"),
238 ));
239 }
240 if axis.iter().any(|v| !v.is_finite()) {
241 return Err(interp_error(
242 NAME,
243 format!("{NAME}: {label} axis must be finite"),
244 ));
245 }
246 for pair in axis.windows(2) {
247 if pair[1] <= pair[0] {
248 return Err(interp_error(
249 NAME,
250 format!("{NAME}: {label} axis must be strictly increasing"),
251 ));
252 }
253 }
254 Ok(())
255}
256
257fn align_queries(
258 xq: super::pp::QueryPoints,
259 yq: super::pp::QueryPoints,
260) -> crate::BuiltinResult<(Vec<f64>, Vec<f64>, Vec<usize>)> {
261 match (xq.values.len(), yq.values.len()) {
262 (1, 1) => Ok((xq.values, yq.values, vec![1, 1])),
263 (1, len) => Ok((vec![xq.values[0]; len], yq.values, yq.shape)),
264 (len, 1) => Ok((xq.values, vec![yq.values[0]; len], xq.shape)),
265 (left, right) if left == right && xq.shape == yq.shape => {
266 Ok((xq.values, yq.values, xq.shape))
267 }
268 _ => Err(interp_error(
269 NAME,
270 "interp2: Xq and Yq must be scalar or matching-size arrays",
271 )),
272 }
273}
274
275fn evaluate_grid(parsed: &ParsedInterp2) -> crate::BuiltinResult<Vec<f64>> {
276 let mut out = Vec::with_capacity(parsed.xq.len());
277 for (&xq, &yq) in parsed.xq.iter().zip(parsed.yq.iter()) {
278 let value = match parsed.method {
279 InterpMethod::Linear => eval_bilinear(parsed, xq, yq),
280 InterpMethod::Nearest => eval_nearest(parsed, xq, yq),
281 _ => unreachable!("interp2 parse rejects cubic methods"),
282 };
283 out.push(value);
284 }
285 Ok(out)
286}
287
288fn eval_bilinear(parsed: &ParsedInterp2, xq: f64, yq: f64) -> f64 {
289 if !xq.is_finite() || !yq.is_finite() {
290 return f64::NAN;
291 }
292 let allow = matches!(parsed.extrap, Extrapolation::Extrapolate);
293 let Some(col) = interval_index(&parsed.x_axis, xq, allow) else {
294 return out_of_range_value(&parsed.extrap);
295 };
296 let Some(row) = interval_index(&parsed.y_axis, yq, allow) else {
297 return out_of_range_value(&parsed.extrap);
298 };
299 let x0 = parsed.x_axis[col];
300 let x1 = parsed.x_axis[col + 1];
301 let y0 = parsed.y_axis[row];
302 let y1 = parsed.y_axis[row + 1];
303 let tx = (xq - x0) / (x1 - x0);
304 let ty = (yq - y0) / (y1 - y0);
305 let z00 = z_at(&parsed.z, row, col);
306 let z10 = z_at(&parsed.z, row, col + 1);
307 let z01 = z_at(&parsed.z, row + 1, col);
308 let z11 = z_at(&parsed.z, row + 1, col + 1);
309 (1.0 - tx) * (1.0 - ty) * z00 + tx * (1.0 - ty) * z10 + (1.0 - tx) * ty * z01 + tx * ty * z11
310}
311
312fn eval_nearest(parsed: &ParsedInterp2, xq: f64, yq: f64) -> f64 {
313 if !xq.is_finite() || !yq.is_finite() {
314 return f64::NAN;
315 }
316 let Some(col) = nearest_index(&parsed.x_axis, xq, &parsed.extrap) else {
317 return out_of_range_value(&parsed.extrap);
318 };
319 let Some(row) = nearest_index(&parsed.y_axis, yq, &parsed.extrap) else {
320 return out_of_range_value(&parsed.extrap);
321 };
322 z_at(&parsed.z, row, col)
323}
324
325fn z_at(z: &Tensor, row: usize, col: usize) -> f64 {
326 z.data[row + col * z.rows]
327}
328
329fn nearest_index(axis: &[f64], q: f64, extrap: &Extrapolation) -> Option<usize> {
330 if q < axis[0] {
331 return matches!(extrap, Extrapolation::Extrapolate).then_some(0);
332 }
333 let last = axis.len() - 1;
334 if q > axis[last] {
335 return matches!(extrap, Extrapolation::Extrapolate).then_some(last);
336 }
337 match axis.binary_search_by(|probe| probe.partial_cmp(&q).unwrap()) {
338 Ok(index) => Some(index),
339 Err(index) => {
340 let left = index.saturating_sub(1);
341 let right = index.min(last);
342 if (q - axis[left]).abs() <= (axis[right] - q).abs() {
343 Some(left)
344 } else {
345 Some(right)
346 }
347 }
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    /// Builds a `1 x n` row-vector value.
    fn row(values: &[f64]) -> Value {
        let shape = vec![1, values.len()];
        Value::Tensor(Tensor::new(values.to_vec(), shape).expect("tensor"))
    }

    /// Z = [1 2; 3 4], stored column-major.
    fn sample_grid() -> Value {
        Value::Tensor(Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).expect("tensor"))
    }

    #[test]
    fn interp2_implicit_axes_bilinear_scalar() {
        let args = vec![sample_grid(), Value::Num(1.5), Value::Num(1.5)];
        match block_on(interp2_builtin(args)).expect("interp2") {
            Value::Num(result) => assert!((result - 2.5).abs() < 1e-12),
            _ => panic!("expected scalar"),
        }
    }

    #[test]
    fn interp2_vector_axes_nearest() {
        let args = vec![
            row(&[10.0, 20.0]),
            row(&[100.0, 200.0]),
            sample_grid(),
            Value::Num(18.0),
            Value::Num(120.0),
            Value::String("nearest".to_string()),
        ];
        let value = block_on(interp2_builtin(args)).expect("interp2");
        assert_eq!(value, Value::Num(2.0));
    }
}