Skip to main content

runmat_runtime/builtins/math/fft/
forward.rs

1//! MATLAB-compatible `fft` builtin with GPU-aware semantics for RunMat.
2
3use super::common::{
4    default_dimension, gather_gpu_complex_tensor, parse_length, transform_complex_tensor,
5    value_to_complex_tensor, TransformDirection,
6};
7use runmat_accelerate_api::GpuTensorHandle;
8use runmat_builtins::{ComplexTensor, Value};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::random_args::complex_tensor_into_value;
12use crate::builtins::common::spec::{
13    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::{shape::normalize_scalar_shape, tensor};
17use crate::builtins::math::fft::type_resolvers::fft_type;
18use crate::{build_runtime_error, BuiltinResult, RuntimeError};
19
/// GPU dispatch metadata for the `fft` builtin.
///
/// Advertises the custom provider hook `fft_dim`, which `fft_gpu` invokes with a 0-based
/// dimension index. When no provider is registered or the hook returns an error, the runtime
/// gathers the input to the host and performs the transform there.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::fft::forward")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fft",
    op_kind: GpuOpKind::Custom("fft"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    // Must stay in sync with the hook actually called in `fft_gpu` (`provider.fft_dim`).
    provider_hooks: &[ProviderHook::Custom("fft_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers should implement `fft_dim` to transform along an arbitrary dimension; the runtime gathers to host when unavailable.",
};
35
/// Fusion-planner metadata for the `fft` builtin.
///
/// No elementwise or reduction template is supplied (`None` for both), so the planner has
/// nothing to fuse. As the notes state, `fft` only acts as a boundary in fusion plans.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::fft::forward")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fft",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "FFT participates in fusion plans only as a boundary; no fused kernels are generated today.",
};
47
48const BUILTIN_NAME: &str = "fft";
49
50fn fft_error(message: impl Into<String>) -> RuntimeError {
51    build_runtime_error(message)
52        .with_builtin(BUILTIN_NAME)
53        .build()
54}
55
56#[runtime_builtin(
57    name = "fft",
58    category = "math/fft",
59    summary = "Compute the discrete Fourier transform (DFT) of numeric or complex data.",
60    keywords = "fft,fourier transform,complex,gpu",
61    type_resolver(fft_type),
62    builtin_path = "crate::builtins::math::fft::forward"
63)]
64async fn fft_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
65    let (length, dimension) = parse_arguments(&rest).await?;
66    match value {
67        Value::GpuTensor(handle) => fft_gpu(handle, length, dimension).await,
68        other => fft_host(other, length, dimension),
69    }
70}
71
72fn fft_host(value: Value, length: Option<usize>, dimension: Option<usize>) -> BuiltinResult<Value> {
73    let tensor = value_to_complex_tensor(value, BUILTIN_NAME)?;
74    let transformed = fft_complex_tensor(tensor, length, dimension)?;
75    Ok(complex_tensor_into_value(transformed))
76}
77
78async fn fft_gpu(
79    handle: GpuTensorHandle,
80    length: Option<usize>,
81    dimension: Option<usize>,
82) -> BuiltinResult<Value> {
83    let mut shape = normalize_scalar_shape(&handle.shape);
84
85    let dim_one_based = match dimension {
86        Some(0) => return Err(fft_error("fft: dimension must be >= 1")),
87        Some(dim) => dim,
88        None => default_dimension(&shape),
89    };
90
91    let dim_index = dim_one_based - 1;
92    while shape.len() <= dim_index {
93        shape.push(1);
94    }
95    let current_len = shape[dim_index];
96    let target_len = length.unwrap_or(current_len);
97
98    if target_len == 0 {
99        let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME).await?;
100        let transformed = fft_complex_tensor(complex, length, dimension)?;
101        return Ok(complex_tensor_into_value(transformed));
102    }
103
104    if let Some(provider) = runmat_accelerate_api::provider() {
105        if let Ok(out) = provider.fft_dim(&handle, length, dim_index).await {
106            return Ok(Value::GpuTensor(out));
107        }
108    }
109
110    let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME).await?;
111    let transformed = fft_complex_tensor(complex, length, dimension)?;
112    Ok(complex_tensor_into_value(transformed))
113}
114
115async fn parse_dimension_arg(value: &Value) -> BuiltinResult<usize> {
116    tensor::dimension_from_value_async(value, BUILTIN_NAME, false)
117        .await
118        .map_err(fft_error)?
119        .ok_or_else(|| {
120            fft_error(format!(
121                "{BUILTIN_NAME}: dimension must be numeric, got {value:?}"
122            ))
123        })
124}
125
126async fn parse_arguments(args: &[Value]) -> BuiltinResult<(Option<usize>, Option<usize>)> {
127    match args.len() {
128        0 => Ok((None, None)),
129        1 => {
130            let len = parse_length(&args[0], BUILTIN_NAME)?;
131            Ok((len, None))
132        }
133        2 => {
134            let len = parse_length(&args[0], BUILTIN_NAME)?;
135            let dim = Some(parse_dimension_arg(&args[1]).await?);
136            Ok((len, dim))
137        }
138        _ => Err(fft_error(
139            "fft: expected fft(X), fft(X, N), or fft(X, N, DIM)",
140        )),
141    }
142}
143
144pub(super) fn fft_complex_tensor(
145    tensor: ComplexTensor,
146    length: Option<usize>,
147    dimension: Option<usize>,
148) -> BuiltinResult<ComplexTensor> {
149    transform_complex_tensor(
150        tensor,
151        length,
152        dimension,
153        TransformDirection::Forward,
154        BUILTIN_NAME,
155    )
156}
157
#[cfg(test)]
pub(crate) mod tests {
    // Coverage:
    // - the type resolver (`fft_type`);
    // - the host transform (`fft_host`) and argument parsing (`parse_arguments`);
    // - the public builtin on GPU handles, compared against the host result.
    use super::*;
    use crate::builtins::common::test_support;
    use crate::builtins::math::fft::common;
    use futures::executor::block_on;
    use num_complex::Complex;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_builtins::{
        ComplexTensor as HostComplexTensor, IntValue, ResolveContext, Tensor, Type,
    };
    use rustfft::FftPlanner;

