use runmat_builtins::{ResolveContext, Tensor, Type, Value};
use runmat_macros::runtime_builtin;
use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::tensor;
use crate::dispatcher;
use super::pp::{
interp_error, interval_index, is_vector_shape, out_of_range_value, parse_extrapolation,
parse_method, query_points, vector_from_value, Extrapolation, InterpMethod,
};
/// Builtin name, used for the spec `name` fields and passed to the shared interpolation helpers
/// when building error messages.
const NAME: &str = "interp2";
// GPU spec registered for the `interp2` builtin.
//
// The spec lists no provider hooks and uses `ResidencyPolicy::GatherImmediately`; as the `notes`
// field states, GPU inputs are currently gathered and evaluated on the CPU reference path.
// NOTE(review): the exact meaning of each field is defined by `BuiltinGpuSpec`, which is not
// visible in this file.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::interpolation::interp2")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
name: NAME,
op_kind: GpuOpKind::Custom("interpolation-2d"),
supported_precisions: &[ScalarType::F32, ScalarType::F64],
broadcast: BroadcastSemantics::Matlab,
provider_hooks: &[],
constant_strategy: ConstantStrategy::InlineLiteral,
residency: ResidencyPolicy::GatherImmediately,
nan_mode: ReductionNaN::Include,
two_pass_threshold: None,
workgroup_size: None,
accepts_nan_mode: false,
notes: "Initial implementation gathers GPU inputs to the CPU reference path. Bilinear and nearest kernels are good future provider candidates.",
};
// Fusion spec registered for the `interp2` builtin.
//
// No elementwise or reduction template is provided (`None` for both), and `emits_nan` is set
// because out-of-range or non-finite queries evaluate to NaN. The `notes` field describes the
// builtin as a runtime sink.
#[runmat_macros::register_fusion_spec(
builtin_path = "crate::builtins::math::interpolation::interp2"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
name: NAME,
shape: ShapeRequirements::Any,
constant_strategy: ConstantStrategy::InlineLiteral,
elementwise: None,
reduction: None,
emits_nan: true,
notes: "interp2 is currently a runtime sink.",
};
/// Infers the static result type of an `interp2` call from its argument types.
///
/// With three or four arguments the query types are read from positions 1 and 2
/// (`Z, Xq, Yq[, option]`); with five or more they are read from positions 3 and 4
/// (`X, Y, Z, Xq, Yq, ...`). Fewer than three arguments yields a generic tensor type.
///
/// Both query types are considered because the runtime broadcasts a scalar query against the
/// other one:
/// * scalar `Xq` and scalar `Yq` produce `Type::Num`;
/// * scalar `Xq` with a tensor/logical `Yq` produces a tensor with `Yq`'s shape;
/// * a tensor/logical `Xq` produces a tensor with `Xq`'s shape;
/// * anything else falls back to a generic tensor type.
fn interp2_type(args: &[Type], _ctx: &ResolveContext) -> Type {
    let (xq, yq) = match args.len() {
        0..=2 => return Type::tensor(),
        3 | 4 => (&args[1], &args[2]),
        _ => (&args[3], &args[4]),
    };
    match (xq, yq) {
        (Type::Num | Type::Int | Type::Bool, Type::Num | Type::Int | Type::Bool) => Type::Num,
        // A scalar Xq is expanded to match Yq, so the result takes Yq's shape.
        (Type::Num | Type::Int | Type::Bool, Type::Tensor { shape } | Type::Logical { shape }) => {
            Type::Tensor {
                shape: shape.clone(),
            }
        }
        // A non-scalar Xq dictates the result shape (Yq must be scalar or match it at runtime).
        (Type::Tensor { shape } | Type::Logical { shape }, _) => Type::Tensor {
            shape: shape.clone(),
        },
        _ => Type::tensor(),
    }
}
// Entry point for the `interp2` builtin.
//
// Parses and validates the arguments, evaluates every query point, and returns a scalar
// `Value::Num` when exactly one sample was produced; otherwise the samples are wrapped in a
// tensor with the shape computed during parsing.
#[runtime_builtin(
    name = "interp2",
    category = "math/interpolation",
    summary = "Two-dimensional interpolation on gridded data.",
    keywords = "interp2,interpolation,bilinear,nearest,grid,meshgrid",
    accel = "sink",
    sink = true,
    type_resolver(interp2_type),
    builtin_path = "crate::builtins::math::interpolation::interp2"
)]
async fn interp2_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
    let request = ParsedInterp2::parse(args).await?;
    let samples = evaluate_grid(&request)?;
    // A single sample collapses to a plain number.
    if let [only] = samples.as_slice() {
        return Ok(Value::Num(*only));
    }
    Tensor::new(samples, request.output_shape)
        .map(Value::Tensor)
        .map_err(|err| interp_error(NAME, format!("{NAME}: {err}")))
}
/// Validated arguments of one `interp2` call, built by `ParsedInterp2::parse` and read by the
/// evaluation helpers.
struct ParsedInterp2 {
/// Sample coordinates along the columns of `z`; `1..=cols` when no X argument is given.
x_axis: Vec<f64>,
/// Sample coordinates along the rows of `z`; `1..=rows` when no Y argument is given.
y_axis: Vec<f64>,
/// Grid values, indexed column-major as `row + col * rows`.
z: Tensor,
/// X query coordinates after scalar expansion; same length as `yq`.
xq: Vec<f64>,
/// Y query coordinates after scalar expansion; same length as `xq`.
yq: Vec<f64>,
/// Shape given to the result tensor.
output_shape: Vec<usize>,
/// Interpolation method; parsing only accepts `Linear` and `Nearest`.
method: InterpMethod,
/// Policy for query points outside the axis range; defaults to `Extrapolation::Nan`.
extrap: Extrapolation,
}
impl ParsedInterp2 {
    /// Parses the raw builtin arguments into a validated [`ParsedInterp2`].
    ///
    /// Two call forms are accepted:
    /// * `Z, Xq, Yq, options...` — axes default to `1..=cols` and `1..=rows`;
    /// * `X, Y, Z, Xq, Yq, options...` — chosen when at least five arguments are present and the
    ///   fourth one is not an option keyword.
    ///
    /// Trailing options may each be an extrapolation setting or a method name. Only the linear
    /// and nearest methods are accepted.
    ///
    /// # Errors
    /// Returns an error when fewer than three arguments are supplied, when `Z`, the axes or the
    /// query points fail validation, or when an option is unsupported.
    async fn parse(args: Vec<Value>) -> crate::BuiltinResult<Self> {
        if args.len() < 3 {
            return Err(interp_error(
                NAME,
                "interp2: expected Z, Xq, and Yq or X, Y, Z, Xq, and Yq",
            ));
        }

        let has_axes = args.len() >= 5 && !is_option_arg(&args[3]);
        let mut remaining = args.into_iter();

        // The argument count was checked above, so the positional `expect`s cannot fail.
        let supplied_axes = if has_axes {
            let x = remaining.next().expect("X");
            let y = remaining.next().expect("Y");
            Some((x, y))
        } else {
            None
        };
        let z = z_tensor(remaining.next().expect("Z")).await?;
        let (x_axis, y_axis) = match supplied_axes {
            Some((x, y)) => axes_from_values(x, y, z.rows, z.cols).await?,
            None => (
                (1..=z.cols).map(|v| v as f64).collect::<Vec<f64>>(),
                (1..=z.rows).map(|v| v as f64).collect::<Vec<f64>>(),
            ),
        };
        let xq_value = remaining.next().expect("Xq");
        let yq_value = remaining.next().expect("Yq");
        let options: Vec<Value> = remaining.collect();

        validate_axis(&x_axis, "X")?;
        validate_axis(&y_axis, "Y")?;

        let xq_points = query_points(xq_value, NAME).await?;
        let yq_points = query_points(yq_value, NAME).await?;
        let (xq, yq, output_shape) = align_queries(xq_points, yq_points)?;

        let mut method = InterpMethod::Linear;
        let mut extrap = Extrapolation::Nan;
        for option in &options {
            // Extrapolation settings are tried first, then method names.
            if let Some(mode) = parse_extrapolation(option, NAME).await? {
                extrap = mode;
            } else if let Some(kind) = parse_method(option, NAME)? {
                if !matches!(kind, InterpMethod::Linear | InterpMethod::Nearest) {
                    return Err(interp_error(
                        NAME,
                        "interp2: only linear and nearest methods are supported",
                    ));
                }
                method = kind;
            } else {
                return Err(interp_error(
                    NAME,
                    "interp2: unsupported interpolation option",
                ));
            }
        }

        Ok(Self {
            x_axis,
            y_axis,
            z,
            xq,
            yq,
            output_shape,
            method,
            extrap,
        })
    }
}
/// Returns `true` when `value` is recognised as a keyword by the shared `keyword_of` helper.
fn is_option_arg(value: &Value) -> bool {
    use crate::builtins::common::random_args::keyword_of;

    keyword_of(value).is_some()
}
/// Gathers `value` if needed and converts it into the grid tensor `Z`.
///
/// # Errors
/// Fails when the value cannot be converted to a tensor, when it has more than two dimensions,
/// or when it has fewer than two rows or two columns.
async fn z_tensor(value: Value) -> crate::BuiltinResult<Tensor> {
    let host_value = dispatcher::gather_if_needed_async(&value).await?;
    let grid = match tensor::value_into_tensor_for(NAME, host_value) {
        Ok(grid) => grid,
        Err(err) => return Err(interp_error(NAME, format!("{NAME}: {err}"))),
    };

    if grid.shape.len() > 2 {
        return Err(interp_error(NAME, "interp2: Z must be a 2-D matrix"));
    }
    // Interpolation needs at least one cell in each direction.
    if grid.rows.min(grid.cols) < 2 {
        return Err(interp_error(
            NAME,
            "interp2: Z must have at least two rows and two columns",
        ));
    }
    Ok(grid)
}
/// Resolves the X axis and then the Y axis from the user-supplied values.
///
/// `rows` and `cols` are the dimensions of `Z`; both are forwarded to [`axis_from_value`].
async fn axes_from_values(
    x: Value,
    y: Value,
    rows: usize,
    cols: usize,
) -> crate::BuiltinResult<(Vec<f64>, Vec<f64>)> {
    let xs = axis_from_value(x, rows, cols, true).await?;
    axis_from_value(y, rows, cols, false)
        .await
        .map(|ys| (xs, ys))
}
/// Extracts one axis vector from `value`.
///
/// * A vector-shaped tensor must have `cols` elements for X (`is_x`) or `rows` elements for Y
///   and is returned unchanged.
/// * A tensor whose `rows`/`cols` match `Z` is treated as a full coordinate matrix: X is read
///   from the first row, Y from the first column.
/// * Anything else is handed to `vector_from_value`.
async fn axis_from_value(
    value: Value,
    rows: usize,
    cols: usize,
    is_x: bool,
) -> crate::BuiltinResult<Vec<f64>> {
    let host_value = dispatcher::gather_if_needed_async(&value).await?;

    // The clone keeps the gathered value available for the fallback below.
    if let Ok(grid) = tensor::value_into_tensor_for(NAME, host_value.clone()) {
        if is_vector_shape(&grid.shape) {
            let required = if is_x { cols } else { rows };
            return if grid.data.len() == required {
                Ok(grid.data)
            } else {
                Err(interp_error(
                    NAME,
                    format!("{NAME}: axis vector length must match Z dimensions"),
                ))
            };
        }
        if (grid.rows, grid.cols) == (rows, cols) {
            let axis: Vec<f64> = if is_x {
                // First row of a column-major matrix: every `rows`-th element.
                (0..cols).map(|col| grid.data[col * rows]).collect()
            } else {
                // First column: the leading `rows` elements.
                grid.data[..rows].to_vec()
            };
            return Ok(axis);
        }
    }

    let label = match is_x {
        true => "X",
        false => "Y",
    };
    vector_from_value(host_value, label, NAME).await
}
/// Checks that `axis` can be used as a grid axis.
///
/// # Errors
/// Fails, in this order, when the axis has fewer than two points, contains a non-finite value,
/// or is not strictly increasing. `label` names the axis in the message.
fn validate_axis(axis: &[f64], label: &str) -> crate::BuiltinResult<()> {
    let problem = if axis.len() < 2 {
        Some(format!("{NAME}: {label} axis must contain at least two points"))
    } else if !axis.iter().all(|v| v.is_finite()) {
        Some(format!("{NAME}: {label} axis must be finite"))
    } else if axis.windows(2).any(|pair| pair[1] <= pair[0]) {
        Some(format!("{NAME}: {label} axis must be strictly increasing"))
    } else {
        None
    };

    match problem {
        Some(message) => Err(interp_error(NAME, message)),
        None => Ok(()),
    }
}
/// Pairs up the X and Y query points, expanding a single-element side to match the other.
///
/// Returns the aligned X values, the aligned Y values and the output shape. Two single-element
/// queries give shape `[1, 1]`; otherwise the shape of the non-scalar side is used.
///
/// # Errors
/// Fails when neither side has exactly one element and the element counts or shapes differ.
fn align_queries(
    xq: super::pp::QueryPoints,
    yq: super::pp::QueryPoints,
) -> crate::BuiltinResult<(Vec<f64>, Vec<f64>, Vec<usize>)> {
    let nx = xq.values.len();
    let ny = yq.values.len();

    if nx == 1 && ny == 1 {
        return Ok((xq.values, yq.values, vec![1, 1]));
    }
    if nx == 1 {
        let fill = xq.values[0];
        return Ok((vec![fill; ny], yq.values, yq.shape));
    }
    if ny == 1 {
        let fill = yq.values[0];
        return Ok((xq.values, vec![fill; nx], xq.shape));
    }
    if nx == ny && xq.shape == yq.shape {
        return Ok((xq.values, yq.values, xq.shape));
    }
    Err(interp_error(
        NAME,
        "interp2: Xq and Yq must be scalar or matching-size arrays",
    ))
}
/// Evaluates every `(xq, yq)` pair with the configured method and returns the samples in query
/// order.
fn evaluate_grid(parsed: &ParsedInterp2) -> crate::BuiltinResult<Vec<f64>> {
    let samples = parsed
        .xq
        .iter()
        .zip(&parsed.yq)
        .map(|(&x, &y)| match parsed.method {
            InterpMethod::Linear => eval_bilinear(parsed, x, y),
            InterpMethod::Nearest => eval_nearest(parsed, x, y),
            _ => unreachable!("interp2 parse rejects cubic methods"),
        })
        .collect();
    Ok(samples)
}
/// Bilinearly interpolates `Z` at `(xq, yq)`.
///
/// Non-finite queries yield NaN. When `interval_index` finds no cell on either axis, the
/// configured out-of-range value is returned instead.
fn eval_bilinear(parsed: &ParsedInterp2, xq: f64, yq: f64) -> f64 {
    if !(xq.is_finite() && yq.is_finite()) {
        return f64::NAN;
    }

    let extrapolate = matches!(parsed.extrap, Extrapolation::Extrapolate);
    // The X axis is looked up first; the Y lookup only runs when X found a cell.
    let cell = interval_index(&parsed.x_axis, xq, extrapolate).and_then(|col| {
        interval_index(&parsed.y_axis, yq, extrapolate).map(|row| (row, col))
    });
    let Some((row, col)) = cell else {
        return out_of_range_value(&parsed.extrap);
    };

    let (x_lo, x_hi) = (parsed.x_axis[col], parsed.x_axis[col + 1]);
    let (y_lo, y_hi) = (parsed.y_axis[row], parsed.y_axis[row + 1]);
    let tx = (xq - x_lo) / (x_hi - x_lo);
    let ty = (yq - y_lo) / (y_hi - y_lo);

    let near_near = z_at(&parsed.z, row, col);
    let near_far_x = z_at(&parsed.z, row, col + 1);
    let far_y_near = z_at(&parsed.z, row + 1, col);
    let far_far = z_at(&parsed.z, row + 1, col + 1);

    // Corner weights, summed in a fixed left-to-right order.
    let (wx, wy) = (1.0 - tx, 1.0 - ty);
    wx * wy * near_near + tx * wy * near_far_x + wx * ty * far_y_near + tx * ty * far_far
}
/// Returns the value of the grid node nearest to `(xq, yq)`.
///
/// Non-finite queries yield NaN. When either axis lookup finds no node, the configured
/// out-of-range value is returned instead.
fn eval_nearest(parsed: &ParsedInterp2, xq: f64, yq: f64) -> f64 {
    if !(xq.is_finite() && yq.is_finite()) {
        return f64::NAN;
    }

    let policy = &parsed.extrap;
    // The X axis is looked up first; the Y lookup only runs when X found a node.
    let node = nearest_index(&parsed.x_axis, xq, policy)
        .and_then(|col| nearest_index(&parsed.y_axis, yq, policy).map(|row| (row, col)));
    match node {
        Some((row, col)) => z_at(&parsed.z, row, col),
        None => out_of_range_value(policy),
    }
}
/// Reads element `(row, col)` of `z` using column-major indexing.
fn z_at(z: &Tensor, row: usize, col: usize) -> f64 {
    let offset = col * z.rows + row;
    z.data[offset]
}
/// Finds the index of the axis sample closest to `q`.
///
/// `axis` is expected to be sorted in strictly increasing order. Queries outside
/// `[axis[0], axis[last]]` map to the first or last index when `extrap` is
/// `Extrapolation::Extrapolate`, and to `None` otherwise. A query exactly halfway between two
/// samples resolves to the lower index.
///
/// Returns `None` instead of panicking for an empty axis or a NaN query.
// NOTE(review): MATLAB appears to resolve exact midpoints to the upper sample; confirm whether
// the lower-index tie-break here is intentional.
fn nearest_index(axis: &[f64], q: f64, extrap: &Extrapolation) -> Option<usize> {
    let (lowest, highest) = match (axis.first(), axis.last()) {
        (Some(&lo), Some(&hi)) => (lo, hi),
        _ => return None,
    };
    if q.is_nan() {
        return None;
    }

    let last = axis.len() - 1;
    let clamp = matches!(extrap, Extrapolation::Extrapolate);
    if q < lowest {
        return clamp.then_some(0);
    }
    if q > highest {
        return clamp.then_some(last);
    }

    // First sample that is not below `q`; the clamp only matters for malformed axes.
    let upper = axis.partition_point(|&sample| sample < q).min(last);
    if axis[upper] == q {
        return Some(upper);
    }
    let lower = upper.saturating_sub(1);
    if (q - axis[lower]).abs() <= (axis[upper] - q).abs() {
        Some(lower)
    } else {
        Some(upper)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    /// Builds a `1 x n` row-vector value.
    fn row(values: &[f64]) -> Value {
        let shape = vec![1, values.len()];
        Value::Tensor(Tensor::new(values.to_vec(), shape).expect("tensor"))
    }

    /// 2x2 grid stored column-major, i.e. the matrix `[1 2; 3 4]`.
    fn sample_grid() -> Value {
        Value::Tensor(Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).expect("tensor"))
    }

    #[test]
    fn interp2_implicit_axes_bilinear_scalar() {
        let args = vec![sample_grid(), Value::Num(1.5), Value::Num(1.5)];
        let result = match block_on(interp2_builtin(args)).expect("interp2") {
            Value::Num(number) => number,
            _ => panic!("expected scalar"),
        };
        // The centre of the cell is the mean of the four corners.
        assert!((result - 2.5).abs() < 1e-12);
    }

    #[test]
    fn interp2_vector_axes_nearest() {
        let args = vec![
            row(&[10.0, 20.0]),
            row(&[100.0, 200.0]),
            sample_grid(),
            Value::Num(18.0),
            Value::Num(120.0),
            Value::String("nearest".to_string()),
        ];
        let value = block_on(interp2_builtin(args)).expect("interp2");
        // Nearest node is x = 20 (second column), y = 100 (first row).
        assert_eq!(value, Value::Num(2.0));
    }
}