use darling::{ast, FromDeriveInput, FromField, FromMeta};
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, DeriveInput, ItemFn};
/// Container-level input for `#[derive(RingMessage)]`.
///
/// Populated by darling from `#[message(...)]` / `#[ring_message(...)]` attributes;
/// only structs with named fields are accepted (`supports(struct_named)`).
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(message, ring_message), supports(struct_named))]
struct RingMessageArgs {
/// Name of the struct the derive is applied to.
ident: syn::Ident,
/// Generic parameters of the struct, forwarded to the generated impls.
generics: syn::Generics,
/// Parsed struct body; each field is read as a `RingMessageField`.
data: ast::Data<(), RingMessageField>,
/// Explicit message type id. When omitted, the derive hashes the struct name instead.
#[darling(default)]
type_id: Option<u64>,
/// Domain name. When present, the derive also emits a `DomainMessage` impl and adds
/// the domain's base type id to the message type id.
#[darling(default)]
domain: Option<String>,
/// When true, the derive emits a `K2KMessageRegistration` static for the type.
#[darling(default)]
k2k_routable: bool,
/// Optional category string stored in the K2K registration.
#[darling(default)]
category: Option<String>,
}
/// Field-level input for `#[derive(RingMessage)]`, read from `#[message(...)]` on a field.
#[derive(Debug, FromField)]
#[darling(attributes(message))]
struct RingMessageField {
/// Field name (`None` only for tuple fields).
ident: Option<syn::Ident>,
/// Field type; captured by darling but not read by the derive.
#[allow(dead_code)]
ty: syn::Type,
/// `#[message(id)]`: this field is returned from `message_id()`.
#[darling(default)]
id: bool,
/// `#[message(correlation)]`: this field is returned from `correlation_id()`.
#[darling(default)]
correlation: bool,
/// `#[message(priority)]`: this field is returned from `priority()`.
#[darling(default)]
priority: bool,
}
#[proc_macro_derive(RingMessage, attributes(message, ring_message))]
pub fn derive_ring_message(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let args = match RingMessageArgs::from_derive_input(&input) {
Ok(args) => args,
Err(e) => return e.write_errors().into(),
};
let name = &args.ident;
let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
let base_type_id = args.type_id.unwrap_or_else(|| {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
name.to_string().hash(&mut hasher);
if args.domain.is_some() {
hasher.finish() % 100
} else {
hasher.finish()
}
});
let fields = match &args.data {
ast::Data::Struct(fields) => fields,
_ => panic!("RingMessage can only be derived for structs"),
};
let mut id_field: Option<&syn::Ident> = None;
let mut correlation_field: Option<&syn::Ident> = None;
let mut priority_field: Option<&syn::Ident> = None;
for field in fields.iter() {
if field.id {
id_field = field.ident.as_ref();
}
if field.correlation {
correlation_field = field.ident.as_ref();
}
if field.priority {
priority_field = field.ident.as_ref();
}
}
let message_id_impl = if let Some(field) = id_field {
quote! { self.#field }
} else {
quote! { ::ringkernel_core::message::MessageId::new(0) }
};
let correlation_id_impl = if let Some(field) = correlation_field {
quote! { self.#field }
} else {
quote! { ::ringkernel_core::message::CorrelationId::none() }
};
let priority_impl = if let Some(field) = priority_field {
quote! { self.#field }
} else {
quote! { ::ringkernel_core::message::Priority::Normal }
};
let message_type_impl = if let Some(ref domain_str) = args.domain {
quote! {
::ringkernel_core::domain::Domain::from_str(#domain_str)
.unwrap_or(::ringkernel_core::domain::Domain::General)
.base_type_id() + #base_type_id
}
} else {
quote! { #base_type_id }
};
let domain_impl = if let Some(ref domain_str) = args.domain {
quote! {
impl #impl_generics ::ringkernel_core::domain::DomainMessage for #name #ty_generics #where_clause {
fn domain() -> ::ringkernel_core::domain::Domain {
::ringkernel_core::domain::Domain::from_str(#domain_str)
.unwrap_or(::ringkernel_core::domain::Domain::General)
}
}
}
} else {
quote! {}
};
let k2k_registration = if args.k2k_routable {
let registration_name = format_ident!(
"__K2K_MESSAGE_REGISTRATION_{}",
name.to_string().to_uppercase()
);
let type_name_str = name.to_string();
let category_tokens = match &args.category {
Some(cat) => quote! { ::std::option::Option::Some(#cat) },
None => quote! { ::std::option::Option::None },
};
quote! {
#[allow(non_upper_case_globals)]
#[::inventory::submit]
static #registration_name: ::ringkernel_core::k2k::K2KMessageRegistration =
::ringkernel_core::k2k::K2KMessageRegistration {
type_id: {
#base_type_id
},
type_name: #type_name_str,
k2k_routable: true,
category: #category_tokens,
};
}
} else {
quote! {}
};
let expanded = quote! {
impl #impl_generics ::ringkernel_core::message::RingMessage for #name #ty_generics #where_clause {
fn message_type() -> u64 {
#message_type_impl
}
fn message_id(&self) -> ::ringkernel_core::message::MessageId {
#message_id_impl
}
fn correlation_id(&self) -> ::ringkernel_core::message::CorrelationId {
#correlation_id_impl
}
fn priority(&self) -> ::ringkernel_core::message::Priority {
#priority_impl
}
fn serialize(&self) -> Vec<u8> {
::rkyv::to_bytes::<_, 4096>(self)
.map(|v| v.to_vec())
.unwrap_or_default()
}
fn deserialize(bytes: &[u8]) -> ::ringkernel_core::error::Result<Self>
where
Self: Sized,
{
use ::rkyv::Deserialize as _;
let archived = unsafe { ::rkyv::archived_root::<Self>(bytes) };
let deserialized: Self = archived.deserialize(&mut ::rkyv::Infallible)
.map_err(|_| ::ringkernel_core::error::RingKernelError::DeserializationError(
"rkyv deserialization failed".to_string()
))?;
Ok(deserialized)
}
fn size_hint(&self) -> usize {
::std::mem::size_of::<Self>()
}
}
#domain_impl
#k2k_registration
};
TokenStream::from(expanded)
}
// Inline payload capacity in bytes. Not referenced in this file: the code generated by
// `derive_persistent_message` uses
// `::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE` instead.
// NOTE(review): presumably mirrors the core crate's value — confirm they agree.
#[allow(dead_code)]
const MAX_INLINE_PAYLOAD_SIZE: usize = 32;
/// Container-level input for `#[derive(PersistentMessage)]`, read from
/// `#[persistent_message(...)]` on a struct with named fields.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(persistent_message), supports(struct_named))]
struct PersistentMessageArgs {
/// Name of the struct the derive is applied to.
ident: syn::Ident,
/// Generic parameters of the struct, forwarded to the generated impl.
generics: syn::Generics,
/// Parsed struct body; captured by darling but not read by the derive.
#[allow(dead_code)]
data: ast::Data<(), PersistentMessageField>,
/// Required: value returned by the generated `handler_id()`.
handler_id: u32,
/// Value returned by the generated `requires_response()`; defaults to `false`.
#[darling(default)]
requires_response: bool,
}
/// Field-level input for `#[derive(PersistentMessage)]`.
///
/// Both members are captured by darling but never read by the derive.
#[derive(Debug, FromField)]
#[darling(attributes(persistent_message))]
struct PersistentMessageField {
/// Field name.
#[allow(dead_code)]
ident: Option<syn::Ident>,
/// Field type.
#[allow(dead_code)]
ty: syn::Type,
}
/// Derive macro implementing `::ringkernel_core::persistent_message::PersistentMessage`.
///
/// Reads `handler_id` (required) and `requires_response` (default `false`) from
/// `#[persistent_message(...)]` and generates:
///
/// * `handler_id()` / `requires_response()` returning those values,
/// * `payload_size()` returning `size_of::<Self>()`,
/// * `to_inline_payload()`, which byte-copies `self` into a zero-filled
///   `MAX_INLINE_PAYLOAD_SIZE` buffer, or returns `None` when `Self` is too large,
/// * `from_inline_payload()`, which rebuilds `Self` from the leading
///   `size_of::<Self>()` bytes, or returns a `DeserializationError` when the slice is
///   too short.
///
/// The generated conversions reinterpret raw bytes, so the annotated type is expected
/// to be plain data for which any bit pattern is a valid value; the macro does not
/// check this.
#[proc_macro_derive(PersistentMessage, attributes(persistent_message))]
pub fn derive_persistent_message(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let args = match PersistentMessageArgs::from_derive_input(&input) {
        Ok(args) => args,
        Err(e) => return e.write_errors().into(),
    };
    let name = &args.ident;
    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
    let handler_id = args.handler_id;
    let requires_response = args.requires_response;
    let expanded = quote! {
        impl #impl_generics ::ringkernel_core::persistent_message::PersistentMessage for #name #ty_generics #where_clause {
            fn handler_id() -> u32 {
                #handler_id
            }
            fn requires_response() -> bool {
                #requires_response
            }
            fn payload_size() -> usize {
                ::std::mem::size_of::<Self>()
            }
            fn to_inline_payload(&self) -> ::std::option::Option<[u8; ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE]> {
                if ::std::mem::size_of::<Self>() > ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE {
                    return ::std::option::Option::None;
                }
                let mut payload = [0u8; ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE];
                // SAFETY: the size check above guarantees `payload` can hold
                // `size_of::<Self>()` bytes, and the two regions are distinct objects.
                unsafe {
                    ::std::ptr::copy_nonoverlapping(
                        self as *const Self as *const u8,
                        payload.as_mut_ptr(),
                        ::std::mem::size_of::<Self>()
                    );
                }
                ::std::option::Option::Some(payload)
            }
            fn from_inline_payload(payload: &[u8]) -> ::ringkernel_core::error::Result<Self> {
                let size = ::std::mem::size_of::<Self>();
                if payload.len() < size {
                    return ::std::result::Result::Err(
                        ::ringkernel_core::error::RingKernelError::DeserializationError(
                            ::std::format!(
                                "Payload too small: expected {} bytes, got {}",
                                size,
                                payload.len()
                            )
                        )
                    );
                }
                // SAFETY: the length check above guarantees `size` readable bytes. A byte
                // slice carries no alignment guarantee for `Self`, so the value must be
                // read with `read_unaligned`; a plain `read` would be undefined behaviour
                // for any `Self` with alignment greater than one.
                let value = unsafe {
                    ::std::ptr::read_unaligned(payload.as_ptr() as *const Self)
                };
                ::std::result::Result::Ok(value)
            }
        }
    };
    TokenStream::from(expanded)
}
/// Arguments accepted by the `#[ring_kernel(...)]` attribute.
#[derive(Debug, FromMeta)]
struct RingKernelArgs {
/// Required kernel identifier, stored in the generated registration.
id: String,
/// Execution mode. `"event_driven"` selects `KernelMode::EventDriven`; any other
/// value, or none, selects `KernelMode::Persistent`.
#[darling(default)]
mode: Option<String>,
/// Grid size recorded in the registration; the macro substitutes 1 when omitted.
#[darling(default)]
grid_size: Option<u32>,
/// Block size recorded in the registration; the macro substitutes 256 when omitted.
#[darling(default)]
block_size: Option<u32>,
/// Comma-separated list of target names; each entry is trimmed by the macro.
#[darling(default)]
publishes_to: Option<String>,
}
/// Attribute macro that turns a function into a registered ring kernel.
///
/// Accepts the arguments described by `RingKernelArgs` (`id`, `mode`, `grid_size`,
/// `block_size`, `publishes_to`) and expands to three items:
///
/// 1. the annotated function, re-emitted as an `async fn` with its attributes,
///    visibility, parameters, return type and body (generic parameters of the original
///    signature are not carried over);
/// 2. `<name>_handler`, which deserializes the envelope payload into the type of the
///    function's second parameter, awaits the kernel, serializes the response and wraps
///    it in a reply envelope addressed back to the sender with the same correlation id;
/// 3. a `KernelRegistration` static submitted through `inventory`.
///
/// If the function has fewer than two parameters the message type falls back to `()`.
/// Parse failures are reported as compile errors.
#[proc_macro_attribute]
pub fn ring_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
        Ok(v) => v,
        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
    };
    let args = match RingKernelArgs::from_list(&args) {
        Ok(v) => v,
        Err(e) => return TokenStream::from(e.write_errors()),
    };
    let input = parse_macro_input!(item as ItemFn);

    let kernel_id = &args.id;
    let fn_name = &input.sig.ident;
    let fn_vis = &input.vis;
    let fn_block = &input.block;
    let fn_attrs = &input.attrs;
    let inputs = &input.sig.inputs;
    let output = &input.sig.output;

    // The message type is the type of the second parameter; `()` when there is none or
    // when it is a receiver.
    let msg_type: Box<syn::Type> = match inputs.iter().nth(1) {
        Some(syn::FnArg::Typed(pat_type)) => pat_type.ty.clone(),
        _ => syn::parse_quote!(()),
    };

    let mode_expr = match args.mode.as_deref().unwrap_or("persistent") {
        "event_driven" => quote! { ::ringkernel_core::types::KernelMode::EventDriven },
        _ => quote! { ::ringkernel_core::types::KernelMode::Persistent },
    };
    let grid_size = args.grid_size.unwrap_or(1);
    let block_size = args.block_size.unwrap_or(256);

    // Skip empty entries so that `""` or a trailing comma does not register a target
    // with an empty name.
    let publishes_to_targets: Vec<String> = args
        .publishes_to
        .as_deref()
        .map(|list| {
            list.split(',')
                .map(str::trim)
                .filter(|target| !target.is_empty())
                .map(str::to_string)
                .collect()
        })
        .unwrap_or_default();

    let registration_name = format_ident!(
        "__RINGKERNEL_REGISTRATION_{}",
        fn_name.to_string().to_uppercase()
    );
    let handler_name = format_ident!("{}_handler", fn_name);

    let expanded = quote! {
        #(#fn_attrs)*
        // `inputs` is a bare comma-separated list, so the parentheses must be written
        // explicitly here.
        #fn_vis async fn #fn_name(#inputs) #output #fn_block

        // The returned future borrows `ctx`; the borrow is named because elision cannot
        // choose between the two anonymous lifetimes in the parameter list.
        #fn_vis fn #handler_name<'a>(
            ctx: &'a mut ::ringkernel_core::RingContext<'_>,
            envelope: ::ringkernel_core::message::MessageEnvelope,
        ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::ringkernel_core::error::Result<::ringkernel_core::message::MessageEnvelope>> + Send + 'a>> {
            Box::pin(async move {
                // Resolves the response's message type from the value itself; a bare
                // `<_ as RingMessage>::message_type()` gives inference nothing to work with.
                fn __ring_message_type_of<M: ::ringkernel_core::message::RingMessage>(_: &M) -> u64 {
                    <M as ::ringkernel_core::message::RingMessage>::message_type()
                }
                let msg: #msg_type = ::ringkernel_core::message::RingMessage::deserialize(&envelope.payload)?;
                let response = #fn_name(ctx, msg).await;
                let response_payload = ::ringkernel_core::message::RingMessage::serialize(&response);
                let response_header = ::ringkernel_core::message::MessageHeader::new(
                    __ring_message_type_of(&response),
                    envelope.header.dest_kernel,
                    envelope.header.source_kernel,
                    response_payload.len(),
                    ctx.now(),
                ).with_correlation(envelope.header.correlation_id);
                Ok(::ringkernel_core::message::MessageEnvelope {
                    header: response_header,
                    payload: response_payload,
                    ..::std::default::Default::default()
                })
            })
        }

        #[allow(non_upper_case_globals)]
        #[::inventory::submit]
        static #registration_name: ::ringkernel_core::__private::KernelRegistration = ::ringkernel_core::__private::KernelRegistration {
            id: #kernel_id,
            mode: #mode_expr,
            grid_size: #grid_size,
            block_size: #block_size,
            publishes_to: &[#(#publishes_to_targets),*],
        };
    };
    TokenStream::from(expanded)
}
#[proc_macro_derive(GpuType)]
pub fn derive_gpu_type(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let expanded = quote! {
const _: fn() = || {
fn assert_copy<T: Copy>() {}
assert_copy::<#name #ty_generics>();
};
unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
};
TokenStream::from(expanded)
}
/// Arguments accepted by the `#[stencil_kernel(...)]` attribute.
#[derive(Debug, FromMeta)]
struct StencilKernelArgs {
/// Required kernel identifier.
id: String,
/// Grid dimensionality as a string; the macro substitutes `"2d"` when omitted.
#[darling(default)]
grid: Option<String>,
/// Tile edge length used for whichever of `tile_width` / `tile_height` is omitted;
/// the macro substitutes 16 when this is omitted as well.
#[darling(default)]
tile_size: Option<u32>,
/// Tile width; takes precedence over `tile_size`.
#[darling(default)]
tile_width: Option<u32>,
/// Tile height; takes precedence over `tile_size`.
#[darling(default)]
tile_height: Option<u32>,
/// Halo size; the macro substitutes 1 when omitted.
#[darling(default)]
halo: Option<u32>,
}
#[proc_macro_attribute]
pub fn stencil_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
Ok(v) => v,
Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
};
let args = match StencilKernelArgs::from_list(&args) {
Ok(v) => v,
Err(e) => return TokenStream::from(e.write_errors()),
};
let input = parse_macro_input!(item as ItemFn);
stencil_kernel_impl(args, input)
}
/// Expands a `#[stencil_kernel]` function.
///
/// Emits, in order:
///
/// 1. the annotated function exactly as written;
/// 2. a `<NAME>_CUDA_SOURCE` string constant with the function's visibility;
/// 3. a `StencilKernelRegistration` static submitted through `inventory`.
///
/// With the `cuda-codegen` feature the CUDA source is produced by
/// `transpile_stencil_kernel`, and a transpilation failure becomes a compile error on
/// the function name; an unrecognised `grid` value is transpiled as a 2D grid. Without
/// the feature the source is a placeholder comment naming the kernel.
///
/// Defaults: `grid = "2d"`, tile width/height = `tile_size` or 16, `halo = 1`.
fn stencil_kernel_impl(args: StencilKernelArgs, input: ItemFn) -> TokenStream {
    let kernel_id = &args.id;
    let fn_name = &input.sig.ident;
    let fn_vis = &input.vis;

    let grid = args.grid.as_deref().unwrap_or("2d");
    let default_tile = args.tile_size.unwrap_or(16);
    let tile_width = args.tile_width.unwrap_or(default_tile);
    let tile_height = args.tile_height.unwrap_or(default_tile);
    let halo = args.halo.unwrap_or(1);

    let upper_name = fn_name.to_string().to_uppercase();
    let cuda_const_name = format_ident!("{}_CUDA_SOURCE", upper_name);
    let registration_name = format_ident!("__STENCIL_KERNEL_REGISTRATION_{}", upper_name);

    #[cfg(feature = "cuda-codegen")]
    let cuda_source_code = {
        use ringkernel_cuda_codegen::{transpile_stencil_kernel, Grid, StencilConfig};
        let grid_type = match grid {
            "1d" => Grid::Grid1D,
            "2d" => Grid::Grid2D,
            "3d" => Grid::Grid3D,
            _ => Grid::Grid2D,
        };
        let config = StencilConfig::new(kernel_id.clone())
            .with_grid(grid_type)
            .with_tile_size(tile_width as usize, tile_height as usize)
            .with_halo(halo as usize);
        match transpile_stencil_kernel(&input, &config) {
            Ok(cuda) => cuda,
            Err(e) => {
                return TokenStream::from(
                    syn::Error::new_spanned(
                        &input.sig.ident,
                        format!("CUDA transpilation failed: {}", e),
                    )
                    .to_compile_error(),
                );
            }
        }
    };
    #[cfg(not(feature = "cuda-codegen"))]
    let cuda_source_code = format!(
        "// CUDA codegen not enabled. Enable 'cuda-codegen' feature.\n// Kernel: {}\n",
        kernel_id
    );

    let expanded = quote! {
        // Re-emit the whole item. Splicing the parameter list on its own would drop the
        // surrounding parentheses (and any generics or qualifiers) and yield an invalid
        // function.
        #input

        #fn_vis const #cuda_const_name: &str = #cuda_source_code;

        #[allow(non_upper_case_globals)]
        #[::inventory::submit]
        static #registration_name: ::ringkernel_core::__private::StencilKernelRegistration =
            ::ringkernel_core::__private::StencilKernelRegistration {
                id: #kernel_id,
                grid: #grid,
                tile_width: #tile_width,
                tile_height: #tile_height,
                halo: #halo,
                cuda_source: #cuda_source_code,
            };
    };
    TokenStream::from(expanded)
}
/// Execution backends a `#[gpu_kernel]` can target.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuBackend {
    Cuda,
    Metal,
    Wgpu,
    Cpu,
}

