use darling::{FromDeriveInput, FromMeta};
use proc_macro::TokenStream;
use quote::quote;
use syn::{DeriveInput, ItemFn, parse_macro_input};
/// Arguments accepted by the `#[gpu_kernel(...)]` attribute, parsed by darling
/// from the attribute's `key = value` list.
#[derive(Debug, FromMeta)]
struct GpuKernelArgs {
/// Kernel identifier; copied into the generated metadata's `id` field.
id: String,
/// Execution mode; `gpu_kernel` accepts only `"batch"` or `"ring"`.
mode: String,
/// Name of the `Domain` variant the generated metadata refers to.
domain: String,
/// Optional description; `gpu_kernel` falls back to an empty string.
#[darling(default)]
description: Option<String>,
/// Optional expected throughput; `gpu_kernel` falls back to `10_000`.
#[darling(default)]
throughput: Option<u64>,
/// Optional target latency; `gpu_kernel` falls back to `50.0`.
#[darling(default)]
latency_us: Option<f64>,
/// Optional GPU-native requirement flag; `gpu_kernel` falls back to `false`.
#[darling(default)]
gpu_native: Option<bool>,
}
#[proc_macro_attribute]
pub fn gpu_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
Ok(v) => v,
Err(e) => return TokenStream::from(e.to_compile_error()),
};
let args = match GpuKernelArgs::from_list(&args) {
Ok(v) => v,
Err(e) => return TokenStream::from(e.write_errors()),
};
let input = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
let fn_vis = &input.vis;
let fn_block = &input.block;
let fn_inputs = &input.sig.inputs;
let fn_output = &input.sig.output;
let fn_asyncness = &input.sig.asyncness;
let struct_name = to_pascal_case(&fn_name.to_string());
let struct_ident = syn::Ident::new(&struct_name, fn_name.span());
let mode = match args.mode.as_str() {
"batch" => quote! { rustkernel_core::kernel::KernelMode::Batch },
"ring" => quote! { rustkernel_core::kernel::KernelMode::Ring },
_ => {
return syn::Error::new_spanned(&input.sig, "mode must be 'batch' or 'ring'")
.to_compile_error()
.into();
}
};
let domain = &args.domain;
let domain_ident = syn::Ident::new(domain, proc_macro2::Span::call_site());
let description = args.description.unwrap_or_default();
let throughput = args.throughput.unwrap_or(10_000);
let latency_us = args.latency_us.unwrap_or(50.0);
let gpu_native = args.gpu_native.unwrap_or(false);
let kernel_id = &args.id;
let expanded = quote! {
#[derive(Debug, Clone)]
#fn_vis struct #struct_ident {
metadata: rustkernel_core::kernel::KernelMetadata,
}
impl #struct_ident {
#[must_use]
pub fn new() -> Self {
Self {
metadata: rustkernel_core::kernel::KernelMetadata {
id: #kernel_id.to_string(),
mode: #mode,
domain: rustkernel_core::domain::Domain::#domain_ident,
description: #description.to_string(),
expected_throughput: #throughput,
target_latency_us: #latency_us,
requires_gpu_native: #gpu_native,
version: 1,
},
}
}
}
impl Default for #struct_ident {
fn default() -> Self {
Self::new()
}
}
impl rustkernel_core::traits::GpuKernel for #struct_ident {
fn metadata(&self) -> &rustkernel_core::kernel::KernelMetadata {
&self.metadata
}
}
#fn_vis #fn_asyncness fn #fn_name(#fn_inputs) #fn_output
#fn_block
};
TokenStream::from(expanded)
}
/// Converts a `snake_case` name to `PascalCase`.
///
/// The input is split on `_`; empty segments (from leading, trailing or
/// repeated underscores) are dropped. The first character of every remaining
/// segment is upper-cased and the rest of the segment is copied verbatim.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        // An empty segment yields no first character and is skipped.
        if let Some(head) = rest.next() {
            // `to_uppercase` may expand to several characters (e.g. 'ß').
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
/// Options read by `#[derive(KernelMessage)]`, collected by darling from the
/// deriving item and its `#[message(...)]` attributes.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(message))]
struct KernelMessageArgs {
/// Name of the type the derive is applied to.
ident: syn::Ident,
/// Generics of the deriving type; used to build the generated `impl` headers.
generics: syn::Generics,
/// Explicit message type id; when absent the derive hashes the type name.
#[darling(default)]
type_id: Option<u64>,
/// Accepted so that `#[message(domain = "...")]` parses, but the derive does
/// not read it, hence the `dead_code` allowance.
#[darling(default)]
#[allow(dead_code)]
domain: Option<String>,
}
#[proc_macro_derive(KernelMessage, attributes(message))]
pub fn derive_kernel_message(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let args = match KernelMessageArgs::from_derive_input(&input) {
Ok(v) => v,
Err(e) => return TokenStream::from(e.write_errors()),
};
let name = args.ident;
let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
let type_id = args.type_id.unwrap_or_else(|| {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
name.to_string().hash(&mut hasher);
hasher.finish()
});
let expanded = quote! {
impl #impl_generics #name #ty_generics #where_clause {
#[must_use]
pub const fn message_type_id() -> u64 {
#type_id
}
}
impl #impl_generics ::rustkernel_core::messages::BatchMessage for #name #ty_generics #where_clause {
fn message_type_id() -> u64 {
#type_id
}
}
};
TokenStream::from(expanded)
}
#[proc_macro_attribute]
pub fn kernel_state(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let expanded = quote! {
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
#input
};
TokenStream::from(expanded)
}