1use super::common::{
4 default_dimension, gather_gpu_complex_tensor, parse_length, transform_complex_tensor,
5 value_to_complex_tensor, TransformDirection,
6};
7use runmat_accelerate_api::GpuTensorHandle;
8use runmat_builtins::{
9 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
10 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
11 ComplexTensor, Value,
12};
13use runmat_macros::runtime_builtin;
14
15use crate::builtins::common::random_args::complex_tensor_into_value;
16use crate::builtins::common::spec::{
17 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
18 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
19};
20use crate::builtins::common::{shape::normalize_scalar_shape, tensor};
21use crate::builtins::math::fft::type_resolvers::fft_type;
22use crate::{build_runtime_error, BuiltinResult, RuntimeError};
23
24#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::fft::forward")]
25pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
26 name: "fft",
27 op_kind: GpuOpKind::Custom("fft"),
28 supported_precisions: &[ScalarType::F32, ScalarType::F64],
29 broadcast: BroadcastSemantics::Matlab,
30 provider_hooks: &[ProviderHook::Custom("fft_dim")],
31 constant_strategy: ConstantStrategy::InlineLiteral,
32 residency: ResidencyPolicy::NewHandle,
33 nan_mode: ReductionNaN::Include,
34 two_pass_threshold: None,
35 workgroup_size: None,
36 accepts_nan_mode: false,
37 notes: "Providers should implement `fft_dim` to transform along an arbitrary dimension; the runtime gathers to host when unavailable.",
38};
39
40#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::fft::forward")]
41pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
42 name: "fft",
43 shape: ShapeRequirements::Any,
44 constant_strategy: ConstantStrategy::InlineLiteral,
45 elementwise: None,
46 reduction: None,
47 emits_nan: false,
48 notes:
49 "FFT participates in fusion plans only as a boundary; no fused kernels are generated today.",
50};
51
52const BUILTIN_NAME: &str = "fft";
53
54const FFT_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
55 name: "Y",
56 ty: BuiltinParamType::NumericArray,
57 arity: BuiltinParamArity::Required,
58 default: None,
59 description: "Complex Fourier spectrum output.",
60}];
61
62const FFT_INPUTS_CORE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
63 name: "X",
64 ty: BuiltinParamType::Any,
65 arity: BuiltinParamArity::Required,
66 default: None,
67 description: "Input signal/array.",
68}];
69
70const FFT_INPUTS_WITH_N: [BuiltinParamDescriptor; 2] = [
71 BuiltinParamDescriptor {
72 name: "X",
73 ty: BuiltinParamType::Any,
74 arity: BuiltinParamArity::Required,
75 default: None,
76 description: "Input signal/array.",
77 },
78 BuiltinParamDescriptor {
79 name: "N",
80 ty: BuiltinParamType::NumericScalar,
81 arity: BuiltinParamArity::Optional,
82 default: Some("[]"),
83 description: "Transform length along selected dimension.",
84 },
85];
86
87const FFT_INPUTS_WITH_N_DIM: [BuiltinParamDescriptor; 3] = [
88 BuiltinParamDescriptor {
89 name: "X",
90 ty: BuiltinParamType::Any,
91 arity: BuiltinParamArity::Required,
92 default: None,
93 description: "Input signal/array.",
94 },
95 BuiltinParamDescriptor {
96 name: "N",
97 ty: BuiltinParamType::NumericScalar,
98 arity: BuiltinParamArity::Optional,
99 default: Some("[]"),
100 description: "Transform length along selected dimension.",
101 },
102 BuiltinParamDescriptor {
103 name: "DIM",
104 ty: BuiltinParamType::NumericScalar,
105 arity: BuiltinParamArity::Optional,
106 default: Some("first non-singleton dimension"),
107 description: "Dimension to transform along.",
108 },
109];
110
111const FFT_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
112 BuiltinSignatureDescriptor {
113 label: "Y = fft(X)",
114 inputs: &FFT_INPUTS_CORE,
115 outputs: &FFT_OUTPUT,
116 },
117 BuiltinSignatureDescriptor {
118 label: "Y = fft(X, N)",
119 inputs: &FFT_INPUTS_WITH_N,
120 outputs: &FFT_OUTPUT,
121 },
122 BuiltinSignatureDescriptor {
123 label: "Y = fft(X, N, DIM)",
124 inputs: &FFT_INPUTS_WITH_N_DIM,
125 outputs: &FFT_OUTPUT,
126 },
127];
128
129const FFT_ERROR_ARG_COUNT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
130 code: "RM.FFT.ARG_COUNT",
131 identifier: Some("RunMat:fft:ArgCount"),
132 when: "More than three input arguments are supplied.",
133 message: "fft: expected fft(X), fft(X, N), or fft(X, N, DIM)",
134};
135
136const FFT_ERROR_INVALID_LENGTH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
137 code: "RM.FFT.INVALID_LENGTH",
138 identifier: Some("RunMat:fft:InvalidLength"),
139 when: "Length argument N is invalid.",
140 message: "fft: invalid length argument",
141};
142
143const FFT_ERROR_INVALID_DIMENSION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
144 code: "RM.FFT.INVALID_DIMENSION",
145 identifier: Some("RunMat:fft:InvalidDimension"),
146 when: "Dimension argument DIM is invalid.",
147 message: "fft: invalid dimension argument",
148};
149
150const FFT_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
151 code: "RM.FFT.INVALID_INPUT",
152 identifier: Some("RunMat:fft:InvalidInput"),
153 when: "Input cannot be converted to supported numeric/complex domain.",
154 message: "fft: invalid input",
155};
156
157const FFT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
158 code: "RM.FFT.INTERNAL",
159 identifier: Some("RunMat:fft:Internal"),
160 when: "FFT execution or tensor shaping fails.",
161 message: "fft: internal error",
162};
163
164const FFT_ERRORS: [BuiltinErrorDescriptor; 5] = [
165 FFT_ERROR_ARG_COUNT,
166 FFT_ERROR_INVALID_LENGTH,
167 FFT_ERROR_INVALID_DIMENSION,
168 FFT_ERROR_INVALID_INPUT,
169 FFT_ERROR_INTERNAL,
170];
171
172pub const FFT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
173 signatures: &FFT_SIGNATURES,
174 output_mode: BuiltinOutputMode::Fixed,
175 completion_policy: BuiltinCompletionPolicy::Public,
176 errors: &FFT_ERRORS,
177};
178
179fn fft_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
180 fft_error_with_message(error.message, error)
181}
182
183fn fft_error_with_detail(
184 error: &'static BuiltinErrorDescriptor,
185 detail: impl AsRef<str>,
186) -> RuntimeError {
187 fft_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
188}
189
190fn fft_error_with_source(
191 error: &'static BuiltinErrorDescriptor,
192 detail: impl AsRef<str>,
193 source: RuntimeError,
194) -> RuntimeError {
195 let mut builder = build_runtime_error(format!("{}: {}", error.message, detail.as_ref()))
196 .with_builtin(BUILTIN_NAME)
197 .with_source(source);
198 if let Some(identifier) = error.identifier {
199 builder = builder.with_identifier(identifier);
200 }
201 builder.build()
202}
203
204fn fft_error_with_message(
205 message: impl Into<String>,
206 error: &'static BuiltinErrorDescriptor,
207) -> RuntimeError {
208 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
209 if let Some(identifier) = error.identifier {
210 builder = builder.with_identifier(identifier);
211 }
212 builder.build()
213}
214
215#[runtime_builtin(
216 name = "fft",
217 category = "math/fft",
218 summary = "Compute discrete Fourier transforms.",
219 keywords = "fft,fourier transform,complex,gpu",
220 type_resolver(fft_type),
221 descriptor(crate::builtins::math::fft::forward::FFT_DESCRIPTOR),
222 builtin_path = "crate::builtins::math::fft::forward"
223)]
224async fn fft_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
225 let (length, dimension) = parse_arguments(&rest).await?;
226 match value {
227 Value::GpuTensor(handle) => fft_gpu(handle, length, dimension).await,
228 other => fft_host(other, length, dimension),
229 }
230}
231
232fn fft_host(value: Value, length: Option<usize>, dimension: Option<usize>) -> BuiltinResult<Value> {
233 let tensor = value_to_complex_tensor(value, BUILTIN_NAME).map_err(|source| {
234 fft_error_with_source(&FFT_ERROR_INVALID_INPUT, "input conversion failed", source)
235 })?;
236 let transformed = fft_complex_tensor(tensor, length, dimension)?;
237 Ok(complex_tensor_into_value(transformed))
238}
239
240async fn fft_gpu(
241 handle: GpuTensorHandle,
242 length: Option<usize>,
243 dimension: Option<usize>,
244) -> BuiltinResult<Value> {
245 let mut shape = normalize_scalar_shape(&handle.shape);
246
247 let dim_one_based = match dimension {
248 Some(0) => return Err(fft_error(&FFT_ERROR_INVALID_DIMENSION)),
249 Some(dim) => dim,
250 None => default_dimension(&shape),
251 };
252
253 let dim_index = dim_one_based - 1;
254 while shape.len() <= dim_index {
255 shape.push(1);
256 }
257 let current_len = shape[dim_index];
258 let target_len = length.unwrap_or(current_len);
259
260 if target_len == 0 {
261 let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME)
262 .await
263 .map_err(|source| {
264 fft_error_with_source(&FFT_ERROR_INVALID_INPUT, "gpu gather failed", source)
265 })?;
266 let transformed = fft_complex_tensor(complex, length, dimension)?;
267 return Ok(complex_tensor_into_value(transformed));
268 }
269
270 if let Some(provider) = runmat_accelerate_api::provider() {
271 if let Ok(out) = provider.fft_dim(&handle, length, dim_index).await {
272 return Ok(Value::GpuTensor(out));
273 }
274 }
275
276 let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME)
277 .await
278 .map_err(|source| {
279 fft_error_with_source(&FFT_ERROR_INVALID_INPUT, "gpu gather failed", source)
280 })?;
281 let transformed = fft_complex_tensor(complex, length, dimension)?;
282 Ok(complex_tensor_into_value(transformed))
283}
284
285async fn parse_dimension_arg(value: &Value) -> BuiltinResult<usize> {
286 tensor::dimension_from_value_async(value, BUILTIN_NAME, false)
287 .await
288 .map_err(|detail| fft_error_with_detail(&FFT_ERROR_INVALID_DIMENSION, detail))?
289 .ok_or_else(|| {
290 fft_error_with_detail(&FFT_ERROR_INVALID_DIMENSION, format!("received {value:?}"))
291 })
292}
293
294async fn parse_arguments(args: &[Value]) -> BuiltinResult<(Option<usize>, Option<usize>)> {
295 match args.len() {
296 0 => Ok((None, None)),
297 1 => {
298 let len = parse_length(&args[0], BUILTIN_NAME).map_err(|source| {
299 fft_error_with_source(&FFT_ERROR_INVALID_LENGTH, "length parse failed", source)
300 })?;
301 Ok((len, None))
302 }
303 2 => {
304 let len = parse_length(&args[0], BUILTIN_NAME).map_err(|source| {
305 fft_error_with_source(&FFT_ERROR_INVALID_LENGTH, "length parse failed", source)
306 })?;
307 let dim = Some(parse_dimension_arg(&args[1]).await?);
308 Ok((len, dim))
309 }
310 _ => Err(fft_error(&FFT_ERROR_ARG_COUNT)),
311 }
312}
313
314pub(super) fn fft_complex_tensor(
315 tensor: ComplexTensor,
316 length: Option<usize>,
317 dimension: Option<usize>,
318) -> BuiltinResult<ComplexTensor> {
319 transform_complex_tensor(
320 tensor,
321 length,
322 dimension,
323 TransformDirection::Forward,
324 BUILTIN_NAME,
325 )
326 .map_err(|source| fft_error_with_source(&FFT_ERROR_INTERNAL, "transform failed", source))
327}
328
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use crate::builtins::math::fft::common;
    use futures::executor::block_on;
    use num_complex::Complex;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_builtins::{
        builtin_function_by_name, ComplexTensor as HostComplexTensor, IntValue, ResolveContext,
        Tensor, Type,
    };
    use rustfft::FftPlanner;

    /// DFT of the real sequence `[1, 2, 3, 4]`, shared by several vector tests.
    const SPECTRUM_1234: [(f64, f64); 4] = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];

    /// Component-wise closeness check for `(re, im)` pairs.
    fn approx_eq(a: (f64, f64), b: (f64, f64), tol: f64) -> bool {
        (a.0 - b.0).abs() <= tol && (a.1 - b.1).abs() <= tol
    }

    fn error_message(error: crate::RuntimeError) -> String {
        error.message().to_string()
    }

    fn error_identifier(error: &crate::RuntimeError) -> Option<&str> {
        error.identifier()
    }

    /// Decodes a complex result (host tensor, complex scalar, or GPU handle) to a host tensor.
    fn value_as_complex_tensor(value: Value) -> HostComplexTensor {
        match value {
            Value::ComplexTensor(tensor) => tensor,
            Value::Complex(re, im) => HostComplexTensor::new(vec![(re, im)], vec![1, 1]).unwrap(),
            Value::GpuTensor(handle) => {
                let provider = runmat_accelerate_api::provider_for_handle(&handle)
                    .or_else(runmat_accelerate_api::provider)
                    .expect("provider for gpu handle");
                let host = block_on(provider.download(&handle)).expect("download gpu fft output");
                common::host_to_complex_tensor(host, BUILTIN_NAME).expect("decode gpu complex")
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    /// Unwraps a host `Value::ComplexTensor`, panicking on any other variant.
    fn expect_complex_tensor(value: Value) -> HostComplexTensor {
        match value {
            Value::ComplexTensor(ct) => ct,
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    /// Asserts equal lengths and element-wise closeness within `tol`.
    fn assert_all_close(actual: &[(f64, f64)], expected: &[(f64, f64)], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "element count mismatch");
        for (idx, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(approx_eq(*a, *e, tol), "idx {idx}: {a:?} !~= {e:?}");
        }
    }

    /// Decodes both values and asserts identical shapes and element-wise closeness.
    fn assert_values_match(actual: Value, expected: Value, tol: f64) {
        let actual = value_as_complex_tensor(actual);
        let expected = value_as_complex_tensor(expected);
        assert_eq!(actual.shape, expected.shape);
        assert_all_close(&actual.data, &expected.data, tol);
    }

    /// Reference forward DFT of `input`, computed independently with `rustfft`.
    fn reference_fft(input: &[(f64, f64)]) -> Vec<(f64, f64)> {
        let mut buffer: Vec<Complex<f64>> = input
            .iter()
            .map(|&(re, im)| Complex::new(re, im))
            .collect();
        FftPlanner::<f64>::new()
            .plan_fft_forward(buffer.len())
            .process(&mut buffer);
        buffer.into_iter().map(|c| (c.re, c.im)).collect()
    }

    fn fft_builtin_sync(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::fft_builtin(value, rest))
    }

    #[test]
    fn fft_type_preserves_shape() {
        let out = fft_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    #[test]
    fn fft_descriptor_signatures_and_errors() {
        let builtin = builtin_function_by_name(BUILTIN_NAME).expect("fft builtin");
        let descriptor = builtin.descriptor.expect("fft descriptor");
        let labels: Vec<&str> = descriptor.signatures.iter().map(|sig| sig.label).collect();
        for label in ["Y = fft(X)", "Y = fft(X, N)", "Y = fft(X, N, DIM)"] {
            assert!(labels.contains(&label), "missing signature {label}");
        }
        for code in [
            "RM.FFT.ARG_COUNT",
            "RM.FFT.INVALID_LENGTH",
            "RM.FFT.INVALID_DIMENSION",
            "RM.FFT.INVALID_INPUT",
            "RM.FFT.INTERNAL",
        ] {
            assert!(
                descriptor.errors.iter().any(|err| err.code == code),
                "missing error {code}"
            );
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_real_vector() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let ct = expect_complex_tensor(fft_host(Value::Tensor(tensor), None, None).expect("fft"));
        assert_eq!(ct.shape, vec![4]);
        assert_all_close(&ct.data, &SPECTRUM_1234, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_row_vector_default_dimension_preserves_orientation() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let ct = expect_complex_tensor(fft_host(Value::Tensor(tensor), None, None).expect("fft"));
        assert_eq!(ct.shape, vec![1, 4]);
        assert_all_close(&ct.data, &SPECTRUM_1234, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_matrix_default_dimension() {
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let ct = expect_complex_tensor(fft_host(Value::Tensor(tensor), None, None).expect("fft"));
        assert_eq!(ct.shape, vec![2, 3]);
        let expected = [
            (5.0, 0.0),
            (-3.0, 0.0),
            (7.0, 0.0),
            (-3.0, 0.0),
            (9.0, 0.0),
            (-3.0, 0.0),
        ];
        assert_all_close(&ct.data, &expected, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_zero_padding_with_length_argument() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), Some(5), None).expect("fft with explicit length"),
        );
        assert_eq!(ct.shape, vec![5]);
        assert!(approx_eq(ct.data[0], (6.0, 0.0), 1e-12));
        let padded = [(1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (0.0, 0.0), (0.0, 0.0)];
        assert_all_close(&ct.data, &reference_fft(&padded), 1e-10);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_empty_length_argument_defaults_to_input_length() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let baseline =
            fft_builtin_sync(Value::Tensor(tensor.clone()), Vec::new()).expect("baseline fft");
        let empty = Tensor::new(Vec::<f64>::new(), vec![0]).unwrap();
        let result = fft_builtin_sync(
            Value::Tensor(tensor),
            vec![Value::Tensor(empty), Value::Int(IntValue::I32(1))],
        )
        .expect("fft with empty length");
        assert_values_match(result, baseline, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_truncates_when_length_smaller() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), Some(2), None).expect("fft with truncation length"),
        );
        assert_eq!(ct.shape, vec![2]);
        assert_all_close(&ct.data, &[(3.0, 0.0), (-1.0, 0.0)], 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_zero_length_returns_empty_tensor() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), Some(0), None).expect("fft with zero length"),
        );
        assert_eq!(ct.shape, vec![0]);
        assert!(ct.data.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_complex_input_preserves_imaginary_components() {
        let tensor =
            HostComplexTensor::new(vec![(1.0, 1.0), (0.0, -1.0), (2.0, 0.5)], vec![3]).unwrap();
        let expected = reference_fft(&tensor.data);
        let ct = expect_complex_tensor(
            fft_host(Value::ComplexTensor(tensor), None, None).expect("fft complex"),
        );
        assert_eq!(ct.shape, vec![3]);
        assert_all_close(&ct.data, &expected, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_row_vector_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), None, Some(2)).expect("fft along dimension 2"),
        );
        assert_eq!(ct.shape, vec![1, 4]);
        assert_all_close(&ct.data, &SPECTRUM_1234, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_dimension_extends_rank() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        // A length-1 transform along a new trailing dimension is the identity.
        let expected: Vec<(f64, f64)> = tensor.data.iter().map(|&v| (v, 0.0)).collect();
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), None, Some(3)).expect("fft with extra dimension"),
        );
        assert_eq!(ct.shape, vec![1, 4, 1]);
        assert_all_close(&ct.data, &expected, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_dimension_extends_rank_with_padding() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        // Padding a singleton dimension to length 4 gives [x, 0, 0, 0], whose DFT is
        // [x, x, x, x]. So every depth slice equals the original data.
        let slice: Vec<(f64, f64)> = tensor.data.iter().map(|&v| (v, 0.0)).collect();
        let expected = slice.repeat(4);
        let ct = expect_complex_tensor(
            fft_host(Value::Tensor(tensor), Some(4), Some(3))
                .expect("fft with padded third dimension"),
        );
        assert_eq!(ct.shape, vec![1, 4, 4]);
        assert_all_close(&ct.data, &expected, 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_non_numeric_length() {
        let err = block_on(parse_arguments(&[Value::Bool(true)])).unwrap_err();
        assert_eq!(error_identifier(&err), FFT_ERROR_INVALID_LENGTH.identifier);
        assert!(error_message(err).contains(FFT_ERROR_INVALID_LENGTH.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_negative_length() {
        let err = block_on(parse_arguments(&[Value::Num(-1.0)])).unwrap_err();
        assert_eq!(error_identifier(&err), FFT_ERROR_INVALID_LENGTH.identifier);
        assert!(error_message(err).contains(FFT_ERROR_INVALID_LENGTH.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_fractional_length() {
        let err = block_on(parse_arguments(&[Value::Num(1.5)])).unwrap_err();
        assert_eq!(error_identifier(&err), FFT_ERROR_INVALID_LENGTH.identifier);
        assert!(error_message(err).contains(FFT_ERROR_INVALID_LENGTH.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_dimension_zero() {
        let err = block_on(parse_arguments(&[
            Value::Num(4.0),
            Value::Int(IntValue::I32(0)),
        ]))
        .unwrap_err();
        assert_eq!(
            error_identifier(&err),
            FFT_ERROR_INVALID_DIMENSION.identifier
        );
        assert!(error_message(err).contains(FFT_ERROR_INVALID_DIMENSION.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_too_many_arguments() {
        // `rest` excludes X, so three entries here means four inputs in total.
        let err = block_on(parse_arguments(&[
            Value::Num(4.0),
            Value::Int(IntValue::I32(1)),
            Value::Num(1.0),
        ]))
        .unwrap_err();
        assert_eq!(error_identifier(&err), FFT_ERROR_ARG_COUNT.identifier);
        assert!(error_message(err).contains(FFT_ERROR_ARG_COUNT.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_accepts_scalar_tensor_dimension_argument() {
        let dim = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
        let (len, parsed_dim) = block_on(parse_arguments(&[Value::Num(4.0), Value::Tensor(dim)]))
            .expect("parse arguments");
        assert_eq!(len, Some(4));
        assert_eq!(parsed_dim, Some(2));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_roundtrip_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("fft");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), Vec::new()).expect("fft");
            assert_values_match(gpu, cpu, 1e-12);
            provider.free(&handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_non_power_of_two_length_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin_sync(
                Value::GpuTensor(handle.clone()),
                vec![Value::Int(IntValue::I32(7))],
            )
            .expect("fft gpu");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(7))])
                .expect("fft cpu");
            assert_values_match(gpu, cpu, 1e-10);
            provider.free(&handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_prime_length_on_non_last_dimension_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new((1..=18).map(|v| v as f64).collect(), vec![2, 3, 3]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let args = vec![Value::Int(IntValue::I32(7)), Value::Int(IntValue::I32(2))];
            let gpu =
                fft_builtin_sync(Value::GpuTensor(handle.clone()), args.clone()).expect("fft gpu");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), args).expect("fft cpu");
            assert_values_match(gpu, cpu, 1e-10);
            provider.free(&handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn fft_wgpu_matches_cpu() {
        if let Some(provider) = runmat_accelerate::backend::wgpu::provider::ensure_wgpu_provider()
            .expect("wgpu provider")
        {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let tensor_cpu = tensor.clone();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu =
                fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpu fft");
            let cpu = fft_builtin_sync(Value::Tensor(tensor_cpu), Vec::new()).expect("cpu fft");
            let tol = match provider.precision() {
                runmat_accelerate_api::ProviderPrecision::F64 => 1e-10,
                runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
            };
            assert_values_match(gpu, cpu, tol);
            provider.free(&handle).ok();
        }
    }
}