// runmat_runtime/builtins/math/fft/forward.rs

//! MATLAB-compatible `fft` builtin with GPU-aware semantics for RunMat.
2
3use super::common::{
4    default_dimension, gather_gpu_complex_tensor, parse_length, transform_complex_tensor,
5    value_to_complex_tensor, TransformDirection,
6};
7use runmat_accelerate_api::GpuTensorHandle;
8use runmat_builtins::{ComplexTensor, Value};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::random_args::complex_tensor_into_value;
12use crate::builtins::common::spec::{
13    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::{shape::normalize_scalar_shape, tensor};
17use crate::builtins::math::fft::type_resolvers::fft_type;
18use crate::{build_runtime_error, BuiltinResult, RuntimeError};
19
20#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::fft::forward")]
21pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
22    name: "fft",
23    op_kind: GpuOpKind::Custom("fft"),
24    supported_precisions: &[ScalarType::F32, ScalarType::F64],
25    broadcast: BroadcastSemantics::Matlab,
26    provider_hooks: &[ProviderHook::Custom("fft_dim")],
27    constant_strategy: ConstantStrategy::InlineLiteral,
28    residency: ResidencyPolicy::NewHandle,
29    nan_mode: ReductionNaN::Include,
30    two_pass_threshold: None,
31    workgroup_size: None,
32    accepts_nan_mode: false,
33    notes: "Providers should implement `fft_dim` to transform along an arbitrary dimension; the runtime gathers to host when unavailable.",
34};
35
36#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::fft::forward")]
37pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
38    name: "fft",
39    shape: ShapeRequirements::Any,
40    constant_strategy: ConstantStrategy::InlineLiteral,
41    elementwise: None,
42    reduction: None,
43    emits_nan: false,
44    notes:
45        "FFT participates in fusion plans only as a boundary; no fused kernels are generated today.",
46};
47
48const BUILTIN_NAME: &str = "fft";
49
50fn fft_error(message: impl Into<String>) -> RuntimeError {
51    build_runtime_error(message)
52        .with_builtin(BUILTIN_NAME)
53        .build()
54}
55
56#[runtime_builtin(
57    name = "fft",
58    category = "math/fft",
59    summary = "Compute the discrete Fourier transform (DFT) of numeric or complex data.",
60    keywords = "fft,fourier transform,complex,gpu",
61    type_resolver(fft_type),
62    builtin_path = "crate::builtins::math::fft::forward"
63)]
64async fn fft_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
65    let (length, dimension) = parse_arguments(&rest).await?;
66    match value {
67        Value::GpuTensor(handle) => fft_gpu(handle, length, dimension).await,
68        other => fft_host(other, length, dimension),
69    }
70}
71
72fn fft_host(value: Value, length: Option<usize>, dimension: Option<usize>) -> BuiltinResult<Value> {
73    let tensor = value_to_complex_tensor(value, BUILTIN_NAME)?;
74    let transformed = fft_complex_tensor(tensor, length, dimension)?;
75    Ok(complex_tensor_into_value(transformed))
76}
77
78async fn fft_gpu(
79    handle: GpuTensorHandle,
80    length: Option<usize>,
81    dimension: Option<usize>,
82) -> BuiltinResult<Value> {
83    let mut shape = normalize_scalar_shape(&handle.shape);
84
85    let dim_one_based = match dimension {
86        Some(0) => return Err(fft_error("fft: dimension must be >= 1")),
87        Some(dim) => dim,
88        None => default_dimension(&shape),
89    };
90
91    let dim_index = dim_one_based - 1;
92    while shape.len() <= dim_index {
93        shape.push(1);
94    }
95    let current_len = shape[dim_index];
96    let target_len = length.unwrap_or(current_len);
97
98    if target_len == 0 {
99        let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME).await?;
100        let transformed = fft_complex_tensor(complex, length, dimension)?;
101        return Ok(complex_tensor_into_value(transformed));
102    }
103
104    if let Some(provider) = runmat_accelerate_api::provider() {
105        if let Ok(out) = provider.fft_dim(&handle, length, dim_index).await {
106            return Ok(Value::GpuTensor(out));
107        }
108    }
109
110    let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME).await?;
111    let transformed = fft_complex_tensor(complex, length, dimension)?;
112    Ok(complex_tensor_into_value(transformed))
113}
114
115async fn parse_dimension_arg(value: &Value) -> BuiltinResult<usize> {
116    tensor::dimension_from_value_async(value, BUILTIN_NAME, false)
117        .await
118        .map_err(fft_error)?
119        .ok_or_else(|| {
120            fft_error(format!(
121                "{BUILTIN_NAME}: dimension must be numeric, got {value:?}"
122            ))
123        })
124}
125
126async fn parse_arguments(args: &[Value]) -> BuiltinResult<(Option<usize>, Option<usize>)> {
127    match args.len() {
128        0 => Ok((None, None)),
129        1 => {
130            let len = parse_length(&args[0], BUILTIN_NAME)?;
131            Ok((len, None))
132        }
133        2 => {
134            let len = parse_length(&args[0], BUILTIN_NAME)?;
135            let dim = Some(parse_dimension_arg(&args[1]).await?);
136            Ok((len, dim))
137        }
138        _ => Err(fft_error(
139            "fft: expected fft(X), fft(X, N), or fft(X, N, DIM)",
140        )),
141    }
142}
143
144pub(super) fn fft_complex_tensor(
145    tensor: ComplexTensor,
146    length: Option<usize>,
147    dimension: Option<usize>,
148) -> BuiltinResult<ComplexTensor> {
149    transform_complex_tensor(
150        tensor,
151        length,
152        dimension,
153        TransformDirection::Forward,
154        BUILTIN_NAME,
155    )
156}
157
158#[cfg(test)]
159pub(crate) mod tests {
160    use super::*;
161    use crate::builtins::common::test_support;
162    use crate::builtins::math::fft::common;
163    use futures::executor::block_on;
164    use num_complex::Complex;
165    #[cfg(feature = "wgpu")]
166    use runmat_accelerate_api::AccelProvider;
167    use runmat_builtins::{
168        ComplexTensor as HostComplexTensor, IntValue, ResolveContext, Tensor, Type,
169    };
170    use rustfft::FftPlanner;
171
172    fn approx_eq(a: (f64, f64), b: (f64, f64), tol: f64) -> bool {
173        (a.0 - b.0).abs() <= tol && (a.1 - b.1).abs() <= tol
174    }
175
176    fn error_message(error: crate::RuntimeError) -> String {
177        error.message().to_string()
178    }
179
180    fn value_as_complex_tensor(value: Value) -> HostComplexTensor {
181        match value {
182            Value::ComplexTensor(tensor) => tensor,
183            Value::Complex(re, im) => HostComplexTensor::new(vec![(re, im)], vec![1, 1]).unwrap(),
184            Value::GpuTensor(handle) => {
185                let provider = runmat_accelerate_api::provider_for_handle(&handle)
186                    .or_else(runmat_accelerate_api::provider)
187                    .expect("provider for gpu handle");
188                let host = block_on(provider.download(&handle)).expect("download gpu fft output");
189                common::host_to_complex_tensor(host, BUILTIN_NAME).expect("decode gpu complex")
190            }
191            other => panic!("expected complex tensor, got {other:?}"),
192        }
193    }
194
195    fn fft_builtin_sync(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
196        block_on(super::fft_builtin(value, rest))
197    }
198
199    #[test]
200    fn fft_type_preserves_shape() {
201        let out = fft_type(
202            &[Type::Tensor {
203                shape: Some(vec![Some(2), Some(3)]),
204            }],
205            &ResolveContext::new(Vec::new()),
206        );
207        assert_eq!(
208            out,
209            Type::Tensor {
210                shape: Some(vec![Some(2), Some(3)])
211            }
212        );
213    }
214
215    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
216    #[test]
217    fn fft_real_vector() {
218        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
219        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
220        match result {
221            Value::ComplexTensor(ct) => {
222                assert_eq!(ct.shape, vec![4]);
223                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
224                for (idx, val) in ct.data.iter().enumerate() {
225                    assert!(
226                        approx_eq(*val, expected[idx], 1e-12),
227                        "idx {idx} {:?} ~= {:?}",
228                        val,
229                        expected[idx]
230                    );
231                }
232            }
233            other => panic!("expected complex tensor, got {other:?}"),
234        }
235    }
236
237    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
238    #[test]
239    fn fft_row_vector_default_dimension_preserves_orientation() {
240        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
241        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
242        match result {
243            Value::ComplexTensor(ct) => {
244                assert_eq!(ct.shape, vec![1, 4]);
245                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
246                for (idx, val) in ct.data.iter().enumerate() {
247                    assert!(
248                        approx_eq(*val, expected[idx], 1e-12),
249                        "idx {idx} {:?} ~= {:?}",
250                        val,
251                        expected[idx]
252                    );
253                }
254            }
255            other => panic!("expected complex tensor, got {other:?}"),
256        }
257    }
258
259    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
260    #[test]
261    fn fft_matrix_default_dimension() {
262        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
263        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
264        match result {
265            Value::ComplexTensor(ct) => {
266                assert_eq!(ct.shape, vec![2, 3]);
267                let expected = [
268                    (5.0, 0.0),
269                    (-3.0, 0.0),
270                    (7.0, 0.0),
271                    (-3.0, 0.0),
272                    (9.0, 0.0),
273                    (-3.0, 0.0),
274                ];
275                for (idx, val) in ct.data.iter().enumerate() {
276                    assert!(approx_eq(*val, expected[idx], 1e-12));
277                }
278            }
279            other => panic!("expected complex tensor, got {other:?}"),
280        }
281    }
282
283    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
284    #[test]
285    fn fft_zero_padding_with_length_argument() {
286        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
287        let result =
288            fft_host(Value::Tensor(tensor), Some(5), None).expect("fft with explicit length");
289        match result {
290            Value::ComplexTensor(ct) => {
291                assert_eq!(ct.shape, vec![5]);
292                assert!(approx_eq(ct.data[0], (6.0, 0.0), 1e-12));
293                assert_eq!(ct.data.len(), 5);
294            }
295            other => panic!("expected complex tensor, got {other:?}"),
296        }
297    }
298
299    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
300    #[test]
301    fn fft_empty_length_argument_defaults_to_input_length() {
302        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
303        let baseline =
304            fft_builtin_sync(Value::Tensor(tensor.clone()), Vec::new()).expect("baseline fft");
305        let empty = Tensor::new(Vec::<f64>::new(), vec![0]).unwrap();
306        let result = fft_builtin_sync(
307            Value::Tensor(tensor),
308            vec![Value::Tensor(empty), Value::Int(IntValue::I32(1))],
309        )
310        .expect("fft with empty length");
311        let base_ct = value_as_complex_tensor(baseline);
312        let result_ct = value_as_complex_tensor(result);
313        assert_eq!(base_ct.shape, result_ct.shape);
314        assert_eq!(base_ct.data.len(), result_ct.data.len());
315        for (idx, (a, b)) in base_ct.data.iter().zip(result_ct.data.iter()).enumerate() {
316            assert!(
317                approx_eq(*a, *b, 1e-12),
318                "mismatch at index {idx}: {:?} vs {:?}",
319                a,
320                b
321            );
322        }
323    }
324
325    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
326    #[test]
327    fn fft_truncates_when_length_smaller() {
328        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
329        let result =
330            fft_host(Value::Tensor(tensor), Some(2), None).expect("fft with truncation length");
331        match result {
332            Value::ComplexTensor(ct) => {
333                assert_eq!(ct.shape, vec![2]);
334                let expected = [(3.0, 0.0), (-1.0, 0.0)];
335                for (idx, val) in ct.data.iter().enumerate() {
336                    assert!(approx_eq(*val, expected[idx], 1e-12));
337                }
338            }
339            other => panic!("expected complex tensor, got {other:?}"),
340        }
341    }
342
343    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
344    #[test]
345    fn fft_zero_length_returns_empty_tensor() {
346        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
347        let result = fft_host(Value::Tensor(tensor), Some(0), None).expect("fft with zero length");
348        match result {
349            Value::ComplexTensor(ct) => {
350                assert_eq!(ct.shape, vec![0]);
351                assert!(ct.data.is_empty());
352            }
353            other => panic!("expected complex tensor, got {other:?}"),
354        }
355    }
356
357    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
358    #[test]
359    fn fft_complex_input_preserves_imaginary_components() {
360        let tensor =
361            HostComplexTensor::new(vec![(1.0, 1.0), (0.0, -1.0), (2.0, 0.5)], vec![3]).unwrap();
362        let result =
363            fft_host(Value::ComplexTensor(tensor.clone()), None, None).expect("fft complex");
364        let mut expected = tensor
365            .data
366            .iter()
367            .map(|(re, im)| Complex::new(*re, *im))
368            .collect::<Vec<_>>();
369        FftPlanner::<f64>::new()
370            .plan_fft_forward(expected.len())
371            .process(&mut expected);
372        match result {
373            Value::ComplexTensor(ct) => {
374                assert_eq!(ct.shape, vec![3]);
375                assert_eq!(ct.data.len(), 3);
376                for (idx, val) in ct.data.iter().enumerate() {
377                    let exp = expected[idx];
378                    assert!(approx_eq(*val, (exp.re, exp.im), 1e-12));
379                }
380            }
381            other => panic!("expected complex tensor, got {other:?}"),
382        }
383    }
384
385    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
386    #[test]
387    fn fft_row_vector_dimension_two() {
388        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
389        let result = fft_host(Value::Tensor(tensor), None, Some(2)).expect("fft along dimension 2");
390        match result {
391            Value::ComplexTensor(ct) => {
392                assert_eq!(ct.shape, vec![1, 4]);
393                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
394                for (idx, val) in ct.data.iter().enumerate() {
395                    assert!(approx_eq(*val, expected[idx], 1e-12));
396                }
397            }
398            other => panic!("expected complex tensor, got {other:?}"),
399        }
400    }
401
402    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
403    #[test]
404    fn fft_dimension_extends_rank() {
405        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
406        let original = tensor.clone();
407        let result =
408            fft_host(Value::Tensor(tensor), None, Some(3)).expect("fft with extra dimension");
409        match result {
410            Value::ComplexTensor(ct) => {
411                assert_eq!(ct.shape, vec![1, 4, 1]);
412                assert_eq!(ct.data.len(), original.data.len());
413                for (idx, (re, im)) in ct.data.iter().enumerate() {
414                    assert!(approx_eq((*re, *im), (original.data[idx], 0.0), 1e-12));
415                }
416            }
417            other => panic!("expected complex tensor, got {other:?}"),
418        }
419    }
420
421    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
422    #[test]
423    fn fft_dimension_extends_rank_with_padding() {
424        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
425        let original = tensor.clone();
426        let result = fft_host(Value::Tensor(tensor), Some(4), Some(3))
427            .expect("fft with padded third dimension");
428        match result {
429            Value::ComplexTensor(ct) => {
430                assert_eq!(ct.shape, vec![1, 4, 4]);
431                let mut expected = Vec::with_capacity(16);
432                for _depth in 0..4 {
433                    for &value in &original.data {
434                        expected.push((value, 0.0));
435                    }
436                }
437                assert_eq!(ct.data.len(), expected.len());
438                for (idx, (actual, expected)) in ct.data.iter().zip(expected.iter()).enumerate() {
439                    assert!(
440                        approx_eq(*actual, *expected, 1e-12),
441                        "idx {idx}: {:?} != {:?}",
442                        actual,
443                        expected
444                    );
445                }
446            }
447            other => panic!("expected complex tensor, got {other:?}"),
448        }
449    }
450
451    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
452    #[test]
453    fn fft_rejects_non_numeric_length() {
454        assert!(block_on(parse_arguments(&[Value::Bool(true)])).is_err());
455    }
456
457    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
458    #[test]
459    fn fft_rejects_negative_length() {
460        let err = error_message(block_on(parse_arguments(&[Value::Num(-1.0)])).unwrap_err());
461        assert!(err.contains("length must be non-negative"));
462    }
463
464    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
465    #[test]
466    fn fft_rejects_fractional_length() {
467        let err = error_message(block_on(parse_arguments(&[Value::Num(1.5)])).unwrap_err());
468        assert!(err.contains("length must be an integer"));
469    }
470
471    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
472    #[test]
473    fn fft_rejects_dimension_zero() {
474        let err = error_message(
475            block_on(parse_arguments(&[
476                Value::Num(4.0),
477                Value::Int(IntValue::I32(0)),
478            ]))
479            .unwrap_err(),
480        );
481        assert!(err.contains("dimension must be >= 1"));
482    }
483
484    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
485    #[test]
486    fn fft_accepts_scalar_tensor_dimension_argument() {
487        let dim = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
488        let (len, parsed_dim) = block_on(parse_arguments(&[Value::Num(4.0), Value::Tensor(dim)]))
489            .expect("parse arguments");
490        assert_eq!(len, Some(4));
491        assert_eq!(parsed_dim, Some(2));
492    }
493
494    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
495    #[test]
496    fn fft_gpu_roundtrip_matches_cpu() {
497        test_support::with_test_provider(|provider| {
498            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
499            let view = runmat_accelerate_api::HostTensorView {
500                data: &tensor.data,
501                shape: &tensor.shape,
502            };
503            let handle = provider.upload(&view).expect("upload");
504            let gpu = fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("fft");
505            let cpu = fft_builtin_sync(Value::Tensor(tensor), Vec::new()).expect("fft");
506            let gpu_host = value_as_complex_tensor(gpu);
507            let cpu_host = value_as_complex_tensor(cpu);
508            assert_eq!(gpu_host.shape, cpu_host.shape);
509            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
510                assert!(approx_eq(*a, *b, 1e-12));
511            }
512            provider.free(&handle).ok();
513        });
514    }
515
516    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
517    #[test]
518    fn fft_gpu_non_power_of_two_length_matches_cpu() {
519        test_support::with_test_provider(|provider| {
520            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
521            let view = runmat_accelerate_api::HostTensorView {
522                data: &tensor.data,
523                shape: &tensor.shape,
524            };
525            let handle = provider.upload(&view).expect("upload");
526            let gpu = fft_builtin_sync(
527                Value::GpuTensor(handle.clone()),
528                vec![Value::Int(IntValue::I32(7))],
529            )
530            .expect("fft gpu");
531            let cpu = fft_builtin_sync(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(7))])
532                .expect("fft cpu");
533            let gpu_host = value_as_complex_tensor(gpu);
534            let cpu_host = value_as_complex_tensor(cpu);
535            assert_eq!(gpu_host.shape, cpu_host.shape);
536            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
537                assert!(approx_eq(*a, *b, 1e-10));
538            }
539            provider.free(&handle).ok();
540        });
541    }
542
543    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
544    #[test]
545    fn fft_gpu_prime_length_on_non_last_dimension_matches_cpu() {
546        test_support::with_test_provider(|provider| {
547            let tensor = Tensor::new((1..=18).map(|v| v as f64).collect(), vec![2, 3, 3]).unwrap();
548            let view = runmat_accelerate_api::HostTensorView {
549                data: &tensor.data,
550                shape: &tensor.shape,
551            };
552            let handle = provider.upload(&view).expect("upload");
553            let args = vec![Value::Int(IntValue::I32(7)), Value::Int(IntValue::I32(2))];
554            let gpu =
555                fft_builtin_sync(Value::GpuTensor(handle.clone()), args.clone()).expect("fft gpu");
556            let cpu = fft_builtin_sync(Value::Tensor(tensor), args).expect("fft cpu");
557            let gpu_host = value_as_complex_tensor(gpu);
558            let cpu_host = value_as_complex_tensor(cpu);
559            assert_eq!(gpu_host.shape, cpu_host.shape);
560            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
561                assert!(approx_eq(*a, *b, 1e-10), "{a:?} vs {b:?}");
562            }
563            provider.free(&handle).ok();
564        });
565    }
566
567    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
568    #[test]
569    #[cfg(feature = "wgpu")]
570    fn fft_wgpu_matches_cpu() {
571        if let Some(provider) = runmat_accelerate::backend::wgpu::provider::ensure_wgpu_provider()
572            .expect("wgpu provider")
573        {
574            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
575            let tensor_cpu = tensor.clone();
576            let view = runmat_accelerate_api::HostTensorView {
577                data: &tensor.data,
578                shape: &tensor.shape,
579            };
580            let handle = provider.upload(&view).expect("upload");
581            let gpu =
582                fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpu fft");
583            let cpu = fft_builtin_sync(Value::Tensor(tensor_cpu), Vec::new()).expect("cpu fft");
584            let gpu_ct = value_as_complex_tensor(gpu);
585            let cpu_ct = value_as_complex_tensor(cpu);
586            let tol = match provider.precision() {
587                runmat_accelerate_api::ProviderPrecision::F64 => 1e-10,
588                runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
589            };
590            assert_eq!(gpu_ct.shape, cpu_ct.shape);
591            for (a, b) in gpu_ct.data.iter().zip(cpu_ct.data.iter()) {
592                assert!(approx_eq(*a, *b, tol), "{a:?} vs {b:?}");
593            }
594            provider.free(&handle).ok();
595        }
596    }
597}