impl GpuBackend {
    /// Every backend, used to look names up through `as_str`.
    const ALL: [GpuBackend; 4] = [
        GpuBackend::Cuda,
        GpuBackend::Metal,
        GpuBackend::Wgpu,
        GpuBackend::Cpu,
    ];

    /// Parses a backend name case-insensitively.
    ///
    /// Accepts the canonical names returned by `as_str` plus the alias `"webgpu"`.
    /// Returns `None` for anything else; the input is not trimmed.
    fn from_str(s: &str) -> Option<Self> {
        let wanted = s.to_lowercase();
        if wanted == "webgpu" {
            return Some(GpuBackend::Wgpu);
        }
        Self::ALL
            .iter()
            .copied()
            .find(|backend| backend.as_str() == wanted.as_str())
    }

    /// Canonical lowercase name of the backend.
    fn as_str(&self) -> &'static str {
        match *self {
            GpuBackend::Cuda => "cuda",
            GpuBackend::Metal => "metal",
            GpuBackend::Wgpu => "wgpu",
            GpuBackend::Cpu => "cpu",
        }
    }
}
/// Hardware/runtime features a `#[gpu_kernel]` may declare in `requires`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuCapability {
    Float64,
    Int64,
    Atomic64,
    CooperativeGroups,
    Subgroups,
    SharedMemory,
    DynamicParallelism,
    Float16,
}

