// ringkernel_derive/lib.rs

1//! Procedural macros for RingKernel.
2//!
3//! This crate provides the following macros:
4//!
5//! - `#[derive(RingMessage)]` - Implement the RingMessage trait for message types
6//! - `#[ring_kernel]` - Define a ring kernel handler
7//! - `#[stencil_kernel]` - Define a GPU stencil kernel (with `cuda-codegen` feature)
8//! - `#[gpu_kernel]` - Define a multi-backend GPU kernel with capability checking
9//!
10//! # Example
11//!
12//! ```ignore
13//! use ringkernel_derive::{RingMessage, ring_kernel};
14//!
15//! #[derive(RingMessage)]
16//! struct AddRequest {
17//!     #[message(id)]
18//!     id: MessageId,
19//!     a: f32,
20//!     b: f32,
21//! }
22//!
23//! #[derive(RingMessage)]
24//! struct AddResponse {
25//!     #[message(id)]
26//!     id: MessageId,
27//!     result: f32,
28//! }
29//!
30//! #[ring_kernel(id = "adder")]
31//! async fn process(ctx: &mut RingContext, req: AddRequest) -> AddResponse {
32//!     AddResponse {
33//!         id: MessageId::generate(),
34//!         result: req.a + req.b,
35//!     }
36//! }
37//! ```
38//!
39//! # Multi-Backend GPU Kernels
40//!
41//! The `#[gpu_kernel]` macro enables multi-backend code generation with capability checking:
42//!
43//! ```ignore
44//! use ringkernel_derive::gpu_kernel;
45//!
46//! // Generate code for CUDA and Metal, with fallback order
47//! #[gpu_kernel(backends = [cuda, metal], fallback = [wgpu, cpu])]
48//! fn saxpy(x: &[f32], y: &mut [f32], a: f32, n: i32) {
49//!     let idx = global_thread_id_x();
50//!     if idx < n {
51//!         y[idx as usize] = a * x[idx as usize] + y[idx as usize];
52//!     }
53//! }
54//!
55//! // Require specific capabilities at compile time
56//! #[gpu_kernel(backends = [cuda], requires = [f64, atomic64])]
57//! fn double_precision(data: &mut [f64], n: i32) {
58//!     // Uses f64 operations - validated at compile time
59//! }
60//! ```
61//!
62//! # Stencil Kernels (with `cuda-codegen` feature)
63//!
64//! ```ignore
65//! use ringkernel_derive::stencil_kernel;
66//! use ringkernel_cuda_codegen::GridPos;
67//!
68//! #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
69//! fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
70//!     let curr = p[pos.idx()];
71//!     let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
72//!     p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
73//! }
74//! ```
75
76use darling::{ast, FromDeriveInput, FromField, FromMeta};
77use proc_macro::TokenStream;
78use quote::{format_ident, quote};
79use syn::{parse_macro_input, DeriveInput, ItemFn};
80
/// Attributes for the RingMessage derive macro.
///
/// Parsed by `darling` from the derive input; only named structs are
/// accepted (enforced by `supports(struct_named)`).
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(message, ring_message), supports(struct_named))]
struct RingMessageArgs {
    /// Name of the struct the trait is derived for.
    ident: syn::Ident,
    /// Generics of the struct, re-emitted on every generated impl.
    generics: syn::Generics,
    /// The struct's fields with their per-field `#[message(...)]` flags.
    data: ast::Data<(), RingMessageField>,
    /// Optional explicit message type ID.
    /// If domain is specified, this is the offset within the domain (0-99).
    /// If domain is not specified, this is the absolute type ID.
    /// When omitted, a hash of the type name is used instead.
    #[darling(default)]
    type_id: Option<u64>,
    /// Optional domain for message classification.
    /// When specified, the final type ID = domain.base_type_id() + type_id.
    #[darling(default)]
    domain: Option<String>,
    /// Whether this message is routable via K2K.
    /// When true, generates a K2KMessageRegistration for runtime discovery.
    #[darling(default)]
    k2k_routable: bool,
    /// Optional category for K2K routing groups.
    /// Multiple messages can share a category for grouped routing.
    #[darling(default)]
    category: Option<String>,
}
106
/// Field attributes for RingMessage.
///
/// Each boolean flag marks the field as the source of one accessor in the
/// generated trait impl; unmarked roles fall back to generated defaults.
#[derive(Debug, FromField)]
#[darling(attributes(message))]
struct RingMessageField {
    /// Field name; `Some` for all fields here since the derive only accepts
    /// named structs.
    ident: Option<syn::Ident>,
    /// Field type; captured but not currently used by the expansion.
    #[allow(dead_code)]
    ty: syn::Type,
    /// Mark this field as the message ID.
    #[darling(default)]
    id: bool,
    /// Mark this field as the correlation ID.
    #[darling(default)]
    correlation: bool,
    /// Mark this field as the priority.
    #[darling(default)]
    priority: bool,
}
124
125/// Derive macro for implementing the RingMessage trait.
126///
127/// # Attributes
128///
129/// On the struct (via `#[message(...)]` or `#[ring_message(...)]`):
130/// - `type_id = 123` - Set explicit message type ID (or domain offset if domain is set)
131/// - `domain = "OrderMatching"` - Assign to a business domain (adds base type ID)
132/// - `k2k_routable = true` - Register for K2K routing discovery
133/// - `category = "orders"` - Group messages for K2K routing
134///
135/// On fields:
136/// - `#[message(id)]` - Mark as message ID field
137/// - `#[message(correlation)]` - Mark as correlation ID field
138/// - `#[message(priority)]` - Mark as priority field
139///
140/// # Examples
141///
142/// Basic usage:
143/// ```ignore
144/// #[derive(RingMessage)]
145/// #[message(type_id = 1)]
146/// struct MyMessage {
147///     #[message(id)]
148///     id: MessageId,
149///     #[message(correlation)]
150///     correlation: CorrelationId,
151///     #[message(priority)]
152///     priority: Priority,
153///     payload: Vec<u8>,
154/// }
155/// ```
156///
157/// With domain (type ID = 500 + 1 = 501):
158/// ```ignore
159/// #[derive(RingMessage)]
160/// #[ring_message(type_id = 1, domain = "OrderMatching")]
161/// pub struct SubmitOrderInput {
162///     #[message(id)]
163///     id: MessageId,
164///     pub order: Order,
165/// }
166/// // Also implements DomainMessage trait
167/// assert_eq!(SubmitOrderInput::domain(), Domain::OrderMatching);
168/// ```
169///
170/// K2K-routable message:
171/// ```ignore
172/// #[derive(RingMessage)]
173/// #[ring_message(type_id = 1, domain = "OrderMatching", k2k_routable = true, category = "orders")]
174/// pub struct SubmitOrderInput { ... }
175///
176/// // Runtime discovery:
177/// let registry = K2KTypeRegistry::discover();
178/// assert!(registry.is_routable(501));
179/// ```
180#[proc_macro_derive(RingMessage, attributes(message, ring_message))]
181pub fn derive_ring_message(input: TokenStream) -> TokenStream {
182    let input = parse_macro_input!(input as DeriveInput);
183
184    let args = match RingMessageArgs::from_derive_input(&input) {
185        Ok(args) => args,
186        Err(e) => return e.write_errors().into(),
187    };
188
189    let name = &args.ident;
190    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
191
192    // Calculate base type ID (offset within domain, or absolute if no domain)
193    let base_type_id = args.type_id.unwrap_or_else(|| {
194        use std::collections::hash_map::DefaultHasher;
195        use std::hash::{Hash, Hasher};
196        let mut hasher = DefaultHasher::new();
197        name.to_string().hash(&mut hasher);
198        // If domain is set, hash to a value within 0-99 range
199        if args.domain.is_some() {
200            hasher.finish() % 100
201        } else {
202            hasher.finish()
203        }
204    });
205
206    // Find annotated fields
207    let fields = match &args.data {
208        ast::Data::Struct(fields) => fields,
209        _ => panic!("RingMessage can only be derived for structs"),
210    };
211
212    let mut id_field: Option<&syn::Ident> = None;
213    let mut correlation_field: Option<&syn::Ident> = None;
214    let mut priority_field: Option<&syn::Ident> = None;
215
216    for field in fields.iter() {
217        if field.id {
218            id_field = field.ident.as_ref();
219        }
220        if field.correlation {
221            correlation_field = field.ident.as_ref();
222        }
223        if field.priority {
224            priority_field = field.ident.as_ref();
225        }
226    }
227
228    // Generate message_id method
229    let message_id_impl = if let Some(field) = id_field {
230        quote! { self.#field }
231    } else {
232        quote! { ::ringkernel_core::message::MessageId::new(0) }
233    };
234
235    // Generate correlation_id method
236    let correlation_id_impl = if let Some(field) = correlation_field {
237        quote! { self.#field }
238    } else {
239        quote! { ::ringkernel_core::message::CorrelationId::none() }
240    };
241
242    // Generate priority method
243    let priority_impl = if let Some(field) = priority_field {
244        quote! { self.#field }
245    } else {
246        quote! { ::ringkernel_core::message::Priority::Normal }
247    };
248
249    // Generate message_type() implementation based on whether domain is specified
250    let message_type_impl = if let Some(ref domain_str) = args.domain {
251        // With domain: type_id = domain.base_type_id() + offset
252        quote! {
253            ::ringkernel_core::domain::Domain::from_str(#domain_str)
254                .unwrap_or(::ringkernel_core::domain::Domain::General)
255                .base_type_id() + #base_type_id
256        }
257    } else {
258        // Without domain: use absolute type_id
259        quote! { #base_type_id }
260    };
261
262    // Generate DomainMessage impl if domain is specified
263    let domain_impl = if let Some(ref domain_str) = args.domain {
264        quote! {
265            impl #impl_generics ::ringkernel_core::domain::DomainMessage for #name #ty_generics #where_clause {
266                fn domain() -> ::ringkernel_core::domain::Domain {
267                    ::ringkernel_core::domain::Domain::from_str(#domain_str)
268                        .unwrap_or(::ringkernel_core::domain::Domain::General)
269                }
270            }
271        }
272    } else {
273        quote! {}
274    };
275
276    // Generate K2K registration if k2k_routable is set
277    let k2k_registration = if args.k2k_routable {
278        let registration_name = format_ident!(
279            "__K2K_MESSAGE_REGISTRATION_{}",
280            name.to_string().to_uppercase()
281        );
282        let type_name_str = name.to_string();
283        let category_tokens = match &args.category {
284            Some(cat) => quote! { ::std::option::Option::Some(#cat) },
285            None => quote! { ::std::option::Option::None },
286        };
287
288        quote! {
289            #[allow(non_upper_case_globals)]
290            #[::inventory::submit]
291            static #registration_name: ::ringkernel_core::k2k::K2KMessageRegistration =
292                ::ringkernel_core::k2k::K2KMessageRegistration {
293                    type_id: {
294                        // Note: This is a const context, so we use the base calculation
295                        // For domain types, we need to add the base manually
296                        #base_type_id
297                    },
298                    type_name: #type_name_str,
299                    k2k_routable: true,
300                    category: #category_tokens,
301                };
302        }
303    } else {
304        quote! {}
305    };
306
307    let expanded = quote! {
308        impl #impl_generics ::ringkernel_core::message::RingMessage for #name #ty_generics #where_clause {
309            fn message_type() -> u64 {
310                #message_type_impl
311            }
312
313            fn message_id(&self) -> ::ringkernel_core::message::MessageId {
314                #message_id_impl
315            }
316
317            fn correlation_id(&self) -> ::ringkernel_core::message::CorrelationId {
318                #correlation_id_impl
319            }
320
321            fn priority(&self) -> ::ringkernel_core::message::Priority {
322                #priority_impl
323            }
324
325            fn serialize(&self) -> Vec<u8> {
326                // Use rkyv for serialization with a 4KB scratch buffer
327                // For larger payloads, rkyv will allocate as needed
328                ::rkyv::to_bytes::<_, 4096>(self)
329                    .map(|v| v.to_vec())
330                    .unwrap_or_default()
331            }
332
333            fn deserialize(bytes: &[u8]) -> ::ringkernel_core::error::Result<Self>
334            where
335                Self: Sized,
336            {
337                use ::rkyv::Deserialize as _;
338                let archived = unsafe { ::rkyv::archived_root::<Self>(bytes) };
339                let deserialized: Self = archived.deserialize(&mut ::rkyv::Infallible)
340                    .map_err(|_| ::ringkernel_core::error::RingKernelError::DeserializationError(
341                        "rkyv deserialization failed".to_string()
342                    ))?;
343                Ok(deserialized)
344            }
345
346            fn size_hint(&self) -> usize {
347                ::std::mem::size_of::<Self>()
348            }
349        }
350
351        #domain_impl
352
353        #k2k_registration
354    };
355
356    TokenStream::from(expanded)
357}
358
/// Attributes for the ring_kernel macro.
///
/// Parsed by `darling::FromMeta` from the attribute's meta list.
#[derive(Debug, FromMeta)]
struct RingKernelArgs {
    /// Kernel identifier.
    id: String,
    /// Execution mode (persistent or event_driven).
    /// Any value other than "event_driven" is treated as persistent.
    #[darling(default)]
    mode: Option<String>,
    /// Grid size.
    /// Number of blocks; defaults to 1 when omitted.
    #[darling(default)]
    grid_size: Option<u32>,
    /// Block size.
    /// Threads per block; defaults to 256 when omitted.
    #[darling(default)]
    block_size: Option<u32>,
    /// Target kernels this kernel publishes to.
    /// Comma-separated list; split and trimmed at expansion time.
    #[darling(default)]
    publishes_to: Option<String>,
}
377
378/// Attribute macro for defining ring kernel handlers.
379///
380/// # Attributes
381///
382/// - `id` (required) - Unique kernel identifier
383/// - `mode` - Execution mode: "persistent" (default) or "event_driven"
384/// - `grid_size` - Number of blocks (default: 1)
385/// - `block_size` - Threads per block (default: 256)
386/// - `publishes_to` - Comma-separated list of target kernel IDs
387///
388/// # Example
389///
390/// ```ignore
391/// #[ring_kernel(id = "processor", mode = "persistent", block_size = 128)]
392/// async fn handle(ctx: &mut RingContext, msg: MyMessage) -> MyResponse {
393///     // Process message
394///     MyResponse { ... }
395/// }
396/// ```
397#[proc_macro_attribute]
398pub fn ring_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
399    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
400        Ok(v) => v,
401        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
402    };
403
404    let args = match RingKernelArgs::from_list(&args) {
405        Ok(v) => v,
406        Err(e) => return TokenStream::from(e.write_errors()),
407    };
408
409    let input = parse_macro_input!(item as ItemFn);
410
411    let kernel_id = &args.id;
412    let fn_name = &input.sig.ident;
413    let fn_vis = &input.vis;
414    let fn_block = &input.block;
415    let fn_attrs = &input.attrs;
416
417    // Parse function signature
418    let inputs = &input.sig.inputs;
419    let output = &input.sig.output;
420
421    // Extract context and message types from signature
422    let (_ctx_arg, msg_arg) = if inputs.len() >= 2 {
423        let ctx = inputs.first();
424        let msg = inputs.iter().nth(1);
425        (ctx, msg)
426    } else {
427        (None, None)
428    };
429
430    // Get message type
431    let msg_type = msg_arg
432        .map(|arg| {
433            if let syn::FnArg::Typed(pat_type) = arg {
434                pat_type.ty.clone()
435            } else {
436                syn::parse_quote!(())
437            }
438        })
439        .unwrap_or_else(|| syn::parse_quote!(()));
440
441    // Generate kernel mode
442    let mode = args.mode.as_deref().unwrap_or("persistent");
443    let mode_expr = if mode == "event_driven" {
444        quote! { ::ringkernel_core::types::KernelMode::EventDriven }
445    } else {
446        quote! { ::ringkernel_core::types::KernelMode::Persistent }
447    };
448
449    // Generate grid/block size
450    let grid_size = args.grid_size.unwrap_or(1);
451    let block_size = args.block_size.unwrap_or(256);
452
453    // Parse publishes_to into a list of target kernel IDs
454    let publishes_to_targets: Vec<String> = args
455        .publishes_to
456        .as_ref()
457        .map(|s| s.split(',').map(|t| t.trim().to_string()).collect())
458        .unwrap_or_default();
459
460    // Generate registration struct name
461    let registration_name = format_ident!(
462        "__RINGKERNEL_REGISTRATION_{}",
463        fn_name.to_string().to_uppercase()
464    );
465    let handler_name = format_ident!("{}_handler", fn_name);
466
467    // Generate the expanded code
468    let expanded = quote! {
469        // Original function (preserved for documentation/testing)
470        #(#fn_attrs)*
471        #fn_vis async fn #fn_name #inputs #output #fn_block
472
473        // Kernel handler wrapper
474        #fn_vis fn #handler_name(
475            ctx: &mut ::ringkernel_core::RingContext<'_>,
476            envelope: ::ringkernel_core::message::MessageEnvelope,
477        ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::ringkernel_core::error::Result<::ringkernel_core::message::MessageEnvelope>> + Send + '_>> {
478            Box::pin(async move {
479                // Deserialize input message
480                let msg: #msg_type = ::ringkernel_core::message::RingMessage::deserialize(&envelope.payload)?;
481
482                // Call the actual handler
483                let response = #fn_name(ctx, msg).await;
484
485                // Serialize response
486                let response_payload = ::ringkernel_core::message::RingMessage::serialize(&response);
487                let response_header = ::ringkernel_core::message::MessageHeader::new(
488                    <_ as ::ringkernel_core::message::RingMessage>::message_type(),
489                    envelope.header.dest_kernel,
490                    envelope.header.source_kernel,
491                    response_payload.len(),
492                    ctx.now(),
493                ).with_correlation(envelope.header.correlation_id);
494
495                Ok(::ringkernel_core::message::MessageEnvelope {
496                    header: response_header,
497                    payload: response_payload,
498                })
499            })
500        }
501
502        // Kernel registration
503        #[allow(non_upper_case_globals)]
504        #[::inventory::submit]
505        static #registration_name: ::ringkernel_core::__private::KernelRegistration = ::ringkernel_core::__private::KernelRegistration {
506            id: #kernel_id,
507            mode: #mode_expr,
508            grid_size: #grid_size,
509            block_size: #block_size,
510            publishes_to: &[#(#publishes_to_targets),*],
511        };
512    };
513
514    TokenStream::from(expanded)
515}
516
517/// Derive macro for GPU-compatible types.
518///
519/// Ensures the type has a stable memory layout suitable for GPU transfer.
520#[proc_macro_derive(GpuType)]
521pub fn derive_gpu_type(input: TokenStream) -> TokenStream {
522    let input = parse_macro_input!(input as DeriveInput);
523    let name = &input.ident;
524    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
525
526    // Generate assertions for GPU compatibility
527    let expanded = quote! {
528        // Verify type is Copy (required for GPU transfer)
529        const _: fn() = || {
530            fn assert_copy<T: Copy>() {}
531            assert_copy::<#name #ty_generics>();
532        };
533
534        // Verify type is Pod (plain old data)
535        unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
536        unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
537    };
538
539    TokenStream::from(expanded)
540}
541
542// ============================================================================
543// Stencil Kernel Macro (requires cuda-codegen feature)
544// ============================================================================
545
/// Attributes for the stencil_kernel macro.
///
/// Parsed by `darling::FromMeta`; defaults are applied in
/// `stencil_kernel_impl`.
#[derive(Debug, FromMeta)]
struct StencilKernelArgs {
    /// Kernel identifier.
    id: String,
    /// Grid dimensionality: "1d", "2d", or "3d".
    /// Defaults to "2d"; unrecognized values also fall back to 2D.
    #[darling(default)]
    grid: Option<String>,
    /// Tile/block size (single value for square tiles).
    /// Used for both width and height when the explicit fields are absent
    /// (16 when omitted entirely).
    #[darling(default)]
    tile_size: Option<u32>,
    /// Tile width (for non-square tiles).
    /// Takes precedence over `tile_size` horizontally.
    #[darling(default)]
    tile_width: Option<u32>,
    /// Tile height (for non-square tiles).
    /// Takes precedence over `tile_size` vertically.
    #[darling(default)]
    tile_height: Option<u32>,
    /// Halo/ghost cell width (stencil radius).
    /// Defaults to 1 when omitted.
    #[darling(default)]
    halo: Option<u32>,
}
567
568/// Attribute macro for defining stencil kernels that transpile to CUDA.
569///
570/// This macro generates CUDA C code from Rust stencil kernel functions at compile time.
571/// The generated CUDA source is embedded in the binary and can be compiled at runtime
572/// using NVRTC.
573///
574/// # Attributes
575///
576/// - `id` (required) - Unique kernel identifier
577/// - `grid` - Grid dimensionality: "1d", "2d" (default), or "3d"
578/// - `tile_size` - Tile/block size (default: 16)
579/// - `tile_width` / `tile_height` - Non-square tile dimensions
580/// - `halo` - Stencil radius / ghost cell width (default: 1)
581///
582/// # Supported Rust Subset
583///
584/// - Primitives: `f32`, `f64`, `i32`, `u32`, `i64`, `u64`, `bool`
585/// - Slices: `&[T]`, `&mut [T]`
586/// - Arithmetic: `+`, `-`, `*`, `/`, `%`
587/// - Comparisons: `<`, `>`, `<=`, `>=`, `==`, `!=`
588/// - Let bindings: `let x = expr;`
589/// - If/else: `if cond { a } else { b }`
590/// - Stencil intrinsics via `GridPos`
591///
592/// # Example
593///
594/// ```ignore
595/// use ringkernel_derive::stencil_kernel;
596/// use ringkernel_cuda_codegen::GridPos;
597///
598/// #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
599/// fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
600///     let curr = p[pos.idx()];
601///     let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
602///     p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
603/// }
604///
605/// // Access generated CUDA source:
606/// assert!(FDTD_CUDA_SOURCE.contains("__global__"));
607/// ```
608#[proc_macro_attribute]
609pub fn stencil_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
610    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
611        Ok(v) => v,
612        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
613    };
614
615    let args = match StencilKernelArgs::from_list(&args) {
616        Ok(v) => v,
617        Err(e) => return TokenStream::from(e.write_errors()),
618    };
619
620    let input = parse_macro_input!(item as ItemFn);
621
622    // Generate the stencil kernel code
623    stencil_kernel_impl(args, input)
624}
625
626fn stencil_kernel_impl(args: StencilKernelArgs, input: ItemFn) -> TokenStream {
627    let kernel_id = &args.id;
628    let fn_name = &input.sig.ident;
629    let fn_vis = &input.vis;
630    let fn_block = &input.block;
631    let fn_inputs = &input.sig.inputs;
632    let fn_output = &input.sig.output;
633    let fn_attrs = &input.attrs;
634
635    // Parse configuration
636    let grid = args.grid.as_deref().unwrap_or("2d");
637    let tile_width = args
638        .tile_width
639        .unwrap_or_else(|| args.tile_size.unwrap_or(16));
640    let tile_height = args
641        .tile_height
642        .unwrap_or_else(|| args.tile_size.unwrap_or(16));
643    let halo = args.halo.unwrap_or(1);
644
645    // Generate CUDA source constant name
646    let cuda_const_name = format_ident!("{}_CUDA_SOURCE", fn_name.to_string().to_uppercase());
647
648    // Generate registration name
649    let registration_name = format_ident!(
650        "__STENCIL_KERNEL_REGISTRATION_{}",
651        fn_name.to_string().to_uppercase()
652    );
653
654    // Transpile to CUDA (if feature enabled)
655    #[cfg(feature = "cuda-codegen")]
656    let cuda_source_code = {
657        use ringkernel_cuda_codegen::{transpile_stencil_kernel, Grid, StencilConfig};
658
659        let grid_type = match grid {
660            "1d" => Grid::Grid1D,
661            "2d" => Grid::Grid2D,
662            "3d" => Grid::Grid3D,
663            _ => Grid::Grid2D,
664        };
665
666        let config = StencilConfig::new(kernel_id.clone())
667            .with_grid(grid_type)
668            .with_tile_size(tile_width as usize, tile_height as usize)
669            .with_halo(halo as usize);
670
671        match transpile_stencil_kernel(&input, &config) {
672            Ok(cuda) => cuda,
673            Err(e) => {
674                return TokenStream::from(
675                    syn::Error::new_spanned(
676                        &input.sig.ident,
677                        format!("CUDA transpilation failed: {}", e),
678                    )
679                    .to_compile_error(),
680                );
681            }
682        }
683    };
684
685    #[cfg(not(feature = "cuda-codegen"))]
686    let cuda_source_code = format!(
687        "// CUDA codegen not enabled. Enable 'cuda-codegen' feature.\n// Kernel: {}\n",
688        kernel_id
689    );
690
691    // Generate the expanded code
692    let expanded = quote! {
693        // Original function (for documentation/testing/CPU fallback)
694        #(#fn_attrs)*
695        #fn_vis fn #fn_name #fn_inputs #fn_output #fn_block
696
697        /// Generated CUDA source code for this stencil kernel.
698        #fn_vis const #cuda_const_name: &str = #cuda_source_code;
699
700        /// Stencil kernel registration for runtime discovery.
701        #[allow(non_upper_case_globals)]
702        #[::inventory::submit]
703        static #registration_name: ::ringkernel_core::__private::StencilKernelRegistration =
704            ::ringkernel_core::__private::StencilKernelRegistration {
705                id: #kernel_id,
706                grid: #grid,
707                tile_width: #tile_width,
708                tile_height: #tile_height,
709                halo: #halo,
710                cuda_source: #cuda_source_code,
711            };
712    };
713
714    TokenStream::from(expanded)
715}
716
717// ============================================================================
718// Multi-Backend GPU Kernel Macro
719// ============================================================================
720
/// GPU backend targets (internal use only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuBackend {
    /// NVIDIA CUDA backend.
    Cuda,
    /// Apple Metal backend.
    Metal,
    /// WebGPU backend (cross-platform).
    Wgpu,
    /// CPU fallback backend.
    Cpu,
}

