extern crate proc_macro;
use darling::FromMeta;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::{parse_macro_input, parse_quote, DeriveInput, FnArg, ItemFn, Pat};
/// Device selected by the `device = ...` argument of `#[default_device(...)]`.
///
/// Parsed through darling's derived `FromMeta`; the exact spelling darling accepts for each
/// variant (e.g. a string literal such as `"cpu"`) is decided by that derive —
/// NOTE(review): confirm the accepted form against darling's enum conventions.
#[derive(Debug, FromMeta)]
enum DeviceType {
/// Makes the generated wrapper start with `let stream = StreamOrDevice::cpu();`.
Cpu,
/// Makes the generated wrapper start with `let stream = StreamOrDevice::gpu();`.
Gpu,
}
/// Parsed argument of `#[default_device(device = ...)]`.
///
/// Only built when the attribute is given a non-empty argument; with no argument the macro
/// falls back to `StreamOrDevice::default()`.
#[derive(Debug)]
struct DefaultDeviceInput {
/// Which `StreamOrDevice` constructor the generated wrapper calls for its `stream` local.
device: DeviceType,
}
impl FromMeta for DefaultDeviceInput {
fn from_meta(meta: &syn::Meta) -> darling::Result<Self> {
let syn::Meta::NameValue(meta_name_value) = meta else {
return Err(darling::Error::unsupported_format(
"expected a name-value attribute",
));
};
let ident = meta_name_value.path.get_ident().unwrap();
assert_eq!(ident, "device", "expected `device`");
let device = DeviceType::from_expr(&meta_name_value.value)?;
Ok(DefaultDeviceInput { device })
}
}
#[proc_macro_attribute]
pub fn default_device(attr: TokenStream, item: TokenStream) -> TokenStream {
let input = if !attr.is_empty() {
let meta = syn::parse_macro_input!(attr as syn::Meta);
Some(DefaultDeviceInput::from_meta(&meta).unwrap())
} else {
None
};
let mut input_fn = parse_macro_input!(item as ItemFn);
let original_fn = input_fn.clone();
if !input_fn.sig.ident.to_string().contains("_device") {
panic!("Function name must end with '_device'");
}
let new_fn_name = format_ident!("{}", &input_fn.sig.ident.to_string().replace("_device", ""));
input_fn.sig.ident = new_fn_name;
let filtered_inputs = input_fn
.sig
.inputs
.iter()
.filter(|arg| match arg {
FnArg::Typed(pat_typed) => {
if let Pat::Ident(pat_ident) = &*pat_typed.pat {
pat_ident.ident != "stream"
} else {
true
}
}
_ => true,
})
.cloned()
.collect::<Vec<_>>();
input_fn.sig.inputs = Punctuated::from_iter(filtered_inputs);
let default_stream_stmt = match input.map(|input| input.device) {
Some(DeviceType::Cpu) => parse_quote! {
let stream = StreamOrDevice::cpu();
},
Some(DeviceType::Gpu) => parse_quote! {
let stream = StreamOrDevice::gpu();
},
None => parse_quote! {
let stream = StreamOrDevice::default();
},
};
input_fn.block.stmts.insert(0, default_stream_stmt);
let expanded = quote! {
#original_fn
#input_fn
};
TokenStream::from(expanded)
}
/// Derives an exhaustive dtype-promotion test for the annotated type.
///
/// The expansion contains:
/// - a `TYPE_RULES: [[Dtype; 13]; 13]` constant, where `TYPE_RULES[a][b]` is the dtype
///   expected from promoting `a` with `b`;
/// - a `#[cfg(test)] mod generated_tests` whose single test walks every pair of values
///   yielded by `<Type>::iter()` and asserts
///   `a.promote_with(b) == TYPE_RULES[a as usize][b as usize]`.
///
/// The generated code names `Dtype::...` directly and uses `strum::IntoEnumIterator` and
/// `pretty_assertions`, so those must be resolvable at the derive site.
/// NOTE(review): indexing with `as usize` assumes the enum's discriminants run 0..13 in the
/// same order as the table rows — confirm against the `Dtype` definition.
#[proc_macro_derive(GenerateDtypeTestCases)]
pub fn generate_test_cases(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;
    let tests = quote! {
        // Only the generated test module reads this table; allow it to be unused in
        // non-test builds instead of emitting a dead-code warning there.
        #[cfg_attr(not(test), allow(dead_code))]
        #[rustfmt::skip]
        const TYPE_RULES: [[Dtype; 13]; 13] = [
            // Row/column order: Bool, Uint8, Uint16, Uint32, Uint64, Int8, Int16, Int32,
            // Int64, Float16, Float32, Bfloat16, Complex64.
            [Dtype::Bool, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64],
            [Dtype::Uint8, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64],
            [Dtype::Uint16, Dtype::Uint16, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64],
            [Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64],
            [Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64],
            [Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64],
            [Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64],
            [Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64],
            [Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float32, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64],
            [Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float32, Dtype::Float32, Dtype::Complex64],
            [Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Complex64],
            [Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Float32, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64],
            [Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64],
        ];

        #[cfg(test)]
        mod generated_tests {
            use super::*;
            use strum::IntoEnumIterator;
            use pretty_assertions::assert_eq;

            #[test]
            fn test_all_combinations() {
                for a in #name::iter() {
                    for b in #name::iter() {
                        let result = a.promote_with(b);
                        let expected = TYPE_RULES[a as usize][b as usize];
                        // Pass the format arguments directly rather than pre-formatting
                        // them into a String and printing it through "{}".
                        assert_eq!(
                            result,
                            expected,
                            "Failed promotion test for {:?} and {:?}",
                            a,
                            b
                        );
                    }
                }
            }
        }
    };
    TokenStream::from(tests)
}