Skip to main content

runmat_runtime/builtins/math/fft/
forward.rs

1//! MATLAB-compatible `fft` builtin with GPU-aware semantics for RunMat.
2
3use super::common::{
4    default_dimension, gather_gpu_complex_tensor, parse_length, transform_complex_tensor,
5    value_to_complex_tensor, TransformDirection,
6};
7use runmat_accelerate_api::GpuTensorHandle;
8use runmat_builtins::{
9    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
10    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
11    ComplexTensor, Value,
12};
13use runmat_macros::runtime_builtin;
14
15use crate::builtins::common::random_args::complex_tensor_into_value;
16use crate::builtins::common::spec::{
17    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
18    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
19};
20use crate::builtins::common::{shape::normalize_scalar_shape, tensor};
21use crate::builtins::math::fft::type_resolvers::fft_type;
22use crate::{build_runtime_error, BuiltinResult, RuntimeError};
23
24#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::fft::forward")]
25pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
26    name: "fft",
27    op_kind: GpuOpKind::Custom("fft"),
28    supported_precisions: &[ScalarType::F32, ScalarType::F64],
29    broadcast: BroadcastSemantics::Matlab,
30    provider_hooks: &[ProviderHook::Custom("fft_dim")],
31    constant_strategy: ConstantStrategy::InlineLiteral,
32    residency: ResidencyPolicy::NewHandle,
33    nan_mode: ReductionNaN::Include,
34    two_pass_threshold: None,
35    workgroup_size: None,
36    accepts_nan_mode: false,
37    notes: "Providers should implement `fft_dim` to transform along an arbitrary dimension; the runtime gathers to host when unavailable.",
38};
39
40#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::fft::forward")]
41pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
42    name: "fft",
43    shape: ShapeRequirements::Any,
44    constant_strategy: ConstantStrategy::InlineLiteral,
45    elementwise: None,
46    reduction: None,
47    emits_nan: false,
48    notes:
49        "FFT participates in fusion plans only as a boundary; no fused kernels are generated today.",
50};
51
52const BUILTIN_NAME: &str = "fft";
53
54const FFT_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
55    name: "Y",
56    ty: BuiltinParamType::NumericArray,
57    arity: BuiltinParamArity::Required,
58    default: None,
59    description: "Complex Fourier spectrum output.",
60}];
61
62const FFT_INPUTS_CORE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
63    name: "X",
64    ty: BuiltinParamType::Any,
65    arity: BuiltinParamArity::Required,
66    default: None,
67    description: "Input signal/array.",
68}];
69
70const FFT_INPUTS_WITH_N: [BuiltinParamDescriptor; 2] = [
71    BuiltinParamDescriptor {
72        name: "X",
73        ty: BuiltinParamType::Any,
74        arity: BuiltinParamArity::Required,
75        default: None,
76        description: "Input signal/array.",
77    },
78    BuiltinParamDescriptor {
79        name: "N",
80        ty: BuiltinParamType::NumericScalar,
81        arity: BuiltinParamArity::Optional,
82        default: Some("[]"),
83        description: "Transform length along selected dimension.",
84    },
85];
86
87const FFT_INPUTS_WITH_N_DIM: [BuiltinParamDescriptor; 3] = [
88    BuiltinParamDescriptor {
89        name: "X",
90        ty: BuiltinParamType::Any,
91        arity: BuiltinParamArity::Required,
92        default: None,
93        description: "Input signal/array.",
94    },
95    BuiltinParamDescriptor {
96        name: "N",
97        ty: BuiltinParamType::NumericScalar,
98        arity: BuiltinParamArity::Optional,
99        default: Some("[]"),
100        description: "Transform length along selected dimension.",
101    },
102    BuiltinParamDescriptor {
103        name: "DIM",
104        ty: BuiltinParamType::NumericScalar,
105        arity: BuiltinParamArity::Optional,
106        default: Some("first non-singleton dimension"),
107        description: "Dimension to transform along.",
108    },
109];
110
111const FFT_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
112    BuiltinSignatureDescriptor {
113        label: "Y = fft(X)",
114        inputs: &FFT_INPUTS_CORE,
115        outputs: &FFT_OUTPUT,
116    },
117    BuiltinSignatureDescriptor {
118        label: "Y = fft(X, N)",
119        inputs: &FFT_INPUTS_WITH_N,
120        outputs: &FFT_OUTPUT,
121    },
122    BuiltinSignatureDescriptor {
123        label: "Y = fft(X, N, DIM)",
124        inputs: &FFT_INPUTS_WITH_N_DIM,
125        outputs: &FFT_OUTPUT,
126    },
127];
128
129const FFT_ERROR_ARG_COUNT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
130    code: "RM.FFT.ARG_COUNT",
131    identifier: Some("RunMat:fft:ArgCount"),
132    when: "More than three input arguments are supplied.",
133    message: "fft: expected fft(X), fft(X, N), or fft(X, N, DIM)",
134};
135
136const FFT_ERROR_INVALID_LENGTH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
137    code: "RM.FFT.INVALID_LENGTH",
138    identifier: Some("RunMat:fft:InvalidLength"),
139    when: "Length argument N is invalid.",
140    message: "fft: invalid length argument",
141};
142
143const FFT_ERROR_INVALID_DIMENSION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
144    code: "RM.FFT.INVALID_DIMENSION",
145    identifier: Some("RunMat:fft:InvalidDimension"),
146    when: "Dimension argument DIM is invalid.",
147    message: "fft: invalid dimension argument",
148};
149
150const FFT_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
151    code: "RM.FFT.INVALID_INPUT",
152    identifier: Some("RunMat:fft:InvalidInput"),
153    when: "Input cannot be converted to supported numeric/complex domain.",
154    message: "fft: invalid input",
155};
156
157const FFT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
158    code: "RM.FFT.INTERNAL",
159    identifier: Some("RunMat:fft:Internal"),
160    when: "FFT execution or tensor shaping fails.",
161    message: "fft: internal error",
162};
163
164const FFT_ERRORS: [BuiltinErrorDescriptor; 5] = [
165    FFT_ERROR_ARG_COUNT,
166    FFT_ERROR_INVALID_LENGTH,
167    FFT_ERROR_INVALID_DIMENSION,
168    FFT_ERROR_INVALID_INPUT,
169    FFT_ERROR_INTERNAL,
170];
171
172pub const FFT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
173    signatures: &FFT_SIGNATURES,
174    output_mode: BuiltinOutputMode::Fixed,
175    completion_policy: BuiltinCompletionPolicy::Public,
176    errors: &FFT_ERRORS,
177};
178
179fn fft_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
180    fft_error_with_message(error.message, error)
181}
182
183fn fft_error_with_detail(
184    error: &'static BuiltinErrorDescriptor,
185    detail: impl AsRef<str>,
186) -> RuntimeError {
187    fft_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
188}
189
190fn fft_error_with_source(
191    error: &'static BuiltinErrorDescriptor,
192    detail: impl AsRef<str>,
193    source: RuntimeError,
194) -> RuntimeError {
195    let mut builder = build_runtime_error(format!("{}: {}", error.message, detail.as_ref()))
196        .with_builtin(BUILTIN_NAME)
197        .with_source(source);
198    if let Some(identifier) = error.identifier {
199        builder = builder.with_identifier(identifier);
200    }
201    builder.build()
202}
203
204fn fft_error_with_message(
205    message: impl Into<String>,
206    error: &'static BuiltinErrorDescriptor,
207) -> RuntimeError {
208    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
209    if let Some(identifier) = error.identifier {
210        builder = builder.with_identifier(identifier);
211    }
212    builder.build()
213}
214
215#[runtime_builtin(
216    name = "fft",
217    category = "math/fft",
218    summary = "Compute discrete Fourier transforms.",
219    keywords = "fft,fourier transform,complex,gpu",
220    type_resolver(fft_type),
221    descriptor(crate::builtins::math::fft::forward::FFT_DESCRIPTOR),
222    builtin_path = "crate::builtins::math::fft::forward"
223)]
224async fn fft_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
225    let (length, dimension) = parse_arguments(&rest).await?;
226    match value {
227        Value::GpuTensor(handle) => fft_gpu(handle, length, dimension).await,
228        other => fft_host(other, length, dimension),
229    }
230}
231
232fn fft_host(value: Value, length: Option<usize>, dimension: Option<usize>) -> BuiltinResult<Value> {
233    let tensor = value_to_complex_tensor(value, BUILTIN_NAME).map_err(|source| {
234        fft_error_with_source(&FFT_ERROR_INVALID_INPUT, "input conversion failed", source)
235    })?;
236    let transformed = fft_complex_tensor(tensor, length, dimension)?;
237    Ok(complex_tensor_into_value(transformed))
238}
239
240async fn fft_gpu(
241    handle: GpuTensorHandle,
242    length: Option<usize>,
243    dimension: Option<usize>,
244) -> BuiltinResult<Value> {
245    let mut shape = normalize_scalar_shape(&handle.shape);
246
247    let dim_one_based = match dimension {
248        Some(0) => return Err(fft_error(&FFT_ERROR_INVALID_DIMENSION)),
249        Some(dim) => dim,
250        None => default_dimension(&shape),
251    };
252
253    let dim_index = dim_one_based - 1;
254    while shape.len() <= dim_index {
255        shape.push(1);
256    }
257    let current_len = shape[dim_index];
258    let target_len = length.unwrap_or(current_len);
259
260    if target_len == 0 {
261        let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME)
262            .await
263            .map_err(|source| {
264                fft_error_with_source(&FFT_ERROR_INVALID_INPUT, "gpu gather failed", source)
265            })?;
266        let transformed = fft_complex_tensor(complex, length, dimension)?;
267        return Ok(complex_tensor_into_value(transformed));
268    }
269
270    if let Some(provider) = runmat_accelerate_api::provider() {
271        if let Ok(out) = provider.fft_dim(&handle, length, dim_index).await {
272            return Ok(Value::GpuTensor(out));
273        }
274    }
275
276    let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME)
277        .await
278        .map_err(|source| {
279            fft_error_with_source(&FFT_ERROR_INVALID_INPUT, "gpu gather failed", source)
280        })?;
281    let transformed = fft_complex_tensor(complex, length, dimension)?;
282    Ok(complex_tensor_into_value(transformed))
283}
284
285async fn parse_dimension_arg(value: &Value) -> BuiltinResult<usize> {
286    tensor::dimension_from_value_async(value, BUILTIN_NAME, false)
287        .await
288        .map_err(|detail| fft_error_with_detail(&FFT_ERROR_INVALID_DIMENSION, detail))?
289        .ok_or_else(|| {
290            fft_error_with_detail(&FFT_ERROR_INVALID_DIMENSION, format!("received {value:?}"))
291        })
292}
293
294async fn parse_arguments(args: &[Value]) -> BuiltinResult<(Option<usize>, Option<usize>)> {
295    match args.len() {
296        0 => Ok((None, None)),
297        1 => {
298            let len = parse_length(&args[0], BUILTIN_NAME).map_err(|source| {
299                fft_error_with_source(&FFT_ERROR_INVALID_LENGTH, "length parse failed", source)
300            })?;
301            Ok((len, None))
302        }
303        2 => {
304            let len = parse_length(&args[0], BUILTIN_NAME).map_err(|source| {
305                fft_error_with_source(&FFT_ERROR_INVALID_LENGTH, "length parse failed", source)
306            })?;
307            let dim = Some(parse_dimension_arg(&args[1]).await?);
308            Ok((len, dim))
309        }
310        _ => Err(fft_error(&FFT_ERROR_ARG_COUNT)),
311    }
312}
313
314pub(super) fn fft_complex_tensor(
315    tensor: ComplexTensor,
316    length: Option<usize>,
317    dimension: Option<usize>,
318) -> BuiltinResult<ComplexTensor> {
319    transform_complex_tensor(
320        tensor,
321        length,
322        dimension,
323        TransformDirection::Forward,
324        BUILTIN_NAME,
325    )
326    .map_err(|source| fft_error_with_source(&FFT_ERROR_INTERNAL, "transform failed", source))
327}
328
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for `fft`: descriptor metadata, argument parsing, host transforms, and
    //! GPU/host parity through the test provider.

    use super::*;
    use crate::builtins::common::test_support;
    use crate::builtins::math::fft::common;
    use futures::executor::block_on;
    use num_complex::Complex;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_builtins::{
        builtin_function_by_name, ComplexTensor as HostComplexTensor, IntValue, ResolveContext,
        Tensor, Type,
    };
    use rustfft::FftPlanner;

    /// DFT of `[1, 2, 3, 4]`, shared by the vector-orientation tests.
    const SPECTRUM_1234: [(f64, f64); 4] = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];

    fn approx_eq(a: (f64, f64), b: (f64, f64), tol: f64) -> bool {
        (a.0 - b.0).abs() <= tol && (a.1 - b.1).abs() <= tol
    }

    /// Asserts equal lengths and element-wise agreement within `tol`.
    ///
    /// The length check matters: a bare zip/index loop would silently accept a result that
    /// is shorter than expected.
    fn assert_all_close(actual: &[(f64, f64)], expected: &[(f64, f64)], tol: f64) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {actual:?} vs {expected:?}"
        );
        for (idx, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(approx_eq(*a, *e, tol), "idx {idx}: {a:?} vs {e:?}");
        }
    }

    fn error_message(error: crate::RuntimeError) -> String {
        error.message().to_string()
    }

    fn error_identifier(error: &crate::RuntimeError) -> Option<&str> {
        error.identifier()
    }

    /// Asserts that `err` carries `expected`'s identifier and contains its canonical message.
    fn assert_fft_error(err: crate::RuntimeError, expected: &BuiltinErrorDescriptor) {
        assert_eq!(error_identifier(&err), expected.identifier);
        assert!(error_message(err).contains(expected.message));
    }

    /// Unwraps a host-side `fft` result, panicking unless it is a `ComplexTensor`.
    fn expect_complex_tensor(value: Value) -> HostComplexTensor {
        match value {
            Value::ComplexTensor(tensor) => tensor,
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    /// Converts any complex-valued result (host tensor, scalar, or GPU handle) to a host
    /// complex tensor, downloading GPU results through their provider.
    fn value_as_complex_tensor(value: Value) -> HostComplexTensor {
        match value {
            Value::ComplexTensor(tensor) => tensor,
            Value::Complex(re, im) => HostComplexTensor::new(vec![(re, im)], vec![1, 1]).unwrap(),
            Value::GpuTensor(handle) => {
                let provider = runmat_accelerate_api::provider_for_handle(&handle)
                    .or_else(runmat_accelerate_api::provider)
                    .expect("provider for gpu handle");
                let host = block_on(provider.download(&handle)).expect("download gpu fft output");
                common::host_to_complex_tensor(host, BUILTIN_NAME).expect("decode gpu complex")
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    /// Asserts that two results (either may be GPU-resident) have the same shape and agree
    /// within `tol`.
    fn assert_same_spectrum(actual: Value, reference: Value, tol: f64) {
        let actual = value_as_complex_tensor(actual);
        let reference = value_as_complex_tensor(reference);
        assert_eq!(actual.shape, reference.shape);
        assert_all_close(&actual.data, &reference.data, tol);
    }

    fn fft_builtin_sync(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::fft_builtin(value, rest))
    }

    #[test]
    fn fft_type_preserves_shape() {
        let out = fft_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    #[test]
    fn fft_descriptor_signatures_and_errors() {
        let builtin = builtin_function_by_name(BUILTIN_NAME).expect("fft builtin");
        let descriptor = builtin.descriptor.expect("fft descriptor");
        let labels: Vec<&str> = descriptor.signatures.iter().map(|sig| sig.label).collect();
        for label in ["Y = fft(X)", "Y = fft(X, N)", "Y = fft(X, N, DIM)"] {
            assert!(labels.contains(&label), "missing signature {label}");
        }
        assert!(descriptor
            .errors
            .iter()
            .any(|err| err.code == "RM.FFT.INVALID_LENGTH"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_real_vector() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let ct = expect_complex_tensor(fft_host(Value::Tensor(tensor), None, None).expect("fft"));
        assert_eq!(ct.shape, vec![4]);
        assert_all_close(&ct.data, &SPECTRUM_1234, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_row_vector_default_dimension_preserves_orientation() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let ct = expect_complex_tensor(fft_host(Value::Tensor(tensor), None, None).expect("fft"));
        assert_eq!(ct.shape, vec![1, 4]);
        assert_all_close(&ct.data, &SPECTRUM_1234, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_matrix_default_dimension() {
        // Column-major 2x3 matrix [1 2 3; 4 5 6]; each column is transformed independently.
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let ct = expect_complex_tensor(fft_host(Value::Tensor(tensor), None, None).expect("fft"));
        assert_eq!(ct.shape, vec![2, 3]);
        let expected = [
            (5.0, 0.0),
            (-3.0, 0.0),
            (7.0, 0.0),
            (-3.0, 0.0),
            (9.0, 0.0),
            (-3.0, 0.0),
        ];
        assert_all_close(&ct.data, &expected, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_zero_padding_with_length_argument() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), Some(5), None).expect("fft with explicit length"),
        );
        assert_eq!(ct.shape, vec![5]);
        assert_eq!(ct.data.len(), 5);
        // Zero padding does not change the DC term: it is still the sum of the inputs.
        assert!(approx_eq(ct.data[0], (6.0, 0.0), 1e-12));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_empty_length_argument_defaults_to_input_length() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let baseline =
            fft_builtin_sync(Value::Tensor(tensor.clone()), Vec::new()).expect("baseline fft");
        let empty = Tensor::new(Vec::<f64>::new(), vec![0]).unwrap();
        let result = fft_builtin_sync(
            Value::Tensor(tensor),
            vec![Value::Tensor(empty), Value::Int(IntValue::I32(1))],
        )
        .expect("fft with empty length");
        assert_same_spectrum(result, baseline, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_truncates_when_length_smaller() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), Some(2), None).expect("fft with truncation length"),
        );
        assert_eq!(ct.shape, vec![2]);
        assert_all_close(&ct.data, &[(3.0, 0.0), (-1.0, 0.0)], 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_zero_length_returns_empty_tensor() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), Some(0), None).expect("fft with zero length"),
        );
        assert_eq!(ct.shape, vec![0]);
        assert!(ct.data.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_complex_input_preserves_imaginary_components() {
        let tensor =
            HostComplexTensor::new(vec![(1.0, 1.0), (0.0, -1.0), (2.0, 0.5)], vec![3]).unwrap();

        // Reference spectrum computed independently with rustfft.
        let mut reference: Vec<Complex<f64>> = tensor
            .data
            .iter()
            .map(|&(re, im)| Complex::new(re, im))
            .collect();
        FftPlanner::<f64>::new()
            .plan_fft_forward(reference.len())
            .process(&mut reference);
        let expected: Vec<(f64, f64)> = reference.iter().map(|c| (c.re, c.im)).collect();

        let ct = expect_complex_tensor(
            fft_host(Value::ComplexTensor(tensor), None, None).expect("fft complex"),
        );
        assert_eq!(ct.shape, vec![3]);
        assert_all_close(&ct.data, &expected, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_row_vector_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), None, Some(2)).expect("fft along dimension 2"),
        );
        assert_eq!(ct.shape, vec![1, 4]);
        assert_all_close(&ct.data, &SPECTRUM_1234, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_dimension_extends_rank() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        // A length-1 transform along a new trailing dimension is the identity.
        let expected: Vec<(f64, f64)> = tensor.data.iter().map(|&v| (v, 0.0)).collect();
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), None, Some(3)).expect("fft with extra dimension"),
        );
        assert_eq!(ct.shape, vec![1, 4, 1]);
        assert_all_close(&ct.data, &expected, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_dimension_extends_rank_with_padding() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        // fft([x 0 0 0]) == [x x x x], so each of the 4 pages repeats the input (column-major).
        let expected: Vec<(f64, f64)> = (0..4)
            .flat_map(|_| tensor.data.iter().map(|&v| (v, 0.0)))
            .collect();
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), Some(4), Some(3))
                .expect("fft with padded third dimension"),
        );
        assert_eq!(ct.shape, vec![1, 4, 4]);
        assert_all_close(&ct.data, &expected, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_non_numeric_length() {
        let err = block_on(parse_arguments(&[Value::Bool(true)])).unwrap_err();
        assert_fft_error(err, &FFT_ERROR_INVALID_LENGTH);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_negative_length() {
        let err = block_on(parse_arguments(&[Value::Num(-1.0)])).unwrap_err();
        assert_fft_error(err, &FFT_ERROR_INVALID_LENGTH);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_fractional_length() {
        let err = block_on(parse_arguments(&[Value::Num(1.5)])).unwrap_err();
        assert_fft_error(err, &FFT_ERROR_INVALID_LENGTH);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_dimension_zero() {
        let err = block_on(parse_arguments(&[
            Value::Num(4.0),
            Value::Int(IntValue::I32(0)),
        ]))
        .unwrap_err();
        assert_fft_error(err, &FFT_ERROR_INVALID_DIMENSION);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_accepts_scalar_tensor_dimension_argument() {
        let dim = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
        let (len, parsed_dim) = block_on(parse_arguments(&[Value::Num(4.0), Value::Tensor(dim)]))
            .expect("parse arguments");
        assert_eq!(len, Some(4));
        assert_eq!(parsed_dim, Some(2));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_roundtrip_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("fft");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), Vec::new()).expect("fft");
            assert_same_spectrum(gpu, cpu, 1e-12);
            provider.free(&handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_non_power_of_two_length_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin_sync(
                Value::GpuTensor(handle.clone()),
                vec![Value::Int(IntValue::I32(7))],
            )
            .expect("fft gpu");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(7))])
                .expect("fft cpu");
            assert_same_spectrum(gpu, cpu, 1e-10);
            provider.free(&handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_prime_length_on_non_last_dimension_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new((1..=18).map(|v| v as f64).collect(), vec![2, 3, 3]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let args = vec![Value::Int(IntValue::I32(7)), Value::Int(IntValue::I32(2))];
            let gpu =
                fft_builtin_sync(Value::GpuTensor(handle.clone()), args.clone()).expect("fft gpu");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), args).expect("fft cpu");
            assert_same_spectrum(gpu, cpu, 1e-10);
            provider.free(&handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn fft_wgpu_matches_cpu() {
        if let Some(provider) = runmat_accelerate::backend::wgpu::provider::ensure_wgpu_provider()
            .expect("wgpu provider")
        {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let tensor_cpu = tensor.clone();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu =
                fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpu fft");
            let cpu = fft_builtin_sync(Value::Tensor(tensor_cpu), Vec::new()).expect("cpu fft");
            // Single-precision devices need a looser tolerance than the f64 host reference.
            let tol = match provider.precision() {
                runmat_accelerate_api::ProviderPrecision::F64 => 1e-10,
                runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
            };
            assert_same_spectrum(gpu, cpu, tol);
            provider.free(&handle).ok();
        }
    }
}