impl GpuBackend {
    /// Parse a backend name, case-insensitively; `None` for unknown names.
    fn from_str(s: &str) -> Option<Self> {
        let lowered = s.to_lowercase();
        match lowered.as_str() {
            "cpu" => Some(Self::Cpu),
            "cuda" => Some(Self::Cuda),
            "metal" => Some(Self::Metal),
            "wgpu" | "webgpu" => Some(Self::Wgpu),
            _ => None,
        }
    }

    /// Canonical lowercase name for this backend.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Cpu => "cpu",
            Self::Cuda => "cuda",
            Self::Metal => "metal",
            Self::Wgpu => "wgpu",
        }
    }
}

/// GPU capability flags that can be required by a kernel (internal use only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuCapability {
    /// 64-bit floating point support.
    Float64,
    /// 64-bit integer support.
    Int64,
    /// 64-bit atomics support.
    Atomic64,
    /// Cooperative groups / grid-wide sync.
    CooperativeGroups,
    /// Subgroup / warp / SIMD operations.
    Subgroups,
    /// Shared memory / threadgroup memory.
    SharedMemory,
    /// Dynamic parallelism (launching kernels from kernels).
    DynamicParallelism,
    /// Half-precision (f16) support.
    Float16,
}

impl GpuCapability {
    /// Parse a capability name, case-insensitively, accepting several
    /// aliases per capability; `None` for unknown names.
    fn from_str(s: &str) -> Option<Self> {
        let lowered = s.to_lowercase();
        let cap = match lowered.as_str() {
            "f64" | "float64" => Self::Float64,
            "i64" | "int64" => Self::Int64,
            "atomic64" => Self::Atomic64,
            "cooperative_groups" | "cooperativegroups" | "grid_sync" => Self::CooperativeGroups,
            "subgroups" | "warp" | "simd" => Self::Subgroups,
            "shared_memory" | "sharedmemory" | "threadgroup" => Self::SharedMemory,
            "dynamic_parallelism" | "dynamicparallelism" => Self::DynamicParallelism,
            "f16" | "float16" | "half" => Self::Float16,
            _ => return None,
        };
        Some(cap)
    }

