rustkernel_derive/lib.rs
1//! Procedural macros for RustKernels.
2//!
3//! This crate provides the following macros:
4//! - `#[gpu_kernel]` - Define a GPU kernel with metadata
//! - `#[derive(KernelMessage)]` - Derive message type IDs (and the `BatchMessage` impl) for kernel messages
6//! - `#[kernel_state]` - Mark types as GPU-compatible kernel state
7//!
8//! For low-level ring kernel macros, see `ringkernel-derive` 0.4.2 which provides:
9//! - `#[derive(RingMessage)]` - Ring message serialization with domain-based type IDs
10//! - `#[derive(PersistentMessage)]` - CUDA persistent message dispatch
11//! - `#[derive(ControlBlockState)]` - Embedded state for GPU ControlBlocks
12//! - `#[derive(GpuType)]` - Pod+Zeroable for GPU data transfer
13//! - `#[ring_kernel]` - Ring kernel handler generation
14//! - `#[stencil_kernel]` - CUDA stencil pattern kernels
15//!
16//! # Example
17//!
18//! ```ignore
19//! use rustkernel_derive::gpu_kernel;
20//!
21//! #[gpu_kernel(
22//! id = "graph/pagerank",
23//! mode = "ring",
24//! domain = "GraphAnalytics",
25//! throughput = 100_000,
26//! latency_us = 1.0
27//! )]
28//! pub async fn pagerank_kernel(
29//! ctx: &mut RingContext,
30//! request: PageRankRequest,
31//! ) -> PageRankResponse {
32//! // Implementation
33//! }
34//! ```
35
36use darling::{FromDeriveInput, FromMeta};
37use proc_macro::TokenStream;
38use quote::quote;
39use syn::{DeriveInput, ItemFn, parse_macro_input};
40
41/// Arguments for the `#[gpu_kernel]` attribute.
42#[derive(Debug, FromMeta)]
43struct GpuKernelArgs {
44 /// Kernel ID (e.g., "graph/pagerank").
45 id: String,
46
47 /// Kernel mode: "batch" or "ring".
48 mode: String,
49
50 /// Domain name (e.g., "GraphAnalytics").
51 domain: String,
52
53 /// Description (optional).
54 #[darling(default)]
55 description: Option<String>,
56
57 /// Expected throughput in ops/sec (optional).
58 #[darling(default)]
59 throughput: Option<u64>,
60
61 /// Target latency in microseconds (optional).
62 #[darling(default)]
63 latency_us: Option<f64>,
64
65 /// Whether GPU-native execution is required (optional).
66 #[darling(default)]
67 gpu_native: Option<bool>,
68}
69
70/// Define a GPU kernel with metadata.
71///
72/// This attribute generates a kernel struct and implements the necessary traits.
73///
74/// # Attributes
75///
76/// - `id` - Unique kernel identifier (required)
77/// - `mode` - Kernel mode: "batch" or "ring" (required)
78/// - `domain` - Business domain (required)
79/// - `description` - Human-readable description (optional)
80/// - `throughput` - Expected throughput in ops/sec (optional)
81/// - `latency_us` - Target latency in microseconds (optional)
82/// - `gpu_native` - Whether GPU-native execution is required (optional)
83///
84/// # Example
85///
86/// ```ignore
87/// #[gpu_kernel(
88/// id = "graph/pagerank",
89/// mode = "ring",
90/// domain = "GraphAnalytics",
91/// description = "PageRank centrality calculation",
92/// throughput = 100_000,
93/// latency_us = 1.0,
94/// gpu_native = true
95/// )]
96/// pub async fn pagerank(ctx: &mut RingContext, req: PageRankRequest) -> PageRankResponse {
97/// // Implementation
98/// }
99/// ```
100#[proc_macro_attribute]
101pub fn gpu_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
102 let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
103 Ok(v) => v,
104 Err(e) => return TokenStream::from(e.to_compile_error()),
105 };
106
107 let args = match GpuKernelArgs::from_list(&args) {
108 Ok(v) => v,
109 Err(e) => return TokenStream::from(e.write_errors()),
110 };
111
112 let input = parse_macro_input!(item as ItemFn);
113 let fn_name = &input.sig.ident;
114 let fn_vis = &input.vis;
115 let fn_block = &input.block;
116 let fn_inputs = &input.sig.inputs;
117 let fn_output = &input.sig.output;
118 let fn_asyncness = &input.sig.asyncness;
119
120 // Generate struct name from function name (PascalCase)
121 let struct_name = to_pascal_case(&fn_name.to_string());
122 let struct_ident = syn::Ident::new(&struct_name, fn_name.span());
123
124 // Parse mode
125 let mode = match args.mode.as_str() {
126 "batch" => quote! { rustkernel_core::kernel::KernelMode::Batch },
127 "ring" => quote! { rustkernel_core::kernel::KernelMode::Ring },
128 _ => {
129 return syn::Error::new_spanned(&input.sig, "mode must be 'batch' or 'ring'")
130 .to_compile_error()
131 .into();
132 }
133 };
134
135 // Parse domain
136 let domain = &args.domain;
137 let domain_ident = syn::Ident::new(domain, proc_macro2::Span::call_site());
138
139 // Default values
140 let description = args.description.unwrap_or_default();
141 let throughput = args.throughput.unwrap_or(10_000);
142 let latency_us = args.latency_us.unwrap_or(50.0);
143 let gpu_native = args.gpu_native.unwrap_or(false);
144 let kernel_id = &args.id;
145
146 // Generate the kernel struct and implementation
147 let expanded = quote! {
148 /// Generated kernel struct for #fn_name.
149 #[derive(Debug, Clone)]
150 #fn_vis struct #struct_ident {
151 metadata: rustkernel_core::kernel::KernelMetadata,
152 }
153
154 impl #struct_ident {
155 /// Create a new instance of this kernel.
156 #[must_use]
157 pub fn new() -> Self {
158 Self {
159 metadata: rustkernel_core::kernel::KernelMetadata {
160 id: #kernel_id.to_string(),
161 mode: #mode,
162 domain: rustkernel_core::domain::Domain::#domain_ident,
163 description: #description.to_string(),
164 expected_throughput: #throughput,
165 target_latency_us: #latency_us,
166 requires_gpu_native: #gpu_native,
167 version: 1,
168 },
169 }
170 }
171 }
172
173 impl Default for #struct_ident {
174 fn default() -> Self {
175 Self::new()
176 }
177 }
178
179 impl rustkernel_core::traits::GpuKernel for #struct_ident {
180 fn metadata(&self) -> &rustkernel_core::kernel::KernelMetadata {
181 &self.metadata
182 }
183 }
184
185 // Keep the original function for implementation
186 #fn_vis #fn_asyncness fn #fn_name(#fn_inputs) #fn_output
187 #fn_block
188 };
189
190 TokenStream::from(expanded)
191}
192
193/// Convert a snake_case string to PascalCase.
194fn to_pascal_case(s: &str) -> String {
195 s.split('_')
196 .filter(|part| !part.is_empty())
197 .map(|part| {
198 let mut chars = part.chars();
199 match chars.next() {
200 Some(first) => first.to_uppercase().chain(chars).collect::<String>(),
201 None => String::new(),
202 }
203 })
204 .collect()
205}
206
207/// Arguments for `#[derive(KernelMessage)]`.
208#[derive(Debug, FromDeriveInput)]
209#[darling(attributes(message))]
210struct KernelMessageArgs {
211 ident: syn::Ident,
212 generics: syn::Generics,
213
214 /// Message type ID.
215 #[darling(default)]
216 type_id: Option<u64>,
217
218 /// Domain for the message (reserved for future use).
219 #[darling(default)]
220 #[allow(dead_code)]
221 domain: Option<String>,
222}
223
224/// Derive macro for kernel messages.
225///
226/// This generates implementations for the `BatchMessage` trait, providing
227/// serialization and type information for batch kernel messages.
228///
229/// # Attributes
230///
231/// - `type_id` - Unique message type identifier (optional, defaults to hash of type name)
232/// - `domain` - Domain for the message (optional)
233///
234/// # Example
235///
236/// ```ignore
237/// #[derive(Debug, Clone, Serialize, Deserialize, KernelMessage)]
238/// #[message(type_id = 100, domain = "GraphAnalytics")]
239/// pub struct PageRankInput {
240/// pub graph: CsrGraph,
241/// pub damping: f64,
242/// }
243/// ```
244///
245/// # Generated Implementation
246///
247/// The macro generates:
248/// - `BatchMessage` trait implementation with `message_type_id()`
249/// - `to_json()` and `from_json()` methods for JSON serialization
250/// - A `message_type_id()` associated function on the type itself
251#[proc_macro_derive(KernelMessage, attributes(message))]
252pub fn derive_kernel_message(input: TokenStream) -> TokenStream {
253 let input = parse_macro_input!(input as DeriveInput);
254
255 let args = match KernelMessageArgs::from_derive_input(&input) {
256 Ok(v) => v,
257 Err(e) => return TokenStream::from(e.write_errors()),
258 };
259
260 let name = args.ident;
261 let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
262
263 // Calculate type_id: use provided value or hash of type name
264 let type_id = args.type_id.unwrap_or_else(|| {
265 use std::collections::hash_map::DefaultHasher;
266 use std::hash::{Hash, Hasher};
267 let mut hasher = DefaultHasher::new();
268 name.to_string().hash(&mut hasher);
269 hasher.finish()
270 });
271
272 let expanded = quote! {
273 // Associated function for direct access
274 impl #impl_generics #name #ty_generics #where_clause {
275 /// Get the message type ID.
276 #[must_use]
277 pub const fn message_type_id() -> u64 {
278 #type_id
279 }
280 }
281
282 // Implement BatchMessage trait for batch kernel communication
283 impl #impl_generics ::rustkernel_core::messages::BatchMessage for #name #ty_generics #where_clause {
284 fn message_type_id() -> u64 {
285 #type_id
286 }
287 }
288 };
289
290 TokenStream::from(expanded)
291}
292
293/// Attribute for marking kernel state types.
294///
295/// This ensures the type meets GPU requirements (unmanaged, fixed layout).
296///
297/// # Example
298///
299/// ```ignore
300/// #[kernel_state(size = 256)]
301/// pub struct PageRankState {
302/// pub scores: [f32; 64],
303/// }
304/// ```
305#[proc_macro_attribute]
306pub fn kernel_state(_attr: TokenStream, item: TokenStream) -> TokenStream {
307 // For now, just pass through - state validation can be added later
308 let input = parse_macro_input!(item as DeriveInput);
309
310 let expanded = quote! {
311 #[repr(C)]
312 #[derive(Clone, Copy, Debug, Default)]
313 #input
314 };
315
316 TokenStream::from(expanded)
317}