1use crate::winograd_constants::{
18 C3_1, C3_2, C5_COS1, C5_COS2, C5_SIN1, C5_SIN2, C7_COS1, C7_COS2, C7_COS3, C7_SIN1, C7_SIN2,
19 C7_SIN3,
20};
21use proc_macro2::TokenStream;
22use quote::quote;
23use syn::LitInt;
24
25pub fn generate_from_macro(input: TokenStream) -> Result<TokenStream, syn::Error> {
35 let size: LitInt = syn::parse2(input)?;
36 let n: usize = size.base10_parse().map_err(|_| {
37 syn::Error::new(
38 size.span(),
39 "gen_odd_codelet: expected an integer size literal",
40 )
41 })?;
42
43 match n {
44 3 => Ok(gen_size_3()),
45 5 => Ok(gen_size_5()),
46 7 => Ok(gen_size_7()),
47 _ => Err(syn::Error::new(
48 size.span(),
49 format!("gen_odd_codelet: unsupported size {n} (expected one of 3, 5, 7)"),
50 )),
51 }
52}
53
54fn gen_size_3() -> TokenStream {
81 let c3_1 = C3_1;
82 let c3_2 = C3_2;
83 quote! {
84 #[inline(always)]
91 #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
92 pub fn codelet_notw_3<T: crate::kernel::Float>(
93 x: &mut [crate::kernel::Complex<T>],
94 sign: i32,
95 ) {
96 debug_assert!(x.len() >= 3);
97
98 let x0 = x[0];
99 let x1 = x[1];
100 let x2 = x[2];
101
102 let s_re = x1.re + x2.re;
104 let s_im = x1.im + x2.im;
105 let d_re = x1.re - x2.re;
106 let d_im = x1.im - x2.im;
107
108 x[0] = crate::kernel::Complex::new(x0.re + s_re, x0.im + s_im);
110
111 let c3_1 = T::from_f64(#c3_1);
113 let c3_2 = T::from_f64(#c3_2);
114 let tmp_re = x0.re + c3_1 * s_re;
115 let tmp_im = x0.im + c3_1 * s_im;
116
117 if sign < 0 {
118 x[1] = crate::kernel::Complex::new(tmp_re + c3_2 * d_im, tmp_im - c3_2 * d_re);
121 x[2] = crate::kernel::Complex::new(tmp_re - c3_2 * d_im, tmp_im + c3_2 * d_re);
122 } else {
123 x[1] = crate::kernel::Complex::new(tmp_re - c3_2 * d_im, tmp_im + c3_2 * d_re);
125 x[2] = crate::kernel::Complex::new(tmp_re + c3_2 * d_im, tmp_im - c3_2 * d_re);
126 }
127 }
128 }
129}
130
131fn gen_size_5() -> TokenStream {
172 let c5_cos1 = C5_COS1;
173 let c5_cos2 = C5_COS2;
174 let c5_sin1 = C5_SIN1;
175 let c5_sin2 = C5_SIN2;
176 quote! {
177 #[inline(always)]
184 #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
185 pub fn codelet_notw_5<T: crate::kernel::Float>(
186 x: &mut [crate::kernel::Complex<T>],
187 sign: i32,
188 ) {
189 debug_assert!(x.len() >= 5);
190
191 let x0 = x[0];
192 let x1 = x[1];
193 let x2 = x[2];
194 let x3 = x[3];
195 let x4 = x[4];
196
197 let r1_re = x1.re + x4.re;
201 let r1_im = x1.im + x4.im;
202 let r2_re = x2.re + x3.re;
203 let r2_im = x2.im + x3.im;
204 let i1_re = x1.re - x4.re;
205 let i1_im = x1.im - x4.im;
206 let i2_re = x2.re - x3.re;
207 let i2_im = x2.im - x3.im;
208
209 x[0] = crate::kernel::Complex::new(x0.re + r1_re + r2_re, x0.im + r1_im + r2_im);
211
212 let cos1 = T::from_f64(#c5_cos1);
213 let cos2 = T::from_f64(#c5_cos2);
214 let sin1 = T::from_f64(#c5_sin1);
215 let sin2 = T::from_f64(#c5_sin2);
216
217 let cr1_re = cos1 * r1_re + cos2 * r2_re;
219 let cr1_im = cos1 * r1_im + cos2 * r2_im;
220 let cr2_re = cos2 * r1_re + cos1 * r2_re;
221 let cr2_im = cos2 * r1_im + cos1 * r2_im;
222
223 let sr1_re = sin1 * i1_re + sin2 * i2_re;
225 let sr1_im = sin1 * i1_im + sin2 * i2_im;
226 let sr2_re = sin2 * i1_re - sin1 * i2_re;
227 let sr2_im = sin2 * i1_im - sin1 * i2_im;
228
229 let tmp1_re = x0.re + cr1_re;
231 let tmp1_im = x0.im + cr1_im;
232 let tmp2_re = x0.re + cr2_re;
233 let tmp2_im = x0.im + cr2_im;
234
235 if sign < 0 {
236 x[1] = crate::kernel::Complex::new(tmp1_re + sr1_im, tmp1_im - sr1_re);
240 x[4] = crate::kernel::Complex::new(tmp1_re - sr1_im, tmp1_im + sr1_re);
241 x[2] = crate::kernel::Complex::new(tmp2_re + sr2_im, tmp2_im - sr2_re);
242 x[3] = crate::kernel::Complex::new(tmp2_re - sr2_im, tmp2_im + sr2_re);
243 } else {
244 x[1] = crate::kernel::Complex::new(tmp1_re - sr1_im, tmp1_im + sr1_re);
246 x[4] = crate::kernel::Complex::new(tmp1_re + sr1_im, tmp1_im - sr1_re);
247 x[2] = crate::kernel::Complex::new(tmp2_re - sr2_im, tmp2_im + sr2_re);
248 x[3] = crate::kernel::Complex::new(tmp2_re + sr2_im, tmp2_im - sr2_re);
249 }
250 }
251 }
252}
253
254fn gen_size_7() -> TokenStream {
291 let c7_cos1 = C7_COS1;
292 let c7_cos2 = C7_COS2;
293 let c7_cos3 = C7_COS3;
294 let c7_sin1 = C7_SIN1;
295 let c7_sin2 = C7_SIN2;
296 let c7_sin3 = C7_SIN3;
297 quote! {
298 #[inline(always)]
306 #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
307 pub fn codelet_notw_7<T: crate::kernel::Float>(
308 x: &mut [crate::kernel::Complex<T>],
309 sign: i32,
310 ) {
311 debug_assert!(x.len() >= 7);
312
313 let x0 = x[0];
314 let x1 = x[1];
315 let x2 = x[2];
316 let x3 = x[3];
317 let x4 = x[4];
318 let x5 = x[5];
319 let x6 = x[6];
320
321 let r1_re = x1.re + x6.re;
324 let r1_im = x1.im + x6.im;
325 let r2_re = x2.re + x5.re;
326 let r2_im = x2.im + x5.im;
327 let r3_re = x3.re + x4.re;
328 let r3_im = x3.im + x4.im;
329 let i1_re = x1.re - x6.re;
330 let i1_im = x1.im - x6.im;
331 let i2_re = x2.re - x5.re;
332 let i2_im = x2.im - x5.im;
333 let i3_re = x3.re - x4.re;
334 let i3_im = x3.im - x4.im;
335
336 x[0] = crate::kernel::Complex::new(
338 x0.re + r1_re + r2_re + r3_re,
339 x0.im + r1_im + r2_im + r3_im,
340 );
341
342 let cos1 = T::from_f64(#c7_cos1);
343 let cos2 = T::from_f64(#c7_cos2);
344 let cos3 = T::from_f64(#c7_cos3);
345 let sin1 = T::from_f64(#c7_sin1);
346 let sin2 = T::from_f64(#c7_sin2);
347 let sin3 = T::from_f64(#c7_sin3);
348
349 let cp1_re = cos1 * r1_re + cos2 * r2_re + cos3 * r3_re;
352 let cp1_im = cos1 * r1_im + cos2 * r2_im + cos3 * r3_im;
353 let cp2_re = cos2 * r1_re + cos3 * r2_re + cos1 * r3_re;
355 let cp2_im = cos2 * r1_im + cos3 * r2_im + cos1 * r3_im;
356 let cp3_re = cos3 * r1_re + cos1 * r2_re + cos2 * r3_re;
358 let cp3_im = cos3 * r1_im + cos1 * r2_im + cos2 * r3_im;
359
360 let sp1_re = sin1 * i1_im + sin2 * i2_im + sin3 * i3_im;
364 let sp1_im = sin1 * i1_re + sin2 * i2_re + sin3 * i3_re;
365 let sp2_re = sin2 * i1_im - sin3 * i2_im - sin1 * i3_im;
368 let sp2_im = sin2 * i1_re - sin3 * i2_re - sin1 * i3_re;
369 let sp3_re = sin3 * i1_im - sin1 * i2_im + sin2 * i3_im;
372 let sp3_im = sin3 * i1_re - sin1 * i2_re + sin2 * i3_re;
373
374 let tmp1_re = x0.re + cp1_re;
376 let tmp1_im = x0.im + cp1_im;
377 let tmp2_re = x0.re + cp2_re;
378 let tmp2_im = x0.im + cp2_im;
379 let tmp3_re = x0.re + cp3_re;
380 let tmp3_im = x0.im + cp3_im;
381
382 if sign < 0 {
383 x[1] = crate::kernel::Complex::new(tmp1_re + sp1_re, tmp1_im - sp1_im);
387 x[6] = crate::kernel::Complex::new(tmp1_re - sp1_re, tmp1_im + sp1_im);
388 x[2] = crate::kernel::Complex::new(tmp2_re + sp2_re, tmp2_im - sp2_im);
389 x[5] = crate::kernel::Complex::new(tmp2_re - sp2_re, tmp2_im + sp2_im);
390 x[3] = crate::kernel::Complex::new(tmp3_re + sp3_re, tmp3_im - sp3_im);
391 x[4] = crate::kernel::Complex::new(tmp3_re - sp3_re, tmp3_im + sp3_im);
392 } else {
393 x[1] = crate::kernel::Complex::new(tmp1_re - sp1_re, tmp1_im + sp1_im);
395 x[6] = crate::kernel::Complex::new(tmp1_re + sp1_re, tmp1_im - sp1_im);
396 x[2] = crate::kernel::Complex::new(tmp2_re - sp2_re, tmp2_im + sp2_im);
397 x[5] = crate::kernel::Complex::new(tmp2_re + sp2_re, tmp2_im - sp2_im);
398 x[3] = crate::kernel::Complex::new(tmp3_re - sp3_re, tmp3_im + sp3_im);
399 x[4] = crate::kernel::Complex::new(tmp3_re + sp3_re, tmp3_im - sp3_im);
400 }
401 }
402 }
403}
404
405#[cfg(test)]
415#[allow(clippy::suboptimal_flops)]
416pub(crate) fn naive_dft_fwd(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
417 let n = x_re.len();
418 debug_assert_eq!(x_im.len(), n);
419 let mut out_re = vec![0.0_f64; n];
420 let mut out_im = vec![0.0_f64; n];
421 for k in 0..n {
422 for j in 0..n {
423 let angle = -2.0 * std::f64::consts::PI * (k * j) as f64 / n as f64;
424 let (s, c) = angle.sin_cos();
425 out_re[k] += x_re[j] * c - x_im[j] * s;
426 out_im[k] += x_re[j] * s + x_im[j] * c;
427 }
428 }
429 (out_re, out_im)
430}
431
432#[cfg(test)]
434#[allow(clippy::suboptimal_flops)]
435pub(crate) fn naive_dft_inv(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
436 let n = x_re.len();
437 debug_assert_eq!(x_im.len(), n);
438 let mut out_re = vec![0.0_f64; n];
439 let mut out_im = vec![0.0_f64; n];
440 for k in 0..n {
441 for j in 0..n {
442 let angle = 2.0 * std::f64::consts::PI * (k * j) as f64 / n as f64;
443 let (s, c) = angle.sin_cos();
444 out_re[k] += x_re[j] * c - x_im[j] * s;
445 out_im[k] += x_re[j] * s + x_im[j] * c;
446 }
447 }
448 (out_re, out_im)
449}
450
451#[cfg(test)]
455#[allow(clippy::suboptimal_flops)]
456pub(crate) fn winograd_dft3_fwd(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
457 debug_assert_eq!(x_re.len(), 3);
458 let mut out_re = vec![0.0_f64; 3];
459 let mut out_im = vec![0.0_f64; 3];
460
461 let s_re = x_re[1] + x_re[2];
462 let s_im = x_im[1] + x_im[2];
463 let d_re = x_re[1] - x_re[2];
464 let d_im = x_im[1] - x_im[2];
465
466 out_re[0] = x_re[0] + s_re;
467 out_im[0] = x_im[0] + s_im;
468
469 let tmp_re = x_re[0] + C3_1 * s_re;
470 let tmp_im = x_im[0] + C3_1 * s_im;
471
472 out_re[1] = tmp_re + C3_2 * d_im;
473 out_im[1] = tmp_im - C3_2 * d_re;
474 out_re[2] = tmp_re - C3_2 * d_im;
475 out_im[2] = tmp_im + C3_2 * d_re;
476
477 (out_re, out_im)
478}
479
480#[cfg(test)]
482#[allow(clippy::suboptimal_flops)]
483pub(crate) fn winograd_dft3_inv(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
484 debug_assert_eq!(x_re.len(), 3);
485 let mut out_re = vec![0.0_f64; 3];
486 let mut out_im = vec![0.0_f64; 3];
487
488 let s_re = x_re[1] + x_re[2];
489 let s_im = x_im[1] + x_im[2];
490 let d_re = x_re[1] - x_re[2];
491 let d_im = x_im[1] - x_im[2];
492
493 out_re[0] = x_re[0] + s_re;
494 out_im[0] = x_im[0] + s_im;
495
496 let tmp_re = x_re[0] + C3_1 * s_re;
497 let tmp_im = x_im[0] + C3_1 * s_im;
498
499 out_re[1] = tmp_re - C3_2 * d_im;
501 out_im[1] = tmp_im + C3_2 * d_re;
502 out_re[2] = tmp_re + C3_2 * d_im;
503 out_im[2] = tmp_im - C3_2 * d_re;
504
505 (out_re, out_im)
506}
507
508#[cfg(test)]
510#[allow(clippy::suboptimal_flops)]
511pub(crate) fn winograd_dft5_fwd(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
512 debug_assert_eq!(x_re.len(), 5);
513 let mut out_re = vec![0.0_f64; 5];
514 let mut out_im = vec![0.0_f64; 5];
515
516 let r1_re = x_re[1] + x_re[4];
517 let r1_im = x_im[1] + x_im[4];
518 let r2_re = x_re[2] + x_re[3];
519 let r2_im = x_im[2] + x_im[3];
520 let i1_re = x_re[1] - x_re[4];
521 let i1_im = x_im[1] - x_im[4];
522 let i2_re = x_re[2] - x_re[3];
523 let i2_im = x_im[2] - x_im[3];
524
525 out_re[0] = x_re[0] + r1_re + r2_re;
526 out_im[0] = x_im[0] + r1_im + r2_im;
527
528 let cr1_re = C5_COS1 * r1_re + C5_COS2 * r2_re;
529 let cr1_im = C5_COS1 * r1_im + C5_COS2 * r2_im;
530 let cr2_re = C5_COS2 * r1_re + C5_COS1 * r2_re;
531 let cr2_im = C5_COS2 * r1_im + C5_COS1 * r2_im;
532
533 let sr1_re = C5_SIN1 * i1_re + C5_SIN2 * i2_re;
534 let sr1_im = C5_SIN1 * i1_im + C5_SIN2 * i2_im;
535 let sr2_re = C5_SIN2 * i1_re - C5_SIN1 * i2_re;
536 let sr2_im = C5_SIN2 * i1_im - C5_SIN1 * i2_im;
537
538 let tmp1_re = x_re[0] + cr1_re;
539 let tmp1_im = x_im[0] + cr1_im;
540 let tmp2_re = x_re[0] + cr2_re;
541 let tmp2_im = x_im[0] + cr2_im;
542
543 out_re[1] = tmp1_re + sr1_im;
545 out_im[1] = tmp1_im - sr1_re;
546 out_re[4] = tmp1_re - sr1_im;
547 out_im[4] = tmp1_im + sr1_re;
548 out_re[2] = tmp2_re + sr2_im;
549 out_im[2] = tmp2_im - sr2_re;
550 out_re[3] = tmp2_re - sr2_im;
551 out_im[3] = tmp2_im + sr2_re;
552
553 (out_re, out_im)
554}
555
556#[cfg(test)]
558#[allow(clippy::suboptimal_flops)]
559pub(crate) fn winograd_dft5_inv(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
560 debug_assert_eq!(x_re.len(), 5);
561 let mut out_re = vec![0.0_f64; 5];
562 let mut out_im = vec![0.0_f64; 5];
563
564 let r1_re = x_re[1] + x_re[4];
565 let r1_im = x_im[1] + x_im[4];
566 let r2_re = x_re[2] + x_re[3];
567 let r2_im = x_im[2] + x_im[3];
568 let i1_re = x_re[1] - x_re[4];
569 let i1_im = x_im[1] - x_im[4];
570 let i2_re = x_re[2] - x_re[3];
571 let i2_im = x_im[2] - x_im[3];
572
573 out_re[0] = x_re[0] + r1_re + r2_re;
574 out_im[0] = x_im[0] + r1_im + r2_im;
575
576 let cr1_re = C5_COS1 * r1_re + C5_COS2 * r2_re;
577 let cr1_im = C5_COS1 * r1_im + C5_COS2 * r2_im;
578 let cr2_re = C5_COS2 * r1_re + C5_COS1 * r2_re;
579 let cr2_im = C5_COS2 * r1_im + C5_COS1 * r2_im;
580
581 let sr1_re = C5_SIN1 * i1_re + C5_SIN2 * i2_re;
582 let sr1_im = C5_SIN1 * i1_im + C5_SIN2 * i2_im;
583 let sr2_re = C5_SIN2 * i1_re - C5_SIN1 * i2_re;
584 let sr2_im = C5_SIN2 * i1_im - C5_SIN1 * i2_im;
585
586 let tmp1_re = x_re[0] + cr1_re;
587 let tmp1_im = x_im[0] + cr1_im;
588 let tmp2_re = x_re[0] + cr2_re;
589 let tmp2_im = x_im[0] + cr2_im;
590
591 out_re[1] = tmp1_re - sr1_im;
593 out_im[1] = tmp1_im + sr1_re;
594 out_re[4] = tmp1_re + sr1_im;
595 out_im[4] = tmp1_im - sr1_re;
596 out_re[2] = tmp2_re - sr2_im;
597 out_im[2] = tmp2_im + sr2_re;
598 out_re[3] = tmp2_re + sr2_im;
599 out_im[3] = tmp2_im - sr2_re;
600
601 (out_re, out_im)
602}
603
604#[cfg(test)]
606#[allow(clippy::suboptimal_flops)]
607pub(crate) fn winograd_dft7_fwd(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
608 debug_assert_eq!(x_re.len(), 7);
609 let mut out_re = vec![0.0_f64; 7];
610 let mut out_im = vec![0.0_f64; 7];
611
612 let r1_re = x_re[1] + x_re[6];
613 let r1_im = x_im[1] + x_im[6];
614 let r2_re = x_re[2] + x_re[5];
615 let r2_im = x_im[2] + x_im[5];
616 let r3_re = x_re[3] + x_re[4];
617 let r3_im = x_im[3] + x_im[4];
618 let i1_re = x_re[1] - x_re[6];
619 let i1_im = x_im[1] - x_im[6];
620 let i2_re = x_re[2] - x_re[5];
621 let i2_im = x_im[2] - x_im[5];
622 let i3_re = x_re[3] - x_re[4];
623 let i3_im = x_im[3] - x_im[4];
624
625 out_re[0] = x_re[0] + r1_re + r2_re + r3_re;
626 out_im[0] = x_im[0] + r1_im + r2_im + r3_im;
627
628 let cp1_re = C7_COS1 * r1_re + C7_COS2 * r2_re + C7_COS3 * r3_re;
629 let cp1_im = C7_COS1 * r1_im + C7_COS2 * r2_im + C7_COS3 * r3_im;
630 let cp2_re = C7_COS2 * r1_re + C7_COS3 * r2_re + C7_COS1 * r3_re;
631 let cp2_im = C7_COS2 * r1_im + C7_COS3 * r2_im + C7_COS1 * r3_im;
632 let cp3_re = C7_COS3 * r1_re + C7_COS1 * r2_re + C7_COS2 * r3_re;
633 let cp3_im = C7_COS3 * r1_im + C7_COS1 * r2_im + C7_COS2 * r3_im;
634
635 let sp1_re = C7_SIN1 * i1_im + C7_SIN2 * i2_im + C7_SIN3 * i3_im;
636 let sp1_im = C7_SIN1 * i1_re + C7_SIN2 * i2_re + C7_SIN3 * i3_re;
637 let sp2_re = C7_SIN2 * i1_im - C7_SIN3 * i2_im - C7_SIN1 * i3_im;
638 let sp2_im = C7_SIN2 * i1_re - C7_SIN3 * i2_re - C7_SIN1 * i3_re;
639 let sp3_re = C7_SIN3 * i1_im - C7_SIN1 * i2_im + C7_SIN2 * i3_im;
640 let sp3_im = C7_SIN3 * i1_re - C7_SIN1 * i2_re + C7_SIN2 * i3_re;
641
642 let tmp1_re = x_re[0] + cp1_re;
643 let tmp1_im = x_im[0] + cp1_im;
644 let tmp2_re = x_re[0] + cp2_re;
645 let tmp2_im = x_im[0] + cp2_im;
646 let tmp3_re = x_re[0] + cp3_re;
647 let tmp3_im = x_im[0] + cp3_im;
648
649 out_re[1] = tmp1_re + sp1_re;
651 out_im[1] = tmp1_im - sp1_im;
652 out_re[6] = tmp1_re - sp1_re;
653 out_im[6] = tmp1_im + sp1_im;
654 out_re[2] = tmp2_re + sp2_re;
655 out_im[2] = tmp2_im - sp2_im;
656 out_re[5] = tmp2_re - sp2_re;
657 out_im[5] = tmp2_im + sp2_im;
658 out_re[3] = tmp3_re + sp3_re;
659 out_im[3] = tmp3_im - sp3_im;
660 out_re[4] = tmp3_re - sp3_re;
661 out_im[4] = tmp3_im + sp3_im;
662
663 (out_re, out_im)
664}
665
666#[cfg(test)]
668#[allow(clippy::suboptimal_flops)]
669pub(crate) fn winograd_dft7_inv(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
670 debug_assert_eq!(x_re.len(), 7);
671 let mut out_re = vec![0.0_f64; 7];
672 let mut out_im = vec![0.0_f64; 7];
673
674 let r1_re = x_re[1] + x_re[6];
675 let r1_im = x_im[1] + x_im[6];
676 let r2_re = x_re[2] + x_re[5];
677 let r2_im = x_im[2] + x_im[5];
678 let r3_re = x_re[3] + x_re[4];
679 let r3_im = x_im[3] + x_im[4];
680 let i1_re = x_re[1] - x_re[6];
681 let i1_im = x_im[1] - x_im[6];
682 let i2_re = x_re[2] - x_re[5];
683 let i2_im = x_im[2] - x_im[5];
684 let i3_re = x_re[3] - x_re[4];
685 let i3_im = x_im[3] - x_im[4];
686
687 out_re[0] = x_re[0] + r1_re + r2_re + r3_re;
688 out_im[0] = x_im[0] + r1_im + r2_im + r3_im;
689
690 let cp1_re = C7_COS1 * r1_re + C7_COS2 * r2_re + C7_COS3 * r3_re;
691 let cp1_im = C7_COS1 * r1_im + C7_COS2 * r2_im + C7_COS3 * r3_im;
692 let cp2_re = C7_COS2 * r1_re + C7_COS3 * r2_re + C7_COS1 * r3_re;
693 let cp2_im = C7_COS2 * r1_im + C7_COS3 * r2_im + C7_COS1 * r3_im;
694 let cp3_re = C7_COS3 * r1_re + C7_COS1 * r2_re + C7_COS2 * r3_re;
695 let cp3_im = C7_COS3 * r1_im + C7_COS1 * r2_im + C7_COS2 * r3_im;
696
697 let sp1_re = C7_SIN1 * i1_im + C7_SIN2 * i2_im + C7_SIN3 * i3_im;
698 let sp1_im = C7_SIN1 * i1_re + C7_SIN2 * i2_re + C7_SIN3 * i3_re;
699 let sp2_re = C7_SIN2 * i1_im - C7_SIN3 * i2_im - C7_SIN1 * i3_im;
700 let sp2_im = C7_SIN2 * i1_re - C7_SIN3 * i2_re - C7_SIN1 * i3_re;
701 let sp3_re = C7_SIN3 * i1_im - C7_SIN1 * i2_im + C7_SIN2 * i3_im;
702 let sp3_im = C7_SIN3 * i1_re - C7_SIN1 * i2_re + C7_SIN2 * i3_re;
703
704 let tmp1_re = x_re[0] + cp1_re;
705 let tmp1_im = x_im[0] + cp1_im;
706 let tmp2_re = x_re[0] + cp2_re;
707 let tmp2_im = x_im[0] + cp2_im;
708 let tmp3_re = x_re[0] + cp3_re;
709 let tmp3_im = x_im[0] + cp3_im;
710
711 out_re[1] = tmp1_re - sp1_re;
713 out_im[1] = tmp1_im + sp1_im;
714 out_re[6] = tmp1_re + sp1_re;
715 out_im[6] = tmp1_im - sp1_im;
716 out_re[2] = tmp2_re - sp2_re;
717 out_im[2] = tmp2_im + sp2_im;
718 out_re[5] = tmp2_re + sp2_re;
719 out_im[5] = tmp2_im - sp2_im;
720 out_re[3] = tmp3_re - sp3_re;
721 out_im[3] = tmp3_im + sp3_im;
722 out_re[4] = tmp3_re + sp3_re;
723 out_im[4] = tmp3_im - sp3_im;
724
725 (out_re, out_im)
726}
727
728#[cfg(test)]
733mod tests {
734 use super::*;
735
736 const TOL: f64 = 1e-12;
737
738 fn assert_close(a: &[f64], b: &[f64], label: &str) {
739 assert_eq!(a.len(), b.len(), "{label}: length mismatch");
740 for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
741 assert!(
742 (x - y).abs() < TOL,
743 "{label}[{i}]: got {x}, expected {y}, diff = {}",
744 (x - y).abs()
745 );
746 }
747 }
748
749 #[test]
754 fn test_dft3_forward_f64_impulse() {
755 let x_re = [1.0, 0.0, 0.0];
757 let x_im = [0.0, 0.0, 0.0];
758 let (got_re, got_im) = winograd_dft3_fwd(&x_re, &x_im);
759 assert_close(&got_re, &[1.0, 1.0, 1.0], "dft3_impulse_re");
760 assert_close(&got_im, &[0.0, 0.0, 0.0], "dft3_impulse_im");
761 }
762
763 #[test]
764 fn test_dft3_forward_vs_naive() {
765 let x_re = [1.3, -0.7, 0.4];
767 let x_im = [0.2, 1.1, -0.5];
768 let (got_re, got_im) = winograd_dft3_fwd(&x_re, &x_im);
769 let (ref_re, ref_im) = naive_dft_fwd(&x_re, &x_im);
770 assert_close(&got_re, &ref_re, "dft3_fwd_re");
771 assert_close(&got_im, &ref_im, "dft3_fwd_im");
772 }
773
774 #[test]
775 fn test_dft3_inverse_vs_naive() {
776 let x_re = [1.3, -0.7, 0.4];
777 let x_im = [0.2, 1.1, -0.5];
778 let (got_re, got_im) = winograd_dft3_inv(&x_re, &x_im);
779 let (ref_re, ref_im) = naive_dft_inv(&x_re, &x_im);
780 assert_close(&got_re, &ref_re, "dft3_inv_re");
781 assert_close(&got_im, &ref_im, "dft3_inv_im");
782 }
783
784 #[test]
785 fn test_roundtrip_dft3() {
786 let x_re = [1.3, -0.7, 0.4];
788 let x_im = [0.2, 1.1, -0.5];
789 let (fwd_re, fwd_im) = winograd_dft3_fwd(&x_re, &x_im);
790 let (inv_re, inv_im) = winograd_dft3_inv(&fwd_re, &fwd_im);
791 let n = 3.0_f64;
792 let scaled_re: Vec<f64> = inv_re.iter().map(|&v| v / n).collect();
793 let scaled_im: Vec<f64> = inv_im.iter().map(|&v| v / n).collect();
794 assert_close(&scaled_re, &x_re, "roundtrip_dft3_re");
795 assert_close(&scaled_im, &x_im, "roundtrip_dft3_im");
796 }
797
798 #[test]
803 fn test_dft5_forward_f64_impulse() {
804 let x_re = [1.0, 0.0, 0.0, 0.0, 0.0];
805 let x_im = [0.0, 0.0, 0.0, 0.0, 0.0];
806 let (got_re, got_im) = winograd_dft5_fwd(&x_re, &x_im);
807 assert_close(&got_re, &[1.0, 1.0, 1.0, 1.0, 1.0], "dft5_impulse_re");
808 assert_close(&got_im, &[0.0, 0.0, 0.0, 0.0, 0.0], "dft5_impulse_im");
809 }
810
811 #[test]
812 fn test_dft5_forward_vs_naive() {
813 let x_re = [0.5, -1.2, 0.8, 0.3, -0.6];
814 let x_im = [0.1, 0.4, -0.9, 0.7, -0.2];
815 let (got_re, got_im) = winograd_dft5_fwd(&x_re, &x_im);
816 let (ref_re, ref_im) = naive_dft_fwd(&x_re, &x_im);
817 assert_close(&got_re, &ref_re, "dft5_fwd_re");
818 assert_close(&got_im, &ref_im, "dft5_fwd_im");
819 }
820
821 #[test]
822 fn test_dft5_inverse_vs_naive() {
823 let x_re = [0.5, -1.2, 0.8, 0.3, -0.6];
824 let x_im = [0.1, 0.4, -0.9, 0.7, -0.2];
825 let (got_re, got_im) = winograd_dft5_inv(&x_re, &x_im);
826 let (ref_re, ref_im) = naive_dft_inv(&x_re, &x_im);
827 assert_close(&got_re, &ref_re, "dft5_inv_re");
828 assert_close(&got_im, &ref_im, "dft5_inv_im");
829 }
830
831 #[test]
832 fn test_roundtrip_dft5() {
833 let x_re = [0.5, -1.2, 0.8, 0.3, -0.6];
834 let x_im = [0.1, 0.4, -0.9, 0.7, -0.2];
835 let (fwd_re, fwd_im) = winograd_dft5_fwd(&x_re, &x_im);
836 let (inv_re, inv_im) = winograd_dft5_inv(&fwd_re, &fwd_im);
837 let n = 5.0_f64;
838 let scaled_re: Vec<f64> = inv_re.iter().map(|&v| v / n).collect();
839 let scaled_im: Vec<f64> = inv_im.iter().map(|&v| v / n).collect();
840 assert_close(&scaled_re, &x_re, "roundtrip_dft5_re");
841 assert_close(&scaled_im, &x_im, "roundtrip_dft5_im");
842 }
843
844 #[test]
849 fn test_dft7_forward_f64_impulse() {
850 let x_re = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
851 let x_im = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
852 let (got_re, got_im) = winograd_dft7_fwd(&x_re, &x_im);
853 assert_close(
854 &got_re,
855 &[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
856 "dft7_impulse_re",
857 );
858 assert_close(
859 &got_im,
860 &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
861 "dft7_impulse_im",
862 );
863 }
864
865 #[test]
866 fn test_dft7_forward_vs_naive() {
867 let x_re = [0.5, -1.2, 0.8, 0.3, -0.6, 1.4, -0.1];
868 let x_im = [0.1, 0.4, -0.9, 0.7, -0.2, 0.5, 0.3];
869 let (got_re, got_im) = winograd_dft7_fwd(&x_re, &x_im);
870 let (ref_re, ref_im) = naive_dft_fwd(&x_re, &x_im);
871 assert_close(&got_re, &ref_re, "dft7_fwd_re");
872 assert_close(&got_im, &ref_im, "dft7_fwd_im");
873 }
874
875 #[test]
876 fn test_dft7_inverse_vs_naive() {
877 let x_re = [0.5, -1.2, 0.8, 0.3, -0.6, 1.4, -0.1];
878 let x_im = [0.1, 0.4, -0.9, 0.7, -0.2, 0.5, 0.3];
879 let (got_re, got_im) = winograd_dft7_inv(&x_re, &x_im);
880 let (ref_re, ref_im) = naive_dft_inv(&x_re, &x_im);
881 assert_close(&got_re, &ref_re, "dft7_inv_re");
882 assert_close(&got_im, &ref_im, "dft7_inv_im");
883 }
884
885 #[test]
886 fn test_roundtrip_dft7() {
887 let x_re = [0.5, -1.2, 0.8, 0.3, -0.6, 1.4, -0.1];
888 let x_im = [0.1, 0.4, -0.9, 0.7, -0.2, 0.5, 0.3];
889 let (fwd_re, fwd_im) = winograd_dft7_fwd(&x_re, &x_im);
890 let (inv_re, inv_im) = winograd_dft7_inv(&fwd_re, &fwd_im);
891 let n = 7.0_f64;
892 let scaled_re: Vec<f64> = inv_re.iter().map(|&v| v / n).collect();
893 let scaled_im: Vec<f64> = inv_im.iter().map(|&v| v / n).collect();
894 assert_close(&scaled_re, &x_re, "roundtrip_dft7_re");
895 assert_close(&scaled_im, &x_im, "roundtrip_dft7_im");
896 }
897
898 #[test]
903 fn test_winograd_constants_match_runtime() {
904 crate::winograd_constants::verify_constants_match_runtime();
905 }
906
907 #[test]
912 fn test_generate_from_macro_size3() {
913 let input: proc_macro2::TokenStream = "3".parse().expect("parse literal");
914 let result = generate_from_macro(input);
915 assert!(result.is_ok(), "gen_odd_codelet!(3) should succeed");
916 let ts = result.expect("TokenStream for size 3");
917 let s = ts.to_string();
918 assert!(
919 s.contains("codelet_notw_3"),
920 "should contain codelet_notw_3"
921 );
922 assert!(s.contains("sign"), "should contain sign parameter");
923 }
924
925 #[test]
926 fn test_generate_from_macro_size5() {
927 let input: proc_macro2::TokenStream = "5".parse().expect("parse literal");
928 let result = generate_from_macro(input);
929 assert!(result.is_ok(), "gen_odd_codelet!(5) should succeed");
930 let ts = result.expect("TokenStream for size 5");
931 let s = ts.to_string();
932 assert!(
933 s.contains("codelet_notw_5"),
934 "should contain codelet_notw_5"
935 );
936 }
937
938 #[test]
939 fn test_generate_from_macro_size7() {
940 let input: proc_macro2::TokenStream = "7".parse().expect("parse literal");
941 let result = generate_from_macro(input);
942 assert!(result.is_ok(), "gen_odd_codelet!(7) should succeed");
943 let ts = result.expect("TokenStream for size 7");
944 let s = ts.to_string();
945 assert!(
946 s.contains("codelet_notw_7"),
947 "should contain codelet_notw_7"
948 );
949 }
950
951 #[test]
952 fn test_generate_from_macro_unsupported() {
953 let input: proc_macro2::TokenStream = "4".parse().expect("parse literal");
954 let result = generate_from_macro(input);
955 assert!(
956 result.is_err(),
957 "gen_odd_codelet!(4) should fail with unsupported size"
958 );
959 }
960}