use crate::builtins::common::gpu_helpers;
use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::{build_runtime_error, RuntimeError};
use runmat_accelerate_api::GpuTensorHandle;
use runmat_builtins::{
ComplexTensor, LogicalArray, ResolveContext, StringArray, Tensor, Type, Value,
};
use runmat_macros::runtime_builtin;
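
/// GPU dispatch metadata for `squeeze`. The op never launches a kernel of its
/// own: it routes through the provider's `reshape` hook so shape metadata is
/// rewritten while the device buffer stays where it is.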
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::shape::squeeze")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
name: "squeeze",
op_kind: GpuOpKind::Custom("squeeze"),
supported_precisions: &[
ScalarType::F32,
ScalarType::F64,
ScalarType::I32,
ScalarType::Bool,
],
broadcast: BroadcastSemantics::None,
provider_hooks: &[ProviderHook::Custom("reshape")],
constant_strategy: ConstantStrategy::InlineLiteral,
residency: ResidencyPolicy::InheritInputs,
nan_mode: ReductionNaN::Include,
two_pass_threshold: None,
workgroup_size: None,
accepts_nan_mode: false,
notes: "Uses provider reshape hook to drop singleton metadata without moving device buffers.",
};
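
/// Fusion metadata for `squeeze`. Because the op only rewrites shape
/// metadata, the fusion planner treats it as a no-op during kernel generation.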
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::shape::squeeze")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
name: "squeeze",
shape: ShapeRequirements::Any,
constant_strategy: ConstantStrategy::InlineLiteral,
elementwise: None,
reduction: None,
emits_nan: false,
notes:
"Squeeze only mutates metadata; fusion planner treats it as a no-op for kernel generation.",
};
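
/// Builds a `RuntimeError` tagged with the `squeeze` builtin name.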
fn squeeze_error(message: impl Into<String>) -> RuntimeError {
build_runtime_error(message).with_builtin("squeeze").build()
}
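
/// Type-level analogue of [`squeeze_shape`] for shapes whose dimensions may be
/// statically unknown. Known singletons (`Some(1)`) are dropped, unknown
/// dimensions (`None`) are kept, and the result is padded back to at least two
/// dimensions so the inferred type stays a valid MATLAB shape.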
fn squeeze_shape_options(shape: &[Option<usize>]) -> Vec<Option<usize>> {
if shape.len() <= 2 {
return shape.to_vec();
}
let mut squeezed: Vec<Option<usize>> = shape
.iter()
.copied()
.filter(|dim| *dim != Some(1))
.collect();
if squeezed.is_empty() {
squeezed = vec![Some(1), Some(1)];
} else if squeezed.len() == 1 {
squeezed.push(Some(1));
}
squeezed
}
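
/// Type resolver: propagates squeezed shapes for tensors and logical arrays,
/// passes scalars and cell types through unchanged, and answers `Unknown` for
/// anything it cannot reason about.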
fn squeeze_type(args: &[Type], _ctx: &ResolveContext) -> Type {
let input = match args.first() {
Some(value) => value,
None => return Type::Unknown,
};
match input {
Type::Tensor { shape: Some(shape) } => Type::Tensor {
shape: Some(squeeze_shape_options(shape)),
},
Type::Logical { shape: Some(shape) } => Type::Logical {
shape: Some(squeeze_shape_options(shape)),
},
Type::Tensor { shape: None } => Type::tensor(),
Type::Logical { shape: None } => Type::logical(),
Type::Num | Type::Int | Type::Bool => input.clone(),
Type::Cell { .. } => input.clone(),
Type::Unknown => Type::Unknown,
_ => Type::Unknown,
}
}
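
// Runtime entry point. `squeeze(A)` removes singleton dimensions from `A`;
// the real work happens in `squeeze_value`.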
#[runtime_builtin(
name = "squeeze",
category = "array/shape",
summary = "Remove singleton dimensions while preserving MATLAB row/column semantics.",
keywords = "squeeze,singleton dimensions,array reshape,gpu",
accel = "shape",
type_resolver(squeeze_type),
builtin_path = "crate::builtins::array::shape::squeeze"
)]
async fn squeeze_builtin(value: Value) -> crate::BuiltinResult<Value> {
squeeze_value(value).await
}
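
/// Dispatches on the value kind: numeric, complex, logical, and string arrays
/// are squeezed on the host, GPU tensors go through the provider path, and
/// scalars, strings, char arrays, cells, and structs pass through unchanged.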
async fn squeeze_value(value: Value) -> crate::BuiltinResult<Value> {
match value {
Value::Tensor(tensor) => squeeze_numeric_tensor(tensor).map(Value::Tensor),
Value::ComplexTensor(ct) => squeeze_complex_tensor(ct).map(Value::ComplexTensor),
Value::LogicalArray(logical) => squeeze_logical_array(logical).map(Value::LogicalArray),
Value::StringArray(strings) => squeeze_string_array(strings).map(Value::StringArray),
Value::GpuTensor(handle) => squeeze_gpu(handle).await,
Value::String(_) | Value::CharArray(_) | Value::Cell(_) | Value::Struct(_) => Ok(value),
Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => Ok(value),
        other => Err(squeeze_error(format!(
            "squeeze: unsupported input type {}; expected numeric, logical, string, char, cell, struct, or gpu array",
            value_kind(&other)
        ))),
}
}
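
/// Human-readable kind names for `squeeze` error messages.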
fn value_kind(value: &Value) -> &'static str {
match value {
Value::Tensor(_) => "tensor",
Value::ComplexTensor(_) => "complex tensor",
Value::LogicalArray(_) => "logical array",
Value::StringArray(_) => "string array",
Value::CharArray(_) => "char array",
Value::Cell(_) => "cell array",
Value::GpuTensor(_) => "gpu array",
Value::Num(_) => "double scalar",
Value::Int(_) => "integer scalar",
Value::Bool(_) => "logical scalar",
Value::Complex(_, _) => "complex scalar",
Value::String(_) => "string scalar",
Value::Object(_) => "object",
Value::HandleObject(_) => "handle object",
Value::Listener(_) => "listener",
Value::Struct(_) => "struct",
Value::FunctionHandle(_) | Value::Closure(_) => "function handle",
Value::ClassRef(_) => "class reference",
Value::MException(_) => "exception",
Value::OutputList(_) => "output list",
}
}
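
/// Host-side squeeze helpers. Each one short-circuits when the shape is
/// already fully squeezed so unchanged inputs are returned without rebuilding.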
fn squeeze_numeric_tensor(tensor: Tensor) -> crate::BuiltinResult<Tensor> {
let shape = squeeze_shape(&tensor.shape);
if shape == tensor.shape {
return Ok(tensor);
}
Tensor::new(tensor.data, shape).map_err(|e| squeeze_error(format!("squeeze: {e}")))
}
fn squeeze_complex_tensor(ct: ComplexTensor) -> crate::BuiltinResult<ComplexTensor> {
let shape = squeeze_shape(&ct.shape);
if shape == ct.shape {
return Ok(ct);
}
ComplexTensor::new(ct.data, shape).map_err(|e| squeeze_error(format!("squeeze: {e}")))
}
fn squeeze_logical_array(logical: LogicalArray) -> crate::BuiltinResult<LogicalArray> {
let shape = squeeze_shape(&logical.shape);
if shape == logical.shape {
return Ok(logical);
}
LogicalArray::new(logical.data, shape).map_err(|e| squeeze_error(format!("squeeze: {e}")))
}
fn squeeze_string_array(strings: StringArray) -> crate::BuiltinResult<StringArray> {
let shape = squeeze_shape(&strings.shape);
if shape == strings.shape {
return Ok(strings);
}
StringArray::new(strings.data, shape).map_err(|e| squeeze_error(format!("squeeze: {e}")))
}
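
/// GPU path: derive the squeezed shape (gathering to the host only when the
/// handle carries no shape metadata), then prefer the provider `reshape` hook
/// and fall back to rewriting the handle's metadata locally.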
async fn squeeze_gpu(handle: GpuTensorHandle) -> crate::BuiltinResult<Value> {
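    // A handle that carries no shape metadata forces a gather just to learn
    // the true shape; the gathered data itself is discarded.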
let shape_source = if handle.shape.is_empty() {
let gathered = gpu_helpers::gather_tensor_async(&handle).await?;
gathered.shape
} else {
handle.shape.clone()
};
let squeezed = squeeze_shape(&shape_source);
if squeezed == handle.shape {
return Ok(Value::GpuTensor(handle));
}
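    // Prefer the provider's reshape hook so the buffer keeps its device residency.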
if let Some(provider) = runmat_accelerate_api::provider() {
if let Ok(updated) = provider.reshape(&handle, &squeezed) {
return Ok(Value::GpuTensor(updated));
}
}
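    // No provider, or the reshape hook failed: rewrite the handle's shape
    // metadata locally.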
let mut updated = handle;
updated.shape = squeezed;
Ok(Value::GpuTensor(updated))
}
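
/// Core shape rule, matching MATLAB semantics: shapes with at most two
/// dimensions are returned unchanged, every singleton dimension is dropped
/// from higher-rank shapes, an all-singleton shape collapses to `1x1`, and a
/// single surviving dimension becomes an `n x 1` column vector. For example,
/// `[1, 1, 4, 1]` squeezes to `[4, 1]` and `[1, 0, 3]` to `[0, 3]`.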
fn squeeze_shape(shape: &[usize]) -> Vec<usize> {
if shape.len() <= 2 {
return shape.to_vec();
}
let mut squeezed: Vec<usize> = shape.iter().copied().filter(|&d| d != 1).collect();
if squeezed.is_empty() {
squeezed.push(1);
squeezed.push(1);
} else if squeezed.len() == 1 {
let first = squeezed[0];
squeezed = vec![first, 1];
}
squeezed
}
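
// Tests cover host array squeezing, type resolution, scalar passthrough, and
// GPU round-trips (including the optional wgpu backend).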
#[cfg(test)]
pub(crate) mod tests {
use super::*;
#[cfg(feature = "wgpu")]
use crate::dispatcher::download_handle_async;
use futures::executor::block_on;
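    // Synchronous wrapper so tests can call the async builtin directly.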
fn squeeze_builtin(value: Value) -> crate::BuiltinResult<Value> {
block_on(super::squeeze_builtin(value))
}
use crate::builtins::common::test_support;
use runmat_builtins::{IntValue, Tensor};
#[test]
fn squeeze_type_preserves_logical_shape() {
let out = squeeze_type(
&[Type::Logical {
shape: Some(vec![Some(1), Some(3), Some(1)]),
}],
&ResolveContext::new(Vec::new()),
);
assert_eq!(
out,
Type::Logical {
shape: Some(vec![Some(3), Some(1)])
}
);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
    fn squeeze_removes_middle_singletons() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 1, 2]).unwrap();
let result = squeeze_builtin(Value::Tensor(tensor)).expect("squeeze ok");
match result {
Value::Tensor(t) => assert_eq!(t.shape, vec![2, 2]),
other => panic!("expected tensor, got {other:?}"),
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn squeeze_preserves_row_vector() {
let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
let result = squeeze_builtin(Value::Tensor(tensor)).expect("squeeze ok");
match result {
Value::Tensor(t) => assert_eq!(t.shape, vec![1, 3]),
other => panic!("expected tensor, got {other:?}"),
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn squeeze_single_dimension_becomes_column_vector() {
let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 1, 4, 1]).unwrap();
let result = squeeze_builtin(Value::Tensor(tensor)).expect("squeeze ok");
match result {
Value::Tensor(t) => assert_eq!(t.shape, vec![4, 1]),
other => panic!("expected tensor, got {other:?}"),
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
    fn squeeze_on_logical_array_drops_singletons() {
let logical = LogicalArray::new(vec![1, 0, 0, 1], vec![1, 4, 1]).unwrap();
let result = squeeze_builtin(Value::LogicalArray(logical)).expect("squeeze ok");
match result {
Value::LogicalArray(arr) => assert_eq!(arr.shape, vec![4, 1]),
other => panic!("expected logical array, got {other:?}"),
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn squeeze_on_string_array() {
let strings = StringArray::new(vec!["a".into(), "b".into()], vec![1, 1, 2]).unwrap();
let result = squeeze_builtin(Value::StringArray(strings)).expect("squeeze ok");
match result {
Value::StringArray(sa) => assert_eq!(sa.shape, vec![2, 1]),
other => panic!("expected string array, got {other:?}"),
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn squeeze_preserves_zero_length_dimensions() {
let tensor = Tensor::new(Vec::<f64>::new(), vec![1, 0, 3]).unwrap();
let result = squeeze_builtin(Value::Tensor(tensor)).expect("squeeze ok");
match result {
Value::Tensor(t) => assert_eq!(t.shape, vec![0, 3]),
other => panic!("expected tensor, got {other:?}"),
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn squeeze_gpu_roundtrip() {
test_support::with_test_provider(|provider| {
let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]).unwrap();
let view = runmat_accelerate_api::HostTensorView {
data: &tensor.data,
shape: &tensor.shape,
};
let handle = provider.upload(&view).expect("upload");
let value = squeeze_builtin(Value::GpuTensor(handle)).expect("squeeze ok");
let gathered = test_support::gather(value).expect("gather");
assert_eq!(gathered.shape, vec![2, 2]);
});
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn squeeze_scalar_inputs_passthrough() {
let result = squeeze_builtin(Value::Int(IntValue::I32(42))).expect("squeeze ok for scalar");
assert_eq!(result, Value::Int(IntValue::I32(42)));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "wgpu")]
fn squeeze_wgpu_updates_shape_metadata() {
let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
);
let provider = runmat_accelerate_api::provider().expect("wgpu provider");
let tensor = Tensor::new(
(0..12).map(|v| v as f64).collect::<Vec<_>>(),
vec![1, 3, 4, 1],
)
.unwrap();
let view = runmat_accelerate_api::HostTensorView {
data: &tensor.data,
shape: &tensor.shape,
};
let handle = provider.upload(&view).expect("upload source tensor");
let squeezed =
squeeze_builtin(Value::GpuTensor(handle.clone())).expect("squeeze gpu tensor");
let gpu_handle = match squeezed {
Value::GpuTensor(h) => h,
other => panic!("expected gpu tensor, got {other:?}"),
};
assert_eq!(gpu_handle.shape, vec![3, 4]);
let downloaded = block_on(download_handle_async(provider, &gpu_handle))
.expect("download squeezed tensor");
assert_eq!(downloaded.shape.as_slice(), &[3, 4]);
assert_eq!(downloaded.data.as_slice(), tensor.data.as_slice());
}
}