runmat_runtime/builtins/math/fft/
forward.rs

1//! MATLAB-compatible `fft` builtin with GPU-aware semantics for RunMat.
2
3use super::common::{
4    default_dimension, host_to_complex_tensor, parse_length, tensor_to_complex_tensor,
5    trim_trailing_ones, value_to_complex_tensor,
6};
7use num_complex::Complex;
8use runmat_accelerate_api::{AccelProvider, GpuTensorHandle};
9use runmat_builtins::{ComplexTensor, Value};
10use runmat_macros::runtime_builtin;
11use rustfft::FftPlanner;
12use std::sync::Arc;
13
14use crate::builtins::common::random_args::complex_tensor_into_value;
15use crate::builtins::common::spec::{
16    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18};
19use crate::builtins::common::{gpu_helpers, tensor};
20#[cfg(feature = "doc_export")]
21use crate::register_builtin_doc_text;
22use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
23
24#[cfg(feature = "doc_export")]
25pub const DOC_MD: &str = r#"---
26title: "fft"
27category: "math/fft"
28keywords: ["fft", "fourier transform", "complex", "zero padding", "gpu"]
29summary: "Compute the discrete Fourier transform (DFT) of vectors, matrices, or N-D tensors."
30references:
31  - title: "MATLAB fft documentation"
32    url: "https://www.mathworks.com/help/matlab/ref/fft.html"
33gpu_support:
34  elementwise: false
35  reduction: false
36  precisions: ["f32", "f64"]
37  broadcasting: "matlab"
38  notes: "Falls back to host execution when the active acceleration provider does not expose an FFT hook."
39fusion:
40  elementwise: false
41  reduction: false
42  max_inputs: 1
43  constants: "inline"
44requires_feature: null
45tested:
46  unit: "builtins::math::fft::fft::tests"
47  integration: "builtins::math::fft::fft::tests::fft_gpu_roundtrip_matches_cpu"
48---
49
50# What does the `fft` function do in MATLAB / RunMat?
51`fft(X)` computes the discrete Fourier transform (DFT) of the input data. When `X` is a vector,
52`fft` returns the frequency-domain representation of the vector. When `X` is a matrix or an
53N-D tensor, the transform is applied along the first non-singleton dimension unless another
54dimension is specified.
55
56## How does the `fft` function behave in MATLAB / RunMat?
57- `fft(X)` transforms along the first dimension whose size is greater than 1.
58- `fft(X, n)` zero-pads or truncates `X` to length `n` before transforming along the default dimension.
59- `fft(X, n, dim)` applies the transform along dimension `dim`.
60- Real inputs produce complex outputs; complex inputs are handled element-wise with no additional conversion.
- Empty inputs remain empty; zero-padding an empty vector to length `n` produces an all-zero spectrum of length `n`.
62- GPU arrays are gathered to the host when the selected provider has no FFT implementation.
63
64## Examples of using the `fft` function in MATLAB / RunMat
65
66### Computing the FFT of a real time-domain vector
67```matlab
68x = [1 2 3 4];
69Y = fft(x);
70```
71Expected output (RunMat prints complex numbers with `a + bi` formatting):
72```matlab
73Y =
74  Columns 1 through 4
75   10 + 0i  -2 + 2i  -2 + 0i  -2 - 2i
76```
77
78### Applying fft column-wise to a matrix
79```matlab
80A = [1 2 3; 4 5 6];
81F = fft(A);
82```
83Expected output:
84```matlab
85F =
86   5 + 0i   7 + 0i   9 + 0i
  -3 + 0i  -3 + 0i  -3 + 0i
88```
89
90### Zero-padding before the FFT
91```matlab
92x = [1 2 3];
93Y = fft(x, 5);
94```
95The transform is computed on a length-5 sequence `[1 2 3 0 0]`, producing five complex frequency bins.
96
97### Selecting the transform dimension for a row vector
98```matlab
99x = [1 2 3 4];
100Y = fft(x, [], 2);
101```
`Y` matches `fft(x)`: for a row vector the default transform dimension is already 2 (the first dimension whose size is greater than 1), so naming it explicitly changes nothing.
103
104### FFT of a complex-valued signal
105```matlab
106t = 0:3;
107x = exp(1i * pi/2 * t);
108Y = fft(x);
109```
The complex sinusoid completes exactly one cycle over the four samples, so all of its energy lands in the second frequency bin: `Y` is `[0 4 0 0]` (up to rounding error).
111
112### FFT with gpuArray inputs
113```matlab
114g = gpuArray(rand(1, 1024));  % Residency is on the GPU
115G = fft(g);                   % Falls back to host if provider FFT hooks are unavailable
116result = gather(G);
117```
118RunMat gathers the data from the device and performs the transform on the host unless the active
119provider advertises an FFT implementation. When the WGPU provider handles the FFT, the kernel executes
120on the device but the result is downloaded immediately so the builtin can return a MATLAB-compatible
121`ComplexTensor`.
122
123## FAQ
124
125### Does `fft` always return complex values?
126Yes. Even when the imaginary part is zero, the result is stored as a complex array to match MATLAB semantics.
127
128### What happens if I pass `[]` as the second argument?
129Passing `[]` leaves the transform length unchanged. This is equivalent to omitting the `n` parameter.
130
131### Can I transform along a dimension larger than the current rank?
132Yes. RunMat automatically treats trailing dimensions as length-1 and will create the requested dimension on output.
133
134### How does zero-padding work?
When `n` is larger than the size of `X` along the transform dimension, RunMat pads with zeros before evaluating the FFT. When `n` is smaller, `X` is truncated to its first `n` elements along that dimension.
136
137### What precision is used for the FFT?
138RunMat computes FFTs in double precision on the host. Providers may use single or double precision depending on device capabilities.
139
140### Will RunMat run the FFT on my GPU automatically?
141When a provider installs an FFT hook, RunMat executes on the GPU. Otherwise, the runtime gathers the data and performs the transform on the CPU.
142
143### Is inverse FFT (`ifft`) available?
Use the companion `ifft` builtin where it is available. You can also invert a transform with `fft` alone: conjugate the spectrum, apply `fft`, conjugate the result again, and divide by the transform length. For a vector `Y`: `x = conj(fft(conj(Y))) / numel(Y)`.
145
146### How do I compute multi-dimensional FFTs?
147Call `fft` repeatedly along each dimension (`fft(fft(X, [], 1), [], 2)` for a 2-D FFT). Future releases will add dedicated helpers.
148
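### Does `fft` support non-unit sampling intervals?
`fft` works on sample indices and takes no sampling interval. For samples spaced `dt` apart, bin `k` (1-based) corresponds to frequency `(k-1)/(n*dt)`, and bins past `n/2` represent negative frequencies. Multiply the result by `dt` to approximate the continuous-time Fourier transform, or by the phase factor `exp(-2i*pi*f*t0)` to account for a nonzero start time `t0`.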
151
152## See Also
153[ifft](./ifft), [fftshift](./fftshift), [abs](../elementwise/abs), [angle](../elementwise/angle), [gpuArray](../../acceleration/gpu/gpuArray), [gather](../../acceleration/gpu/gather)
154
155## Source & Feedback
156- Full source: `crates/runmat-runtime/src/builtins/math/fft/fft.rs`
157- Found an issue? [Open a ticket](https://github.com/runmat-org/runmat/issues/new/choose) with a minimal reproduction.
158"#;
159
160pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
161    name: "fft",
162    op_kind: GpuOpKind::Custom("fft"),
163    supported_precisions: &[ScalarType::F32, ScalarType::F64],
164    broadcast: BroadcastSemantics::Matlab,
165    provider_hooks: &[ProviderHook::Custom("fft_dim")],
166    constant_strategy: ConstantStrategy::InlineLiteral,
167    residency: ResidencyPolicy::NewHandle,
168    nan_mode: ReductionNaN::Include,
169    two_pass_threshold: None,
170    workgroup_size: None,
171    accepts_nan_mode: false,
172    notes: "Providers should implement `fft_dim` to transform along an arbitrary dimension; the runtime gathers to host when unavailable.",
173};
174
175register_builtin_gpu_spec!(GPU_SPEC);
176
177pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
178    name: "fft",
179    shape: ShapeRequirements::Any,
180    constant_strategy: ConstantStrategy::InlineLiteral,
181    elementwise: None,
182    reduction: None,
183    emits_nan: false,
184    notes:
185        "FFT participates in fusion plans only as a boundary; no fused kernels are generated today.",
186};
187
188register_builtin_fusion_spec!(FUSION_SPEC);
189
// Register the `DOC_MD` reference page for `fft`; compiled only when documentation export
// is enabled, matching the `cfg` on `DOC_MD` itself.
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("fft", DOC_MD);
192
193#[runtime_builtin(
194    name = "fft",
195    category = "math/fft",
196    summary = "Compute the discrete Fourier transform (DFT) of numeric or complex data.",
197    keywords = "fft,fourier transform,complex,gpu"
198)]
199fn fft_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
200    let (length, dimension) = parse_arguments(&rest)?;
201    match value {
202        Value::GpuTensor(handle) => fft_gpu(handle, length, dimension),
203        other => fft_host(other, length, dimension),
204    }
205}
206
207fn fft_host(
208    value: Value,
209    length: Option<usize>,
210    dimension: Option<usize>,
211) -> Result<Value, String> {
212    let tensor = value_to_complex_tensor(value, "fft")?;
213    let transformed = fft_complex_tensor(tensor, length, dimension)?;
214    Ok(complex_tensor_into_value(transformed))
215}
216
217fn fft_gpu(
218    handle: GpuTensorHandle,
219    length: Option<usize>,
220    dimension: Option<usize>,
221) -> Result<Value, String> {
222    let mut shape = if handle.shape.is_empty() {
223        vec![1]
224    } else {
225        handle.shape.clone()
226    };
227
228    let dim_one_based = match dimension {
229        Some(0) => return Err("fft: dimension must be >= 1".to_string()),
230        Some(dim) => dim,
231        None => default_dimension(&shape),
232    };
233
234    let dim_index = dim_one_based - 1;
235    while shape.len() <= dim_index {
236        shape.push(1);
237    }
238    let current_len = shape[dim_index];
239    let target_len = length.unwrap_or(current_len);
240
241    if target_len == 0 {
242        let tensor = gpu_helpers::gather_tensor(&handle)?;
243        let complex = tensor_to_complex_tensor(tensor, "fft")?;
244        let transformed = fft_complex_tensor(complex, length, dimension)?;
245        return Ok(complex_tensor_into_value(transformed));
246    }
247
248    if let Some(provider) = runmat_accelerate_api::provider() {
249        if let Ok(out) = provider.fft_dim(&handle, length, dim_index) {
250            let complex = fft_download_gpu_result(provider, &out)?;
251            return Ok(complex_tensor_into_value(complex));
252        }
253    }
254
255    let tensor = gpu_helpers::gather_tensor(&handle)?;
256    let complex = tensor_to_complex_tensor(tensor, "fft")?;
257    let transformed = fft_complex_tensor(complex, length, dimension)?;
258    Ok(complex_tensor_into_value(transformed))
259}
260
261pub(super) fn fft_download_gpu_result(
262    provider: &dyn AccelProvider,
263    handle: &GpuTensorHandle,
264) -> Result<ComplexTensor, String> {
265    let host = provider.download(handle).map_err(|e| format!("fft: {e}"))?;
266    provider.free(handle).ok();
267    runmat_accelerate_api::clear_residency(handle);
268    host_to_complex_tensor(host, "fft")
269}
270
271fn parse_arguments(args: &[Value]) -> Result<(Option<usize>, Option<usize>), String> {
272    match args.len() {
273        0 => Ok((None, None)),
274        1 => {
275            let len = parse_length(&args[0], "fft")?;
276            Ok((len, None))
277        }
278        2 => {
279            let len = parse_length(&args[0], "fft")?;
280            let dim = Some(tensor::parse_dimension(&args[1], "fft")?);
281            Ok((len, dim))
282        }
283        _ => Err("fft: expected fft(X), fft(X, N), or fft(X, N, DIM)".to_string()),
284    }
285}
286
287pub(super) fn fft_complex_tensor(
288    mut tensor: ComplexTensor,
289    length: Option<usize>,
290    dimension: Option<usize>,
291) -> Result<ComplexTensor, String> {
292    if tensor.shape.is_empty() {
293        tensor.shape = vec![tensor.data.len()];
294        tensor.rows = tensor.shape.first().copied().unwrap_or(1);
295        tensor.cols = tensor.shape.get(1).copied().unwrap_or(1);
296    }
297
298    let mut shape = tensor.shape.clone();
299    let origin_rank = shape.len();
300    let dim = match dimension {
301        Some(0) => return Err("fft: dimension must be >= 1".to_string()),
302        Some(dim) => dim - 1,
303        None => default_dimension(&shape) - 1,
304    };
305
306    while shape.len() <= dim {
307        shape.push(1);
308    }
309
310    let current_len = shape[dim];
311    let target_len = length.unwrap_or(current_len);
312
313    if target_len == 0 {
314        let mut out_shape = shape;
315        out_shape[dim] = 0;
316        trim_trailing_ones(&mut out_shape, origin_rank);
317        return ComplexTensor::new(Vec::<(f64, f64)>::new(), out_shape)
318            .map_err(|e| format!("fft: {e}"));
319    }
320
321    let inner_stride = shape[..dim]
322        .iter()
323        .copied()
324        .fold(1usize, |acc, dim| acc.saturating_mul(dim));
325    let outer_stride = shape[dim + 1..]
326        .iter()
327        .copied()
328        .fold(1usize, |acc, dim| acc.saturating_mul(dim));
329    let num_slices = inner_stride.saturating_mul(outer_stride);
330
331    let input = tensor
332        .data
333        .into_iter()
334        .map(|(re, im)| Complex::new(re, im))
335        .collect::<Vec<_>>();
336
337    if num_slices == 0 {
338        let mut out_shape = shape;
339        out_shape[dim] = target_len;
340        trim_trailing_ones(&mut out_shape, origin_rank);
341        let data = vec![(0.0, 0.0); 0];
342        return ComplexTensor::new(data, out_shape).map_err(|e| format!("fft: {e}"));
343    }
344
345    let output_len = target_len.saturating_mul(num_slices);
346    let mut output = vec![Complex::new(0.0, 0.0); output_len];
347
348    let mut planner = FftPlanner::<f64>::new();
349    let fft_plan: Option<Arc<dyn rustfft::Fft<f64>>> = if target_len > 1 {
350        Some(planner.plan_fft_forward(target_len))
351    } else {
352        None
353    };
354
355    let copy_len = current_len.min(target_len);
356    let mut buffer = vec![Complex::new(0.0, 0.0); target_len];
357
358    for outer in 0..outer_stride {
359        let base_in = outer.saturating_mul(current_len.saturating_mul(inner_stride));
360        let base_out = outer.saturating_mul(target_len.saturating_mul(inner_stride));
361        for inner in 0..inner_stride {
362            buffer.iter_mut().for_each(|c| *c = Complex::new(0.0, 0.0));
363            for (k, slot) in buffer.iter_mut().enumerate().take(copy_len) {
364                let src_idx = base_in + inner + k * inner_stride;
365                if src_idx < input.len() {
366                    *slot = input[src_idx];
367                }
368            }
369            if target_len > 1 {
370                if let Some(plan) = &fft_plan {
371                    plan.process(&mut buffer);
372                }
373            }
374            for (k, value) in buffer.iter().enumerate().take(target_len) {
375                let dst_idx = base_out + inner + k * inner_stride;
376                if dst_idx < output.len() {
377                    output[dst_idx] = *value;
378                }
379            }
380        }
381    }
382
383    let mut out_shape = shape;
384    out_shape[dim] = target_len;
385    trim_trailing_ones(&mut out_shape, origin_rank.max(dim + 1));
386
387    let data = output.into_iter().map(|c| (c.re, c.im)).collect::<Vec<_>>();
388    ComplexTensor::new(data, out_shape).map_err(|e| format!("fft: {e}"))
389}
390
// Unit tests for the `fft` builtin: host numerics, argument parsing, shape handling, and
// GPU-provider parity.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use num_complex::Complex;
    use runmat_builtins::{ComplexTensor as HostComplexTensor, IntValue, Tensor};
    use rustfft::FftPlanner;

    /// Component-wise comparison of two `(re, im)` pairs within absolute tolerance `tol`.
    fn approx_eq(a: (f64, f64), b: (f64, f64), tol: f64) -> bool {
        (a.0 - b.0).abs() <= tol && (a.1 - b.1).abs() <= tol
    }

    /// Normalizes an `fft` result to a complex tensor; a scalar `Value::Complex` becomes 1x1.
    fn value_as_complex_tensor(value: Value) -> HostComplexTensor {
        match value {
            Value::ComplexTensor(tensor) => tensor,
            Value::Complex(re, im) => HostComplexTensor::new(vec![(re, im)], vec![1, 1]).unwrap(),
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // fft([1 2 3 4]) = [10, -2+2i, -2, -2-2i].
    #[test]
    fn fft_real_vector() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![4]);
                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(
                        approx_eq(*val, expected[idx], 1e-12),
                        "idx {idx} {:?} ~= {:?}",
                        val,
                        expected[idx]
                    );
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Column-major data for [1 2 3; 4 5 6]: each length-2 column transforms to
    // [sum, difference] = [5|7|9, -3].
    #[test]
    fn fft_matrix_default_dimension() {
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 3]);
                let expected = [
                    (5.0, 0.0),
                    (-3.0, 0.0),
                    (7.0, 0.0),
                    (-3.0, 0.0),
                    (9.0, 0.0),
                    (-3.0, 0.0),
                ];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Padding [1 2 3] to length 5: the DC bin is still the sum 1 + 2 + 3 = 6.
    #[test]
    fn fft_zero_padding_with_length_argument() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let result =
            fft_host(Value::Tensor(tensor), Some(5), None).expect("fft with explicit length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![5]);
                assert!(approx_eq(ct.data[0], (6.0, 0.0), 1e-12));
                assert_eq!(ct.data.len(), 5);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Passing `[]` as N (with an explicit DIM) must match calling fft(X) with no arguments.
    #[test]
    fn fft_empty_length_argument_defaults_to_input_length() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let baseline =
            fft_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("baseline fft");
        let empty = Tensor::new(Vec::<f64>::new(), vec![0]).unwrap();
        let result = fft_builtin(
            Value::Tensor(tensor),
            vec![Value::Tensor(empty), Value::Int(IntValue::I32(1))],
        )
        .expect("fft with empty length");
        let base_ct = value_as_complex_tensor(baseline);
        let result_ct = value_as_complex_tensor(result);
        assert_eq!(base_ct.shape, result_ct.shape);
        assert_eq!(base_ct.data.len(), result_ct.data.len());
        for (idx, (a, b)) in base_ct.data.iter().zip(result_ct.data.iter()).enumerate() {
            assert!(
                approx_eq(*a, *b, 1e-12),
                "mismatch at index {idx}: {:?} vs {:?}",
                a,
                b
            );
        }
    }

    // Truncating [1 2 3 4] to N = 2 transforms [1 2], giving [3, -1].
    #[test]
    fn fft_truncates_when_length_smaller() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result =
            fft_host(Value::Tensor(tensor), Some(2), None).expect("fft with truncation length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2]);
                let expected = [(3.0, 0.0), (-1.0, 0.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // N = 0 yields an empty tensor whose transform dimension has size 0.
    #[test]
    fn fft_zero_length_returns_empty_tensor() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let result = fft_host(Value::Tensor(tensor), Some(0), None).expect("fft with zero length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![0]);
                assert!(ct.data.is_empty());
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Complex input is cross-checked against a direct rustfft forward transform.
    #[test]
    fn fft_complex_input_preserves_imaginary_components() {
        let tensor =
            HostComplexTensor::new(vec![(1.0, 1.0), (0.0, -1.0), (2.0, 0.5)], vec![3]).unwrap();
        let result =
            fft_host(Value::ComplexTensor(tensor.clone()), None, None).expect("fft complex");
        let mut expected = tensor
            .data
            .iter()
            .map(|(re, im)| Complex::new(*re, *im))
            .collect::<Vec<_>>();
        FftPlanner::<f64>::new()
            .plan_fft_forward(expected.len())
            .process(&mut expected);
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![3]);
                assert_eq!(ct.data.len(), 3);
                for (idx, val) in ct.data.iter().enumerate() {
                    let exp = expected[idx];
                    assert!(approx_eq(*val, (exp.re, exp.im), 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // An explicit DIM = 2 on a 1x4 row vector gives the same spectrum as the 1-D case.
    #[test]
    fn fft_row_vector_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, Some(2)).expect("fft along dimension 2");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4]);
                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // DIM = 3 on a 1x4 tensor: the implicit singleton dimension makes the transform the
    // identity, and the output shape keeps the third dimension.
    #[test]
    fn fft_dimension_extends_rank() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let original = tensor.clone();
        let result =
            fft_host(Value::Tensor(tensor), None, Some(3)).expect("fft with extra dimension");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4, 1]);
                assert_eq!(ct.data.len(), original.data.len());
                for (idx, (re, im)) in ct.data.iter().enumerate() {
                    assert!(approx_eq((*re, *im), (original.data[idx], 0.0), 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Padding the implicit third dimension to 4: each slice [x, 0, 0, 0] transforms to the
    // constant x in every bin, so the data repeats once per depth.
    #[test]
    fn fft_dimension_extends_rank_with_padding() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let original = tensor.clone();
        let result = fft_host(Value::Tensor(tensor), Some(4), Some(3))
            .expect("fft with padded third dimension");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4, 4]);
                let mut expected = Vec::with_capacity(16);
                for _depth in 0..4 {
                    for &value in &original.data {
                        expected.push((value, 0.0));
                    }
                }
                assert_eq!(ct.data.len(), expected.len());
                for (idx, (actual, expected)) in ct.data.iter().zip(expected.iter()).enumerate() {
                    assert!(
                        approx_eq(*actual, *expected, 1e-12),
                        "idx {idx}: {:?} != {:?}",
                        actual,
                        expected
                    );
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // The argument-validation tests below exercise `parse_arguments`. The error-message
    // substrings they check come from `parse_length` / `tensor::parse_dimension`.
    #[test]
    fn fft_rejects_non_numeric_length() {
        assert!(parse_arguments(&[Value::Bool(true)]).is_err());
    }

    #[test]
    fn fft_rejects_negative_length() {
        let err = parse_arguments(&[Value::Num(-1.0)]).unwrap_err();
        assert!(err.contains("length must be non-negative"));
    }

    #[test]
    fn fft_rejects_fractional_length() {
        let err = parse_arguments(&[Value::Num(1.5)]).unwrap_err();
        assert!(err.contains("length must be an integer"));
    }

    #[test]
    fn fft_rejects_dimension_zero() {
        let err = parse_arguments(&[Value::Num(4.0), Value::Int(IntValue::I32(0))]).unwrap_err();
        assert!(err.contains("dimension must be >= 1"));
    }

    // Uploads to the test provider and checks the gpuArray path against the host path.
    // Whether the provider runs `fft_dim` or the runtime falls back to the host depends on
    // the test provider (not visible here).
    #[test]
    fn fft_gpu_roundtrip_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("fft");
            let cpu = fft_builtin(Value::Tensor(tensor), Vec::new()).expect("fft");
            let gpu_host = value_as_complex_tensor(gpu);
            let cpu_host = value_as_complex_tensor(cpu);
            assert_eq!(gpu_host.shape, cpu_host.shape);
            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
                assert!(approx_eq(*a, *b, 1e-12));
            }
            provider.free(&handle).ok();
        });
    }

    // Same parity check against the real WGPU provider (only with the `wgpu` feature). The
    // looser 1e-9 tolerance presumably allows for device precision; confirm against the kernel.
    #[test]
    #[cfg(feature = "wgpu")]
    fn fft_wgpu_matches_cpu() {
        if let Some(provider) = runmat_accelerate::backend::wgpu::provider::ensure_wgpu_provider()
            .expect("wgpu provider")
        {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let tensor_cpu = tensor.clone();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpu fft");
            let cpu = fft_builtin(Value::Tensor(tensor_cpu), Vec::new()).expect("cpu fft");
            let gpu_ct = value_as_complex_tensor(gpu);
            let cpu_ct = value_as_complex_tensor(cpu);
            assert_eq!(gpu_ct.shape, cpu_ct.shape);
            for (a, b) in gpu_ct.data.iter().zip(cpu_ct.data.iter()) {
                assert!(approx_eq(*a, *b, 1e-9));
            }
            provider.free(&handle).ok();
        }
    }

    // The embedded reference page must contain at least one example code block.
    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}
690}