1use super::common::{
4 default_dimension, gather_gpu_complex_tensor, parse_length, transform_complex_tensor,
5 value_to_complex_tensor, TransformDirection,
6};
7use runmat_accelerate_api::GpuTensorHandle;
8use runmat_builtins::{ComplexTensor, Value};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::random_args::complex_tensor_into_value;
12use crate::builtins::common::spec::{
13 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::{shape::normalize_scalar_shape, tensor};
17use crate::builtins::math::fft::type_resolvers::fft_type;
18use crate::{build_runtime_error, BuiltinResult, RuntimeError};
19
// GPU metadata for the `fft` builtin, registered against
// `crate::builtins::math::fft::forward`.
//
// The spec declares one custom provider hook, `fft_dim`. Per the `notes`
// field, providers implement it to transform along an arbitrary dimension and
// the runtime gathers to the host when the hook is unavailable.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::fft::forward")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fft",
    op_kind: GpuOpKind::Custom("fft"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    // Name of the provider entry point invoked by the GPU path of this builtin.
    provider_hooks: &[ProviderHook::Custom("fft_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers should implement `fft_dim` to transform along an arbitrary dimension; the runtime gathers to host when unavailable.",
};
35
// Fusion metadata for the `fft` builtin, registered against
// `crate::builtins::math::fft::forward`.
//
// No elementwise or reduction template is supplied; per the `notes` field the
// FFT only acts as a boundary in fusion plans and no fused kernels are
// generated.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::fft::forward")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fft",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "FFT participates in fusion plans only as a boundary; no fused kernels are generated today.",
};
47
/// Builtin name handed to the shared FFT/tensor helpers and attached to every
/// error built by `fft_error`.
const BUILTIN_NAME: &str = "fft";
49
50fn fft_error(message: impl Into<String>) -> RuntimeError {
51 build_runtime_error(message)
52 .with_builtin(BUILTIN_NAME)
53 .build()
54}
55
56#[runtime_builtin(
57 name = "fft",
58 category = "math/fft",
59 summary = "Compute the discrete Fourier transform (DFT) of numeric or complex data.",
60 keywords = "fft,fourier transform,complex,gpu",
61 type_resolver(fft_type),
62 builtin_path = "crate::builtins::math::fft::forward"
63)]
64async fn fft_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
65 let (length, dimension) = parse_arguments(&rest).await?;
66 match value {
67 Value::GpuTensor(handle) => fft_gpu(handle, length, dimension).await,
68 other => fft_host(other, length, dimension),
69 }
70}
71
72fn fft_host(value: Value, length: Option<usize>, dimension: Option<usize>) -> BuiltinResult<Value> {
73 let tensor = value_to_complex_tensor(value, BUILTIN_NAME)?;
74 let transformed = fft_complex_tensor(tensor, length, dimension)?;
75 Ok(complex_tensor_into_value(transformed))
76}
77
78async fn fft_gpu(
79 handle: GpuTensorHandle,
80 length: Option<usize>,
81 dimension: Option<usize>,
82) -> BuiltinResult<Value> {
83 let mut shape = normalize_scalar_shape(&handle.shape);
84
85 let dim_one_based = match dimension {
86 Some(0) => return Err(fft_error("fft: dimension must be >= 1")),
87 Some(dim) => dim,
88 None => default_dimension(&shape),
89 };
90
91 let dim_index = dim_one_based - 1;
92 while shape.len() <= dim_index {
93 shape.push(1);
94 }
95 let current_len = shape[dim_index];
96 let target_len = length.unwrap_or(current_len);
97
98 if target_len == 0 {
99 let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME).await?;
100 let transformed = fft_complex_tensor(complex, length, dimension)?;
101 return Ok(complex_tensor_into_value(transformed));
102 }
103
104 if let Some(provider) = runmat_accelerate_api::provider() {
105 if let Ok(out) = provider.fft_dim(&handle, length, dim_index).await {
106 return Ok(Value::GpuTensor(out));
107 }
108 }
109
110 let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME).await?;
111 let transformed = fft_complex_tensor(complex, length, dimension)?;
112 Ok(complex_tensor_into_value(transformed))
113}
114
115async fn parse_dimension_arg(value: &Value) -> BuiltinResult<usize> {
116 tensor::dimension_from_value_async(value, BUILTIN_NAME, false)
117 .await
118 .map_err(fft_error)?
119 .ok_or_else(|| {
120 fft_error(format!(
121 "{BUILTIN_NAME}: dimension must be numeric, got {value:?}"
122 ))
123 })
124}
125
126async fn parse_arguments(args: &[Value]) -> BuiltinResult<(Option<usize>, Option<usize>)> {
127 match args.len() {
128 0 => Ok((None, None)),
129 1 => {
130 let len = parse_length(&args[0], BUILTIN_NAME)?;
131 Ok((len, None))
132 }
133 2 => {
134 let len = parse_length(&args[0], BUILTIN_NAME)?;
135 let dim = Some(parse_dimension_arg(&args[1]).await?);
136 Ok((len, dim))
137 }
138 _ => Err(fft_error(
139 "fft: expected fft(X), fft(X, N), or fft(X, N, DIM)",
140 )),
141 }
142}
143
144pub(super) fn fft_complex_tensor(
145 tensor: ComplexTensor,
146 length: Option<usize>,
147 dimension: Option<usize>,
148) -> BuiltinResult<ComplexTensor> {
149 transform_complex_tensor(
150 tensor,
151 length,
152 dimension,
153 TransformDirection::Forward,
154 BUILTIN_NAME,
155 )
156}
157
// Unit tests for the `fft` builtin: host transforms, argument parsing, and
// GPU-vs-host agreement through the test provider (and wgpu when enabled).
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use crate::builtins::math::fft::common;
    use futures::executor::block_on;
    use num_complex::Complex;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_builtins::{
        ComplexTensor as HostComplexTensor, IntValue, ResolveContext, Tensor, Type,
    };
    use rustfft::FftPlanner;

    // Component-wise comparison of two `(re, im)` pairs with absolute tolerance `tol`.
    fn approx_eq(a: (f64, f64), b: (f64, f64), tol: f64) -> bool {
        (a.0 - b.0).abs() <= tol && (a.1 - b.1).abs() <= tol
    }

    // Extracts the message text from a runtime error for substring assertions.
    fn error_message(error: crate::RuntimeError) -> String {
        error.message().to_string()
    }

    // Normalizes a builtin result into a host complex tensor: complex tensors
    // pass through, a complex scalar becomes a 1x1 tensor, and GPU handles are
    // downloaded through their provider. Any other value panics.
    fn value_as_complex_tensor(value: Value) -> HostComplexTensor {
        match value {
            Value::ComplexTensor(tensor) => tensor,
            Value::Complex(re, im) => HostComplexTensor::new(vec![(re, im)], vec![1, 1]).unwrap(),
            Value::GpuTensor(handle) => {
                // Prefer the provider tied to the handle; fall back to the global one.
                let provider = runmat_accelerate_api::provider_for_handle(&handle)
                    .or_else(runmat_accelerate_api::provider)
                    .expect("provider for gpu handle");
                let host = block_on(provider.download(&handle)).expect("download gpu fft output");
                common::host_to_complex_tensor(host, BUILTIN_NAME).expect("decode gpu complex")
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Blocking wrapper around the async builtin entry point.
    fn fft_builtin_sync(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::fft_builtin(value, rest))
    }

    // The type resolver maps a 2x3 tensor type to a 2x3 tensor type.
    #[test]
    fn fft_type_preserves_shape() {
        let out = fft_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    // FFT of [1, 2, 3, 4] is [10, -2+2i, -2, -2-2i].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_real_vector() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![4]);
                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(
                        approx_eq(*val, expected[idx], 1e-12),
                        "idx {idx} {:?} ~= {:?}",
                        val,
                        expected[idx]
                    );
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // A 1x4 row vector with no explicit dimension is transformed along its
    // length-4 axis and keeps its 1x4 shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_row_vector_default_dimension_preserves_orientation() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4]);
                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(
                        approx_eq(*val, expected[idx], 1e-12),
                        "idx {idx} {:?} ~= {:?}",
                        val,
                        expected[idx]
                    );
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // For a 2x3 input the expected values are the length-2 transforms of the
    // consecutive pairs (1, 4), (2, 5), (3, 6): sum then difference.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_matrix_default_dimension() {
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 3]);
                let expected = [
                    (5.0, 0.0),
                    (-3.0, 0.0),
                    (7.0, 0.0),
                    (-3.0, 0.0),
                    (9.0, 0.0),
                    (-3.0, 0.0),
                ];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // A length larger than the input zero-pads: the output has 5 entries and
    // the DC term is still the sum of the inputs (6).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_zero_padding_with_length_argument() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let result =
            fft_host(Value::Tensor(tensor), Some(5), None).expect("fft with explicit length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![5]);
                assert!(approx_eq(ct.data[0], (6.0, 0.0), 1e-12));
                assert_eq!(ct.data.len(), 5);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Passing an empty tensor as the length (`fft(X, [], 1)`) must give the
    // same result as calling `fft(X)` with no extra arguments.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_empty_length_argument_defaults_to_input_length() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let baseline =
            fft_builtin_sync(Value::Tensor(tensor.clone()), Vec::new()).expect("baseline fft");
        let empty = Tensor::new(Vec::<f64>::new(), vec![0]).unwrap();
        let result = fft_builtin_sync(
            Value::Tensor(tensor),
            vec![Value::Tensor(empty), Value::Int(IntValue::I32(1))],
        )
        .expect("fft with empty length");
        let base_ct = value_as_complex_tensor(baseline);
        let result_ct = value_as_complex_tensor(result);
        assert_eq!(base_ct.shape, result_ct.shape);
        assert_eq!(base_ct.data.len(), result_ct.data.len());
        for (idx, (a, b)) in base_ct.data.iter().zip(result_ct.data.iter()).enumerate() {
            assert!(
                approx_eq(*a, *b, 1e-12),
                "mismatch at index {idx}: {:?} vs {:?}",
                a,
                b
            );
        }
    }

    // A length smaller than the input truncates: only [1, 2] is transformed,
    // giving [3, -1].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_truncates_when_length_smaller() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result =
            fft_host(Value::Tensor(tensor), Some(2), None).expect("fft with truncation length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2]);
                let expected = [(3.0, 0.0), (-1.0, 0.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // An explicit length of zero yields an empty complex tensor of shape [0].
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_zero_length_returns_empty_tensor() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let result = fft_host(Value::Tensor(tensor), Some(0), None).expect("fft with zero length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![0]);
                assert!(ct.data.is_empty());
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Complex input is cross-checked against a forward plan from `rustfft`
    // run on the same samples.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_complex_input_preserves_imaginary_components() {
        let tensor =
            HostComplexTensor::new(vec![(1.0, 1.0), (0.0, -1.0), (2.0, 0.5)], vec![3]).unwrap();
        let result =
            fft_host(Value::ComplexTensor(tensor.clone()), None, None).expect("fft complex");
        let mut expected = tensor
            .data
            .iter()
            .map(|(re, im)| Complex::new(*re, *im))
            .collect::<Vec<_>>();
        FftPlanner::<f64>::new()
            .plan_fft_forward(expected.len())
            .process(&mut expected);
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![3]);
                assert_eq!(ct.data.len(), 3);
                for (idx, val) in ct.data.iter().enumerate() {
                    let exp = expected[idx];
                    assert!(approx_eq(*val, (exp.re, exp.im), 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Explicitly selecting dimension 2 of a 1x4 row vector transforms the four samples.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_row_vector_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, Some(2)).expect("fft along dimension 2");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4]);
                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Transforming along dimension 3 of a 1x4 input extends the shape to
    // 1x4x1; each length-1 transform leaves the value unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_dimension_extends_rank() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let original = tensor.clone();
        let result =
            fft_host(Value::Tensor(tensor), None, Some(3)).expect("fft with extra dimension");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4, 1]);
                assert_eq!(ct.data.len(), original.data.len());
                for (idx, (re, im)) in ct.data.iter().enumerate() {
                    assert!(approx_eq((*re, *im), (original.data[idx], 0.0), 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Length 4 along the (new) third dimension gives a 1x4x4 result. Each
    // sample is padded to [x, 0, 0, 0], whose transform is [x, x, x, x], so
    // the expected data is the original row repeated four times.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_dimension_extends_rank_with_padding() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let original = tensor.clone();
        let result = fft_host(Value::Tensor(tensor), Some(4), Some(3))
            .expect("fft with padded third dimension");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4, 4]);
                let mut expected = Vec::with_capacity(16);
                for _depth in 0..4 {
                    for &value in &original.data {
                        expected.push((value, 0.0));
                    }
                }
                assert_eq!(ct.data.len(), expected.len());
                for (idx, (actual, expected)) in ct.data.iter().zip(expected.iter()).enumerate() {
                    assert!(
                        approx_eq(*actual, *expected, 1e-12),
                        "idx {idx}: {:?} != {:?}",
                        actual,
                        expected
                    );
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // A boolean is not accepted as the length argument.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_non_numeric_length() {
        assert!(block_on(parse_arguments(&[Value::Bool(true)])).is_err());
    }

    // A negative length is rejected with a "non-negative" message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_negative_length() {
        let err = error_message(block_on(parse_arguments(&[Value::Num(-1.0)])).unwrap_err());
        assert!(err.contains("length must be non-negative"));
    }

    // A fractional length is rejected with an "integer" message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_fractional_length() {
        let err = error_message(block_on(parse_arguments(&[Value::Num(1.5)])).unwrap_err());
        assert!(err.contains("length must be an integer"));
    }

    // Dimension 0 is rejected during argument parsing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_dimension_zero() {
        let err = error_message(
            block_on(parse_arguments(&[
                Value::Num(4.0),
                Value::Int(IntValue::I32(0)),
            ]))
            .unwrap_err(),
        );
        assert!(err.contains("dimension must be >= 1"));
    }

    // A 1x1 tensor holding 2.0 is accepted as the dimension argument.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_accepts_scalar_tensor_dimension_argument() {
        let dim = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
        let (len, parsed_dim) = block_on(parse_arguments(&[Value::Num(4.0), Value::Tensor(dim)]))
            .expect("parse arguments");
        assert_eq!(len, Some(4));
        assert_eq!(parsed_dim, Some(2));
    }

    // The same input uploaded through the test provider must match the host result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_roundtrip_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("fft");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), Vec::new()).expect("fft");
            let gpu_host = value_as_complex_tensor(gpu);
            let cpu_host = value_as_complex_tensor(cpu);
            assert_eq!(gpu_host.shape, cpu_host.shape);
            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
                assert!(approx_eq(*a, *b, 1e-12));
            }
            // Cleanup is best effort; a failed free does not fail the test.
            provider.free(&handle).ok();
        });
    }

    // GPU and host must agree when a length-4 input is transformed with N = 7.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_non_power_of_two_length_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin_sync(
                Value::GpuTensor(handle.clone()),
                vec![Value::Int(IntValue::I32(7))],
            )
            .expect("fft gpu");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(7))])
                .expect("fft cpu");
            let gpu_host = value_as_complex_tensor(gpu);
            let cpu_host = value_as_complex_tensor(cpu);
            assert_eq!(gpu_host.shape, cpu_host.shape);
            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
                assert!(approx_eq(*a, *b, 1e-10));
            }
            // Cleanup is best effort; a failed free does not fail the test.
            provider.free(&handle).ok();
        });
    }

    // GPU and host must agree for N = 7 along dimension 2 of a 2x3x3 input.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_prime_length_on_non_last_dimension_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new((1..=18).map(|v| v as f64).collect(), vec![2, 3, 3]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let args = vec![Value::Int(IntValue::I32(7)), Value::Int(IntValue::I32(2))];
            let gpu =
                fft_builtin_sync(Value::GpuTensor(handle.clone()), args.clone()).expect("fft gpu");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), args).expect("fft cpu");
            let gpu_host = value_as_complex_tensor(gpu);
            let cpu_host = value_as_complex_tensor(cpu);
            assert_eq!(gpu_host.shape, cpu_host.shape);
            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
                assert!(approx_eq(*a, *b, 1e-10), "{a:?} vs {b:?}");
            }
            // Cleanup is best effort; a failed free does not fail the test.
            provider.free(&handle).ok();
        });
    }

    // wgpu-only check against the host result. The body is skipped when
    // `ensure_wgpu_provider` yields no provider, and the tolerance is chosen
    // from the provider's reported precision.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn fft_wgpu_matches_cpu() {
        if let Some(provider) = runmat_accelerate::backend::wgpu::provider::ensure_wgpu_provider()
            .expect("wgpu provider")
        {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let tensor_cpu = tensor.clone();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu =
                fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpu fft");
            let cpu = fft_builtin_sync(Value::Tensor(tensor_cpu), Vec::new()).expect("cpu fft");
            let gpu_ct = value_as_complex_tensor(gpu);
            let cpu_ct = value_as_complex_tensor(cpu);
            let tol = match provider.precision() {
                runmat_accelerate_api::ProviderPrecision::F64 => 1e-10,
                runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
            };
            assert_eq!(gpu_ct.shape, cpu_ct.shape);
            for (a, b) in gpu_ct.data.iter().zip(cpu_ct.data.iter()) {
                assert!(approx_eq(*a, *b, tol), "{a:?} vs {b:?}");
            }
            // Cleanup is best effort; a failed free does not fail the test.
            provider.free(&handle).ok();
        }
    }
}