use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::LitInt;
/// Errors reported by the codelet code generator in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodegenError {
/// The requested size is not a valid transform size; `classify` returns this for size 0.
InvalidSize(usize),
/// The requested size is not supported.
// NOTE(review): nothing in the visible part of this file constructs this variant — confirm it is still needed.
UnsupportedSize(usize),
/// Parsing the macro input or emitting tokens failed; carries the underlying error message.
EmitError(String),
}
/// Formats each error as a short, human-readable, single-line message.
impl core::fmt::Display for CodegenError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::InvalidSize(size) => {
                f.write_fmt(format_args!("invalid codelet size: {size}"))
            }
            Self::UnsupportedSize(size) => {
                f.write_fmt(format_args!("unsupported codelet size: {size}"))
            }
            Self::EmitError(detail) => {
                f.write_fmt(format_args!("codegen emit error: {detail}"))
            }
        }
    }
}
/// The generation strategy selected for a transform size by `classify`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SizeClass {
/// Size 1 or one of the powers of two 2, 4, 8, 16, 32, 64.
Notw(usize),
/// One of the small odd sizes 3, 5 or 7.
Odd(usize),
/// Size 11 or 13, which `generate` emits through `crate::gen_rader`.
RaderHardcoded(usize),
/// Any other size that factors completely over the radices 16, 8, 7, 5, 4, 3 and 2.
/// Holds the factors in the order they were divided out (largest radix first).
MixedRadix(Vec<u16>),
/// A prime no greater than 1021 that matched none of the classes above.
RaderPrime(usize),
/// Every remaining non-zero size.
Bluestein(usize),
}
/// Decides which generation strategy applies to a transform of size `n`.
///
/// The checks run in priority order: the fixed tables of directly generated
/// sizes first, then 7-smooth factorisation, then the bounded prime test, and
/// finally the catch-all [`SizeClass::Bluestein`].
///
/// # Errors
///
/// Returns [`CodegenError::InvalidSize`] when `n` is 0.
pub fn classify(n: usize) -> Result<SizeClass, CodegenError> {
    /// Largest prime that is classified as [`SizeClass::RaderPrime`].
    const MAX_RADER_PRIME: usize = 1021;

    let class = match n {
        0 => return Err(CodegenError::InvalidSize(0)),
        1 | 2 | 4 | 8 | 16 | 32 | 64 => SizeClass::Notw(n),
        3 | 5 | 7 => SizeClass::Odd(n),
        11 | 13 => SizeClass::RaderHardcoded(n),
        _ => {
            if let Some(factors) = try_factor_smooth7(n) {
                SizeClass::MixedRadix(factors)
            } else if n <= MAX_RADER_PRIME && is_prime(n) {
                // The cheap bound is tested first so that very large sizes never
                // pay for an O(sqrt(n)) trial division whose result is discarded.
                SizeClass::RaderPrime(n)
            } else {
                SizeClass::Bluestein(n)
            }
        }
    };
    Ok(class)
}
/// Factors `n` over the radices 16, 8, 7, 5, 4, 3 and 2.
///
/// Radices are tried from largest to smallest and each one is divided out as
/// many times as it fits, so the returned factors are in non-increasing order
/// and their product equals `n`.
///
/// Returns `None` when `n` is 0, when `n` is 1 (no factors at all), or when a
/// cofactor remains that none of the radices divides.
fn try_factor_smooth7(mut n: usize) -> Option<Vec<u16>> {
    const RADICES: [u16; 7] = [16, 8, 7, 5, 4, 3, 2];

    // Zero is divisible by every radix and never shrinks, so without this
    // guard the division loop below would spin forever.
    if n == 0 {
        return None;
    }

    let mut factors = Vec::new();
    for radix in RADICES {
        let divisor = usize::from(radix);
        while n % divisor == 0 {
            factors.push(radix);
            n /= divisor;
        }
    }

    if n == 1 && !factors.is_empty() {
        Some(factors)
    } else {
        None
    }
}
/// Reports whether `n` is prime, using trial division by odd numbers.
///
/// Usable in constant contexts. The squared divisor is computed with
/// `checked_mul`, so the search also stops cleanly if the square would
/// overflow `usize`.
const fn is_prime(n: usize) -> bool {
    // 0 and 1 are not prime; 2 and 3 are.
    if n < 4 {
        return n >= 2;
    }
    if n % 2 == 0 {
        return false;
    }

    let mut divisor: usize = 3;
    loop {
        match divisor.checked_mul(divisor) {
            // Only divisors up to sqrt(n) need to be tried.
            Some(square) if square <= n => {
                if n % divisor == 0 {
                    return false;
                }
                divisor += 2;
            }
            _ => return true,
        }
    }
}
/// Generates the codelet token stream for a transform of size `n`.
///
/// The size is classified first. `Notw`, `Odd` and `RaderHardcoded` sizes go
/// to their dedicated generators; every other class is served by the runtime
/// wrapper produced by `generate_runtime_wrapper`.
///
/// # Errors
///
/// Propagates the error from `classify` (size 0) and any error returned by
/// the selected generator.
pub fn generate(n: usize) -> Result<TokenStream, CodegenError> {
    let tokens = match classify(n)? {
        SizeClass::Notw(size) => generate_notw_any(size)?,
        SizeClass::Odd(size) => generate_odd_any(size)?,
        SizeClass::RaderHardcoded(size) => generate_rader_hardcoded(size)?,
        SizeClass::MixedRadix(_) | SizeClass::RaderPrime(_) | SizeClass::Bluestein(_) => {
            generate_runtime_wrapper(n)
        }
    };
    Ok(tokens)
}
/// Generates the codelet for a size in the `Notw` class.
///
/// Size 1 is answered locally with the identity codelet. Any other size is
/// handed to `crate::gen_notw::generate` as an unsuffixed integer literal.
///
/// # Errors
///
/// Returns [`CodegenError::EmitError`] carrying the message of the error
/// reported by `crate::gen_notw::generate`.
fn generate_notw_any(sz: usize) -> Result<TokenStream, CodegenError> {
    match sz {
        1 => Ok(generate_identity_codelet()),
        _ => {
            let size_lit = proc_macro2::Literal::usize_unsuffixed(sz);
            match crate::gen_notw::generate(quote! { #size_lit }) {
                Ok(tokens) => Ok(tokens),
                Err(err) => Err(CodegenError::EmitError(err.to_string())),
            }
        }
    }
}
/// Generates the codelet for a size in the `Odd` class.
///
/// The size is passed to `crate::gen_odd::generate_from_macro` as an
/// unsuffixed integer literal.
///
/// # Errors
///
/// Returns [`CodegenError::EmitError`] carrying the message of the error
/// reported by `crate::gen_odd::generate_from_macro`.
fn generate_odd_any(sz: usize) -> Result<TokenStream, CodegenError> {
    let size_lit = proc_macro2::Literal::usize_unsuffixed(sz);
    match crate::gen_odd::generate_from_macro(quote! { #size_lit }) {
        Ok(tokens) => Ok(tokens),
        Err(err) => Err(CodegenError::EmitError(err.to_string())),
    }
}
/// Generates the codelet for one of the hardcoded sizes 11 and 13 through
/// `crate::gen_rader::generate_rader`.
///
/// # Errors
///
/// Returns [`CodegenError::EmitError`] when `sz` is neither 11 nor 13.
fn generate_rader_hardcoded(sz: usize) -> Result<TokenStream, CodegenError> {
    match sz {
        11 | 13 => Ok(crate::gen_rader::generate_rader(sz)),
        _ => {
            let message = format!(
                "generate_rader_hardcoded: size {sz} is not in the hardcoded set {{11, 13}}"
            );
            Err(CodegenError::EmitError(message))
        }
    }
}
/// Builds the token stream for the size-1 codelet, `codelet_any_1`.
///
/// The emitted function takes the same `x` / `sign` parameters as the wrapper
/// emitted by `generate_runtime_wrapper`, but its body contains only a
/// `debug_assert!` that `x` holds at least one element; it never writes to `x`
/// and never reads `sign` (hence the `unused_variables` allow).
fn generate_identity_codelet() -> TokenStream {
quote! {
#[inline(always)]
#[allow(clippy::trivially_copy_pass_by_ref, unused_variables)]
pub fn codelet_any_1<T: crate::kernel::Float>(
x: &mut [crate::kernel::Complex<T>],
sign: i32,
) {
debug_assert!(x.len() >= 1, "codelet_any_1: input must have at least 1 element");
}
}
}
/// Builds the token stream for `codelet_any_<n>`, a wrapper that performs the
/// transform through a plan created at run time.
///
/// The emitted function:
/// - debug-asserts that `x.len()` equals `n`;
/// - maps a negative `sign` to `Direction::Forward` and any other value to
///   `Direction::Backward`;
/// - creates a `Plan` with `Flags::ESTIMATE` on every call and panics if
///   `Plan::dft_1d` yields no plan;
/// - copies `x` into a temporary `Vec` and executes the plan from that copy
///   back into `x`.
///
/// The emitted code names `::oxifft::api` and `crate::kernel`, so both paths
/// must resolve at the expansion site.
fn generate_runtime_wrapper(n: usize) -> TokenStream {
    let size = proc_macro2::Literal::usize_unsuffixed(n);
    let codelet_ident = format_ident!("codelet_any_{n}");

    quote! {
        pub fn #codelet_ident<T: crate::kernel::Float>(
            x: &mut [crate::kernel::Complex<T>],
            sign: i32,
        ) {
            use ::oxifft::api::{Direction, Flags, Plan};
            debug_assert_eq!(x.len(), #size, "codelet input length mismatch");
            let direction = if sign < 0 {
                Direction::Forward
            } else {
                Direction::Backward
            };
            let plan = Plan::<T>::dft_1d(#size, direction, Flags::ESTIMATE)
                .unwrap_or_else(|| {
                    panic!(
                        "OxiFFT: Plan::dft_1d failed for compile-time-verified size {}",
                        #size
                    )
                });
            // The plan reads from a copy so that `x` can serve as the output buffer.
            let input_snapshot: ::std::vec::Vec<crate::kernel::Complex<T>> = x.to_vec();
            plan.execute(&input_snapshot, x);
        }
    }
}
/// Macro entry point: turns the macro input into codelet tokens.
///
/// This never fails. When parsing or generation reports an error, the returned
/// tokens are a `compile_error!` invocation carrying the error message, so the
/// failure surfaces as a compile error at the macro call site.
#[must_use]
pub fn generate_from_macro(input: TokenStream) -> TokenStream {
    parse_and_generate(input).unwrap_or_else(|err| {
        let msg = err.to_string();
        quote! { compile_error!(#msg); }
    })
}
fn parse_and_generate(input: TokenStream) -> Result<TokenStream, CodegenError> {
let size: LitInt = syn::parse2(input).map_err(|e| CodegenError::EmitError(e.to_string()))?;
let n: usize = size
.base10_parse()
.map_err(|e| CodegenError::EmitError(e.to_string()))?;
generate(n)
}
/// Builder-style front end for generating the codelet of a single size.
pub struct CodeletBuilder {
// Transform size handed to `generate` by `build`.
n: usize,
// Name recorded by `name`; `build` does not read it, hence the `dead_code` allow.
#[allow(dead_code)]
name_override: Option<String>,
}
impl CodeletBuilder {
    /// Creates a builder for a transform of size `n` with no name override.
    #[must_use]
    pub const fn new(n: usize) -> Self {
        Self {
            name_override: None,
            n,
        }
    }

    /// Records a name override and returns the updated builder.
    ///
    /// The override is only stored; `build` does not currently consult it, so
    /// it has no effect on the generated tokens.
    #[must_use]
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name_override: Some(name.into()),
            ..self
        }
    }

    /// Consumes the builder and generates the codelet for its size.
    ///
    /// # Errors
    ///
    /// Returns whatever error `generate` reports for the configured size.
    pub fn build(self) -> Result<TokenStream, CodegenError> {
        let Self { n, .. } = self;
        generate(n)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates the codelet for `n` and checks that it renders to non-empty text.
    fn assert_generates_nonempty(n: usize) {
        let rendered = generate(n).unwrap().to_string();
        assert!(!rendered.is_empty());
    }

    #[test]
    fn classify_all_notw() {
        for n in [2usize, 4, 8, 16, 32, 64] {
            let class = classify(n).unwrap();
            assert!(matches!(class, SizeClass::Notw(_)), "n={n} should be Notw");
        }
    }

    #[test]
    fn classify_all_odd() {
        for n in [3usize, 5, 7] {
            let class = classify(n).unwrap();
            assert!(matches!(class, SizeClass::Odd(_)), "n={n} should be Odd");
        }
    }

    #[test]
    fn classify_rader_hardcoded() {
        let eleven = classify(11).unwrap();
        assert!(matches!(eleven, SizeClass::RaderHardcoded(11)));

        let thirteen = classify(13).unwrap();
        assert!(matches!(thirteen, SizeClass::RaderHardcoded(13)));
    }

    #[test]
    fn classify_rader_prime_runtime() {
        let primes = [
            17usize, 19, 23, 29, 31, 37, 41, 43, 47, 53, 97, 101, 1013, 1019, 1021,
        ];
        for p in primes {
            let class = classify(p).unwrap();
            assert!(
                matches!(class, SizeClass::RaderPrime(_)),
                "n={p} should be RaderPrime"
            );
        }
    }

    #[test]
    fn classify_bluestein_large_prime() {
        let class = classify(2003).unwrap();
        assert!(matches!(class, SizeClass::Bluestein(2003)));
    }

    #[test]
    fn classify_invalid_zero() {
        assert_eq!(classify(0), Err(CodegenError::InvalidSize(0)));
    }

    #[test]
    fn smooth7_factoring() {
        let smooth_sizes = [
            6usize, 10, 12, 14, 15, 21, 24, 28, 30, 35, 40, 42, 48, 56, 60, 80, 84, 96, 112, 120,
            168, 240,
        ];
        for n in smooth_sizes {
            let class = classify(n).unwrap();
            assert!(
                matches!(class, SizeClass::MixedRadix(_)),
                "n={n} expected MixedRadix"
            );
        }
    }

    #[test]
    fn smooth7_factors_correct_for_15() {
        let factors = match classify(15).unwrap() {
            SizeClass::MixedRadix(factors) => factors,
            other => panic!("expected MixedRadix, got {other:?}"),
        };
        assert!(factors.contains(&5), "15 factors must include 5");
        assert!(factors.contains(&3), "15 factors must include 3");
    }

    #[test]
    fn is_prime_helper() {
        for p in [2usize, 3, 5, 7, 11, 97] {
            assert!(is_prime(p));
        }
        for c in [0usize, 1, 4, 100] {
            assert!(!is_prime(c));
        }
    }

    #[test]
    fn generate_emits_nonempty_for_direct_size() {
        assert_generates_nonempty(8);
    }

    #[test]
    fn generate_emits_nonempty_for_odd_size() {
        assert_generates_nonempty(3);
    }

    #[test]
    fn generate_emits_nonempty_for_rader_hardcoded() {
        assert_generates_nonempty(11);
    }

    #[test]
    fn generate_emits_nonempty_for_mixed_radix() {
        assert_generates_nonempty(15);
    }

    #[test]
    fn generate_emits_nonempty_for_bluestein() {
        assert_generates_nonempty(2003);
    }

    #[test]
    fn generate_emits_nonempty_for_identity() {
        assert_generates_nonempty(1);
    }

    #[test]
    fn generate_zero_returns_err() {
        let outcome = generate(0);
        assert!(outcome.is_err());
    }

    #[test]
    fn codelet_builder_zero_returns_err() {
        let outcome = CodeletBuilder::new(0).build();
        assert!(outcome.is_err());
    }

    #[test]
    fn codelet_builder_happy_path() {
        let rendered = CodeletBuilder::new(8).build().unwrap().to_string();
        assert!(!rendered.is_empty());
    }
}