    /// Returns `true` when the real and imaginary parts of `a` and `b` each differ by at most `tol`.
    fn approx_eq(a: (f64, f64), b: (f64, f64), tol: f64) -> bool {
        (a.0 - b.0).abs() <= tol && (a.1 - b.1).abs() <= tol
    }

    /// Extracts the message text of a runtime error for substring assertions.
    fn error_message(error: crate::RuntimeError) -> String {
        error.message().to_string()
    }

    /// Normalizes a builtin result into a host complex tensor.
    ///
    /// - Complex scalars become 1x1 tensors.
    /// - GPU handles are downloaded through the provider that owns them (falling back to the
    ///   globally registered provider) and decoded as complex data.
    /// - Any other variant panics.
    fn value_as_complex_tensor(value: Value) -> HostComplexTensor {
        match value {
            Value::ComplexTensor(tensor) => tensor,
            Value::Complex(re, im) => HostComplexTensor::new(vec![(re, im)], vec![1, 1]).unwrap(),
            Value::GpuTensor(handle) => {
                let provider = runmat_accelerate_api::provider_for_handle(&handle)
                    .or_else(runmat_accelerate_api::provider)
                    .expect("provider for gpu handle");
                let host = block_on(provider.download(&handle)).expect("download gpu fft output");
                common::host_to_complex_tensor(host, BUILTIN_NAME).expect("decode gpu complex")
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    /// Drives the async builtin entry point to completion on the current thread.
    fn fft_builtin_sync(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::fft_builtin(value, rest))
    }

    // The type resolver reports the input tensor's shape unchanged.
    #[test]
    fn fft_type_preserves_shape() {
        let out = fft_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    // DFT of [1 2 3 4] is [10, -2+2i, -2, -2-2i].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_real_vector() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![4]);
                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(
                        approx_eq(*val, expected[idx], 1e-12),
                        "idx {idx} {:?} ~= {:?}",
                        val,
                        expected[idx]
                    );
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // A 2x3 column-major matrix with columns [1 4], [2 5], [3 6] is transformed down each
    // column. Each column yields [sum, difference].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_matrix_default_dimension() {
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 3]);
                let expected = [
                    (5.0, 0.0),
                    (-3.0, 0.0),
                    (7.0, 0.0),
                    (-3.0, 0.0),
                    (9.0, 0.0),
                    (-3.0, 0.0),
                ];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // N larger than the input zero-pads. The DC bin is still the sum 1 + 2 + 3.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_zero_padding_with_length_argument() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let result =
            fft_host(Value::Tensor(tensor), Some(5), None).expect("fft with explicit length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![5]);
                assert!(approx_eq(ct.data[0], (6.0, 0.0), 1e-12));
                assert_eq!(ct.data.len(), 5);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // fft(X, [], 1) must match fft(X): an empty N means "use the input length".
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_empty_length_argument_defaults_to_input_length() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let baseline =
            fft_builtin_sync(Value::Tensor(tensor.clone()), Vec::new()).expect("baseline fft");
        let empty = Tensor::new(Vec::<f64>::new(), vec![0]).unwrap();
        let result = fft_builtin_sync(
            Value::Tensor(tensor),
            vec![Value::Tensor(empty), Value::Int(IntValue::I32(1))],
        )
        .expect("fft with empty length");
        let base_ct = value_as_complex_tensor(baseline);
        let result_ct = value_as_complex_tensor(result);
        assert_eq!(base_ct.shape, result_ct.shape);
        assert_eq!(base_ct.data.len(), result_ct.data.len());
        for (idx, (a, b)) in base_ct.data.iter().zip(result_ct.data.iter()).enumerate() {
            assert!(
                approx_eq(*a, *b, 1e-12),
                "mismatch at index {idx}: {:?} vs {:?}",
                a,
                b
            );
        }
    }

    // N smaller than the input truncates to [1 2] before transforming, giving [3, -1].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_truncates_when_length_smaller() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result =
            fft_host(Value::Tensor(tensor), Some(2), None).expect("fft with truncation length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2]);
                let expected = [(3.0, 0.0), (-1.0, 0.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // N = 0 yields an empty result whose transformed axis has extent 0.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_zero_length_returns_empty_tensor() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let result = fft_host(Value::Tensor(tensor), Some(0), None).expect("fft with zero length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![0]);
                assert!(ct.data.is_empty());
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Complex input is transformed with its imaginary parts intact. The reference result is
    // computed with rustfft's forward plan.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_complex_input_preserves_imaginary_components() {
        let tensor =
            HostComplexTensor::new(vec![(1.0, 1.0), (0.0, -1.0), (2.0, 0.5)], vec![3]).unwrap();
        let result =
            fft_host(Value::ComplexTensor(tensor.clone()), None, None).expect("fft complex");
        let mut expected = tensor
            .data
            .iter()
            .map(|(re, im)| Complex::new(*re, *im))
            .collect::<Vec<_>>();
        FftPlanner::<f64>::new()
            .plan_fft_forward(expected.len())
            .process(&mut expected);
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![3]);
                assert_eq!(ct.data.len(), 3);
                for (idx, val) in ct.data.iter().enumerate() {
                    let exp = expected[idx];
                    assert!(approx_eq(*val, (exp.re, exp.im), 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // An explicit DIM = 2 transforms a 1x4 row vector along its length.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_row_vector_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, Some(2)).expect("fft along dimension 2");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4]);
                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // A DIM beyond the stored rank is a singleton axis. A length-1 DFT is the identity, so the
    // data is unchanged and the shape gains a trailing 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_dimension_extends_rank() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let original = tensor.clone();
        let result =
            fft_host(Value::Tensor(tensor), None, Some(3)).expect("fft with extra dimension");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4, 1]);
                assert_eq!(ct.data.len(), original.data.len());
                for (idx, (re, im)) in ct.data.iter().enumerate() {
                    assert!(approx_eq((*re, *im), (original.data[idx], 0.0), 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Zero-padding the singleton third axis to N = 4 turns each element x into [x 0 0 0].
    // Its DFT is x in every bin, so each of the four pages equals the original data.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_dimension_extends_rank_with_padding() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let original = tensor.clone();
        let result = fft_host(Value::Tensor(tensor), Some(4), Some(3))
            .expect("fft with padded third dimension");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4, 4]);
                let mut expected = Vec::with_capacity(16);
                for _depth in 0..4 {
                    for &value in &original.data {
                        expected.push((value, 0.0));
                    }
                }
                assert_eq!(ct.data.len(), expected.len());
                for (idx, (actual, expected)) in ct.data.iter().zip(expected.iter()).enumerate() {
                    assert!(
                        approx_eq(*actual, *expected, 1e-12),
                        "idx {idx}: {:?} != {:?}",
                        actual,
                        expected
                    );
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // A logical value is rejected as N.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_non_numeric_length() {
        assert!(block_on(parse_arguments(&[Value::Bool(true)])).is_err());
    }

    // A negative N is rejected with a descriptive message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_negative_length() {
        let err = error_message(block_on(parse_arguments(&[Value::Num(-1.0)])).unwrap_err());
        assert!(err.contains("length must be non-negative"));
    }

    // A non-integer N is rejected with a descriptive message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_fractional_length() {
        let err = error_message(block_on(parse_arguments(&[Value::Num(1.5)])).unwrap_err());
        assert!(err.contains("length must be an integer"));
    }

    // DIM = 0 is rejected during argument parsing (dimensions are 1-based).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_dimension_zero() {
        let err = error_message(
            block_on(parse_arguments(&[
                Value::Num(4.0),
                Value::Int(IntValue::I32(0)),
            ]))
            .unwrap_err(),
        );
        assert!(err.contains("dimension must be >= 1"));
    }

    // A 1x1 tensor is accepted as DIM, just like a plain scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_accepts_scalar_tensor_dimension_argument() {
        let dim = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
        let (len, parsed_dim) = block_on(parse_arguments(&[Value::Num(4.0), Value::Tensor(dim)]))
            .expect("parse arguments");
        assert_eq!(len, Some(4));
        assert_eq!(parsed_dim, Some(2));
    }

    // The default-dimension transform through the test provider matches the host result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_roundtrip_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("fft");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), Vec::new()).expect("fft");
            let gpu_host = value_as_complex_tensor(gpu);
            let cpu_host = value_as_complex_tensor(cpu);
            assert_eq!(gpu_host.shape, cpu_host.shape);
            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
                assert!(approx_eq(*a, *b, 1e-12));
            }
            provider.free(&handle).ok();
        });
    }

    // A non-power-of-two N (= 7, zero-padded) on the GPU path matches the host result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_non_power_of_two_length_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin_sync(
                Value::GpuTensor(handle.clone()),
                vec![Value::Int(IntValue::I32(7))],
            )
            .expect("fft gpu");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(7))])
                .expect("fft cpu");
            let gpu_host = value_as_complex_tensor(gpu);
            let cpu_host = value_as_complex_tensor(cpu);
            assert_eq!(gpu_host.shape, cpu_host.shape);
            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
                assert!(approx_eq(*a, *b, 1e-10));
            }
            provider.free(&handle).ok();
        });
    }

    // A prime N (= 7) along dimension 2 of a 2x3x3 tensor (not the last axis) matches the
    // host result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_prime_length_on_non_last_dimension_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new((1..=18).map(|v| v as f64).collect(), vec![2, 3, 3]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let args = vec![Value::Int(IntValue::I32(7)), Value::Int(IntValue::I32(2))];
            let gpu =
                fft_builtin_sync(Value::GpuTensor(handle.clone()), args.clone()).expect("fft gpu");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), args).expect("fft cpu");
            let gpu_host = value_as_complex_tensor(gpu);
            let cpu_host = value_as_complex_tensor(cpu);
            assert_eq!(gpu_host.shape, cpu_host.shape);
            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
                assert!(approx_eq(*a, *b, 1e-10), "{a:?} vs {b:?}");
            }
            provider.free(&handle).ok();
        });
    }

    // Parity against the real wgpu backend. The body is skipped when `ensure_wgpu_provider`
    // yields no provider. The tolerance follows the provider's precision.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn fft_wgpu_matches_cpu() {
        if let Some(provider) = runmat_accelerate::backend::wgpu::provider::ensure_wgpu_provider()
            .expect("wgpu provider")
        {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let tensor_cpu = tensor.clone();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu =
                fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpu fft");
            let cpu = fft_builtin_sync(Value::Tensor(tensor_cpu), Vec::new()).expect("cpu fft");
            let gpu_ct = value_as_complex_tensor(gpu);
            let cpu_ct = value_as_complex_tensor(cpu);
            let tol = match provider.precision() {
                runmat_accelerate_api::ProviderPrecision::F64 => 1e-10,
                runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
            };
            assert_eq!(gpu_ct.shape, cpu_ct.shape);
            for (a, b) in gpu_ct.data.iter().zip(cpu_ct.data.iter()) {
                assert!(approx_eq(*a, *b, tol), "{a:?} vs {b:?}");
            }
            provider.free(&handle).ok();
        }
    }
}