1use super::common::{
4 default_dimension, host_to_complex_tensor, parse_length, tensor_to_complex_tensor,
5 trim_trailing_ones, value_to_complex_tensor,
6};
7use num_complex::Complex;
8use runmat_accelerate_api::{AccelProvider, GpuTensorHandle};
9use runmat_builtins::{ComplexTensor, Value};
10use runmat_macros::runtime_builtin;
11use rustfft::FftPlanner;
12use std::sync::Arc;
13
14use crate::builtins::common::random_args::complex_tensor_into_value;
15use crate::builtins::common::spec::{
16 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18};
19use crate::builtins::common::{gpu_helpers, tensor};
20#[cfg(feature = "doc_export")]
21use crate::register_builtin_doc_text;
22use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
23
24#[cfg(feature = "doc_export")]
25pub const DOC_MD: &str = r#"---
26title: "fft"
27category: "math/fft"
28keywords: ["fft", "fourier transform", "complex", "zero padding", "gpu"]
29summary: "Compute the discrete Fourier transform (DFT) of vectors, matrices, or N-D tensors."
30references:
31 - title: "MATLAB fft documentation"
32 url: "https://www.mathworks.com/help/matlab/ref/fft.html"
33gpu_support:
34 elementwise: false
35 reduction: false
36 precisions: ["f32", "f64"]
37 broadcasting: "matlab"
38 notes: "Falls back to host execution when the active acceleration provider does not expose an FFT hook."
39fusion:
40 elementwise: false
41 reduction: false
42 max_inputs: 1
43 constants: "inline"
44requires_feature: null
45tested:
46 unit: "builtins::math::fft::fft::tests"
47 integration: "builtins::math::fft::fft::tests::fft_gpu_roundtrip_matches_cpu"
48---
49
50# What does the `fft` function do in MATLAB / RunMat?
51`fft(X)` computes the discrete Fourier transform (DFT) of the input data. When `X` is a vector,
52`fft` returns the frequency-domain representation of the vector. When `X` is a matrix or an
53N-D tensor, the transform is applied along the first non-singleton dimension unless another
54dimension is specified.
55
56## How does the `fft` function behave in MATLAB / RunMat?
57- `fft(X)` transforms along the first dimension whose size is greater than 1.
58- `fft(X, n)` zero-pads or truncates `X` to length `n` before transforming along the default dimension.
59- `fft(X, n, dim)` applies the transform along dimension `dim`.
- Real inputs produce complex outputs; complex inputs are transformed directly without any type conversion.
61- Empty inputs remain empty; zero-padding with `n` produces zero-valued spectra.
62- GPU arrays are gathered to the host when the selected provider has no FFT implementation.
63
64## Examples of using the `fft` function in MATLAB / RunMat
65
66### Computing the FFT of a real time-domain vector
67```matlab
68x = [1 2 3 4];
69Y = fft(x);
70```
71Expected output (RunMat prints complex numbers with `a + bi` formatting):
72```matlab
73Y =
74 Columns 1 through 4
75 10 + 0i -2 + 2i -2 + 0i -2 - 2i
76```
77
78### Applying fft column-wise to a matrix
79```matlab
80A = [1 2 3; 4 5 6];
81F = fft(A);
82```
83Expected output:
84```matlab
85F =
    5 + 0i    7 + 0i    9 + 0i
   -3 + 0i   -3 + 0i   -3 + 0i
88```
89
90### Zero-padding before the FFT
91```matlab
92x = [1 2 3];
93Y = fft(x, 5);
94```
95The transform is computed on a length-5 sequence `[1 2 3 0 0]`, producing five complex frequency bins.
96
97### Selecting the transform dimension for a row vector
98```matlab
99x = [1 2 3 4];
100Y = fft(x, [], 2);
101```
`Y` matches `fft(x)` because dimension 2 is already the first non-singleton dimension of a row vector, so the explicit `dim` selects the same transform direction as the default.
103
104### FFT of a complex-valued signal
105```matlab
106t = 0:3;
107x = exp(1i * pi/2 * t);
108Y = fft(x);
109```
The complex sinusoid `x = [1 i -1 -i]` completes exactly one cycle over the four samples, so all of its energy lands in the second frequency bin: `Y = [0 4 0 0]` (up to rounding error).
111
112### FFT with gpuArray inputs
113```matlab
114g = gpuArray(rand(1, 1024)); % Residency is on the GPU
115G = fft(g); % Falls back to host if provider FFT hooks are unavailable
116result = gather(G);
117```
118RunMat gathers the data from the device and performs the transform on the host unless the active
119provider advertises an FFT implementation. When the WGPU provider handles the FFT, the kernel executes
120on the device but the result is downloaded immediately so the builtin can return a MATLAB-compatible
121`ComplexTensor`.
122
123## FAQ
124
125### Does `fft` always return complex values?
126Yes. Even when the imaginary part is zero, the result is stored as a complex array to match MATLAB semantics.
127
128### What happens if I pass `[]` as the second argument?
129Passing `[]` leaves the transform length unchanged. This is equivalent to omitting the `n` parameter.
130
131### Can I transform along a dimension larger than the current rank?
132Yes. RunMat automatically treats trailing dimensions as length-1 and will create the requested dimension on output.
133
134### How does zero-padding work?
135When `n` is larger than the size of `X` along the transform dimension, RunMat pads with zeros before evaluating the FFT.
136
137### What precision is used for the FFT?
138RunMat computes FFTs in double precision on the host. Providers may use single or double precision depending on device capabilities.
139
140### Will RunMat run the FFT on my GPU automatically?
141When a provider installs an FFT hook, RunMat executes on the GPU. Otherwise, the runtime gathers the data and performs the transform on the CPU.
142
143### Is inverse FFT (`ifft`) available?
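Use the companion `ifft` builtin where it is available. Otherwise, you can recover the time-domain signal from a spectrum `Y` of length `N` with the identity `x = conj(fft(conj(Y))) / N`. This means conjugating the input, transforming, conjugating again, and dividing by the transform length.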
145
146### How do I compute multi-dimensional FFTs?
147Call `fft` repeatedly along each dimension (`fft(fft(X, [], 1), [], 2)` for a 2-D FFT). Future releases will add dedicated helpers.
148
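### Does `fft` support non-unit or non-uniform sampling intervals?
`fft` works on sample indices and assumes uniformly spaced samples; it has no notion of a sampling interval. For an interval `dt`, bin `k` (0-based) of an `N`-point transform corresponds to the frequency `k / (N*dt)`. If the first sample is taken at time `t0`, multiply bin `k` by `exp(-2i*pi*k*t0/(N*dt))` to account for the offset. Non-uniformly sampled data must be resampled onto a uniform grid first.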
151
152## See Also
153[ifft](./ifft), [fftshift](./fftshift), [abs](../elementwise/abs), [angle](../elementwise/angle), [gpuArray](../../acceleration/gpu/gpuArray), [gather](../../acceleration/gpu/gather)
154
155## Source & Feedback
156- Full source: `crates/runmat-runtime/src/builtins/math/fft/fft.rs`
157- Found an issue? [Open a ticket](https://github.com/runmat-org/runmat/issues/new/choose) with a minimal reproduction.
158"#;
159
160pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
161 name: "fft",
162 op_kind: GpuOpKind::Custom("fft"),
163 supported_precisions: &[ScalarType::F32, ScalarType::F64],
164 broadcast: BroadcastSemantics::Matlab,
165 provider_hooks: &[ProviderHook::Custom("fft_dim")],
166 constant_strategy: ConstantStrategy::InlineLiteral,
167 residency: ResidencyPolicy::NewHandle,
168 nan_mode: ReductionNaN::Include,
169 two_pass_threshold: None,
170 workgroup_size: None,
171 accepts_nan_mode: false,
172 notes: "Providers should implement `fft_dim` to transform along an arbitrary dimension; the runtime gathers to host when unavailable.",
173};
174
175register_builtin_gpu_spec!(GPU_SPEC);
176
177pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
178 name: "fft",
179 shape: ShapeRequirements::Any,
180 constant_strategy: ConstantStrategy::InlineLiteral,
181 elementwise: None,
182 reduction: None,
183 emits_nan: false,
184 notes:
185 "FFT participates in fusion plans only as a boundary; no fused kernels are generated today.",
186};
187
188register_builtin_fusion_spec!(FUSION_SPEC);
189
190#[cfg(feature = "doc_export")]
191register_builtin_doc_text!("fft", DOC_MD);
192
193#[runtime_builtin(
194 name = "fft",
195 category = "math/fft",
196 summary = "Compute the discrete Fourier transform (DFT) of numeric or complex data.",
197 keywords = "fft,fourier transform,complex,gpu"
198)]
199fn fft_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
200 let (length, dimension) = parse_arguments(&rest)?;
201 match value {
202 Value::GpuTensor(handle) => fft_gpu(handle, length, dimension),
203 other => fft_host(other, length, dimension),
204 }
205}
206
207fn fft_host(
208 value: Value,
209 length: Option<usize>,
210 dimension: Option<usize>,
211) -> Result<Value, String> {
212 let tensor = value_to_complex_tensor(value, "fft")?;
213 let transformed = fft_complex_tensor(tensor, length, dimension)?;
214 Ok(complex_tensor_into_value(transformed))
215}
216
217fn fft_gpu(
218 handle: GpuTensorHandle,
219 length: Option<usize>,
220 dimension: Option<usize>,
221) -> Result<Value, String> {
222 let mut shape = if handle.shape.is_empty() {
223 vec![1]
224 } else {
225 handle.shape.clone()
226 };
227
228 let dim_one_based = match dimension {
229 Some(0) => return Err("fft: dimension must be >= 1".to_string()),
230 Some(dim) => dim,
231 None => default_dimension(&shape),
232 };
233
234 let dim_index = dim_one_based - 1;
235 while shape.len() <= dim_index {
236 shape.push(1);
237 }
238 let current_len = shape[dim_index];
239 let target_len = length.unwrap_or(current_len);
240
241 if target_len == 0 {
242 let tensor = gpu_helpers::gather_tensor(&handle)?;
243 let complex = tensor_to_complex_tensor(tensor, "fft")?;
244 let transformed = fft_complex_tensor(complex, length, dimension)?;
245 return Ok(complex_tensor_into_value(transformed));
246 }
247
248 if let Some(provider) = runmat_accelerate_api::provider() {
249 if let Ok(out) = provider.fft_dim(&handle, length, dim_index) {
250 let complex = fft_download_gpu_result(provider, &out)?;
251 return Ok(complex_tensor_into_value(complex));
252 }
253 }
254
255 let tensor = gpu_helpers::gather_tensor(&handle)?;
256 let complex = tensor_to_complex_tensor(tensor, "fft")?;
257 let transformed = fft_complex_tensor(complex, length, dimension)?;
258 Ok(complex_tensor_into_value(transformed))
259}
260
261pub(super) fn fft_download_gpu_result(
262 provider: &dyn AccelProvider,
263 handle: &GpuTensorHandle,
264) -> Result<ComplexTensor, String> {
265 let host = provider.download(handle).map_err(|e| format!("fft: {e}"))?;
266 provider.free(handle).ok();
267 runmat_accelerate_api::clear_residency(handle);
268 host_to_complex_tensor(host, "fft")
269}
270
271fn parse_arguments(args: &[Value]) -> Result<(Option<usize>, Option<usize>), String> {
272 match args.len() {
273 0 => Ok((None, None)),
274 1 => {
275 let len = parse_length(&args[0], "fft")?;
276 Ok((len, None))
277 }
278 2 => {
279 let len = parse_length(&args[0], "fft")?;
280 let dim = Some(tensor::parse_dimension(&args[1], "fft")?);
281 Ok((len, dim))
282 }
283 _ => Err("fft: expected fft(X), fft(X, N), or fft(X, N, DIM)".to_string()),
284 }
285}
286
287pub(super) fn fft_complex_tensor(
288 mut tensor: ComplexTensor,
289 length: Option<usize>,
290 dimension: Option<usize>,
291) -> Result<ComplexTensor, String> {
292 if tensor.shape.is_empty() {
293 tensor.shape = vec![tensor.data.len()];
294 tensor.rows = tensor.shape.first().copied().unwrap_or(1);
295 tensor.cols = tensor.shape.get(1).copied().unwrap_or(1);
296 }
297
298 let mut shape = tensor.shape.clone();
299 let origin_rank = shape.len();
300 let dim = match dimension {
301 Some(0) => return Err("fft: dimension must be >= 1".to_string()),
302 Some(dim) => dim - 1,
303 None => default_dimension(&shape) - 1,
304 };
305
306 while shape.len() <= dim {
307 shape.push(1);
308 }
309
310 let current_len = shape[dim];
311 let target_len = length.unwrap_or(current_len);
312
313 if target_len == 0 {
314 let mut out_shape = shape;
315 out_shape[dim] = 0;
316 trim_trailing_ones(&mut out_shape, origin_rank);
317 return ComplexTensor::new(Vec::<(f64, f64)>::new(), out_shape)
318 .map_err(|e| format!("fft: {e}"));
319 }
320
321 let inner_stride = shape[..dim]
322 .iter()
323 .copied()
324 .fold(1usize, |acc, dim| acc.saturating_mul(dim));
325 let outer_stride = shape[dim + 1..]
326 .iter()
327 .copied()
328 .fold(1usize, |acc, dim| acc.saturating_mul(dim));
329 let num_slices = inner_stride.saturating_mul(outer_stride);
330
331 let input = tensor
332 .data
333 .into_iter()
334 .map(|(re, im)| Complex::new(re, im))
335 .collect::<Vec<_>>();
336
337 if num_slices == 0 {
338 let mut out_shape = shape;
339 out_shape[dim] = target_len;
340 trim_trailing_ones(&mut out_shape, origin_rank);
341 let data = vec![(0.0, 0.0); 0];
342 return ComplexTensor::new(data, out_shape).map_err(|e| format!("fft: {e}"));
343 }
344
345 let output_len = target_len.saturating_mul(num_slices);
346 let mut output = vec![Complex::new(0.0, 0.0); output_len];
347
348 let mut planner = FftPlanner::<f64>::new();
349 let fft_plan: Option<Arc<dyn rustfft::Fft<f64>>> = if target_len > 1 {
350 Some(planner.plan_fft_forward(target_len))
351 } else {
352 None
353 };
354
355 let copy_len = current_len.min(target_len);
356 let mut buffer = vec![Complex::new(0.0, 0.0); target_len];
357
358 for outer in 0..outer_stride {
359 let base_in = outer.saturating_mul(current_len.saturating_mul(inner_stride));
360 let base_out = outer.saturating_mul(target_len.saturating_mul(inner_stride));
361 for inner in 0..inner_stride {
362 buffer.iter_mut().for_each(|c| *c = Complex::new(0.0, 0.0));
363 for (k, slot) in buffer.iter_mut().enumerate().take(copy_len) {
364 let src_idx = base_in + inner + k * inner_stride;
365 if src_idx < input.len() {
366 *slot = input[src_idx];
367 }
368 }
369 if target_len > 1 {
370 if let Some(plan) = &fft_plan {
371 plan.process(&mut buffer);
372 }
373 }
374 for (k, value) in buffer.iter().enumerate().take(target_len) {
375 let dst_idx = base_out + inner + k * inner_stride;
376 if dst_idx < output.len() {
377 output[dst_idx] = *value;
378 }
379 }
380 }
381 }
382
383 let mut out_shape = shape;
384 out_shape[dim] = target_len;
385 trim_trailing_ones(&mut out_shape, origin_rank.max(dim + 1));
386
387 let data = output.into_iter().map(|c| (c.re, c.im)).collect::<Vec<_>>();
388 ComplexTensor::new(data, out_shape).map_err(|e| format!("fft: {e}"))
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use num_complex::Complex;
    use runmat_builtins::{ComplexTensor as HostComplexTensor, IntValue, Tensor};
    use rustfft::FftPlanner;

    /// Component-wise comparison of two complex values with an absolute tolerance.
    fn approx_eq(a: (f64, f64), b: (f64, f64), tol: f64) -> bool {
        (a.0 - b.0).abs() <= tol && (a.1 - b.1).abs() <= tol
    }

    /// Asserts that two complex sequences have equal length and agree element-wise.
    fn assert_all_close(actual: &[(f64, f64)], expected: &[(f64, f64)], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (idx, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(approx_eq(*a, *e, tol), "idx {idx}: {a:?} != {e:?}");
        }
    }

    /// Reference forward DFT of `samples`, computed directly with rustfft.
    fn reference_fft(samples: &[(f64, f64)]) -> Vec<(f64, f64)> {
        let mut buffer: Vec<Complex<f64>> = samples
            .iter()
            .map(|&(re, im)| Complex::new(re, im))
            .collect();
        FftPlanner::<f64>::new()
            .plan_fft_forward(buffer.len())
            .process(&mut buffer);
        buffer.into_iter().map(|c| (c.re, c.im)).collect()
    }

    /// Unwraps a builtin result into a complex tensor, promoting complex scalars to 1x1.
    fn value_as_complex_tensor(value: Value) -> HostComplexTensor {
        match value {
            Value::ComplexTensor(tensor) => tensor,
            Value::Complex(re, im) => HostComplexTensor::new(vec![(re, im)], vec![1, 1]).unwrap(),
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[test]
    fn fft_real_vector() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let ct = value_as_complex_tensor(fft_host(Value::Tensor(tensor), None, None).expect("fft"));
        assert_eq!(ct.shape, vec![4]);
        assert_all_close(
            &ct.data,
            &[(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)],
            1e-12,
        );
    }

    #[test]
    fn fft_matrix_default_dimension() {
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let ct = value_as_complex_tensor(fft_host(Value::Tensor(tensor), None, None).expect("fft"));
        assert_eq!(ct.shape, vec![2, 3]);
        // Each column [a; b] maps to [a + b; a - b].
        assert_all_close(
            &ct.data,
            &[
                (5.0, 0.0),
                (-3.0, 0.0),
                (7.0, 0.0),
                (-3.0, 0.0),
                (9.0, 0.0),
                (-3.0, 0.0),
            ],
            1e-12,
        );
    }

    #[test]
    fn fft_zero_padding_with_length_argument() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let ct = value_as_complex_tensor(
            fft_host(Value::Tensor(tensor), Some(5), None).expect("fft with explicit length"),
        );
        assert_eq!(ct.shape, vec![5]);
        // Every bin must match the transform of the explicitly zero-padded sequence.
        let expected = reference_fft(&[(1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (0.0, 0.0), (0.0, 0.0)]);
        assert_all_close(&ct.data, &expected, 1e-12);
        assert!(approx_eq(ct.data[0], (6.0, 0.0), 1e-12));
    }

    #[test]
    fn fft_empty_length_argument_defaults_to_input_length() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let baseline =
            fft_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("baseline fft");
        let empty = Tensor::new(Vec::<f64>::new(), vec![0]).unwrap();
        let result = fft_builtin(
            Value::Tensor(tensor),
            vec![Value::Tensor(empty), Value::Int(IntValue::I32(1))],
        )
        .expect("fft with empty length");
        let base_ct = value_as_complex_tensor(baseline);
        let result_ct = value_as_complex_tensor(result);
        assert_eq!(base_ct.shape, result_ct.shape);
        assert_all_close(&result_ct.data, &base_ct.data, 1e-12);
    }

    #[test]
    fn fft_truncates_when_length_smaller() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let ct = value_as_complex_tensor(
            fft_host(Value::Tensor(tensor), Some(2), None).expect("fft with truncation length"),
        );
        assert_eq!(ct.shape, vec![2]);
        assert_all_close(&ct.data, &[(3.0, 0.0), (-1.0, 0.0)], 1e-12);
    }

    #[test]
    fn fft_zero_length_returns_empty_tensor() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let result = fft_host(Value::Tensor(tensor), Some(0), None).expect("fft with zero length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![0]);
                assert!(ct.data.is_empty());
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[test]
    fn fft_complex_input_preserves_imaginary_components() {
        let tensor =
            HostComplexTensor::new(vec![(1.0, 1.0), (0.0, -1.0), (2.0, 0.5)], vec![3]).unwrap();
        let expected = reference_fft(&tensor.data);
        let ct = value_as_complex_tensor(
            fft_host(Value::ComplexTensor(tensor), None, None).expect("fft complex"),
        );
        assert_eq!(ct.shape, vec![3]);
        assert_all_close(&ct.data, &expected, 1e-12);
    }

    #[test]
    fn fft_complex_sinusoid_concentrates_in_one_bin() {
        // exp(1i*pi/2*(0:3)) = [1, i, -1, -i] completes one cycle, so only bin 2 (1-based)
        // is non-zero.
        let tensor = HostComplexTensor::new(
            vec![(1.0, 0.0), (0.0, 1.0), (-1.0, 0.0), (0.0, -1.0)],
            vec![4],
        )
        .unwrap();
        let ct = value_as_complex_tensor(
            fft_host(Value::ComplexTensor(tensor), None, None).expect("fft complex sinusoid"),
        );
        assert_eq!(ct.shape, vec![4]);
        assert_all_close(
            &ct.data,
            &[(0.0, 0.0), (4.0, 0.0), (0.0, 0.0), (0.0, 0.0)],
            1e-12,
        );
    }

    #[test]
    fn fft_row_vector_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let ct = value_as_complex_tensor(
            fft_host(Value::Tensor(tensor), None, Some(2)).expect("fft along dimension 2"),
        );
        assert_eq!(ct.shape, vec![1, 4]);
        assert_all_close(
            &ct.data,
            &[(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)],
            1e-12,
        );
    }

    #[test]
    fn fft_dimension_extends_rank() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let expected: Vec<(f64, f64)> = tensor.data.iter().map(|&v| (v, 0.0)).collect();
        let ct = value_as_complex_tensor(
            fft_host(Value::Tensor(tensor), None, Some(3)).expect("fft with extra dimension"),
        );
        assert_eq!(ct.shape, vec![1, 4, 1]);
        // A length-1 transform along the new dimension is the identity.
        assert_all_close(&ct.data, &expected, 1e-12);
    }

    #[test]
    fn fft_dimension_extends_rank_with_padding() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        // fft([a 0 0 0]) = [a a a a], so every depth slice repeats the original data.
        let expected: Vec<(f64, f64)> = (0..4)
            .flat_map(|_| tensor.data.iter().map(|&v| (v, 0.0)))
            .collect();
        let ct = value_as_complex_tensor(
            fft_host(Value::Tensor(tensor), Some(4), Some(3))
                .expect("fft with padded third dimension"),
        );
        assert_eq!(ct.shape, vec![1, 4, 4]);
        assert_all_close(&ct.data, &expected, 1e-12);
    }

    #[test]
    fn fft_rejects_non_numeric_length() {
        assert!(parse_arguments(&[Value::Bool(true)]).is_err());
    }

    #[test]
    fn fft_rejects_negative_length() {
        let err = parse_arguments(&[Value::Num(-1.0)]).unwrap_err();
        assert!(err.contains("length must be non-negative"));
    }

    #[test]
    fn fft_rejects_fractional_length() {
        let err = parse_arguments(&[Value::Num(1.5)]).unwrap_err();
        assert!(err.contains("length must be an integer"));
    }

    #[test]
    fn fft_rejects_dimension_zero() {
        let err = parse_arguments(&[Value::Num(4.0), Value::Int(IntValue::I32(0))]).unwrap_err();
        assert!(err.contains("dimension must be >= 1"));
    }

    #[test]
    fn fft_gpu_roundtrip_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("fft");
            let cpu = fft_builtin(Value::Tensor(tensor), Vec::new()).expect("fft");
            let gpu_host = value_as_complex_tensor(gpu);
            let cpu_host = value_as_complex_tensor(cpu);
            assert_eq!(gpu_host.shape, cpu_host.shape);
            assert_all_close(&gpu_host.data, &cpu_host.data, 1e-12);
            provider.free(&handle).ok();
        });
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn fft_wgpu_matches_cpu() {
        if let Some(provider) = runmat_accelerate::backend::wgpu::provider::ensure_wgpu_provider()
            .expect("wgpu provider")
        {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let tensor_cpu = tensor.clone();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpu fft");
            let cpu = fft_builtin(Value::Tensor(tensor_cpu), Vec::new()).expect("cpu fft");
            let gpu_ct = value_as_complex_tensor(gpu);
            let cpu_ct = value_as_complex_tensor(cpu);
            assert_eq!(gpu_ct.shape, cpu_ct.shape);
            // Device kernels may run in single precision, hence the looser tolerance.
            assert_all_close(&gpu_ct.data, &cpu_ct.data, 1e-9);
            provider.free(&handle).ok();
        }
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}