impl GpuCapability {
    /// Parses a capability name case-insensitively, accepting several aliases per
    /// capability. Returns `None` for unknown names; the input is not trimmed.
    fn from_str(s: &str) -> Option<Self> {
        const ALIASES: &[(&[&str], GpuCapability)] = &[
            (&["f64", "float64"], GpuCapability::Float64),
            (&["i64", "int64"], GpuCapability::Int64),
            (&["atomic64"], GpuCapability::Atomic64),
            (
                &["cooperative_groups", "cooperativegroups", "grid_sync"],
                GpuCapability::CooperativeGroups,
            ),
            (&["subgroups", "warp", "simd"], GpuCapability::Subgroups),
            (
                &["shared_memory", "sharedmemory", "threadgroup"],
                GpuCapability::SharedMemory,
            ),
            (
                &["dynamic_parallelism", "dynamicparallelism"],
                GpuCapability::DynamicParallelism,
            ),
            (&["f16", "float16", "half"], GpuCapability::Float16),
        ];

        let wanted = s.to_lowercase();
        ALIASES
            .iter()
            .find(|(names, _)| names.iter().any(|alias| *alias == wanted.as_str()))
            .map(|&(_, capability)| capability)
    }

    /// Canonical name of the capability.
    fn as_str(&self) -> &'static str {
        match *self {
            GpuCapability::Float64 => "f64",
            GpuCapability::Int64 => "i64",
            GpuCapability::Atomic64 => "atomic64",
            GpuCapability::CooperativeGroups => "cooperative_groups",
            GpuCapability::Subgroups => "subgroups",
            GpuCapability::SharedMemory => "shared_memory",
            GpuCapability::DynamicParallelism => "dynamic_parallelism",
            GpuCapability::Float16 => "f16",
        }
    }

    /// Whether `backend` is treated as providing this capability.
    ///
    /// CUDA and CPU accept everything. Metal rejects `Float64`, `CooperativeGroups`
    /// and `DynamicParallelism`; Wgpu rejects those plus `Int64` and `Atomic64`.
    fn supported_by(&self, backend: GpuBackend) -> bool {
        match backend {
            GpuBackend::Cuda | GpuBackend::Cpu => true,
            GpuBackend::Metal => !matches!(
                self,
                Self::Float64 | Self::CooperativeGroups | Self::DynamicParallelism
            ),
            GpuBackend::Wgpu => !matches!(
                self,
                Self::Float64
                    | Self::Int64
                    | Self::Atomic64
                    | Self::CooperativeGroups
                    | Self::DynamicParallelism
            ),
        }
    }
}
/// Arguments accepted by the `#[gpu_kernel(...)]` attribute.
#[derive(Debug)]
struct GpuKernelArgs {
/// Kernel identifier; the expansion falls back to the function name when `None`.
id: Option<String>,
/// Backends the kernel targets. Defaults to CUDA, Metal and Wgpu.
backends: Vec<GpuBackend>,
/// Fallback order recorded in the registration. Defaults to CUDA, Metal, Wgpu, CPU.
fallback: Vec<GpuBackend>,
/// Capabilities the kernel needs. Defaults to none.
requires: Vec<GpuCapability>,
/// Block size; the expansion substitutes 256 when `None`.
block_size: Option<u32>,
}
impl Default for GpuKernelArgs {
fn default() -> Self {
Self {
id: None,
backends: vec![GpuBackend::Cuda, GpuBackend::Metal, GpuBackend::Wgpu],
fallback: vec![
GpuBackend::Cuda,
GpuBackend::Metal,
GpuBackend::Wgpu,
GpuBackend::Cpu,
],
requires: Vec::new(),
block_size: None,
}
}
}
impl GpuKernelArgs {
fn parse(attr: proc_macro2::TokenStream) -> Result<Self, darling::Error> {
let mut args = Self::default();
let attr_str = attr.to_string();
if let Some(start) = attr_str.find("backends") {
if let Some(bracket_start) = attr_str[start..].find('[') {
if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
let backends_str =
&attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
args.backends = backends_str
.split(',')
.filter_map(|s| GpuBackend::from_str(s.trim()))
.collect();
}
}
}
if let Some(start) = attr_str.find("fallback") {
if let Some(bracket_start) = attr_str[start..].find('[') {
if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
let fallback_str =
&attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
args.fallback = fallback_str
.split(',')
.filter_map(|s| GpuBackend::from_str(s.trim()))
.collect();
}
}
}
if let Some(start) = attr_str.find("requires") {
if let Some(bracket_start) = attr_str[start..].find('[') {
if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
let requires_str =
&attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
args.requires = requires_str
.split(',')
.filter_map(|s| GpuCapability::from_str(s.trim()))
.collect();
}
}
}
if let Some(start) = attr_str.find("id") {
if let Some(quote_start) = attr_str[start..].find('"') {
if let Some(quote_end) = attr_str[start + quote_start + 1..].find('"') {
args.id = Some(
attr_str[start + quote_start + 1..start + quote_start + 1 + quote_end]
.to_string(),
);
}
}
}
if let Some(start) = attr_str.find("block_size") {
if let Some(eq) = attr_str[start..].find('=') {
let rest = &attr_str[start + eq + 1..];
let num_end = rest
.find(|c: char| !c.is_numeric() && c != ' ')
.unwrap_or(rest.len());
if let Ok(n) = rest[..num_end].trim().parse() {
args.block_size = Some(n);
}
}
}
Ok(args)
}
fn validate_capabilities(&self) -> Result<(), String> {
for cap in &self.requires {
let mut supported_by_any = false;
for backend in &self.backends {
if cap.supported_by(*backend) {
supported_by_any = true;
break;
}
}
if !supported_by_any {
return Err(format!(
"Capability '{}' is not supported by any of the specified backends: {:?}",
cap.as_str(),
self.backends.iter().map(|b| b.as_str()).collect::<Vec<_>>()
));
}
}
Ok(())
}
fn compatible_backends(&self) -> Vec<GpuBackend> {
self.backends
.iter()
.filter(|backend| self.requires.iter().all(|cap| cap.supported_by(**backend)))
.copied()
.collect()
}
}
#[proc_macro_attribute]
pub fn gpu_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
let attr2: proc_macro2::TokenStream = attr.into();
let args = match GpuKernelArgs::parse(attr2) {
Ok(args) => args,
Err(e) => return TokenStream::from(e.write_errors()),
};
let input = parse_macro_input!(item as ItemFn);
if let Err(msg) = args.validate_capabilities() {
return TokenStream::from(
syn::Error::new_spanned(&input.sig.ident, msg).to_compile_error(),
);
}
gpu_kernel_impl(args, input)
}
/// Expands a `#[gpu_kernel]` function.
///
/// Emits, in order:
///
/// 1. the annotated function exactly as written;
/// 2. one `<NAME>_<BACKEND>` string constant per compatible backend, holding a
///    placeholder source comment;
/// 3. a `<NAME>_INFO` module with `ID`, `BLOCK_SIZE`, `CAPABILITIES`, `BACKENDS` and
///    `FALLBACK_ORDER` constants;
/// 4. a `GpuKernelRegistration` static submitted through `inventory`.
///
/// The kernel id defaults to the function name and the block size to 256. "Compatible"
/// backends are the selected backends that support every required capability.
fn gpu_kernel_impl(args: GpuKernelArgs, input: ItemFn) -> TokenStream {
    let fn_name = &input.sig.ident;
    let fn_vis = &input.vis;
    let upper_name = fn_name.to_string().to_uppercase();

    let kernel_id = args.id.clone().unwrap_or_else(|| fn_name.to_string());
    let block_size = args.block_size.unwrap_or(256);
    let compatible_backends = args.compatible_backends();

    let capability_strs: Vec<_> = args.requires.iter().map(|c| c.as_str()).collect();
    let backend_strs: Vec<_> = compatible_backends.iter().map(|b| b.as_str()).collect();
    let fallback_strs: Vec<_> = args.fallback.iter().map(|b| b.as_str()).collect();

    let source_constants: Vec<_> = backend_strs
        .iter()
        .map(|backend_str| {
            let backend_upper = backend_str.to_uppercase();
            let const_name = format_ident!("{}_{}", upper_name, backend_upper);
            let source_placeholder = format!(
                "// {} source for kernel '{}'\n// Generated by ringkernel-derive\n// Capabilities: {:?}\n",
                backend_upper, kernel_id, capability_strs
            );
            quote! {
                #fn_vis const #const_name: &str = #source_placeholder;
            }
        })
        .collect();

    let registration_name = format_ident!("__GPU_KERNEL_REGISTRATION_{}", upper_name);
    let info_name = format_ident!("{}_INFO", upper_name);

    let expanded = quote! {
        // Re-emit the whole item. Splicing the parameter list on its own would drop the
        // surrounding parentheses (and any generics or qualifiers) and yield an invalid
        // function.
        #input

        #(#source_constants)*

        // The module name is derived from the upper-cased function name.
        #[allow(non_snake_case)]
        #fn_vis mod #info_name {
            pub const ID: &str = #kernel_id;
            pub const BLOCK_SIZE: u32 = #block_size;
            pub const CAPABILITIES: &[&str] = &[#(#capability_strs),*];
            pub const BACKENDS: &[&str] = &[#(#backend_strs),*];
            pub const FALLBACK_ORDER: &[&str] = &[#(#fallback_strs),*];
        }

        #[allow(non_upper_case_globals)]
        #[::inventory::submit]
        static #registration_name: ::ringkernel_core::__private::GpuKernelRegistration =
            ::ringkernel_core::__private::GpuKernelRegistration {
                id: #kernel_id,
                block_size: #block_size,
                capabilities: &[#(#capability_strs),*],
                backends: &[#(#backend_strs),*],
                fallback_order: &[#(#fallback_strs),*],
            };
    };
    TokenStream::from(expanded)
}
/// Container-level input for `#[derive(ControlBlockState)]`, read from `#[state(...)]`
/// on a struct with named fields.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(state), supports(struct_named))]
struct ControlBlockStateArgs {
/// Name of the struct the derive is applied to.
ident: syn::Ident,
/// Generic parameters of the struct, forwarded to the generated impls.
generics: syn::Generics,
/// Value for `EmbeddedState::VERSION`; the derive substitutes 1 when omitted.
#[darling(default)]
version: Option<u32>,
}
#[proc_macro_derive(ControlBlockState, attributes(state))]
pub fn derive_control_block_state(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let args = match ControlBlockStateArgs::from_derive_input(&input) {
Ok(args) => args,
Err(e) => return e.write_errors().into(),
};
let name = &args.ident;
let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
let version = args.version.unwrap_or(1);
let expanded = quote! {
const _: () = {
assert!(
::std::mem::size_of::<#name #ty_generics>() <= 24,
"ControlBlockState types must fit in 24 bytes (ControlBlock._reserved size)"
);
};
const _: fn() = || {
fn assert_copy<T: Copy>() {}
assert_copy::<#name #ty_generics>();
};
unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
impl #impl_generics ::ringkernel_core::state::EmbeddedState for #name #ty_generics #where_clause {
const VERSION: u32 = #version;
fn is_embedded() -> bool {
true
}
}
};
TokenStream::from(expanded)
}