1use super::common::{
4 default_dimension, gather_gpu_complex_tensor, parse_length, transform_complex_tensor,
5 value_to_complex_tensor, TransformDirection,
6};
7use runmat_accelerate_api::GpuTensorHandle;
8use runmat_builtins::{ComplexTensor, Value};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::random_args::complex_tensor_into_value;
12use crate::builtins::common::spec::{
13 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::{shape::normalize_scalar_shape, tensor};
17use crate::builtins::math::fft::type_resolvers::fft_type;
18use crate::{build_runtime_error, BuiltinResult, RuntimeError};
19
// GPU execution contract for `fft`, registered with the accelerate layer.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::fft::forward")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fft",
    // FFT is neither elementwise nor a reduction, so providers expose it as a custom op.
    op_kind: GpuOpKind::Custom("fft"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    // Providers opt in by implementing the `fft_dim` hook (see `fft_gpu` below).
    provider_hooks: &[ProviderHook::Custom("fft_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // The transform always yields a fresh handle; the input buffer is not mutated.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers should implement `fft_dim` to transform along an arbitrary dimension; the runtime gathers to host when unavailable.",
};
35
// Fusion-planner contract for `fft`: it acts only as a fusion boundary.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::fft::forward")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fft",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // No elementwise/reduction kernels are contributed to fused plans.
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "FFT participates in fusion plans only as a boundary; no fused kernels are generated today.",
};
47
/// Builtin identifier used for error tagging and host/GPU conversion helpers.
const BUILTIN_NAME: &str = "fft";
49
50fn fft_error(message: impl Into<String>) -> RuntimeError {
51 build_runtime_error(message)
52 .with_builtin(BUILTIN_NAME)
53 .build()
54}
55
56#[runtime_builtin(
57 name = "fft",
58 category = "math/fft",
59 summary = "Compute the discrete Fourier transform (DFT) of numeric or complex data.",
60 keywords = "fft,fourier transform,complex,gpu",
61 type_resolver(fft_type),
62 builtin_path = "crate::builtins::math::fft::forward"
63)]
64async fn fft_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
65 let (length, dimension) = parse_arguments(&rest).await?;
66 match value {
67 Value::GpuTensor(handle) => fft_gpu(handle, length, dimension).await,
68 other => fft_host(other, length, dimension),
69 }
70}
71
72fn fft_host(value: Value, length: Option<usize>, dimension: Option<usize>) -> BuiltinResult<Value> {
73 let tensor = value_to_complex_tensor(value, BUILTIN_NAME)?;
74 let transformed = fft_complex_tensor(tensor, length, dimension)?;
75 Ok(complex_tensor_into_value(transformed))
76}
77
78async fn fft_gpu(
79 handle: GpuTensorHandle,
80 length: Option<usize>,
81 dimension: Option<usize>,
82) -> BuiltinResult<Value> {
83 let mut shape = normalize_scalar_shape(&handle.shape);
84
85 let dim_one_based = match dimension {
86 Some(0) => return Err(fft_error("fft: dimension must be >= 1")),
87 Some(dim) => dim,
88 None => default_dimension(&shape),
89 };
90
91 let dim_index = dim_one_based - 1;
92 while shape.len() <= dim_index {
93 shape.push(1);
94 }
95 let current_len = shape[dim_index];
96 let target_len = length.unwrap_or(current_len);
97
98 if target_len == 0 {
99 let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME).await?;
100 let transformed = fft_complex_tensor(complex, length, dimension)?;
101 return Ok(complex_tensor_into_value(transformed));
102 }
103
104 if let Some(provider) = runmat_accelerate_api::provider() {
105 if let Ok(out) = provider.fft_dim(&handle, length, dim_index).await {
106 return Ok(Value::GpuTensor(out));
107 }
108 }
109
110 let complex = gather_gpu_complex_tensor(&handle, BUILTIN_NAME).await?;
111 let transformed = fft_complex_tensor(complex, length, dimension)?;
112 Ok(complex_tensor_into_value(transformed))
113}
114
115async fn parse_dimension_arg(value: &Value) -> BuiltinResult<usize> {
116 tensor::dimension_from_value_async(value, BUILTIN_NAME, false)
117 .await
118 .map_err(fft_error)?
119 .ok_or_else(|| {
120 fft_error(format!(
121 "{BUILTIN_NAME}: dimension must be numeric, got {value:?}"
122 ))
123 })
124}
125
126async fn parse_arguments(args: &[Value]) -> BuiltinResult<(Option<usize>, Option<usize>)> {
127 match args.len() {
128 0 => Ok((None, None)),
129 1 => {
130 let len = parse_length(&args[0], BUILTIN_NAME)?;
131 Ok((len, None))
132 }
133 2 => {
134 let len = parse_length(&args[0], BUILTIN_NAME)?;
135 let dim = Some(parse_dimension_arg(&args[1]).await?);
136 Ok((len, dim))
137 }
138 _ => Err(fft_error(
139 "fft: expected fft(X), fft(X, N), or fft(X, N, DIM)",
140 )),
141 }
142}
143
144pub(super) fn fft_complex_tensor(
145 tensor: ComplexTensor,
146 length: Option<usize>,
147 dimension: Option<usize>,
148) -> BuiltinResult<ComplexTensor> {
149 transform_complex_tensor(
150 tensor,
151 length,
152 dimension,
153 TransformDirection::Forward,
154 BUILTIN_NAME,
155 )
156}
157
/// Unit tests covering type resolution, host transforms, argument parsing, and
/// GPU/CPU parity for the `fft` builtin.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use crate::builtins::math::fft::common;
    use futures::executor::block_on;
    use num_complex::Complex;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_builtins::{
        ComplexTensor as HostComplexTensor, IntValue, ResolveContext, Tensor, Type,
    };
    use rustfft::FftPlanner;

    // Component-wise approximate equality for (re, im) pairs.
    fn approx_eq(a: (f64, f64), b: (f64, f64), tol: f64) -> bool {
        (a.0 - b.0).abs() <= tol && (a.1 - b.1).abs() <= tol
    }

    // Extract the human-readable message from a runtime error.
    fn error_message(error: crate::RuntimeError) -> String {
        error.message().to_string()
    }

    // Normalize any fft output `Value` (complex tensor, complex scalar, or GPU
    // handle) to a host complex tensor, downloading from the device if needed.
    fn value_as_complex_tensor(value: Value) -> HostComplexTensor {
        match value {
            Value::ComplexTensor(tensor) => tensor,
            Value::Complex(re, im) => HostComplexTensor::new(vec![(re, im)], vec![1, 1]).unwrap(),
            Value::GpuTensor(handle) => {
                let provider = runmat_accelerate_api::provider_for_handle(&handle)
                    .or_else(runmat_accelerate_api::provider)
                    .expect("provider for gpu handle");
                let host = block_on(provider.download(&handle)).expect("download gpu fft output");
                common::host_to_complex_tensor(host, BUILTIN_NAME).expect("decode gpu complex")
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Blocking wrapper around the async builtin entry point for tests.
    fn fft_builtin_sync(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::fft_builtin(value, rest))
    }

    // The static type resolver keeps the input tensor shape unchanged.
    #[test]
    fn fft_type_preserves_shape() {
        let out = fft_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    // DFT of a real 4-vector matches the known closed-form result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_real_vector() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![4]);
                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(
                        approx_eq(*val, expected[idx], 1e-12),
                        "idx {idx} {:?} ~= {:?}",
                        val,
                        expected[idx]
                    );
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // With no DIM argument, a matrix is transformed along its columns.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_matrix_default_dimension() {
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, None).expect("fft");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2, 3]);
                let expected = [
                    (5.0, 0.0),
                    (-3.0, 0.0),
                    (7.0, 0.0),
                    (-3.0, 0.0),
                    (9.0, 0.0),
                    (-3.0, 0.0),
                ];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // N larger than the input length zero-pads before transforming.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_zero_padding_with_length_argument() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let result =
            fft_host(Value::Tensor(tensor), Some(5), None).expect("fft with explicit length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![5]);
                assert!(approx_eq(ct.data[0], (6.0, 0.0), 1e-12));
                assert_eq!(ct.data.len(), 5);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // An empty tensor passed as N behaves like omitting N entirely.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_empty_length_argument_defaults_to_input_length() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let baseline =
            fft_builtin_sync(Value::Tensor(tensor.clone()), Vec::new()).expect("baseline fft");
        let empty = Tensor::new(Vec::<f64>::new(), vec![0]).unwrap();
        let result = fft_builtin_sync(
            Value::Tensor(tensor),
            vec![Value::Tensor(empty), Value::Int(IntValue::I32(1))],
        )
        .expect("fft with empty length");
        let base_ct = value_as_complex_tensor(baseline);
        let result_ct = value_as_complex_tensor(result);
        assert_eq!(base_ct.shape, result_ct.shape);
        assert_eq!(base_ct.data.len(), result_ct.data.len());
        for (idx, (a, b)) in base_ct.data.iter().zip(result_ct.data.iter()).enumerate() {
            assert!(
                approx_eq(*a, *b, 1e-12),
                "mismatch at index {idx}: {:?} vs {:?}",
                a,
                b
            );
        }
    }

    // N smaller than the input length truncates before transforming.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_truncates_when_length_smaller() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result =
            fft_host(Value::Tensor(tensor), Some(2), None).expect("fft with truncation length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![2]);
                let expected = [(3.0, 0.0), (-1.0, 0.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // N == 0 yields an empty result along the transform dimension.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_zero_length_returns_empty_tensor() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let result = fft_host(Value::Tensor(tensor), Some(0), None).expect("fft with zero length");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![0]);
                assert!(ct.data.is_empty());
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Complex input passes through untouched imaginary parts; compare against
    // rustfft's reference forward transform.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_complex_input_preserves_imaginary_components() {
        let tensor =
            HostComplexTensor::new(vec![(1.0, 1.0), (0.0, -1.0), (2.0, 0.5)], vec![3]).unwrap();
        let result =
            fft_host(Value::ComplexTensor(tensor.clone()), None, None).expect("fft complex");
        let mut expected = tensor
            .data
            .iter()
            .map(|(re, im)| Complex::new(*re, *im))
            .collect::<Vec<_>>();
        FftPlanner::<f64>::new()
            .plan_fft_forward(expected.len())
            .process(&mut expected);
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![3]);
                assert_eq!(ct.data.len(), 3);
                for (idx, val) in ct.data.iter().enumerate() {
                    let exp = expected[idx];
                    assert!(approx_eq(*val, (exp.re, exp.im), 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Explicit DIM=2 transforms a row vector along its row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_row_vector_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let result = fft_host(Value::Tensor(tensor), None, Some(2)).expect("fft along dimension 2");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4]);
                let expected = [(10.0, 0.0), (-2.0, 2.0), (-2.0, 0.0), (-2.0, -2.0)];
                for (idx, val) in ct.data.iter().enumerate() {
                    assert!(approx_eq(*val, expected[idx], 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // DIM beyond the input rank appends a singleton dimension; a length-1
    // transform is the identity, so data is unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_dimension_extends_rank() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let original = tensor.clone();
        let result =
            fft_host(Value::Tensor(tensor), None, Some(3)).expect("fft with extra dimension");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4, 1]);
                assert_eq!(ct.data.len(), original.data.len());
                for (idx, (re, im)) in ct.data.iter().enumerate() {
                    assert!(approx_eq((*re, *im), (original.data[idx], 0.0), 1e-12));
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Padding a new trailing dimension: the DFT of [x, 0, 0, 0] is constant x,
    // so each depth slice replicates the original data.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_dimension_extends_rank_with_padding() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let original = tensor.clone();
        let result = fft_host(Value::Tensor(tensor), Some(4), Some(3))
            .expect("fft with padded third dimension");
        match result {
            Value::ComplexTensor(ct) => {
                assert_eq!(ct.shape, vec![1, 4, 4]);
                let mut expected = Vec::with_capacity(16);
                for _depth in 0..4 {
                    for &value in &original.data {
                        expected.push((value, 0.0));
                    }
                }
                assert_eq!(ct.data.len(), expected.len());
                for (idx, (actual, expected)) in ct.data.iter().zip(expected.iter()).enumerate() {
                    assert!(
                        approx_eq(*actual, *expected, 1e-12),
                        "idx {idx}: {:?} != {:?}",
                        actual,
                        expected
                    );
                }
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Argument-validation failures surface as parse errors.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_non_numeric_length() {
        assert!(block_on(parse_arguments(&[Value::Bool(true)])).is_err());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_negative_length() {
        let err = error_message(block_on(parse_arguments(&[Value::Num(-1.0)])).unwrap_err());
        assert!(err.contains("length must be non-negative"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_fractional_length() {
        let err = error_message(block_on(parse_arguments(&[Value::Num(1.5)])).unwrap_err());
        assert!(err.contains("length must be an integer"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_rejects_dimension_zero() {
        let err = error_message(
            block_on(parse_arguments(&[
                Value::Num(4.0),
                Value::Int(IntValue::I32(0)),
            ]))
            .unwrap_err(),
        );
        assert!(err.contains("dimension must be >= 1"));
    }

    // DIM may be supplied as a 1x1 tensor, not only as a scalar number.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_accepts_scalar_tensor_dimension_argument() {
        let dim = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
        let (len, parsed_dim) = block_on(parse_arguments(&[Value::Num(4.0), Value::Tensor(dim)]))
            .expect("parse arguments");
        assert_eq!(len, Some(4));
        assert_eq!(parsed_dim, Some(2));
    }

    // GPU path (test provider) produces the same values as the CPU path.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_roundtrip_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("fft");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), Vec::new()).expect("fft");
            let gpu_host = value_as_complex_tensor(gpu);
            let cpu_host = value_as_complex_tensor(cpu);
            assert_eq!(gpu_host.shape, cpu_host.shape);
            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
                assert!(approx_eq(*a, *b, 1e-12));
            }
            provider.free(&handle).ok();
        });
    }

    // GPU/CPU parity with a non-power-of-two transform length (N = 7).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_non_power_of_two_length_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu = fft_builtin_sync(
                Value::GpuTensor(handle.clone()),
                vec![Value::Int(IntValue::I32(7))],
            )
            .expect("fft gpu");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(7))])
                .expect("fft cpu");
            let gpu_host = value_as_complex_tensor(gpu);
            let cpu_host = value_as_complex_tensor(cpu);
            assert_eq!(gpu_host.shape, cpu_host.shape);
            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
                assert!(approx_eq(*a, *b, 1e-10));
            }
            provider.free(&handle).ok();
        });
    }

    // GPU/CPU parity with a prime length (7) along a middle dimension (DIM = 2).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fft_gpu_prime_length_on_non_last_dimension_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new((1..=18).map(|v| v as f64).collect(), vec![2, 3, 3]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let args = vec![Value::Int(IntValue::I32(7)), Value::Int(IntValue::I32(2))];
            let gpu =
                fft_builtin_sync(Value::GpuTensor(handle.clone()), args.clone()).expect("fft gpu");
            let cpu = fft_builtin_sync(Value::Tensor(tensor), args).expect("fft cpu");
            let gpu_host = value_as_complex_tensor(gpu);
            let cpu_host = value_as_complex_tensor(cpu);
            assert_eq!(gpu_host.shape, cpu_host.shape);
            for (a, b) in gpu_host.data.iter().zip(cpu_host.data.iter()) {
                assert!(approx_eq(*a, *b, 1e-10), "{a:?} vs {b:?}");
            }
            provider.free(&handle).ok();
        });
    }

    // Real-hardware parity check; only runs when a wgpu provider is available,
    // with tolerance widened for f32 devices.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn fft_wgpu_matches_cpu() {
        if let Some(provider) = runmat_accelerate::backend::wgpu::provider::ensure_wgpu_provider()
            .expect("wgpu provider")
        {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
            let tensor_cpu = tensor.clone();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let gpu =
                fft_builtin_sync(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpu fft");
            let cpu = fft_builtin_sync(Value::Tensor(tensor_cpu), Vec::new()).expect("cpu fft");
            let gpu_ct = value_as_complex_tensor(gpu);
            let cpu_ct = value_as_complex_tensor(cpu);
            let tol = match provider.precision() {
                runmat_accelerate_api::ProviderPrecision::F64 => 1e-10,
                runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
            };
            assert_eq!(gpu_ct.shape, cpu_ct.shape);
            for (a, b) in gpu_ct.data.iter().zip(cpu_ct.data.iter()) {
                assert!(approx_eq(*a, *b, tol), "{a:?} vs {b:?}");
            }
            provider.free(&handle).ok();
        }
    }
}