1use proc_macro2::TokenStream;
28use quote::{format_ident, quote};
29use syn::LitInt;
30
/// Errors reported by the codelet code generator in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodegenError {
    /// The requested transform size is not valid. `classify` returns this
    /// for a size of `0`.
    InvalidSize(usize),
    /// The requested transform size is valid but cannot be generated.
    ///
    /// NOTE(review): no code path in this file constructs this variant;
    /// it is presumably reserved for callers or future use.
    UnsupportedSize(usize),
    /// Parsing the macro input or emitting tokens failed; the payload is the
    /// stringified underlying error.
    EmitError(String),
}

impl core::fmt::Display for CodegenError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::InvalidSize(n) => write!(f, "invalid codelet size: {n}"),
            Self::UnsupportedSize(n) => write!(f, "unsupported codelet size: {n}"),
            Self::EmitError(s) => write!(f, "codegen emit error: {s}"),
        }
    }
}

// Implementing `Error` lets callers box this type as `dyn Error` and use it
// with `?` in functions that return a boxed or wrapped error.
impl std::error::Error for CodegenError {}
55
/// Code-generation strategy selected for a transform size by `classify`.
///
/// The names follow common FFT terminology; the notes on each variant only
/// describe how this file produces and consumes it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SizeClass {
    /// Sizes 1, 2, 4, 8, 16, 32 and 64. Size 1 is emitted as an identity
    /// codelet; the others are delegated to `crate::gen_notw::generate`.
    Notw(usize),
    /// Sizes 3, 5 and 7, delegated to `crate::gen_odd::generate_from_macro`.
    Odd(usize),
    /// Sizes 11 and 13, delegated to `crate::gen_rader::generate_rader`.
    RaderHardcoded(usize),
    /// A size that factors completely over the radices
    /// `{16, 8, 7, 5, 4, 3, 2}`; the payload holds those factors in the
    /// order they were divided out. Emitted as a runtime wrapper.
    MixedRadix(Vec<u16>),
    /// A prime no larger than 1021 that is not covered by an earlier class.
    /// Emitted as a runtime wrapper.
    RaderPrime(usize),
    /// Any remaining size. Emitted as a runtime wrapper.
    Bluestein(usize),
}
79
80pub fn classify(n: usize) -> Result<SizeClass, CodegenError> {
86 if n == 0 {
87 return Err(CodegenError::InvalidSize(0));
88 }
89 if n == 1 {
90 return Ok(SizeClass::Notw(1));
92 }
93
94 if matches!(n, 2 | 4 | 8 | 16 | 32 | 64) {
96 return Ok(SizeClass::Notw(n));
97 }
98 if matches!(n, 3 | 5 | 7) {
100 return Ok(SizeClass::Odd(n));
101 }
102 if matches!(n, 11 | 13) {
104 return Ok(SizeClass::RaderHardcoded(n));
105 }
106
107 if let Some(factors) = try_factor_smooth7(n) {
109 return Ok(SizeClass::MixedRadix(factors));
110 }
111
112 if is_prime(n) && n <= 1021 {
114 return Ok(SizeClass::RaderPrime(n));
115 }
116
117 Ok(SizeClass::Bluestein(n))
119}
120
/// Greedily factors `n` over the radix set `{16, 8, 7, 5, 4, 3, 2}`.
///
/// Radices are tried from largest to smallest and each one is divided out as
/// often as it fits, so the returned factors are in non-increasing order
/// (for example `240` yields `[16, 5, 3]`).
///
/// Returns `None` when `n` is `0` or `1`, or when a cofactor other than `1`
/// remains after every radix has been divided out.
fn try_factor_smooth7(mut n: usize) -> Option<Vec<u16>> {
    // Stored as `u16` so factors can be pushed without a narrowing cast.
    const RADICES: [u16; 7] = [16, 8, 7, 5, 4, 3, 2];

    // `0 % r == 0` holds for every radix and `0 / r` stays `0`, so without
    // this guard the division loop below would never terminate.
    if n == 0 {
        return None;
    }

    let mut factors = Vec::new();
    for radix in RADICES {
        let divisor = usize::from(radix);
        while n % divisor == 0 {
            factors.push(radix);
            n /= divisor;
        }
    }

    // An empty factor list with `n == 1` means the input itself was 1.
    if n == 1 && !factors.is_empty() {
        Some(factors)
    } else {
        None
    }
}
144
/// Returns `true` when `n` is prime, using trial division by odd numbers.
///
/// Usable in `const` contexts.
const fn is_prime(n: usize) -> bool {
    // Even numbers (including 0): only 2 is prime.
    if n % 2 == 0 {
        return n == 2;
    }
    if n == 1 {
        return false;
    }

    // `divisor <= n / divisor` is the overflow-free form of
    // `divisor * divisor <= n`.
    let mut divisor = 3usize;
    while divisor <= n / divisor {
        if n % divisor == 0 {
            return false;
        }
        divisor += 2;
    }
    true
}
168
169pub fn generate(n: usize) -> Result<TokenStream, CodegenError> {
187 match classify(n)? {
188 SizeClass::Notw(sz) => generate_notw_any(sz),
189 SizeClass::Odd(sz) => generate_odd_any(sz),
190 SizeClass::RaderHardcoded(sz) => generate_rader_hardcoded(sz),
191 SizeClass::MixedRadix(_) | SizeClass::RaderPrime(_) | SizeClass::Bluestein(_) => {
192 Ok(generate_runtime_wrapper(n))
193 }
194 }
195}
196
197fn generate_notw_any(sz: usize) -> Result<TokenStream, CodegenError> {
200 if sz == 1 {
201 return Ok(generate_identity_codelet());
202 }
203 let literal = proc_macro2::Literal::usize_unsuffixed(sz);
205 let ts = quote! { #literal };
206 crate::gen_notw::generate(ts).map_err(|e| CodegenError::EmitError(e.to_string()))
207}
208
209fn generate_odd_any(sz: usize) -> Result<TokenStream, CodegenError> {
210 let literal = proc_macro2::Literal::usize_unsuffixed(sz);
211 let ts = quote! { #literal };
212 crate::gen_odd::generate_from_macro(ts).map_err(|e| CodegenError::EmitError(e.to_string()))
213}
214
215fn generate_rader_hardcoded(sz: usize) -> Result<TokenStream, CodegenError> {
216 if matches!(sz, 11 | 13) {
219 Ok(crate::gen_rader::generate_rader(sz))
220 } else {
221 Err(CodegenError::EmitError(format!(
222 "generate_rader_hardcoded: size {sz} is not in the hardcoded set {{11, 13}}"
223 )))
224 }
225}
226
/// Emits the size-1 codelet, `codelet_any_1`.
///
/// The emitted function has the same `(x, sign)` parameters as the runtime
/// wrapper emitted elsewhere in this file, but its body only debug-asserts
/// that `x` holds at least one element; `x` is left untouched.
fn generate_identity_codelet() -> TokenStream {
    quote! {
        #[inline(always)]
        // `sign` is never read in the emitted body, hence `unused_variables`.
        // NOTE(review): the reason for `trivially_copy_pass_by_ref` is not
        // evident from this file.
        #[allow(clippy::trivially_copy_pass_by_ref, unused_variables)]
        pub fn codelet_any_1<T: crate::kernel::Float>(
            x: &mut [crate::kernel::Complex<T>],
            sign: i32,
        ) {
            debug_assert!(x.len() >= 1, "codelet_any_1: input must have at least 1 element");
        }
    }
}
243
244fn generate_runtime_wrapper(n: usize) -> TokenStream {
251 let fn_name = format_ident!("codelet_any_{n}");
252 let n_lit = proc_macro2::Literal::usize_unsuffixed(n);
253
254 quote! {
255 pub fn #fn_name<T: crate::kernel::Float>(
260 x: &mut [crate::kernel::Complex<T>],
261 sign: i32,
262 ) {
263 use ::oxifft::api::{Direction, Flags, Plan};
264
265 debug_assert_eq!(x.len(), #n_lit, "codelet input length mismatch");
266
267 let direction = if sign < 0 {
268 Direction::Forward
269 } else {
270 Direction::Backward
271 };
272
273 let plan = Plan::<T>::dft_1d(#n_lit, direction, Flags::ESTIMATE)
274 .unwrap_or_else(|| {
275 panic!(
276 "OxiFFT: Plan::dft_1d failed for compile-time-verified size {}",
277 #n_lit
278 )
279 });
280
281 let input_snapshot: ::std::vec::Vec<crate::kernel::Complex<T>> = x.to_vec();
283 plan.execute(&input_snapshot, x);
284 }
285 }
286}
287
288#[must_use]
303pub fn generate_from_macro(input: TokenStream) -> TokenStream {
304 match parse_and_generate(input) {
305 Ok(ts) => ts,
306 Err(e) => {
307 let msg = e.to_string();
308 quote! { compile_error!(#msg); }
309 }
310 }
311}
312
313fn parse_and_generate(input: TokenStream) -> Result<TokenStream, CodegenError> {
314 let size: LitInt = syn::parse2(input).map_err(|e| CodegenError::EmitError(e.to_string()))?;
315 let n: usize = size
316 .base10_parse()
317 .map_err(|e| CodegenError::EmitError(e.to_string()))?;
318 generate(n)
319}
320
/// Builder-style front end for [`generate`].
///
/// Derives `Debug` and `Clone` like the other public types in this file, so
/// a builder can be logged and reused.
#[derive(Debug, Clone)]
pub struct CodeletBuilder {
    /// Transform size passed to `generate` by `build`.
    n: usize,
    /// Name requested through `name`. It is stored but not yet read by
    /// `build`, hence the `dead_code` allowance.
    #[allow(dead_code)]
    name_override: Option<String>,
}
341
342impl CodeletBuilder {
343 #[must_use]
345 pub const fn new(n: usize) -> Self {
346 Self {
347 n,
348 name_override: None,
349 }
350 }
351
352 #[must_use]
354 pub fn name(mut self, name: impl Into<String>) -> Self {
355 self.name_override = Some(name.into());
356 self
357 }
358
359 pub fn build(self) -> Result<TokenStream, CodegenError> {
365 generate(self.n)
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `generate(n)` succeeds and renders to non-empty text.
    fn assert_generates_tokens(n: usize) {
        let rendered = generate(n).unwrap().to_string();
        assert!(!rendered.is_empty());
    }

    // --- classification -------------------------------------------------

    #[test]
    fn classify_all_notw() {
        for n in [2usize, 4, 8, 16, 32, 64] {
            let class = classify(n).unwrap();
            assert!(matches!(class, SizeClass::Notw(_)), "n={n} should be Notw");
        }
    }

    #[test]
    fn classify_all_odd() {
        for n in [3usize, 5, 7] {
            let class = classify(n).unwrap();
            assert!(matches!(class, SizeClass::Odd(_)), "n={n} should be Odd");
        }
    }

    #[test]
    fn classify_rader_hardcoded() {
        assert_eq!(classify(11), Ok(SizeClass::RaderHardcoded(11)));
        assert_eq!(classify(13), Ok(SizeClass::RaderHardcoded(13)));
    }

    #[test]
    fn classify_rader_prime_runtime() {
        let primes = [
            17usize, 19, 23, 29, 31, 37, 41, 43, 47, 53, 97, 101, 1013, 1019, 1021,
        ];
        for p in primes {
            let class = classify(p).unwrap();
            assert!(
                matches!(class, SizeClass::RaderPrime(_)),
                "n={p} should be RaderPrime"
            );
        }
    }

    #[test]
    fn classify_bluestein_large_prime() {
        assert_eq!(classify(2003), Ok(SizeClass::Bluestein(2003)));
    }

    #[test]
    fn classify_invalid_zero() {
        assert_eq!(classify(0), Err(CodegenError::InvalidSize(0)));
    }

    // --- smooth-7 factoring ---------------------------------------------

    #[test]
    fn smooth7_factoring() {
        let smooth_sizes = [
            6usize, 10, 12, 14, 15, 21, 24, 28, 30, 35, 40, 42, 48, 56, 60, 80, 84, 96, 112, 120,
            168, 240,
        ];
        for n in smooth_sizes {
            let class = classify(n).unwrap();
            assert!(
                matches!(class, SizeClass::MixedRadix(_)),
                "n={n} expected MixedRadix"
            );
        }
    }

    #[test]
    fn smooth7_factors_correct_for_15() {
        let class = classify(15).unwrap();
        if let SizeClass::MixedRadix(factors) = &class {
            assert!(factors.contains(&5), "15 factors must include 5");
            assert!(factors.contains(&3), "15 factors must include 3");
        } else {
            panic!("expected MixedRadix, got {class:?}");
        }
    }

    // --- primality helper -----------------------------------------------

    #[test]
    fn is_prime_helper() {
        for p in [2usize, 3, 5, 7, 11, 97] {
            assert!(is_prime(p));
        }
        for c in [0usize, 1, 4, 100] {
            assert!(!is_prime(c));
        }
    }

    // --- generation -----------------------------------------------------

    #[test]
    fn generate_emits_nonempty_for_direct_size() {
        assert_generates_tokens(8);
    }

    #[test]
    fn generate_emits_nonempty_for_odd_size() {
        assert_generates_tokens(3);
    }

    #[test]
    fn generate_emits_nonempty_for_rader_hardcoded() {
        assert_generates_tokens(11);
    }

    #[test]
    fn generate_emits_nonempty_for_mixed_radix() {
        assert_generates_tokens(15);
    }

    #[test]
    fn generate_emits_nonempty_for_bluestein() {
        assert_generates_tokens(2003);
    }

    #[test]
    fn generate_emits_nonempty_for_identity() {
        assert_generates_tokens(1);
    }

    #[test]
    fn generate_zero_returns_err() {
        assert!(generate(0).is_err());
    }

    // --- builder --------------------------------------------------------

    #[test]
    fn codelet_builder_zero_returns_err() {
        let outcome = CodeletBuilder::new(0).build();
        assert!(outcome.is_err());
    }

    #[test]
    fn codelet_builder_happy_path() {
        let rendered = CodeletBuilder::new(8).build().unwrap().to_string();
        assert!(!rendered.is_empty());
    }
}