use runmat_builtins::{
BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
ResolveContext, Tensor, Type, Value,
};
use runmat_macros::runtime_builtin;
use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::tensor;
use crate::dispatcher;
use crate::{build_runtime_error, RuntimeError};
use super::pp::{
interval_index, is_vector_shape, out_of_range_value, parse_extrapolation, parse_method,
query_points, vector_from_value, Extrapolation, InterpMethod,
};
/// Builtin name. Attached to runtime errors and forwarded to the shared
/// interpolation helpers and tensor conversion routines in this file.
const NAME: &str = "interp2";
/// Output parameter list referenced by every `interp2` signature: the single
/// required result `Vq`.
const INTERP2_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
name: "Vq",
ty: BuiltinParamType::NumericArray,
arity: BuiltinParamArity::Required,
default: None,
description: "Interpolated values over a 2-D grid.",
}];
const INTERP2_INPUTS_Z_XQ_YQ: [BuiltinParamDescriptor; 3] = [
BuiltinParamDescriptor {
name: "Z",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid sample matrix.",
},
BuiltinParamDescriptor {
name: "Xq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query X coordinates.",
},
BuiltinParamDescriptor {
name: "Yq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query Y coordinates.",
},
];
const INTERP2_INPUTS_X_Y_Z_XQ_YQ: [BuiltinParamDescriptor; 5] = [
BuiltinParamDescriptor {
name: "X",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid X axis vector or mesh.",
},
BuiltinParamDescriptor {
name: "Y",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid Y axis vector or mesh.",
},
BuiltinParamDescriptor {
name: "Z",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid sample matrix.",
},
BuiltinParamDescriptor {
name: "Xq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query X coordinates.",
},
BuiltinParamDescriptor {
name: "Yq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query Y coordinates.",
},
];
const INTERP2_INPUTS_Z_XQ_YQ_METHOD: [BuiltinParamDescriptor; 4] = [
BuiltinParamDescriptor {
name: "Z",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid sample matrix.",
},
BuiltinParamDescriptor {
name: "Xq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query X coordinates.",
},
BuiltinParamDescriptor {
name: "Yq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query Y coordinates.",
},
BuiltinParamDescriptor {
name: "method",
ty: BuiltinParamType::StringScalar,
arity: BuiltinParamArity::Optional,
default: Some("\"linear\""),
description: "Interpolation method: \"linear\" or \"nearest\".",
},
];
const INTERP2_INPUTS_Z_XQ_YQ_METHOD_EXTRAP: [BuiltinParamDescriptor; 5] = [
BuiltinParamDescriptor {
name: "Z",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid sample matrix.",
},
BuiltinParamDescriptor {
name: "Xq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query X coordinates.",
},
BuiltinParamDescriptor {
name: "Yq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query Y coordinates.",
},
BuiltinParamDescriptor {
name: "method",
ty: BuiltinParamType::StringScalar,
arity: BuiltinParamArity::Optional,
default: Some("\"linear\""),
description: "Interpolation method: \"linear\" or \"nearest\".",
},
BuiltinParamDescriptor {
name: "extrap",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Optional,
default: Some("NaN"),
description: "Extrapolation mode: \"extrap\" or scalar fill value.",
},
];
const INTERP2_INPUTS_X_Y_Z_XQ_YQ_METHOD: [BuiltinParamDescriptor; 6] = [
BuiltinParamDescriptor {
name: "X",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid X axis vector or mesh.",
},
BuiltinParamDescriptor {
name: "Y",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid Y axis vector or mesh.",
},
BuiltinParamDescriptor {
name: "Z",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid sample matrix.",
},
BuiltinParamDescriptor {
name: "Xq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query X coordinates.",
},
BuiltinParamDescriptor {
name: "Yq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query Y coordinates.",
},
BuiltinParamDescriptor {
name: "method",
ty: BuiltinParamType::StringScalar,
arity: BuiltinParamArity::Optional,
default: Some("\"linear\""),
description: "Interpolation method: \"linear\" or \"nearest\".",
},
];
const INTERP2_INPUTS_X_Y_Z_XQ_YQ_METHOD_EXTRAP: [BuiltinParamDescriptor; 7] = [
BuiltinParamDescriptor {
name: "X",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid X axis vector or mesh.",
},
BuiltinParamDescriptor {
name: "Y",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid Y axis vector or mesh.",
},
BuiltinParamDescriptor {
name: "Z",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Grid sample matrix.",
},
BuiltinParamDescriptor {
name: "Xq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query X coordinates.",
},
BuiltinParamDescriptor {
name: "Yq",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Query Y coordinates.",
},
BuiltinParamDescriptor {
name: "method",
ty: BuiltinParamType::StringScalar,
arity: BuiltinParamArity::Optional,
default: Some("\"linear\""),
description: "Interpolation method: \"linear\" or \"nearest\".",
},
BuiltinParamDescriptor {
name: "extrap",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Optional,
default: Some("NaN"),
description: "Extrapolation mode: \"extrap\" or scalar fill value.",
},
];
const INTERP2_SIGNATURES: [BuiltinSignatureDescriptor; 8] = [
BuiltinSignatureDescriptor {
label: "Vq = interp2(Z, Xq, Yq)",
inputs: &INTERP2_INPUTS_Z_XQ_YQ,
outputs: &INTERP2_OUTPUT,
},
BuiltinSignatureDescriptor {
label: "Vq = interp2(X, Y, Z, Xq, Yq)",
inputs: &INTERP2_INPUTS_X_Y_Z_XQ_YQ,
outputs: &INTERP2_OUTPUT,
},
BuiltinSignatureDescriptor {
label: "Vq = interp2(Z, Xq, Yq, method)",
inputs: &INTERP2_INPUTS_Z_XQ_YQ_METHOD,
outputs: &INTERP2_OUTPUT,
},
BuiltinSignatureDescriptor {
label: "Vq = interp2(X, Y, Z, Xq, Yq, method)",
inputs: &INTERP2_INPUTS_X_Y_Z_XQ_YQ_METHOD,
outputs: &INTERP2_OUTPUT,
},
BuiltinSignatureDescriptor {
label: "Vq = interp2(Z, Xq, Yq, extrap)",
inputs: &INTERP2_INPUTS_Z_XQ_YQ_METHOD,
outputs: &INTERP2_OUTPUT,
},
BuiltinSignatureDescriptor {
label: "Vq = interp2(X, Y, Z, Xq, Yq, extrap)",
inputs: &INTERP2_INPUTS_X_Y_Z_XQ_YQ_METHOD,
outputs: &INTERP2_OUTPUT,
},
BuiltinSignatureDescriptor {
label: "Vq = interp2(Z, Xq, Yq, method, extrap)",
inputs: &INTERP2_INPUTS_Z_XQ_YQ_METHOD_EXTRAP,
outputs: &INTERP2_OUTPUT,
},
BuiltinSignatureDescriptor {
label: "Vq = interp2(X, Y, Z, Xq, Yq, method, extrap)",
inputs: &INTERP2_INPUTS_X_Y_Z_XQ_YQ_METHOD_EXTRAP,
outputs: &INTERP2_OUTPUT,
},
];
/// Raised for problems with how `interp2` was called: argument count, option
/// values, or incompatible axis/query arguments.
const INTERP2_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
code: "RM.INTERP2.INVALID_ARGUMENT",
identifier: Some("RunMat:interp2:InvalidArgument"),
when: "Argument count, method/extrapolation options, or axis/query compatibility is invalid.",
message: "interp2: invalid argument",
};
/// Raised when grid or query values cannot be converted to numeric data.
/// Also the fallback descriptor for parse errors that arrive without an
/// identifier.
const INTERP2_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
code: "RM.INTERP2.INVALID_INPUT",
identifier: Some("RunMat:interp2:InvalidInput"),
when: "Grid or query values cannot be converted to numeric interpolation domains.",
message: "interp2: invalid input",
};
/// Raised when evaluation or output tensor construction fails.
const INTERP2_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
code: "RM.INTERP2.INTERNAL",
identifier: Some("RunMat:interp2:Internal"),
when: "Interpolation output construction fails due to internal tensor assembly paths.",
message: "interp2: internal interpolation failure",
};
/// All error descriptors published through `INTERP2_DESCRIPTOR`.
const INTERP2_ERRORS: [BuiltinErrorDescriptor; 3] = [
INTERP2_ERROR_INVALID_ARGUMENT,
INTERP2_ERROR_INVALID_INPUT,
INTERP2_ERROR_INTERNAL,
];
/// Public descriptor for the `interp2` builtin, bundling its signatures and
/// error catalogue. Referenced by path from the `runtime_builtin` attribute on
/// `interp2_builtin`.
pub const INTERP2_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
signatures: &INTERP2_SIGNATURES,
output_mode: BuiltinOutputMode::Fixed,
completion_policy: BuiltinCompletionPolicy::Public,
errors: &INTERP2_ERRORS,
};
/// Builds a runtime error tagged with the `interp2` builtin name.
///
/// `message` becomes the error text verbatim. When `error` carries an
/// identifier, it is attached to the resulting error as well.
fn interp2_error_with_message(
    message: impl Into<String>,
    error: &'static BuiltinErrorDescriptor,
) -> RuntimeError {
    let tagged = build_runtime_error(message).with_builtin(NAME);
    match error.identifier {
        Some(id) => tagged.with_identifier(id).build(),
        None => tagged.build(),
    }
}
fn interp2_invalid_argument(detail: impl AsRef<str>) -> RuntimeError {
interp2_error_with_message(
format!(
"{}: {}",
INTERP2_ERROR_INVALID_ARGUMENT.message,
detail.as_ref()
),
&INTERP2_ERROR_INVALID_ARGUMENT,
)
}
fn interp2_invalid_input(detail: impl AsRef<str>) -> RuntimeError {
interp2_error_with_message(
format!(
"{}: {}",
INTERP2_ERROR_INVALID_INPUT.message,
detail.as_ref()
),
&INTERP2_ERROR_INVALID_INPUT,
)
}
/// Ensures an error carries an identifier.
///
/// Errors that already have one pass through untouched. Otherwise the error is
/// rebuilt from its message using the `fallback` descriptor.
fn interp2_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
    if err.identifier().is_none() {
        return interp2_error_with_message(err.message().to_string(), fallback);
    }
    err
}
// GPU registration for `interp2`.
//
// No provider hooks are declared and residency is `GatherImmediately`; per the
// `notes` field, GPU inputs are gathered and evaluated on the CPU reference
// path for now.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::interpolation::interp2")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
name: NAME,
op_kind: GpuOpKind::Custom("interpolation-2d"),
supported_precisions: &[ScalarType::F32, ScalarType::F64],
broadcast: BroadcastSemantics::Matlab,
provider_hooks: &[],
constant_strategy: ConstantStrategy::InlineLiteral,
residency: ResidencyPolicy::GatherImmediately,
nan_mode: ReductionNaN::Include,
two_pass_threshold: None,
workgroup_size: None,
accepts_nan_mode: false,
notes: "Initial implementation gathers GPU inputs to the CPU reference path. Bilinear and nearest kernels are good future provider candidates.",
};
// Fusion registration for `interp2`.
//
// Neither an elementwise nor a reduction template is provided; the `notes`
// field records that the builtin currently acts as a runtime sink. `emits_nan`
// is set because results can be NaN.
#[runmat_macros::register_fusion_spec(
builtin_path = "crate::builtins::math::interpolation::interp2"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
name: NAME,
shape: ShapeRequirements::Any,
constant_strategy: ConstantStrategy::InlineLiteral,
elementwise: None,
reduction: None,
emits_nan: true,
notes: "interp2 is currently a runtime sink.",
};
/// Infers the static result type of an `interp2` call from its argument types.
///
/// The query arguments are located by argument count: positions 1 and 2 for
/// three or four arguments, positions 3 and 4 otherwise. At runtime a scalar
/// query is broadcast against an array query, so both `Xq` and `Yq` are
/// inspected:
///
/// * array `Xq`: the result takes the shape of `Xq`;
/// * scalar `Xq` with scalar `Yq`: the result is a scalar number;
/// * scalar `Xq` with array `Yq`: the result takes the shape of `Yq`;
/// * anything else (including fewer than three arguments, or option strings
///   sitting in the inspected positions): a generic tensor.
fn interp2_type(args: &[Type], _ctx: &ResolveContext) -> Type {
    let (xq, yq) = match args.len() {
        0..=2 => return Type::tensor(),
        3 | 4 => (args.get(1), args.get(2)),
        _ => (args.get(3), args.get(4)),
    };
    match (xq, yq) {
        (Some(Type::Tensor { shape } | Type::Logical { shape }), _) => Type::Tensor {
            shape: shape.clone(),
        },
        (
            Some(Type::Num | Type::Int | Type::Bool),
            Some(Type::Num | Type::Int | Type::Bool),
        ) => Type::Num,
        // A scalar Xq is expanded to match Yq, so Yq determines the shape.
        (
            Some(Type::Num | Type::Int | Type::Bool),
            Some(Type::Tensor { shape } | Type::Logical { shape }),
        ) => Type::Tensor {
            shape: shape.clone(),
        },
        _ => Type::tensor(),
    }
}
// Entry point for the `interp2` builtin.
//
// Parses the arguments, evaluates every query point, and returns a scalar
// `Value::Num` when exactly one value was produced, otherwise a tensor with the
// parsed output shape. Errors lacking an identifier are mapped to the
// invalid-input descriptor (parsing) or the internal descriptor (evaluation).
#[runtime_builtin(
    name = "interp2",
    category = "math/interpolation",
    summary = "Interpolate two-dimensional gridded data.",
    keywords = "interp2,interpolation,bilinear,nearest,grid,meshgrid",
    accel = "sink",
    sink = true,
    type_resolver(interp2_type),
    descriptor(crate::builtins::math::interpolation::interp2::INTERP2_DESCRIPTOR),
    builtin_path = "crate::builtins::math::interpolation::interp2"
)]
async fn interp2_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
    let parsed = match ParsedInterp2::parse(args).await {
        Ok(parsed) => parsed,
        Err(err) => return Err(interp2_map_error(err, &INTERP2_ERROR_INVALID_INPUT)),
    };
    let samples = match evaluate_grid(&parsed) {
        Ok(samples) => samples,
        Err(err) => return Err(interp2_map_error(err, &INTERP2_ERROR_INTERNAL)),
    };
    if let [only] = samples.as_slice() {
        return Ok(Value::Num(*only));
    }
    Tensor::new(samples, parsed.output_shape)
        .map(Value::Tensor)
        .map_err(|err| {
            interp2_error_with_message(format!("{NAME}: {err}"), &INTERP2_ERROR_INTERNAL)
        })
}
/// Fully validated `interp2` call, ready for evaluation.
struct ParsedInterp2 {
// Grid coordinates along the columns of `z`; checked by `validate_axis`.
x_axis: Vec<f64>,
// Grid coordinates along the rows of `z`; checked by `validate_axis`.
y_axis: Vec<f64>,
// Sample matrix; read in column-major order by `z_at`.
z: Tensor,
// Query X coordinates after alignment with `yq` by `align_queries`.
xq: Vec<f64>,
// Query Y coordinates after alignment with `xq` by `align_queries`.
yq: Vec<f64>,
// Shape used to build the result tensor.
output_shape: Vec<usize>,
// Interpolation method; parsing only admits linear and nearest.
method: InterpMethod,
// Policy applied to queries that fall outside the grid.
extrap: Extrapolation,
}
impl ParsedInterp2 {
/// Parses the raw builtin arguments into a validated `ParsedInterp2`.
///
/// Two leading layouts are recognised: `Z, Xq, Yq` (axes default to
/// `1..=cols` and `1..=rows`) and `X, Y, Z, Xq, Yq`. Any remaining arguments
/// are treated as options and must parse as an extrapolation setting or as
/// an interpolation method.
///
/// Returns an error when fewer than three arguments are given, when the
/// grid, axes, or queries fail validation, when a method other than linear
/// or nearest is requested, or when an option is not recognised.
async fn parse(args: Vec<Value>) -> crate::BuiltinResult<Self> {
if args.len() < 3 {
return Err(interp2_invalid_argument(
"expected Z, Xq, and Yq or X, Y, Z, Xq, and Yq",
));
}
// Defaults used when no option overrides them.
let mut method = InterpMethod::Linear;
let mut extrap = Extrapolation::Nan;
// With five or more arguments the call is ambiguous between explicit axes
// and `Z, Xq, Yq, option, option`; the fourth argument being an option
// keyword selects the latter.
let explicit_axes = args.len() >= 5 && !is_option_arg(&args[3]);
let (x_axis, y_axis, z, xq_value, yq_value, options) = if explicit_axes {
// The `expect` calls cannot fail: at least five arguments are present.
let mut iter = args.into_iter();
let x = iter.next().expect("X");
let y = iter.next().expect("Y");
let z_value = iter.next().expect("Z");
// Z is converted first because the axis lengths are checked against it.
let z = z_tensor(z_value).await?;
let (x_axis, y_axis) = axes_from_values(x, y, z.rows, z.cols).await?;
let xq = iter.next().expect("Xq");
let yq = iter.next().expect("Yq");
(x_axis, y_axis, z, xq, yq, iter.collect::<Vec<_>>())
} else {
// The `expect` calls cannot fail: at least three arguments are present.
let mut iter = args.into_iter();
let z_value = iter.next().expect("Z");
let z = z_tensor(z_value).await?;
// Implicit axes are the 1-based column and row indices of Z.
let x_axis: Vec<f64> = (1..=z.cols).map(|v| v as f64).collect();
let y_axis: Vec<f64> = (1..=z.rows).map(|v| v as f64).collect();
let xq = iter.next().expect("Xq");
let yq = iter.next().expect("Yq");
(x_axis, y_axis, z, xq, yq, iter.collect::<Vec<_>>())
};
validate_axis(&x_axis, "X")?;
validate_axis(&y_axis, "Y")?;
let xq = query_points(xq_value, NAME).await?;
let yq = query_points(yq_value, NAME).await?;
let (xq_values, yq_values, output_shape) = align_queries(xq, yq)?;
// Each option is tried as an extrapolation setting first, then as a method.
// If an option kind is given more than once, the last occurrence wins.
for option in &options {
if let Some(parsed) = parse_extrapolation(option, NAME).await? {
extrap = parsed;
continue;
}
if let Some(parsed) = parse_method(option, NAME)? {
match parsed {
InterpMethod::Linear | InterpMethod::Nearest => method = parsed,
_ => {
return Err(interp2_invalid_argument(
"only linear and nearest methods are supported",
));
}
}
continue;
}
return Err(interp2_error_with_message(
"interp2: unsupported interpolation option",
&INTERP2_ERROR_INVALID_ARGUMENT,
));
}
Ok(Self {
x_axis,
y_axis,
z,
xq: xq_values,
yq: yq_values,
output_shape,
method,
extrap,
})
}
}
/// Reports whether `value` is recognised as a keyword by the shared
/// `keyword_of` helper, i.e. whether it should be treated as an option rather
/// than as positional data.
fn is_option_arg(value: &Value) -> bool {
    let keyword = crate::builtins::common::random_args::keyword_of(value);
    keyword.is_some()
}
/// Gathers `value` if needed and converts it into the sample matrix `Z`.
///
/// Fails with an invalid-input error when the value cannot be converted to a
/// tensor, and with an invalid-argument error when the tensor has more than
/// two dimensions or fewer than two rows or columns.
async fn z_tensor(value: Value) -> crate::BuiltinResult<Tensor> {
    let gathered = dispatcher::gather_if_needed_async(&value).await?;
    let z = match tensor::value_into_tensor_for(NAME, gathered) {
        Ok(converted) => converted,
        Err(err) => return Err(interp2_invalid_input(&err)),
    };
    let rank = z.shape.len();
    if rank > 2 {
        return Err(interp2_invalid_argument("Z must be a 2-D matrix"));
    }
    // Interpolation needs at least one cell, i.e. two samples per dimension.
    let smallest_extent = z.rows.min(z.cols);
    if smallest_extent < 2 {
        return Err(interp2_invalid_argument(
            "Z must have at least two rows and two columns",
        ));
    }
    Ok(z)
}
/// Resolves the X axis and then the Y axis for a `rows` x `cols` grid,
/// returning them as `(x_axis, y_axis)`.
async fn axes_from_values(
    x: Value,
    y: Value,
    rows: usize,
    cols: usize,
) -> crate::BuiltinResult<(Vec<f64>, Vec<f64>)> {
    let horizontal = axis_from_value(x, rows, cols, true).await?;
    let vertical = axis_from_value(y, rows, cols, false).await?;
    Ok((horizontal, vertical))
}
/// Extracts one grid axis from `value` for a `rows` x `cols` sample matrix.
///
/// * A vector must have `cols` elements for the X axis or `rows` elements for
///   the Y axis, and is returned as-is.
/// * A matrix of exactly `rows` x `cols` is treated as a mesh: the X axis is
///   read from its first row and the Y axis from its first column
///   (column-major storage).
/// * Anything else is handed to `vector_from_value`.
async fn axis_from_value(
    value: Value,
    rows: usize,
    cols: usize,
    is_x: bool,
) -> crate::BuiltinResult<Vec<f64>> {
    let gathered = dispatcher::gather_if_needed_async(&value).await?;
    // Converted from a clone so the gathered value stays available for the fallback.
    if let Ok(grid) = tensor::value_into_tensor_for(NAME, gathered.clone()) {
        if is_vector_shape(&grid.shape) {
            let wanted = if is_x { cols } else { rows };
            return if grid.data.len() == wanted {
                Ok(grid.data)
            } else {
                Err(interp2_invalid_argument(
                    "axis vector length must match Z dimensions",
                ))
            };
        }
        if grid.rows == rows && grid.cols == cols {
            let samples: Vec<f64> = if is_x {
                (0..cols).map(|c| grid.data[c * rows]).collect()
            } else {
                (0..rows).map(|r| grid.data[r]).collect()
            };
            return Ok(samples);
        }
    }
    let axis_name = if is_x { "X" } else { "Y" };
    vector_from_value(gathered, axis_name, NAME).await
}
/// Checks that `axis` is usable as a grid axis.
///
/// The checks run in order and the first failure is reported: at least two
/// points, every value finite, values strictly increasing. `label` names the
/// axis in the error message.
fn validate_axis(axis: &[f64], label: &str) -> crate::BuiltinResult<()> {
    let problem = if axis.len() < 2 {
        Some("must contain at least two points")
    } else if !axis.iter().all(|v| v.is_finite()) {
        Some("must be finite")
    } else if axis.windows(2).any(|w| w[1] <= w[0]) {
        Some("must be strictly increasing")
    } else {
        None
    };
    match problem {
        Some(reason) => Err(interp2_invalid_argument(format!("{label} axis {reason}"))),
        None => Ok(()),
    }
}
/// Aligns the X and Y query sets so they have the same number of points.
///
/// Returns `(xq, yq, output_shape)`. Two single-element queries yield a `1x1`
/// shape. A single-element query paired with a longer one is repeated to that
/// length and the longer query's shape is used. Otherwise both queries must
/// have equal length and identical shapes.
fn align_queries(
    xq: super::pp::QueryPoints,
    yq: super::pp::QueryPoints,
) -> crate::BuiltinResult<(Vec<f64>, Vec<f64>, Vec<usize>)> {
    let x_count = xq.values.len();
    let y_count = yq.values.len();
    if x_count == 1 && y_count == 1 {
        return Ok((xq.values, yq.values, vec![1, 1]));
    }
    if x_count == 1 {
        let repeated = vec![xq.values[0]; y_count];
        return Ok((repeated, yq.values, yq.shape));
    }
    if y_count == 1 {
        let repeated = vec![yq.values[0]; x_count];
        return Ok((xq.values, repeated, xq.shape));
    }
    if x_count == y_count && xq.shape == yq.shape {
        return Ok((xq.values, yq.values, xq.shape));
    }
    Err(interp2_invalid_argument(
        "Xq and Yq must be scalar or matching-size arrays",
    ))
}
/// Evaluates every `(xq, yq)` pair with the parsed method and returns the
/// results in query order.
///
/// Only linear and nearest are expected here; parsing rejects other methods.
fn evaluate_grid(parsed: &ParsedInterp2) -> crate::BuiltinResult<Vec<f64>> {
    let samples: Vec<f64> = parsed
        .xq
        .iter()
        .zip(&parsed.yq)
        .map(|(&x, &y)| match parsed.method {
            InterpMethod::Linear => eval_bilinear(parsed, x, y),
            InterpMethod::Nearest => eval_nearest(parsed, x, y),
            _ => unreachable!("interp2 parse rejects cubic methods"),
        })
        .collect();
    Ok(samples)
}
/// Bilinear interpolation at a single query point.
///
/// Non-finite queries yield NaN. When `interval_index` finds no interval on
/// either axis, the out-of-range value for the configured extrapolation mode
/// is returned. Otherwise the four surrounding samples are blended with
/// weights derived from the query's fractional position in the cell.
fn eval_bilinear(parsed: &ParsedInterp2, xq: f64, yq: f64) -> f64 {
    if !(xq.is_finite() && yq.is_finite()) {
        return f64::NAN;
    }
    let extrapolating = matches!(parsed.extrap, Extrapolation::Extrapolate);
    let col = match interval_index(&parsed.x_axis, xq, extrapolating) {
        Some(index) => index,
        None => return out_of_range_value(&parsed.extrap),
    };
    let row = match interval_index(&parsed.y_axis, yq, extrapolating) {
        Some(index) => index,
        None => return out_of_range_value(&parsed.extrap),
    };
    let (x_lo, x_hi) = (parsed.x_axis[col], parsed.x_axis[col + 1]);
    let (y_lo, y_hi) = (parsed.y_axis[row], parsed.y_axis[row + 1]);
    // Fractional position inside the cell along each axis.
    let tx = (xq - x_lo) / (x_hi - x_lo);
    let ty = (yq - y_lo) / (y_hi - y_lo);
    let (wx, wy) = (1.0 - tx, 1.0 - ty);
    let corner = |r: usize, c: usize| z_at(&parsed.z, r, c);
    wx * wy * corner(row, col)
        + tx * wy * corner(row, col + 1)
        + wx * ty * corner(row + 1, col)
        + tx * ty * corner(row + 1, col + 1)
}
/// Nearest-neighbour lookup at a single query point.
///
/// Non-finite queries yield NaN. When either coordinate has no nearest index
/// under the configured extrapolation mode, the out-of-range value for that
/// mode is returned.
fn eval_nearest(parsed: &ParsedInterp2, xq: f64, yq: f64) -> f64 {
    if !(xq.is_finite() && yq.is_finite()) {
        return f64::NAN;
    }
    // The column is resolved first; the row is only looked up if it succeeds.
    let cell = nearest_index(&parsed.x_axis, xq, &parsed.extrap).and_then(|col| {
        nearest_index(&parsed.y_axis, yq, &parsed.extrap).map(|row| (row, col))
    });
    match cell {
        Some((row, col)) => z_at(&parsed.z, row, col),
        None => out_of_range_value(&parsed.extrap),
    }
}
/// Reads the sample at (`row`, `col`) from `z`, indexing its data in
/// column-major order.
fn z_at(z: &Tensor, row: usize, col: usize) -> f64 {
    let offset = row + col * z.rows;
    z.data[offset]
}
/// Finds the index of the axis point closest to `q`.
///
/// Queries outside the axis range resolve to the first or last index when
/// extrapolation is enabled and to `None` otherwise. Inside the range, an
/// exact match returns its own index; otherwise the closer of the two
/// neighbouring points is chosen, with ties going to the lower index.
///
/// Expects a non-empty axis and a `q` that is comparable with every axis
/// value (no NaN); otherwise this panics.
fn nearest_index(axis: &[f64], q: f64, extrap: &Extrapolation) -> Option<usize> {
    if q < axis[0] {
        return matches!(extrap, Extrapolation::Extrapolate).then_some(0);
    }
    let last = axis.len() - 1;
    if q > axis[last] {
        return matches!(extrap, Extrapolation::Extrapolate).then_some(last);
    }
    let slot = match axis.binary_search_by(|probe| probe.partial_cmp(&q).unwrap()) {
        Ok(exact) => return Some(exact),
        Err(slot) => slot,
    };
    // `slot` is where `q` would be inserted; its neighbours bracket the query.
    let below = slot.saturating_sub(1);
    let above = slot.min(last);
    let gap_below = (q - axis[below]).abs();
    let gap_above = (axis[above] - q).abs();
    Some(if gap_below <= gap_above { below } else { above })
}
#[cfg(test)]
mod tests {
use super::*;
use futures::executor::block_on;
// Builds a 1xN row-vector tensor value from `values`.
fn row(values: &[f64]) -> Value {
Value::Tensor(Tensor::new(values.to_vec(), vec![1, values.len()]).expect("tensor"))
}
// Implicit axes: querying the centre of a 2x2 grid averages the four samples.
#[test]
fn interp2_implicit_axes_bilinear_scalar() {
let z = Value::Tensor(Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).expect("tensor"));
let value =
block_on(interp2_builtin(vec![z, Value::Num(1.5), Value::Num(1.5)])).expect("interp2");
let Value::Num(result) = value else {
panic!("expected scalar");
};
assert!((result - 2.5).abs() < 1e-12);
}
// Explicit vector axes with "nearest": x = 18 snaps to 20 (second column) and
// y = 120 snaps to 100 (first row), selecting the sample 2.0.
#[test]
fn interp2_vector_axes_nearest() {
let z = Value::Tensor(Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).expect("tensor"));
let value = block_on(interp2_builtin(vec![
row(&[10.0, 20.0]),
row(&[100.0, 200.0]),
z,
Value::Num(18.0),
Value::Num(120.0),
Value::String("nearest".to_string()),
]))
.expect("interp2");
assert_eq!(value, Value::Num(2.0));
}
// The descriptor must advertise the shortest and the longest call forms.
#[test]
fn interp2_descriptor_signatures_cover_surface() {
let labels: Vec<&str> = INTERP2_DESCRIPTOR
.signatures
.iter()
.map(|signature| signature.label)
.collect();
assert!(labels.contains(&"Vq = interp2(Z, Xq, Yq)"));
assert!(labels.contains(&"Vq = interp2(X, Y, Z, Xq, Yq)"));
assert!(labels.contains(&"Vq = interp2(X, Y, Z, Xq, Yq, method, extrap)"));
}
// Error codes are part of the public contract and must remain stable.
#[test]
fn interp2_descriptor_errors_have_stable_codes() {
let codes: Vec<&str> = INTERP2_DESCRIPTOR
.errors
.iter()
.map(|error| error.code)
.collect();
assert!(codes.contains(&"RM.INTERP2.INVALID_ARGUMENT"));
assert!(codes.contains(&"RM.INTERP2.INVALID_INPUT"));
assert!(codes.contains(&"RM.INTERP2.INTERNAL"));
}
// Calling with too few arguments must surface the invalid-argument identifier.
#[test]
fn interp2_too_few_args_uses_stable_identifier() {
let err = block_on(interp2_builtin(vec![Value::Num(1.0), Value::Num(2.0)]))
.expect_err("expected interp2 argument error");
assert_eq!(err.identifier(), INTERP2_ERROR_INVALID_ARGUMENT.identifier);
}
}