rustkernel_derive/
lib.rs

1//! Procedural macros for RustKernels.
2//!
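//! This crate provides the following macros:
//! - `#[gpu_kernel]` - Define a GPU kernel with metadata
//! - `#[derive(KernelMessage)]` - Derive the `BatchMessage` trait (message type ID) for kernel messages
//! - `#[kernel_state]` - Mark a kernel state type (adds `#[repr(C)]` and standard derives)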
6//!
7//! # Example
8//!
9//! ```ignore
10//! use rustkernel_derive::gpu_kernel;
11//!
12//! #[gpu_kernel(
13//!     id = "graph/pagerank",
14//!     mode = "ring",
15//!     domain = "GraphAnalytics",
16//!     throughput = 100_000,
17//!     latency_us = 1.0
18//! )]
19//! pub async fn pagerank_kernel(
20//!     ctx: &mut RingContext,
21//!     request: PageRankRequest,
22//! ) -> PageRankResponse {
23//!     // Implementation
24//! }
25//! ```
26
27use darling::{FromDeriveInput, FromMeta};
28use proc_macro::TokenStream;
29use quote::quote;
30use syn::{DeriveInput, ItemFn, parse_macro_input};
31
32/// Arguments for the `#[gpu_kernel]` attribute.
33#[derive(Debug, FromMeta)]
34struct GpuKernelArgs {
35    /// Kernel ID (e.g., "graph/pagerank").
36    id: String,
37
38    /// Kernel mode: "batch" or "ring".
39    mode: String,
40
41    /// Domain name (e.g., "GraphAnalytics").
42    domain: String,
43
44    /// Description (optional).
45    #[darling(default)]
46    description: Option<String>,
47
48    /// Expected throughput in ops/sec (optional).
49    #[darling(default)]
50    throughput: Option<u64>,
51
52    /// Target latency in microseconds (optional).
53    #[darling(default)]
54    latency_us: Option<f64>,
55
56    /// Whether GPU-native execution is required (optional).
57    #[darling(default)]
58    gpu_native: Option<bool>,
59}
60
61/// Define a GPU kernel with metadata.
62///
63/// This attribute generates a kernel struct and implements the necessary traits.
64///
65/// # Attributes
66///
67/// - `id` - Unique kernel identifier (required)
68/// - `mode` - Kernel mode: "batch" or "ring" (required)
69/// - `domain` - Business domain (required)
70/// - `description` - Human-readable description (optional)
71/// - `throughput` - Expected throughput in ops/sec (optional)
72/// - `latency_us` - Target latency in microseconds (optional)
73/// - `gpu_native` - Whether GPU-native execution is required (optional)
74///
75/// # Example
76///
77/// ```ignore
78/// #[gpu_kernel(
79///     id = "graph/pagerank",
80///     mode = "ring",
81///     domain = "GraphAnalytics",
82///     description = "PageRank centrality calculation",
83///     throughput = 100_000,
84///     latency_us = 1.0,
85///     gpu_native = true
86/// )]
87/// pub async fn pagerank(ctx: &mut RingContext, req: PageRankRequest) -> PageRankResponse {
88///     // Implementation
89/// }
90/// ```
91#[proc_macro_attribute]
92pub fn gpu_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
93    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
94        Ok(v) => v,
95        Err(e) => return TokenStream::from(e.to_compile_error()),
96    };
97
98    let args = match GpuKernelArgs::from_list(&args) {
99        Ok(v) => v,
100        Err(e) => return TokenStream::from(e.write_errors()),
101    };
102
103    let input = parse_macro_input!(item as ItemFn);
104    let fn_name = &input.sig.ident;
105    let fn_vis = &input.vis;
106    let fn_block = &input.block;
107    let fn_inputs = &input.sig.inputs;
108    let fn_output = &input.sig.output;
109    let fn_asyncness = &input.sig.asyncness;
110
111    // Generate struct name from function name (PascalCase)
112    let struct_name = to_pascal_case(&fn_name.to_string());
113    let struct_ident = syn::Ident::new(&struct_name, fn_name.span());
114
115    // Parse mode
116    let mode = match args.mode.as_str() {
117        "batch" => quote! { rustkernel_core::kernel::KernelMode::Batch },
118        "ring" => quote! { rustkernel_core::kernel::KernelMode::Ring },
119        _ => {
120            return syn::Error::new_spanned(&input.sig, "mode must be 'batch' or 'ring'")
121                .to_compile_error()
122                .into();
123        }
124    };
125
126    // Parse domain
127    let domain = &args.domain;
128    let domain_ident = syn::Ident::new(domain, proc_macro2::Span::call_site());
129
130    // Default values
131    let description = args.description.unwrap_or_default();
132    let throughput = args.throughput.unwrap_or(10_000);
133    let latency_us = args.latency_us.unwrap_or(50.0);
134    let gpu_native = args.gpu_native.unwrap_or(false);
135    let kernel_id = &args.id;
136
137    // Generate the kernel struct and implementation
138    let expanded = quote! {
139        /// Generated kernel struct for #fn_name.
140        #[derive(Debug, Clone)]
141        #fn_vis struct #struct_ident {
142            metadata: rustkernel_core::kernel::KernelMetadata,
143        }
144
145        impl #struct_ident {
146            /// Create a new instance of this kernel.
147            #[must_use]
148            pub fn new() -> Self {
149                Self {
150                    metadata: rustkernel_core::kernel::KernelMetadata {
151                        id: #kernel_id.to_string(),
152                        mode: #mode,
153                        domain: rustkernel_core::domain::Domain::#domain_ident,
154                        description: #description.to_string(),
155                        expected_throughput: #throughput,
156                        target_latency_us: #latency_us,
157                        requires_gpu_native: #gpu_native,
158                        version: 1,
159                    },
160                }
161            }
162        }
163
164        impl Default for #struct_ident {
165            fn default() -> Self {
166                Self::new()
167            }
168        }
169
170        impl rustkernel_core::traits::GpuKernel for #struct_ident {
171            fn metadata(&self) -> &rustkernel_core::kernel::KernelMetadata {
172                &self.metadata
173            }
174        }
175
176        // Keep the original function for implementation
177        #fn_vis #fn_asyncness fn #fn_name(#fn_inputs) #fn_output
178        #fn_block
179    };
180
181    TokenStream::from(expanded)
182}
183
/// Convert a snake_case string to PascalCase.
///
/// The input is split on `_`; empty segments (from leading, trailing or
/// repeated underscores) are skipped. The first character of every remaining
/// segment is upper-cased and the rest of the segment is copied unchanged.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        // An empty segment yields no first character and contributes nothing.
        if let Some(initial) = rest.next() {
            // Upper-casing may expand to several characters (e.g. `ß` -> `SS`).
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
197
198/// Arguments for `#[derive(KernelMessage)]`.
199#[derive(Debug, FromDeriveInput)]
200#[darling(attributes(message))]
201struct KernelMessageArgs {
202    ident: syn::Ident,
203    generics: syn::Generics,
204
205    /// Message type ID.
206    #[darling(default)]
207    type_id: Option<u64>,
208
209    /// Domain for the message (reserved for future use).
210    #[darling(default)]
211    #[allow(dead_code)]
212    domain: Option<String>,
213}
214
215/// Derive macro for kernel messages.
216///
217/// This generates implementations for the `BatchMessage` trait, providing
218/// serialization and type information for batch kernel messages.
219///
220/// # Attributes
221///
222/// - `type_id` - Unique message type identifier (optional, defaults to hash of type name)
223/// - `domain` - Domain for the message (optional)
224///
225/// # Example
226///
227/// ```ignore
228/// #[derive(Debug, Clone, Serialize, Deserialize, KernelMessage)]
229/// #[message(type_id = 100, domain = "GraphAnalytics")]
230/// pub struct PageRankInput {
231///     pub graph: CsrGraph,
232///     pub damping: f64,
233/// }
234/// ```
235///
236/// # Generated Implementation
237///
238/// The macro generates:
239/// - `BatchMessage` trait implementation with `message_type_id()`
240/// - `to_json()` and `from_json()` methods for JSON serialization
241/// - A `message_type_id()` associated function on the type itself
242#[proc_macro_derive(KernelMessage, attributes(message))]
243pub fn derive_kernel_message(input: TokenStream) -> TokenStream {
244    let input = parse_macro_input!(input as DeriveInput);
245
246    let args = match KernelMessageArgs::from_derive_input(&input) {
247        Ok(v) => v,
248        Err(e) => return TokenStream::from(e.write_errors()),
249    };
250
251    let name = args.ident;
252    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
253
254    // Calculate type_id: use provided value or hash of type name
255    let type_id = args.type_id.unwrap_or_else(|| {
256        use std::collections::hash_map::DefaultHasher;
257        use std::hash::{Hash, Hasher};
258        let mut hasher = DefaultHasher::new();
259        name.to_string().hash(&mut hasher);
260        hasher.finish()
261    });
262
263    let expanded = quote! {
264        // Associated function for direct access
265        impl #impl_generics #name #ty_generics #where_clause {
266            /// Get the message type ID.
267            #[must_use]
268            pub const fn message_type_id() -> u64 {
269                #type_id
270            }
271        }
272
273        // Implement BatchMessage trait for batch kernel communication
274        impl #impl_generics ::rustkernel_core::messages::BatchMessage for #name #ty_generics #where_clause {
275            fn message_type_id() -> u64 {
276                #type_id
277            }
278        }
279    };
280
281    TokenStream::from(expanded)
282}
283
284/// Attribute for marking kernel state types.
285///
286/// This ensures the type meets GPU requirements (unmanaged, fixed layout).
287///
288/// # Example
289///
290/// ```ignore
291/// #[kernel_state(size = 256)]
292/// pub struct PageRankState {
293///     pub scores: [f32; 64],
294/// }
295/// ```
296#[proc_macro_attribute]
297pub fn kernel_state(_attr: TokenStream, item: TokenStream) -> TokenStream {
298    // For now, just pass through - state validation can be added later
299    let input = parse_macro_input!(item as DeriveInput);
300
301    let expanded = quote! {
302        #[repr(C)]
303        #[derive(Clone, Copy, Debug, Default)]
304        #input
305    };
306
307    TokenStream::from(expanded)
308}