use super::common::{
complex_tensor_to_real_value, download_provider_complex_tensor, gather_gpu_complex_tensor,
parse_2d_lengths_from_data, parse_length, parse_symflag, transform_axes_complex_tensor,
value_to_complex_tensor, TransformDirection,
};
use super::ifft::ifft_complex_tensor;
use crate::builtins::common::random_args::complex_tensor_into_value;
use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::tensor;
use crate::builtins::math::fft::type_resolvers::ifft2_type;
use crate::{build_runtime_error, BuiltinResult, RuntimeError};
use runmat_accelerate_api::GpuTensorHandle;
use runmat_builtins::{ComplexTensor, Value};
use runmat_macros::runtime_builtin;
// GPU execution metadata for the `ifft2` builtin, registered against
// `crate::builtins::math::fft::ifft2` through `register_gpu_spec`.
//
// The spec declares a custom op backed by the `ifft_dim` provider hook, lists
// F32 and F64 as supported precisions, and uses the `NewHandle` residency policy.
// NOTE(review): the precise semantics of each policy enum live with
// `BuiltinGpuSpec`, outside this file — confirm there before relying on them.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::fft::ifft2")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
name: "ifft2",
op_kind: GpuOpKind::Custom("ifft2"),
supported_precisions: &[ScalarType::F32, ScalarType::F64],
broadcast: BroadcastSemantics::Matlab,
// The only provider entry point declared here; `ifft2_gpu` calls it once per axis.
provider_hooks: &[ProviderHook::Custom("ifft_dim")],
constant_strategy: ConstantStrategy::InlineLiteral,
residency: ResidencyPolicy::NewHandle,
nan_mode: ReductionNaN::Include,
two_pass_threshold: None,
workgroup_size: None,
accepts_nan_mode: false,
notes:
"Performs two sequential `ifft_dim` passes (dimensions 0 and 1); falls back to host execution when the hook is missing.",
};
// Fusion metadata for the `ifft2` builtin, registered against
// `crate::builtins::math::fft::ifft2` through `register_fusion_spec`.
//
// No elementwise or reduction descriptor is provided; per the `notes` field,
// `ifft2` ends a fusion plan rather than participating in a fused kernel.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::fft::ifft2")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
name: "ifft2",
shape: ShapeRequirements::Any,
constant_strategy: ConstantStrategy::InlineLiteral,
elementwise: None,
reduction: None,
emits_nan: false,
notes:
"ifft2 terminates fusion plans; fused kernels are not generated for multi-dimensional inverse FFTs.",
};
/// Builtin name attached to errors raised here and passed to the shared FFT helpers.
const BUILTIN_NAME: &str = "ifft2";
/// Builds a [`RuntimeError`] carrying `message` and tagged with this builtin's name.
fn ifft2_error(message: impl Into<String>) -> RuntimeError {
    let builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
    builder.build()
}
// Entry point for the `ifft2` builtin.
//
// `value` is the data to transform; `rest` holds the optional trailing
// arguments (transform lengths and/or a symmetry flag), which are parsed
// before any work is done. GPU-resident inputs are routed to `ifft2_gpu`,
// everything else to `ifft2_host`.
#[runtime_builtin(
    name = "ifft2",
    category = "math/fft",
    summary = "Compute the two-dimensional inverse discrete Fourier transform (IDFT) of numeric or complex data.",
    keywords = "ifft2,inverse fft,image reconstruction,gpu",
    type_resolver(ifft2_type),
    builtin_path = "crate::builtins::math::fft::ifft2"
)]
async fn ifft2_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
    let (lengths, symmetric) = parse_ifft2_arguments(&rest)?;
    match value {
        Value::GpuTensor(gpu_handle) => ifft2_gpu(gpu_handle, lengths, symmetric).await,
        host_value => ifft2_host(host_value, lengths, symmetric),
    }
}
/// Host path: converts `value` to a complex tensor, applies the inverse
/// transform along the first two axes, and shapes the output according to
/// `symmetric`.
///
/// # Errors
/// Propagates any failure from the conversion, the transform, or the final
/// output conversion.
fn ifft2_host(
    value: Value,
    lengths: (Option<usize>, Option<usize>),
    symmetric: bool,
) -> BuiltinResult<Value> {
    let spectrum = value_to_complex_tensor(value, BUILTIN_NAME)?;
    let spatial = ifft2_complex_tensor(spectrum, lengths)?;
    finalize_ifft2_output(spatial, symmetric)
}
/// GPU path: runs two provider `ifft_dim` passes (dimension 0, then 1) and
/// falls back to the host whenever the provider cannot do the work.
///
/// Fallback behaviour, in order:
/// * an explicit zero length on either axis, no registered provider, or a
///   failed first pass → gather the input and transform entirely on the host;
/// * a failed second pass → download the first-pass result and finish the
///   second axis on the host;
/// * `symmetric` requested but `fft_extract_real` fails → download the complex
///   result and convert it on the host.
///
/// # Errors
/// Propagates download, host-transform, and output-conversion failures.
/// Provider `ifft_dim` / `fft_extract_real` errors are not surfaced; they
/// trigger the fallbacks above instead.
async fn ifft2_gpu(
    handle: GpuTensorHandle,
    lengths: (Option<usize>, Option<usize>),
    symmetric: bool,
) -> BuiltinResult<Value> {
    let (len_rows, len_cols) = lengths;

    if len_rows == Some(0) || len_cols == Some(0) {
        return ifft2_gpu_fallback(handle, lengths, symmetric).await;
    }

    let provider = match runmat_accelerate_api::provider() {
        Some(provider) => provider,
        None => return ifft2_gpu_fallback(handle, lengths, symmetric).await,
    };

    let first = match provider.ifft_dim(&handle, len_rows, 0).await {
        Ok(first) => first,
        Err(_) => return ifft2_gpu_fallback(handle, lengths, symmetric).await,
    };

    let second = match provider.ifft_dim(&first, len_cols, 1).await {
        Ok(second) => second,
        Err(_) => {
            // Second axis failed on the device: finish it on the host
            // (host helper takes a 1-based dimension, hence `Some(2)`).
            let partial =
                download_provider_complex_tensor(provider, &first, BUILTIN_NAME, true).await?;
            let completed = ifft_complex_tensor(partial, len_cols, Some(2))?;
            return finalize_ifft2_output(completed, symmetric);
        }
    };

    // Release the intermediate buffer unless the second pass returned the same buffer.
    // NOTE(review): there is no equivalent check that `first` differs from the
    // caller's `handle` — confirm providers never hand back the input buffer.
    if first.buffer_id != second.buffer_id {
        provider.free(&first).ok();
        runmat_accelerate_api::clear_residency(&first);
    }

    if !symmetric {
        return Ok(Value::GpuTensor(second));
    }

    match provider.fft_extract_real(&second).await {
        Ok(real) => {
            provider.free(&second).ok();
            runmat_accelerate_api::clear_residency(&second);
            Ok(Value::GpuTensor(real))
        }
        Err(_) => {
            let complex =
                download_provider_complex_tensor(provider, &second, BUILTIN_NAME, true).await?;
            finalize_ifft2_output(complex, true)
        }
    }
}
/// Host fallback for GPU inputs: gathers the tensor behind `handle`, runs the
/// inverse transform on the host, and shapes the output according to
/// `symmetric`.
///
/// # Errors
/// Propagates failures from the gather, the transform, or the output conversion.
async fn ifft2_gpu_fallback(
    handle: GpuTensorHandle,
    lengths: (Option<usize>, Option<usize>),
    symmetric: bool,
) -> BuiltinResult<Value> {
    let gathered = gather_gpu_complex_tensor(&handle, BUILTIN_NAME).await?;
    let spatial = ifft2_complex_tensor(gathered, lengths)?;
    finalize_ifft2_output(spatial, symmetric)
}
/// Applies the inverse transform to `tensor` via the shared multi-axis helper,
/// passing the optional row and column lengths as the per-axis lengths.
fn ifft2_complex_tensor(
    tensor: ComplexTensor,
    lengths: (Option<usize>, Option<usize>),
) -> BuiltinResult<ComplexTensor> {
    let axis_lengths = [lengths.0, lengths.1];
    transform_axes_complex_tensor(
        tensor,
        &axis_lengths,
        TransformDirection::Inverse,
        BUILTIN_NAME,
    )
}
/// Converts the transformed tensor into the builtin's return value: a real
/// value when `symmetric` is set, otherwise the complex tensor as-is.
fn finalize_ifft2_output(tensor: ComplexTensor, symmetric: bool) -> BuiltinResult<Value> {
    if !symmetric {
        return Ok(complex_tensor_into_value(tensor));
    }
    complex_tensor_to_real_value(tensor, BUILTIN_NAME)
}
/// Optional transform lengths as `(rows, cols)`; `None` means no explicit length was given.
type LengthPair = (Option<usize>, Option<usize>);
/// Parsed trailing arguments: the length pair plus whether a symmetric (real) result was requested.
type LengthsAndSymmetry = (LengthPair, bool);
/// Parses the arguments that follow the data argument.
///
/// After an optional trailing symmetry flag is split off, zero, one, or two
/// length arguments are accepted: a single argument goes through
/// `parse_ifft2_single`, two arguments are parsed as rows then columns.
/// A missing flag defaults to non-symmetric output.
///
/// # Errors
/// Fails when more than two length arguments remain, or when any flag or
/// length fails to parse.
fn parse_ifft2_arguments(args: &[Value]) -> BuiltinResult<LengthsAndSymmetry> {
    if args.is_empty() {
        return Ok(((None, None), false));
    }
    let (trailing_flag, length_args) = split_symflag(args)?;
    let lengths = match length_args {
        [] => (None, None),
        [single] => parse_ifft2_single(single)?,
        [rows, cols] => (
            parse_length(rows, BUILTIN_NAME)?,
            parse_length(cols, BUILTIN_NAME)?,
        ),
        _ => {
            return Err(ifft2_error(
                "ifft2: expected ifft2(X), ifft2(X, M, N), or ifft2(X, SIZE[, symflag])",
            ))
        }
    };
    Ok((lengths, trailing_flag.unwrap_or(false)))
}
/// Splits an optional trailing symmetry flag off `args`.
///
/// Returns the parsed flag (if the last argument is one) together with the
/// arguments that remain. A symmetry flag anywhere other than the final
/// position is rejected.
///
/// # Errors
/// Fails when a flag appears in a non-final position, or when `parse_symflag`
/// itself rejects an argument.
fn split_symflag(args: &[Value]) -> BuiltinResult<(Option<bool>, &[Value])> {
    // The last argument is inspected first so its parse error (if any) wins.
    let (trailing_flag, to_scan) = match args.split_last() {
        Some((last, rest)) => match parse_symflag(last, BUILTIN_NAME)? {
            Some(flag) => (Some(flag), rest),
            None => (None, args),
        },
        None => (None, args),
    };

    let misplaced_message = if trailing_flag.is_some() {
        "ifft2: symmetry flag must appear once at the end"
    } else {
        "ifft2: symmetry flag must appear as the final argument"
    };

    for candidate in to_scan {
        if parse_symflag(candidate, BUILTIN_NAME)?.is_some() {
            return Err(ifft2_error(misplaced_message));
        }
    }
    Ok((trailing_flag, to_scan))
}
/// Interprets a lone length argument.
///
/// * Real and logical arrays are handed to `parse_2d_lengths_from_data`.
/// * A numeric scalar is parsed once and used for both rows and columns.
/// * A complex scalar is accepted only when its imaginary part is within
///   `f64::EPSILON` of zero; its real part is then treated as a numeric scalar.
/// * Every other value kind is rejected.
///
/// # Errors
/// Fails for unsupported value kinds and propagates length-parsing failures.
fn parse_ifft2_single(value: &Value) -> BuiltinResult<(Option<usize>, Option<usize>)> {
    match value {
        Value::Num(_) | Value::Int(_) => {
            parse_length(value, BUILTIN_NAME).map(|len| (len, len))
        }
        Value::Complex(re, im) => {
            if im.abs() > f64::EPSILON {
                Err(ifft2_error("ifft2: transform lengths must be real-valued"))
            } else {
                let real_part = Value::Num(*re);
                parse_length(&real_part, BUILTIN_NAME).map(|len| (len, len))
            }
        }
        Value::Tensor(sizes) => parse_2d_lengths_from_data(&sizes.data, BUILTIN_NAME),
        Value::LogicalArray(logical) => {
            let converted = tensor::logical_to_tensor(logical)
                .map_err(|e| ifft2_error(format!("{BUILTIN_NAME}: {e}")))?;
            parse_2d_lengths_from_data(&converted.data, BUILTIN_NAME)
        }
        Value::ComplexTensor(_) => Err(ifft2_error("ifft2: size vector must contain real values")),
        Value::GpuTensor(_) => Err(ifft2_error(
            "ifft2: size vector must be numeric and host-resident",
        )),
        // Listed explicitly (no wildcard) so a new `Value` variant forces a decision here.
        Value::Bool(_)
        | Value::String(_)
        | Value::StringArray(_)
        | Value::CharArray(_)
        | Value::Cell(_)
        | Value::Struct(_)
        | Value::FunctionHandle(_)
        | Value::Closure(_)
        | Value::HandleObject(_)
        | Value::Listener(_)
        | Value::Object(_)
        | Value::ClassRef(_)
        | Value::MException(_)
        | Value::OutputList(_) => Err(ifft2_error("ifft2: transform lengths must be numeric")),
    }
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::builtins::common::test_support;
use crate::builtins::math::fft::common;
use futures::executor::block_on;
#[cfg(feature = "wgpu")]
use runmat_accelerate_api::AccelProvider;
use runmat_accelerate_api::HostTensorView;
use runmat_builtins::{IntValue, ResolveContext, Tensor as HostTensor, Type};
/// Returns `true` when both components of `a` and `b` differ by at most `tol`.
fn approx_eq(a: (f64, f64), b: (f64, f64), tol: f64) -> bool {
    let (a_re, a_im) = a;
    let (b_re, b_im) = b;
    let re_close = (a_re - b_re).abs() <= tol;
    let im_close = (a_im - b_im).abs() <= tol;
    re_close && im_close
}
/// Returns the error's message as an owned `String` so tests can run substring assertions on it.
fn error_message(error: crate::RuntimeError) -> String {
error.message().to_string()
}
/// Test helper: builds a 2-D forward spectrum of `tensor` by applying the
/// forward FFT along dimension 1 and then dimension 2.
fn fft2_of_tensor(tensor: &HostTensor) -> ComplexTensor {
    let input = Value::Tensor(tensor.clone());
    let as_complex = value_to_complex_tensor(input, "fft2").unwrap();
    let after_dim1 = super::super::fft::fft_complex_tensor(as_complex, None, Some(1)).unwrap();
    super::super::fft::fft_complex_tensor(after_dim1, None, Some(2)).unwrap()
}
/// Test helper: returns the complex tensor behind `value`, downloading it
/// from the provider when the value is GPU-resident.
///
/// Panics for any value that is neither a complex tensor nor a GPU tensor.
fn value_to_host_complex(value: Value) -> ComplexTensor {
    let handle = match value {
        Value::ComplexTensor(ct) => return ct,
        Value::GpuTensor(handle) => handle,
        other => panic!("expected complex value, got {other:?}"),
    };
    let provider = runmat_accelerate_api::provider_for_handle(&handle)
        .or_else(runmat_accelerate_api::provider)
        .expect("provider for gpu handle");
    let downloaded = block_on(provider.download(&handle)).expect("download gpu ifft2 output");
    common::host_to_complex_tensor(downloaded, BUILTIN_NAME).expect("decode gpu complex")
}
/// The `ifft2` type resolver pads a rank-1 tensor shape `[3]` to `[3, 1]`.
///
/// Carries the same `wasm32` `cfg_attr` as the other tests in this module so
/// it is also registered with the wasm test harness.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_type_pads_rank() {
    let out = ifft2_type(
        &[Type::Tensor {
            shape: Some(vec![Some(3)]),
        }],
        &ResolveContext::new(Vec::new()),
    );
    assert_eq!(
        out,
        Type::Tensor {
            shape: Some(vec![Some(3), Some(1)])
        }
    );
}
/// `ifft2` applied to the forward spectrum of a 2x2 matrix recovers the matrix.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_inverts_known_fft2() {
    let expected = HostTensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let spectrum = fft2_of_tensor(&expected);
    let result =
        ifft2_builtin(Value::ComplexTensor(spectrum.clone()), Vec::new()).expect("ifft2");
    let out = match result {
        Value::ComplexTensor(out) => out,
        other => panic!("expected complex tensor, got {other:?}"),
    };
    assert_eq!(out.shape, expected.shape);
    for (idx, &(re, im)) in out.data.iter().enumerate() {
        assert!(approx_eq((re, im), (expected.data[idx], 0.0), 1e-12));
    }
}
/// With the `"symmetric"` flag the result is a real tensor equal to the original data.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_symmetric_returns_real() {
    let expected = HostTensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let spectrum = fft2_of_tensor(&expected);
    let flags = vec![Value::from("symmetric")];
    let result =
        ifft2_builtin(Value::ComplexTensor(spectrum.clone()), flags).expect("ifft2 symmetric");
    let out = match result {
        Value::Tensor(out) => out,
        other => panic!("expected real tensor, got {other:?}"),
    };
    assert_eq!(out.shape, expected.shape);
    assert_eq!(out.data, expected.data);
}
/// The `"nonsymmetric"` flag is accepted and the output still matches the original data.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_accepts_nonsymmetric_flag() {
    let expected = HostTensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let spectrum = fft2_of_tensor(&expected);
    let flags = vec![Value::from("nonsymmetric")];
    let output = ifft2_builtin(Value::ComplexTensor(spectrum.clone()), flags)
        .expect("ifft2 nonsymmetric");
    let recovered = value_to_complex_tensor(output, "ifft2").expect("complex output");
    assert_eq!(recovered.shape, expected.shape);
    for (idx, &(re, im)) in recovered.data.iter().enumerate() {
        assert!(approx_eq((re, im), (expected.data[idx], 0.0), 1e-12));
    }
}
/// A single scalar length is applied to both axes (3x3 input, length 4 → 4x4 output).
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_accepts_scalar_length() {
    let source = HostTensor::new((0..9).map(|v| v as f64).collect(), vec![3, 3]).unwrap();
    let spectrum = fft2_of_tensor(&source);
    let length_args = vec![Value::Int(IntValue::I32(4))];
    let result = ifft2_builtin(Value::ComplexTensor(spectrum), length_args).expect("ifft2");
    let out = match result {
        Value::ComplexTensor(out) => out,
        other => panic!("expected complex tensor, got {other:?}"),
    };
    assert_eq!(out.shape, vec![4, 4]);
}
/// A two-element size vector `[4, 2]` sets the output shape to 4x2.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_accepts_size_vector() {
    let source = HostTensor::new((0..6).map(|v| v as f64).collect(), vec![2, 3]).unwrap();
    let spectrum = fft2_of_tensor(&source);
    let size_vector = HostTensor::new(vec![4.0, 2.0], vec![1, 2]).unwrap();
    let length_args = vec![Value::Tensor(size_vector)];
    let result = ifft2_builtin(Value::ComplexTensor(spectrum), length_args).expect("ifft2");
    let out = match result {
        Value::ComplexTensor(out) => out,
        other => panic!("expected complex tensor, got {other:?}"),
    };
    assert_eq!(out.shape, vec![4, 2]);
}
/// Passing empty tensors for both lengths behaves like passing no lengths at all.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_treats_empty_lengths_as_defaults() {
    let expected = HostTensor::new((0..6).map(|v| v as f64).collect(), vec![2, 3]).unwrap();
    let spectrum = fft2_of_tensor(&expected);
    let empty_length = || Value::Tensor(HostTensor::new(vec![], vec![0]).unwrap());
    let length_args = vec![empty_length(), empty_length()];
    let result =
        ifft2_builtin(Value::ComplexTensor(spectrum.clone()), length_args).expect("ifft2");
    let out = match result {
        Value::ComplexTensor(out) => out,
        other => panic!("expected complex tensor, got {other:?}"),
    };
    assert_eq!(out.shape, expected.shape);
    for (idx, &(re, im)) in out.data.iter().enumerate() {
        assert!(approx_eq((re, im), (expected.data[idx], 0.0), 1e-12));
    }
}
/// A boolean length argument is rejected with an error mentioning `ifft2` and `numeric`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_rejects_boolean_length() {
    let source = HostTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let spectrum = fft2_of_tensor(&source);
    let failure =
        ifft2_builtin(Value::ComplexTensor(spectrum), vec![Value::Bool(true)]).unwrap_err();
    let message = error_message(failure);
    assert!(message.contains("ifft2"));
    assert!(message.contains("numeric"));
}
/// Three length arguments are one too many and produce an `ifft2` error.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_rejects_excess_arguments() {
    let source = HostTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let spectrum = fft2_of_tensor(&source);
    let too_many: Vec<Value> = (0..3).map(|_| Value::Int(IntValue::I32(2))).collect();
    let failure = ifft2_builtin(Value::ComplexTensor(spectrum), too_many).unwrap_err();
    let message = error_message(failure);
    assert!(message.contains("ifft2"));
}
/// Explicit zero lengths on both axes yield an empty 0x0 complex tensor.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_zero_lengths_return_empty_result() {
    let source = HostTensor::new((0..6).map(|v| v as f64).collect(), vec![2, 3]).unwrap();
    let spectrum = fft2_of_tensor(&source);
    let zero_lengths = vec![Value::Int(IntValue::I32(0)), Value::Int(IntValue::I32(0))];
    let result = ifft2_builtin(Value::ComplexTensor(spectrum), zero_lengths).expect("ifft2");
    let out = match result {
        Value::ComplexTensor(out) => out,
        other => panic!("expected complex tensor, got {other:?}"),
    };
    assert!(out.data.is_empty());
    assert_eq!(out.shape, vec![0, 0]);
}
/// Uploads an interleaved complex spectrum through the test provider and checks
/// that the GPU-path result matches the host-path result to within 1e-10.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_gpu_roundtrip_matches_cpu() {
    test_support::with_test_provider(|provider| {
        let source = HostTensor::new((0..8).map(|v| v as f64).collect(), vec![2, 4]).unwrap();
        let spectrum = fft2_of_tensor(&source);

        // Upload as interleaved (re, im) pairs with a trailing axis of size 2.
        let interleaved = spectrum
            .data
            .iter()
            .flat_map(|(re, im)| [*re, *im])
            .collect::<Vec<_>>();
        let view = HostTensorView {
            data: &interleaved,
            shape: &[2, 4, 2],
        };
        let raw = provider.upload(&view).expect("upload spectrum");

        // Re-describe the uploaded buffer with the logical complex shape.
        let complex_handle = runmat_accelerate_api::GpuTensorHandle {
            shape: spectrum.shape.clone(),
            device_id: raw.device_id,
            buffer_id: raw.buffer_id,
        };
        runmat_accelerate_api::set_handle_storage(
            &complex_handle,
            runmat_accelerate_api::GpuTensorStorage::ComplexInterleaved,
        );

        let gpu_value = ifft2_builtin(Value::GpuTensor(complex_handle.clone()), Vec::new())
            .expect("ifft2 gpu");
        let cpu_value = ifft2_builtin(Value::ComplexTensor(spectrum.clone()), Vec::new())
            .expect("ifft2 cpu");
        let gpu_ct = value_to_host_complex(gpu_value);
        let cpu_ct = value_to_host_complex(cpu_value);

        assert_eq!(gpu_ct.shape, cpu_ct.shape);
        for (lhs, rhs) in gpu_ct.data.iter().zip(cpu_ct.data.iter()) {
            assert!(approx_eq(*lhs, *rhs, 1e-10), "{lhs:?} vs {rhs:?}");
        }

        provider.free(&raw).ok();
        provider.free(&complex_handle).ok();
    });
}
/// Separate row and column lengths (5 and 2) set the output shape to 5x2.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_handles_row_and_column_lengths() {
    let source = HostTensor::new((0..12).map(|v| v as f64).collect(), vec![3, 4]).unwrap();
    let spectrum = fft2_of_tensor(&source);
    let length_args = vec![Value::Int(IntValue::I32(5)), Value::Int(IntValue::I32(2))];
    let result = ifft2_builtin(Value::ComplexTensor(spectrum), length_args).expect("ifft2");
    let out = match result {
        Value::ComplexTensor(out) => out,
        other => panic!("expected complex tensor, got {other:?}"),
    };
    assert_eq!(out.shape, vec![5, 2]);
}
/// An unknown string option is reported as an unrecognized option.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_rejects_unknown_symmetry_flag() {
    let args = [Value::from("invalid")];
    let failure = parse_ifft2_arguments(&args).unwrap_err();
    let message = error_message(failure);
    assert!(message.contains("unrecognized option"));
}
/// A symmetry flag followed by a length argument is rejected.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ifft2_requires_symflag_last() {
    let source = HostTensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let spectrum = fft2_of_tensor(&source);
    let misordered = vec![Value::from("symmetric"), Value::Int(IntValue::I32(2))];
    let failure = ifft2_builtin(Value::ComplexTensor(spectrum), misordered).unwrap_err();
    let message = error_message(failure);
    assert!(message.contains("symmetry flag must appear as the final argument"));
}
/// Compares the wgpu-backed result with the host result, using a tolerance
/// chosen from the provider's reported precision. Returns early (test passes)
/// when no wgpu provider can be obtained.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "wgpu")]
fn ifft2_wgpu_matches_cpu() {
    let provider = match std::panic::catch_unwind(|| {
        runmat_accelerate::backend::wgpu::provider::ensure_wgpu_provider()
    }) {
        Ok(Ok(Some(provider))) => provider,
        _ => return,
    };

    let source = HostTensor::new((0..16).map(|v| v as f64).collect(), vec![4, 4]).unwrap();
    let spectrum = fft2_of_tensor(&source);

    // Upload as interleaved (re, im) pairs with a trailing axis of size 2.
    let interleaved = spectrum
        .data
        .iter()
        .flat_map(|(re, im)| [*re, *im])
        .collect::<Vec<_>>();
    let view = HostTensorView {
        data: &interleaved,
        shape: &[4, 4, 2],
    };
    let raw = provider.upload(&view).expect("upload spectrum");

    // Re-describe the uploaded buffer with the logical complex shape.
    let complex_handle = runmat_accelerate_api::GpuTensorHandle {
        shape: spectrum.shape.clone(),
        device_id: raw.device_id,
        buffer_id: raw.buffer_id,
    };
    runmat_accelerate_api::set_handle_storage(
        &complex_handle,
        runmat_accelerate_api::GpuTensorStorage::ComplexInterleaved,
    );

    let gpu_value =
        ifft2_builtin(Value::GpuTensor(complex_handle.clone()), Vec::new()).expect("ifft2 gpu");
    let cpu_value = ifft2_builtin(Value::ComplexTensor(spectrum), Vec::new()).expect("ifft2 cpu");
    let gpu_ct = value_to_host_complex(gpu_value);
    let cpu_ct = value_to_host_complex(cpu_value);
    assert_eq!(gpu_ct.shape, cpu_ct.shape);

    let tol = match provider.precision() {
        runmat_accelerate_api::ProviderPrecision::F64 => 1e-10,
        runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
    };
    for (lhs, rhs) in gpu_ct.data.iter().zip(cpu_ct.data.iter()) {
        assert!(approx_eq(*lhs, *rhs, tol), "{lhs:?} vs {rhs:?}");
    }

    provider.free(&raw).ok();
    provider.free(&complex_handle).ok();
}
/// Synchronous test wrapper: drives the async builtin to completion with `block_on`.
fn ifft2_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
    let pending = super::ifft2_builtin(value, rest);
    block_on(pending)
}
}