    /// Canonical name for this capability.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Float64 => "f64",
            Self::Int64 => "i64",
            Self::Atomic64 => "atomic64",
            Self::CooperativeGroups => "cooperative_groups",
            Self::Subgroups => "subgroups",
            Self::SharedMemory => "shared_memory",
            Self::DynamicParallelism => "dynamic_parallelism",
            Self::Float16 => "f16",
        }
    }

    /// Check if a backend supports this capability.
    ///
    /// CUDA and CPU (emulation) support everything; Metal and WebGPU each
    /// exclude a fixed set of capabilities (64-bit atomics on WebGPU are
    /// emulated only, and subgroups are an optional extension but treated
    /// as supported).
    fn supported_by(&self, backend: GpuBackend) -> bool {
        match backend {
            GpuBackend::Cuda | GpuBackend::Cpu => true,
            GpuBackend::Metal => !matches!(
                self,
                Self::Float64 | Self::CooperativeGroups | Self::DynamicParallelism
            ),
            GpuBackend::Wgpu => !matches!(
                self,
                Self::Float64
                    | Self::Int64
                    | Self::Atomic64
                    | Self::CooperativeGroups
                    | Self::DynamicParallelism
            ),
        }
    }
}
832
/// Attributes for the gpu_kernel macro.
///
/// Unlike the darling-based argument structs above, these are parsed by
/// hand from the raw attribute token stream (see `GpuKernelArgs::parse`);
/// defaults come from the `Default` impl.
#[derive(Debug)]
struct GpuKernelArgs {
    /// Kernel identifier.
    /// Optional; parsed from `id = "..."` when present.
    id: Option<String>,
    /// Target backends to generate code for.
    backends: Vec<GpuBackend>,
    /// Fallback order for backend selection.
    fallback: Vec<GpuBackend>,
    /// Required capabilities.
    requires: Vec<GpuCapability>,
    /// Block/workgroup size.
    block_size: Option<u32>,
}
847
848impl Default for GpuKernelArgs {
849    fn default() -> Self {
850        Self {
851            id: None,
852            backends: vec![GpuBackend::Cuda, GpuBackend::Metal, GpuBackend::Wgpu],
853            fallback: vec![
854                GpuBackend::Cuda,
855                GpuBackend::Metal,
856                GpuBackend::Wgpu,
857                GpuBackend::Cpu,
858            ],
859            requires: Vec::new(),
860            block_size: None,
861        }
862    }
863}
864
865impl GpuKernelArgs {
866    fn parse(attr: proc_macro2::TokenStream) -> Result<Self, darling::Error> {
867        let mut args = Self::default();
868        let attr_str = attr.to_string();
869
870        // Parse backends = [...]
871        if let Some(start) = attr_str.find("backends") {
872            if let Some(bracket_start) = attr_str[start..].find('[') {
873                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
874                    let backends_str =
875                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
876                    args.backends = backends_str
877                        .split(',')
878                        .filter_map(|s| GpuBackend::from_str(s.trim()))
879                        .collect();
880                }
881            }
882        }
883
884        // Parse fallback = [...]
885        if let Some(start) = attr_str.find("fallback") {
886            if let Some(bracket_start) = attr_str[start..].find('[') {
887                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
888                    let fallback_str =
889                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
890                    args.fallback = fallback_str
891                        .split(',')
892                        .filter_map(|s| GpuBackend::from_str(s.trim()))
893                        .collect();
894                }
895            }
896        }
897
898        // Parse requires = [...]
899        if let Some(start) = attr_str.find("requires") {
900            if let Some(bracket_start) = attr_str[start..].find('[') {
901                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
902                    let requires_str =
903                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
904                    args.requires = requires_str
905                        .split(',')
906                        .filter_map(|s| GpuCapability::from_str(s.trim()))
907                        .collect();
908                }
909            }
910        }
911
912        // Parse id = "..."
913        if let Some(start) = attr_str.find("id") {
914            if let Some(quote_start) = attr_str[start..].find('"') {
915                if let Some(quote_end) = attr_str[start + quote_start + 1..].find('"') {
916                    args.id = Some(
917                        attr_str[start + quote_start + 1..start + quote_start + 1 + quote_end]
918                            .to_string(),
919                    );
920                }
921            }
922        }
923
924        // Parse block_size = N
925        if let Some(start) = attr_str.find("block_size") {
926            if let Some(eq) = attr_str[start..].find('=') {
927                let rest = &attr_str[start + eq + 1..];
928                let num_end = rest
929                    .find(|c: char| !c.is_numeric() && c != ' ')
930                    .unwrap_or(rest.len());
931                if let Ok(n) = rest[..num_end].trim().parse() {
932                    args.block_size = Some(n);
933                }
934            }
935        }
936
937        Ok(args)
938    }
939
940    /// Validate that all required capabilities are supported by at least one backend.
941    fn validate_capabilities(&self) -> Result<(), String> {
942        for cap in &self.requires {
943            let mut supported_by_any = false;
944            for backend in &self.backends {
945                if cap.supported_by(*backend) {
946                    supported_by_any = true;
947                    break;
948                }
949            }
950            if !supported_by_any {
951                return Err(format!(
952                    "Capability '{}' is not supported by any of the specified backends: {:?}",
953                    cap.as_str(),
954                    self.backends.iter().map(|b| b.as_str()).collect::<Vec<_>>()
955                ));
956            }
957        }
958        Ok(())
959    }
960
961    /// Get backends that support all required capabilities.
962    fn compatible_backends(&self) -> Vec<GpuBackend> {
963        self.backends
964            .iter()
965            .filter(|backend| self.requires.iter().all(|cap| cap.supported_by(**backend)))
966            .copied()
967            .collect()
968    }
969}
970
971/// Attribute macro for defining multi-backend GPU kernels.
972///
973/// This macro generates code for multiple GPU backends with compile-time
974/// capability validation. It integrates with the `ringkernel-ir` crate
975/// to lower Rust DSL to backend-specific shader code.
976///
977/// # Attributes
978///
979/// - `backends = [cuda, metal, wgpu]` - Target backends (default: all)
980/// - `fallback = [cuda, metal, wgpu, cpu]` - Fallback order for runtime selection
981/// - `requires = [f64, atomic64]` - Required capabilities (validated at compile time)
982/// - `id = "kernel_name"` - Explicit kernel identifier
983/// - `block_size = 256` - Thread block size
984///
985/// # Example
986///
987/// ```ignore
988/// use ringkernel_derive::gpu_kernel;
989///
990/// #[gpu_kernel(backends = [cuda, metal], requires = [subgroups])]
991/// fn warp_reduce(data: &mut [f32], n: i32) {
992///     let idx = global_thread_id_x();
993///     if idx < n {
994///         // Use warp shuffle for reduction
995///         let val = data[idx as usize];
996///         let reduced = warp_reduce_sum(val);
997///         if lane_id() == 0 {
998///             data[idx as usize] = reduced;
999///         }
1000///     }
1001/// }
1002/// ```
1003///
1004/// # Capability Checking
1005///
1006/// The macro validates at compile time that all required capabilities are
1007/// supported by at least one target backend:
1008///
1009/// | Capability | CUDA | Metal | WebGPU | CPU |
1010/// |------------|------|-------|--------|-----|
1011/// | f64        | Yes  | No    | No     | Yes |
1012/// | i64        | Yes  | Yes   | No     | Yes |
1013/// | atomic64   | Yes  | Yes   | No*    | Yes |
1014/// | cooperative_groups | Yes | No | No | Yes |
1015/// | subgroups  | Yes  | Yes   | Opt    | Yes |
1016/// | shared_memory | Yes | Yes | Yes    | Yes |
1017/// | f16        | Yes  | Yes   | Yes    | Yes |
1018///
1019/// *WebGPU emulates 64-bit atomics with 32-bit pairs.
1020///
1021/// # Generated Code
1022///
/// For each compatible backend, the macro generates:
/// - Backend-specific source code constant (e.g., `KERNEL_NAME_CUDA_SOURCE`)
/// - Registration entry for runtime discovery
/// - The original Rust function is re-emitted unchanged and doubles as the
///   CPU fallback
1027#[proc_macro_attribute]
1028pub fn gpu_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
1029    let attr2: proc_macro2::TokenStream = attr.into();
1030    let args = match GpuKernelArgs::parse(attr2) {
1031        Ok(args) => args,
1032        Err(e) => return TokenStream::from(e.write_errors()),
1033    };
1034
1035    let input = parse_macro_input!(item as ItemFn);
1036
1037    // Validate capabilities
1038    if let Err(msg) = args.validate_capabilities() {
1039        return TokenStream::from(
1040            syn::Error::new_spanned(&input.sig.ident, msg).to_compile_error(),
1041        );
1042    }
1043
1044    gpu_kernel_impl(args, input)
1045}
1046
1047fn gpu_kernel_impl(args: GpuKernelArgs, input: ItemFn) -> TokenStream {
1048    let fn_name = &input.sig.ident;
1049    let fn_vis = &input.vis;
1050    let fn_block = &input.block;
1051    let fn_inputs = &input.sig.inputs;
1052    let fn_output = &input.sig.output;
1053    let fn_attrs = &input.attrs;
1054
1055    let kernel_id = args.id.clone().unwrap_or_else(|| fn_name.to_string());
1056    let block_size = args.block_size.unwrap_or(256);
1057
1058    // Get compatible backends
1059    let compatible_backends = args.compatible_backends();
1060
1061    // Generate backend-specific source constants
1062    let mut source_constants = Vec::new();
1063
1064    for backend in &compatible_backends {
1065        let const_name = format_ident!(
1066            "{}_{}",
1067            fn_name.to_string().to_uppercase(),
1068            backend.as_str().to_uppercase()
1069        );
1070
1071        let backend_str = backend.as_str();
1072
1073        // Generate placeholder source (actual IR lowering happens at build time)
1074        // In a full implementation, this would call ringkernel-ir lowering
1075        let source_placeholder = format!(
1076            "// {} source for kernel '{}'\n// Generated by ringkernel-derive\n// Capabilities: {:?}\n",
1077            backend_str.to_uppercase(),
1078            kernel_id,
1079            args.requires.iter().map(|c| c.as_str()).collect::<Vec<_>>()
1080        );
1081
1082        source_constants.push(quote! {
1083            /// Generated source code for this kernel.
1084            #fn_vis const #const_name: &str = #source_placeholder;
1085        });
1086    }
1087
1088    // Generate capability flags as strings
1089    let capability_strs: Vec<_> = args.requires.iter().map(|c| c.as_str()).collect();
1090    let backend_strs: Vec<_> = compatible_backends.iter().map(|b| b.as_str()).collect();
1091    let fallback_strs: Vec<_> = args.fallback.iter().map(|b| b.as_str()).collect();
1092
1093    // Generate registration struct name
1094    let registration_name = format_ident!(
1095        "__GPU_KERNEL_REGISTRATION_{}",
1096        fn_name.to_string().to_uppercase()
1097    );
1098
1099    // Generate info struct name
1100    let info_name = format_ident!("{}_INFO", fn_name.to_string().to_uppercase());
1101
1102    // Generate the expanded code
1103    let expanded = quote! {
1104        // Original function (CPU fallback / documentation / testing)
1105        #(#fn_attrs)*
1106        #fn_vis fn #fn_name #fn_inputs #fn_output #fn_block
1107
1108        // Backend source constants
1109        #(#source_constants)*
1110
1111        /// Multi-backend kernel information.
1112        #fn_vis mod #info_name {
1113            /// Kernel identifier.
1114            pub const ID: &str = #kernel_id;
1115
1116            /// Block/workgroup size.
1117            pub const BLOCK_SIZE: u32 = #block_size;
1118
1119            /// Required capabilities.
1120            pub const CAPABILITIES: &[&str] = &[#(#capability_strs),*];
1121
1122            /// Compatible backends (those that support all required capabilities).
1123            pub const BACKENDS: &[&str] = &[#(#backend_strs),*];
1124
1125            /// Fallback order for runtime backend selection.
1126            pub const FALLBACK_ORDER: &[&str] = &[#(#fallback_strs),*];
1127        }
1128
1129        /// GPU kernel registration for runtime discovery.
1130        #[allow(non_upper_case_globals)]
1131        #[::inventory::submit]
1132        static #registration_name: ::ringkernel_core::__private::GpuKernelRegistration =
1133            ::ringkernel_core::__private::GpuKernelRegistration {
1134                id: #kernel_id,
1135                block_size: #block_size,
1136                capabilities: &[#(#capability_strs),*],
1137                backends: &[#(#backend_strs),*],
1138                fallback_order: &[#(#fallback_strs),*],
1139            };
1140    };
1141
1142    TokenStream::from(expanded)
1143}
1144
1145// ============================================================================
1146// ControlBlockState Derive Macro (FR-4)
1147// ============================================================================
1148
/// Attributes for the ControlBlockState derive macro.
///
/// Extracted by `darling` from the `DeriveInput`; only named structs are
/// accepted, and the optional `#[state(version = N)]` attribute feeds the
/// generated `EmbeddedState::VERSION` constant.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(state), supports(struct_named))]
struct ControlBlockStateArgs {
    // Name of the deriving type (field names below are fixed by darling's
    // `FromDeriveInput` contract — do not rename).
    ident: syn::Ident,
    // Generics of the deriving type, split for the generated impls.
    generics: syn::Generics,
    /// State version for forward compatibility; defaults to 1 when the
    /// `#[state(version = N)]` attribute is absent.
    #[darling(default)]
    version: Option<u32>,
}
1159
1160/// Derive macro for implementing EmbeddedState trait.
1161///
1162/// This macro generates implementations for types that can be stored in
1163/// the ControlBlock's 24-byte `_reserved` field for zero-copy state access.
1164///
1165/// # Requirements
1166///
1167/// The type must:
1168/// - Be `#[repr(C)]` for stable memory layout
1169/// - Be <= 24 bytes in size (checked at compile time)
1170/// - Implement `Clone`, `Copy`, and `Default`
1171/// - Contain only POD (Plain Old Data) types
1172///
1173/// # Attributes
1174///
1175/// - `#[state(version = N)]` - Set state version for migrations (default: 1)
1176///
1177/// # Example
1178///
1179/// ```ignore
1180/// #[derive(ControlBlockState, Default, Clone, Copy)]
1181/// #[repr(C, align(8))]
1182/// #[state(version = 1)]
1183/// pub struct OrderBookState {
1184///     pub best_bid: u64,    // 8 bytes
1185///     pub best_ask: u64,    // 8 bytes
1186///     pub order_count: u32, // 4 bytes
1187///     pub _pad: u32,        // 4 bytes (padding for alignment)
1188/// }  // Total: 24 bytes - fits in ControlBlock._reserved
1189///
1190/// // Use with ControlBlockStateHelper:
1191/// let mut block = ControlBlock::new();
1192/// let state = OrderBookState { best_bid: 100, best_ask: 101, order_count: 42, _pad: 0 };
1193/// ControlBlockStateHelper::write_embedded(&mut block, &state)?;
1194/// ```
1195///
1196/// # Size Validation
1197///
1198/// The macro generates a compile-time assertion that fails if the type
1199/// exceeds 24 bytes:
1200///
1201/// ```ignore
1202/// #[derive(ControlBlockState, Default, Clone, Copy)]
1203/// #[repr(C)]
1204/// struct TooLarge {
1205///     data: [u8; 32],  // 32 bytes - COMPILE ERROR!
1206/// }
1207/// ```
1208#[proc_macro_derive(ControlBlockState, attributes(state))]
1209pub fn derive_control_block_state(input: TokenStream) -> TokenStream {
1210    let input = parse_macro_input!(input as DeriveInput);
1211
1212    let args = match ControlBlockStateArgs::from_derive_input(&input) {
1213        Ok(args) => args,
1214        Err(e) => return e.write_errors().into(),
1215    };
1216
1217    let name = &args.ident;
1218    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
1219    let version = args.version.unwrap_or(1);
1220
1221    let expanded = quote! {
1222        // Compile-time size check: EmbeddedState must fit in 24 bytes
1223        const _: () = {
1224            assert!(
1225                ::std::mem::size_of::<#name #ty_generics>() <= 24,
1226                "ControlBlockState types must fit in 24 bytes (ControlBlock._reserved size)"
1227            );
1228        };
1229
1230        // Verify type is Copy (required for GPU transfer)
1231        const _: fn() = || {
1232            fn assert_copy<T: Copy>() {}
1233            assert_copy::<#name #ty_generics>();
1234        };
1235
1236        // Implement Pod and Zeroable (required by EmbeddedState)
1237        // SAFETY: Type is #[repr(C)] with only primitive types, verified by user
1238        unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
1239        unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
1240
1241        // Implement EmbeddedState
1242        impl #impl_generics ::ringkernel_core::state::EmbeddedState for #name #ty_generics #where_clause {
1243            const VERSION: u32 = #version;
1244
1245            fn is_embedded() -> bool {
1246                true
1247            }
1248        }
1249    };
1250
1251    TokenStream::from(expanded)
1252}