// ringkernel_derive/lib.rs

//! Procedural macros for RingKernel.
//!
//! This crate provides the following macros:
//!
//! - `#[derive(RingMessage)]` - Implement the RingMessage trait for message types
//! - `#[derive(PersistentMessage)]` - Implement PersistentMessage for GPU kernel dispatch
//! - `#[ring_kernel]` - Define a ring kernel handler
//! - `#[stencil_kernel]` - Define a GPU stencil kernel (with the `cuda-codegen` feature)
//! - `#[gpu_kernel]` - Define a multi-backend GPU kernel with capability checking
//!
//! # Example
//!
//! ```ignore
//! use ringkernel_derive::{RingMessage, ring_kernel};
//!
//! #[derive(RingMessage)]
//! struct AddRequest {
//!     #[message(id)]
//!     id: MessageId,
//!     a: f32,
//!     b: f32,
//! }
//!
//! #[derive(RingMessage)]
//! struct AddResponse {
//!     #[message(id)]
//!     id: MessageId,
//!     result: f32,
//! }
//!
//! #[ring_kernel(id = "adder")]
//! async fn process(ctx: &mut RingContext, req: AddRequest) -> AddResponse {
//!     AddResponse {
//!         id: MessageId::generate(),
//!         result: req.a + req.b,
//!     }
//! }
//! ```
//!
//! # Multi-Backend GPU Kernels
//!
//! The `#[gpu_kernel]` macro enables multi-backend code generation with capability checking:
//!
//! ```ignore
//! use ringkernel_derive::gpu_kernel;
//!
//! // Generate code for CUDA and Metal, with fallback order
//! #[gpu_kernel(backends = [cuda, metal], fallback = [wgpu, cpu])]
//! fn saxpy(x: &[f32], y: &mut [f32], a: f32, n: i32) {
//!     let idx = global_thread_id_x();
//!     if idx < n {
//!         y[idx as usize] = a * x[idx as usize] + y[idx as usize];
//!     }
//! }
//!
//! // Require specific capabilities at compile time
//! #[gpu_kernel(backends = [cuda], requires = [f64, atomic64])]
//! fn double_precision(data: &mut [f64], n: i32) {
//!     // Uses f64 operations - validated at compile time
//! }
//! ```
//!
//! # Stencil Kernels (with the `cuda-codegen` feature)
//!
//! ```ignore
//! use ringkernel_derive::stencil_kernel;
//! use ringkernel_cuda_codegen::GridPos;
//!
//! #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
//! fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
//!     let curr = p[pos.idx()];
//!     let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
//!     p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
//! }
//! ```

use darling::{ast, FromDeriveInput, FromField, FromMeta};
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, DeriveInput, ItemFn};

/// Attributes for the RingMessage derive macro.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(message, ring_message), supports(struct_named))]
struct RingMessageArgs {
    ident: syn::Ident,
    generics: syn::Generics,
    data: ast::Data<(), RingMessageField>,
    /// Optional explicit message type ID.
    /// If domain is specified, this is the offset within the domain (0-99).
    /// If domain is not specified, this is the absolute type ID.
    #[darling(default)]
    type_id: Option<u64>,
    /// Optional domain for message classification.
    /// When specified, the final type ID = domain.base_type_id() + type_id.
    #[darling(default)]
    domain: Option<String>,
    /// Whether this message is routable via K2K.
    /// When true, generates a K2KMessageRegistration for runtime discovery.
    #[darling(default)]
    k2k_routable: bool,
    /// Optional category for K2K routing groups.
    /// Multiple messages can share a category for grouped routing.
    #[darling(default)]
    category: Option<String>,
}

/// Field attributes for RingMessage.
#[derive(Debug, FromField)]
#[darling(attributes(message))]
struct RingMessageField {
    ident: Option<syn::Ident>,
    #[allow(dead_code)]
    ty: syn::Type,
    /// Mark this field as the message ID.
    #[darling(default)]
    id: bool,
    /// Mark this field as the correlation ID.
    #[darling(default)]
    correlation: bool,
    /// Mark this field as the priority.
    #[darling(default)]
    priority: bool,
}

/// Derive macro for implementing the RingMessage trait.
///
/// # Attributes
///
/// On the struct (via `#[message(...)]` or `#[ring_message(...)]`):
/// - `type_id = 123` - Set explicit message type ID (or domain offset if domain is set)
/// - `domain = "OrderMatching"` - Assign to a business domain (adds base type ID)
/// - `k2k_routable = true` - Register for K2K routing discovery
/// - `category = "orders"` - Group messages for K2K routing
///
/// On fields:
/// - `#[message(id)]` - Mark as message ID field
/// - `#[message(correlation)]` - Mark as correlation ID field
/// - `#[message(priority)]` - Mark as priority field
///
/// # Examples
///
/// Basic usage:
/// ```ignore
/// #[derive(RingMessage)]
/// #[message(type_id = 1)]
/// struct MyMessage {
///     #[message(id)]
///     id: MessageId,
///     #[message(correlation)]
///     correlation: CorrelationId,
///     #[message(priority)]
///     priority: Priority,
///     payload: Vec<u8>,
/// }
/// ```
///
/// With domain (type ID = 500 + 1 = 501):
/// ```ignore
/// #[derive(RingMessage)]
/// #[ring_message(type_id = 1, domain = "OrderMatching")]
/// pub struct SubmitOrderInput {
///     #[message(id)]
///     id: MessageId,
///     pub order: Order,
/// }
/// // Also implements DomainMessage trait
/// assert_eq!(SubmitOrderInput::domain(), Domain::OrderMatching);
/// ```
///
/// K2K-routable message:
/// ```ignore
/// #[derive(RingMessage)]
/// #[ring_message(type_id = 1, domain = "OrderMatching", k2k_routable = true, category = "orders")]
/// pub struct SubmitOrderInput { ... }
///
/// // Runtime discovery:
/// let registry = K2KTypeRegistry::discover();
/// assert!(registry.is_routable(501));
/// ```
#[proc_macro_derive(RingMessage, attributes(message, ring_message))]
pub fn derive_ring_message(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let args = match RingMessageArgs::from_derive_input(&input) {
        Ok(args) => args,
        Err(e) => return e.write_errors().into(),
    };

    let name = &args.ident;
    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();

    // Calculate base type ID (offset within domain, or absolute if no domain)
    let base_type_id = args.type_id.unwrap_or_else(|| {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        name.to_string().hash(&mut hasher);
        // If domain is set, hash to a value within 0-99 range
        if args.domain.is_some() {
            hasher.finish() % 100
        } else {
            hasher.finish()
        }
    });
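    // Example: with `domain = "OrderMatching"` (base type ID 500, per the docs
    // above), a hashed offset of 42 yields final type ID 542 in `message_type()`.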

    // Find annotated fields
    let fields = match &args.data {
        ast::Data::Struct(fields) => fields,
        _ => panic!("RingMessage can only be derived for structs"),
    };

    let mut id_field: Option<&syn::Ident> = None;
    let mut correlation_field: Option<&syn::Ident> = None;
    let mut priority_field: Option<&syn::Ident> = None;

    for field in fields.iter() {
        if field.id {
            id_field = field.ident.as_ref();
        }
        if field.correlation {
            correlation_field = field.ident.as_ref();
        }
        if field.priority {
            priority_field = field.ident.as_ref();
        }
    }
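    // If more than one field carries the same marker, the last field wins.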

    // Generate message_id method
    let message_id_impl = if let Some(field) = id_field {
        quote! { self.#field }
    } else {
        quote! { ::ringkernel_core::message::MessageId::new(0) }
    };

    // Generate correlation_id method
    let correlation_id_impl = if let Some(field) = correlation_field {
        quote! { self.#field }
    } else {
        quote! { ::ringkernel_core::message::CorrelationId::none() }
    };

    // Generate priority method
    let priority_impl = if let Some(field) = priority_field {
        quote! { self.#field }
    } else {
        quote! { ::ringkernel_core::message::Priority::Normal }
    };

    // Generate message_type() implementation based on whether domain is specified
    let message_type_impl = if let Some(ref domain_str) = args.domain {
        // With domain: type_id = domain.base_type_id() + offset
        quote! {
            ::ringkernel_core::domain::Domain::from_str(#domain_str)
                .unwrap_or(::ringkernel_core::domain::Domain::General)
                .base_type_id() + #base_type_id
        }
    } else {
        // Without domain: use absolute type_id
        quote! { #base_type_id }
    };

    // Generate DomainMessage impl if domain is specified
    let domain_impl = if let Some(ref domain_str) = args.domain {
        quote! {
            impl #impl_generics ::ringkernel_core::domain::DomainMessage for #name #ty_generics #where_clause {
                fn domain() -> ::ringkernel_core::domain::Domain {
                    ::ringkernel_core::domain::Domain::from_str(#domain_str)
                        .unwrap_or(::ringkernel_core::domain::Domain::General)
                }
            }
        }
    } else {
        quote! {}
    };

    // Generate K2K registration if k2k_routable is set
    let k2k_registration = if args.k2k_routable {
        let type_name_str = name.to_string();
        let category_tokens = match &args.category {
            Some(cat) => quote! { ::std::option::Option::Some(#cat) },
            None => quote! { ::std::option::Option::None },
        };

        quote! {
            ::inventory::submit! {
                ::ringkernel_core::k2k::K2KMessageRegistration {
                    // Note: this is a const context, so only the base offset is
                    // recorded here; for domain types the domain base is not added.
                    type_id: #base_type_id,
                    type_name: #type_name_str,
                    k2k_routable: true,
                    category: #category_tokens,
                }
            }
        }
    } else {
        quote! {}
    };

    let expanded = quote! {
        impl #impl_generics ::ringkernel_core::message::RingMessage for #name #ty_generics #where_clause {
            fn message_type() -> u64 {
                #message_type_impl
            }

            fn message_id(&self) -> ::ringkernel_core::message::MessageId {
                #message_id_impl
            }

            fn correlation_id(&self) -> ::ringkernel_core::message::CorrelationId {
                #correlation_id_impl
            }

            fn priority(&self) -> ::ringkernel_core::message::Priority {
                #priority_impl
            }

            fn serialize(&self) -> Vec<u8> {
                // Use rkyv for serialization with a 4KB scratch buffer
                // For larger payloads, rkyv will allocate as needed
                ::rkyv::to_bytes::<_, 4096>(self)
                    .map(|v| v.to_vec())
                    .unwrap_or_default()
            }

            fn deserialize(bytes: &[u8]) -> ::ringkernel_core::error::Result<Self>
            where
                Self: Sized,
            {
                use ::rkyv::Deserialize as _;
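                // SAFETY: assumes `bytes` were produced by `serialize` for this
                // same type; `archived_root` does not validate the input.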
                let archived = unsafe { ::rkyv::archived_root::<Self>(bytes) };
                let deserialized: Self = archived.deserialize(&mut ::rkyv::Infallible)
                    .map_err(|_| ::ringkernel_core::error::RingKernelError::DeserializationError(
                        "rkyv deserialization failed".to_string()
                    ))?;
                Ok(deserialized)
            }

            fn size_hint(&self) -> usize {
                ::std::mem::size_of::<Self>()
            }
        }

        #domain_impl

        #k2k_registration
    };

    TokenStream::from(expanded)
}

// ============================================================================
// PersistentMessage Derive Macro
// ============================================================================

/// Maximum size for inline payload in persistent messages.
#[allow(dead_code)]
const MAX_INLINE_PAYLOAD_SIZE: usize = 32;

/// Attributes for the PersistentMessage derive macro.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(persistent_message), supports(struct_named))]
struct PersistentMessageArgs {
    ident: syn::Ident,
    generics: syn::Generics,
    /// Field data (reserved for future per-field attributes).
    #[allow(dead_code)]
    data: ast::Data<(), PersistentMessageField>,
    /// Handler ID for CUDA dispatch (0-255).
    handler_id: u32,
    /// Whether this message type expects a response.
    #[darling(default)]
    requires_response: bool,
}

/// Field attributes for PersistentMessage (reserved for future use).
#[derive(Debug, FromField)]
#[darling(attributes(persistent_message))]
struct PersistentMessageField {
    /// Field identifier.
    #[allow(dead_code)]
    ident: Option<syn::Ident>,
    /// Field type.
    #[allow(dead_code)]
    ty: syn::Type,
}

/// Derive macro for implementing the PersistentMessage trait.
///
/// This macro enables type-based dispatch within persistent GPU kernels by
/// generating handler_id, inline payload serialization, and deserialization.
///
/// # Requirements
///
/// The struct must:
/// - Already implement `RingMessage` (use `#[derive(RingMessage)]`)
/// - Be `#[repr(C)]` for safe memory layout
/// - Be `Copy` + `Clone` for inline payload serialization
///
/// # Attributes
///
/// On the struct:
/// - `handler_id = N` (required) - CUDA dispatch handler ID (0-255)
/// - `requires_response = true/false` (optional) - Whether this message expects a response
///
/// # Example
///
/// ```ignore
/// use ringkernel_derive::{RingMessage, PersistentMessage};
///
/// #[derive(RingMessage, PersistentMessage, Clone, Copy)]
/// #[repr(C)]
/// #[message(type_id = 1001)]
/// #[persistent_message(handler_id = 1, requires_response = true)]
/// pub struct FraudCheckRequest {
///     pub transaction_id: u64,
///     pub amount: f32,
///     pub account_id: u32,
/// }
///
/// // Generated implementations:
/// // - handler_id() returns 1
/// // - requires_response() returns true
/// // - to_inline_payload() serializes the struct to [u8; 32] if it fits
/// // - from_inline_payload() deserializes from bytes
/// // - payload_size() returns the struct size
/// ```
///
/// # Size Validation
///
/// For inline payload serialization, structs must be <= 32 bytes.
/// Larger structs will return `None` from `to_inline_payload()`.
///
/// # CUDA Integration
///
/// The handler_id maps to a switch case in generated CUDA code:
///
/// ```cuda
/// switch (msg->handler_id) {
///     case 1: handle_fraud_check(msg, state, response); break;
///     case 2: handle_aggregate(msg, state, response); break;
///     // ...
/// }
/// ```
#[proc_macro_derive(PersistentMessage, attributes(persistent_message))]
pub fn derive_persistent_message(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let args = match PersistentMessageArgs::from_derive_input(&input) {
        Ok(args) => args,
        Err(e) => return e.write_errors().into(),
    };

    let name = &args.ident;
    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();

    let handler_id = args.handler_id;
    let requires_response = args.requires_response;

    // Generate the PersistentMessage implementation
    let expanded = quote! {
        impl #impl_generics ::ringkernel_core::persistent_message::PersistentMessage for #name #ty_generics #where_clause {
            fn handler_id() -> u32 {
                #handler_id
            }

            fn requires_response() -> bool {
                #requires_response
            }

            fn payload_size() -> usize {
                ::std::mem::size_of::<Self>()
            }

            fn to_inline_payload(&self) -> ::std::option::Option<[u8; ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE]> {
                // Only serialize if the struct fits in the inline payload
                if ::std::mem::size_of::<Self>() > ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE {
                    return ::std::option::Option::None;
                }

                let mut payload = [0u8; ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE];

                // SAFETY: We've verified the struct fits in the payload,
                // and the struct is repr(C) + Copy
                unsafe {
                    ::std::ptr::copy_nonoverlapping(
                        self as *const Self as *const u8,
                        payload.as_mut_ptr(),
                        ::std::mem::size_of::<Self>()
                    );
                }

                ::std::option::Option::Some(payload)
            }

            fn from_inline_payload(payload: &[u8]) -> ::ringkernel_core::error::Result<Self> {
                let size = ::std::mem::size_of::<Self>();

                if payload.len() < size {
                    return ::std::result::Result::Err(
                        ::ringkernel_core::error::RingKernelError::DeserializationError(
                            ::std::format!(
                                "Payload too small: expected {} bytes, got {}",
                                size,
                                payload.len()
                            )
                        )
                    );
                }

                // SAFETY: We've verified the payload is large enough, and the
                // struct is repr(C) + Copy; `read_unaligned` is used because a
                // byte slice only guarantees 1-byte alignment.
                let value = unsafe {
                    ::std::ptr::read_unaligned(payload.as_ptr() as *const Self)
                };

                ::std::result::Result::Ok(value)
            }
        }
    };

    TokenStream::from(expanded)
}

/// Attributes for the ring_kernel macro.
#[derive(Debug, FromMeta)]
struct RingKernelArgs {
    /// Kernel identifier.
    id: String,
    /// Execution mode (persistent or event_driven).
    #[darling(default)]
    mode: Option<String>,
    /// Grid size.
    #[darling(default)]
    grid_size: Option<u32>,
    /// Block size.
    #[darling(default)]
    block_size: Option<u32>,
    /// Target kernels this kernel publishes to.
    #[darling(default)]
    publishes_to: Option<String>,
}

/// Attribute macro for defining ring kernel handlers.
///
/// # Attributes
///
/// - `id` (required) - Unique kernel identifier
/// - `mode` - Execution mode: "persistent" (default) or "event_driven"
/// - `grid_size` - Number of blocks (default: 1)
/// - `block_size` - Threads per block (default: 256)
/// - `publishes_to` - Comma-separated list of target kernel IDs
///
/// # Example
///
/// ```ignore
/// #[ring_kernel(id = "processor", mode = "persistent", block_size = 128)]
/// async fn handle(ctx: &mut RingContext, msg: MyMessage) -> MyResponse {
///     // Process message
///     MyResponse { ... }
/// }
/// ```
#[proc_macro_attribute]
pub fn ring_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
        Ok(v) => v,
        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
    };

    let args = match RingKernelArgs::from_list(&args) {
        Ok(v) => v,
        Err(e) => return TokenStream::from(e.write_errors()),
    };

    let input = parse_macro_input!(item as ItemFn);

    let kernel_id = &args.id;
    let fn_name = &input.sig.ident;
    let fn_vis = &input.vis;
    let fn_block = &input.block;
    let fn_attrs = &input.attrs;

    // Parse function signature
    let inputs = &input.sig.inputs;
    let output = &input.sig.output;

    // Extract context and message types from signature
    let (_ctx_arg, msg_arg) = if inputs.len() >= 2 {
        let ctx = inputs.first();
        let msg = inputs.iter().nth(1);
        (ctx, msg)
    } else {
        (None, None)
    };

    // Get message type
    let msg_type = msg_arg
        .map(|arg| {
            if let syn::FnArg::Typed(pat_type) = arg {
                pat_type.ty.clone()
            } else {
                syn::parse_quote!(())
            }
        })
        .unwrap_or_else(|| syn::parse_quote!(()));

    // Generate kernel mode
    let mode = args.mode.as_deref().unwrap_or("persistent");
    let mode_expr = if mode == "event_driven" {
        quote! { ::ringkernel_core::types::KernelMode::EventDriven }
    } else {
        quote! { ::ringkernel_core::types::KernelMode::Persistent }
    };

    // Generate grid/block size
    let grid_size = args.grid_size.unwrap_or(1);
    let block_size = args.block_size.unwrap_or(256);

    // Parse publishes_to into a list of target kernel IDs
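    // e.g. `publishes_to = "risk,audit"` becomes ["risk", "audit"] (names illustrative)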
    let publishes_to_targets: Vec<String> = args
        .publishes_to
        .as_ref()
        .map(|s| s.split(',').map(|t| t.trim().to_string()).collect())
        .unwrap_or_default();

    let handler_name = format_ident!("{}_handler", fn_name);

    // Generate the expanded code
    let expanded = quote! {
        // Original function (preserved for documentation/testing)
        #(#fn_attrs)*
        #fn_vis async fn #fn_name(#inputs) #output #fn_block

        // Kernel handler wrapper
        #fn_vis fn #handler_name<'a>(
            ctx: &'a mut ::ringkernel_core::RingContext<'_>,
            envelope: ::ringkernel_core::message::MessageEnvelope,
        ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::ringkernel_core::error::Result<::ringkernel_core::message::MessageEnvelope>> + Send + 'a>> {
            Box::pin(async move {
                // Deserialize input message
                let msg: #msg_type = ::ringkernel_core::message::RingMessage::deserialize(&envelope.payload)?;

                // Call the actual handler
                let response = #fn_name(ctx, msg).await;

                // Serialize response
                let response_payload = ::ringkernel_core::message::RingMessage::serialize(&response);

                // Recover the response type's message ID without naming the type
                fn __response_type_id<M: ::ringkernel_core::message::RingMessage>(_: &M) -> u64 {
                    M::message_type()
                }

                let response_header = ::ringkernel_core::message::MessageHeader::new(
                    __response_type_id(&response),
                    envelope.header.dest_kernel,
                    envelope.header.source_kernel,
                    response_payload.len(),
                    ctx.now(),
                ).with_correlation(envelope.header.correlation_id);

                Ok(::ringkernel_core::message::MessageEnvelope {
                    header: response_header,
                    payload: response_payload,
                })
            })
        }

        // Kernel registration for runtime discovery
        ::inventory::submit! {
            ::ringkernel_core::__private::KernelRegistration {
                id: #kernel_id,
                mode: #mode_expr,
                grid_size: #grid_size,
                block_size: #block_size,
                publishes_to: &[#(#publishes_to_targets),*],
            }
        }
    };

    TokenStream::from(expanded)
}

/// Derive macro for GPU-compatible types.
///
/// Ensures the type has a stable memory layout suitable for GPU transfer.
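///
/// # Example
///
/// A minimal sketch of intended usage (the struct shown is illustrative):
///
/// ```ignore
/// use ringkernel_derive::GpuType;
///
/// #[derive(GpuType, Clone, Copy)]
/// #[repr(C)]
/// struct Particle {
///     position: [f32; 4],
///     velocity: [f32; 4],
/// }
/// ```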
#[proc_macro_derive(GpuType)]
pub fn derive_gpu_type(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Generate assertions for GPU compatibility
    let expanded = quote! {
        // Verify type is Copy (required for GPU transfer)
        const _: fn() = || {
            fn assert_copy<T: Copy>() {}
            assert_copy::<#name #ty_generics>();
        };

        // Verify type is Pod (plain old data)
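        // SAFETY: assumes the annotated type is #[repr(C)] and contains only
        // plain-old-data fields; the derive does not verify this itself.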
        unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
        unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
    };

    TokenStream::from(expanded)
}

// ============================================================================
// Stencil Kernel Macro (requires cuda-codegen feature)
// ============================================================================

/// Attributes for the stencil_kernel macro.
#[derive(Debug, FromMeta)]
struct StencilKernelArgs {
    /// Kernel identifier.
    id: String,
    /// Grid dimensionality: "1d", "2d", or "3d".
    #[darling(default)]
    grid: Option<String>,
    /// Tile/block size (single value for square tiles).
    #[darling(default)]
    tile_size: Option<u32>,
    /// Tile width (for non-square tiles).
    #[darling(default)]
    tile_width: Option<u32>,
    /// Tile height (for non-square tiles).
    #[darling(default)]
    tile_height: Option<u32>,
    /// Halo/ghost cell width (stencil radius).
    #[darling(default)]
    halo: Option<u32>,
}

/// Attribute macro for defining stencil kernels that transpile to CUDA.
///
/// This macro generates CUDA C code from Rust stencil kernel functions at compile time.
/// The generated CUDA source is embedded in the binary and can be compiled at runtime
/// using NVRTC.
///
/// # Attributes
///
/// - `id` (required) - Unique kernel identifier
/// - `grid` - Grid dimensionality: "1d", "2d" (default), or "3d"
/// - `tile_size` - Tile/block size (default: 16)
/// - `tile_width` / `tile_height` - Non-square tile dimensions
/// - `halo` - Stencil radius / ghost cell width (default: 1)
///
/// # Supported Rust Subset
///
/// - Primitives: `f32`, `f64`, `i32`, `u32`, `i64`, `u64`, `bool`
/// - Slices: `&[T]`, `&mut [T]`
/// - Arithmetic: `+`, `-`, `*`, `/`, `%`
/// - Comparisons: `<`, `>`, `<=`, `>=`, `==`, `!=`
/// - Let bindings: `let x = expr;`
/// - If/else: `if cond { a } else { b }`
/// - Stencil intrinsics via `GridPos`
///
/// # Example
///
/// ```ignore
/// use ringkernel_derive::stencil_kernel;
/// use ringkernel_cuda_codegen::GridPos;
///
/// #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
/// fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
///     let curr = p[pos.idx()];
///     let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
///     p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
/// }
///
/// // Access generated CUDA source:
/// assert!(FDTD_CUDA_SOURCE.contains("__global__"));
/// ```
#[proc_macro_attribute]
pub fn stencil_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
        Ok(v) => v,
        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
    };

    let args = match StencilKernelArgs::from_list(&args) {
        Ok(v) => v,
        Err(e) => return TokenStream::from(e.write_errors()),
    };

    let input = parse_macro_input!(item as ItemFn);

    // Generate the stencil kernel code
    stencil_kernel_impl(args, input)
}

fn stencil_kernel_impl(args: StencilKernelArgs, input: ItemFn) -> TokenStream {
    let kernel_id = &args.id;
    let fn_name = &input.sig.ident;
    let fn_vis = &input.vis;
    let fn_block = &input.block;
    let fn_inputs = &input.sig.inputs;
    let fn_output = &input.sig.output;
    let fn_attrs = &input.attrs;

    // Parse configuration
    let grid = args.grid.as_deref().unwrap_or("2d");
    let tile_width = args
        .tile_width
        .unwrap_or_else(|| args.tile_size.unwrap_or(16));
    let tile_height = args
        .tile_height
        .unwrap_or_else(|| args.tile_size.unwrap_or(16));
    let halo = args.halo.unwrap_or(1);

    // Generate CUDA source constant name
    let cuda_const_name = format_ident!("{}_CUDA_SOURCE", fn_name.to_string().to_uppercase());

    // Transpile to CUDA (if feature enabled)
    #[cfg(feature = "cuda-codegen")]
    let cuda_source_code = {
        use ringkernel_cuda_codegen::{transpile_stencil_kernel, Grid, StencilConfig};

        let grid_type = match grid {
            "1d" => Grid::Grid1D,
            "2d" => Grid::Grid2D,
            "3d" => Grid::Grid3D,
            _ => Grid::Grid2D,
        };

        let config = StencilConfig::new(kernel_id.clone())
            .with_grid(grid_type)
            .with_tile_size(tile_width as usize, tile_height as usize)
            .with_halo(halo as usize);

        match transpile_stencil_kernel(&input, &config) {
            Ok(cuda) => cuda,
            Err(e) => {
                return TokenStream::from(
                    syn::Error::new_spanned(
                        &input.sig.ident,
                        format!("CUDA transpilation failed: {}", e),
                    )
                    .to_compile_error(),
                );
            }
        }
    };

    #[cfg(not(feature = "cuda-codegen"))]
    let cuda_source_code = format!(
        "// CUDA codegen not enabled. Enable the 'cuda-codegen' feature.\n// Kernel: {}\n",
        kernel_id
    );

    // Generate the expanded code
    let expanded = quote! {
        // Original function (for documentation/testing/CPU fallback)
        #(#fn_attrs)*
        #fn_vis fn #fn_name(#fn_inputs) #fn_output #fn_block

        /// Generated CUDA source code for this stencil kernel.
        #fn_vis const #cuda_const_name: &str = #cuda_source_code;

        // Stencil kernel registration for runtime discovery
        ::inventory::submit! {
            ::ringkernel_core::__private::StencilKernelRegistration {
                id: #kernel_id,
                grid: #grid,
                tile_width: #tile_width,
                tile_height: #tile_height,
                halo: #halo,
                cuda_source: #cuda_source_code,
            }
        }
    };

    TokenStream::from(expanded)
}

// ============================================================================
// Multi-Backend GPU Kernel Macro
// ============================================================================

/// GPU backend targets (internal use only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuBackend {
    /// NVIDIA CUDA backend.
    Cuda,
    /// Apple Metal backend.
    Metal,
    /// WebGPU backend (cross-platform).
    Wgpu,
    /// CPU fallback backend.
    Cpu,
}

impl GpuBackend {
    fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "cuda" => Some(Self::Cuda),
            "metal" => Some(Self::Metal),
            "wgpu" | "webgpu" => Some(Self::Wgpu),
            "cpu" => Some(Self::Cpu),
            _ => None,
        }
    }

    fn as_str(&self) -> &'static str {
        match self {
            Self::Cuda => "cuda",
            Self::Metal => "metal",
            Self::Wgpu => "wgpu",
            Self::Cpu => "cpu",
        }
    }
}

/// GPU capability flags that can be required by a kernel (internal use only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuCapability {
    /// 64-bit floating point support.
    Float64,
    /// 64-bit integer support.
    Int64,
    /// 64-bit atomics support.
    Atomic64,
    /// Cooperative groups / grid-wide sync.
    CooperativeGroups,
    /// Subgroup / warp / SIMD operations.
    Subgroups,
    /// Shared memory / threadgroup memory.
    SharedMemory,
    /// Dynamic parallelism (launching kernels from kernels).
    DynamicParallelism,
    /// Half-precision (f16) support.
    Float16,
}

impl GpuCapability {
    fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "f64" | "float64" => Some(Self::Float64),
            "i64" | "int64" => Some(Self::Int64),
            "atomic64" => Some(Self::Atomic64),
            "cooperative_groups" | "cooperativegroups" | "grid_sync" => {
                Some(Self::CooperativeGroups)
            }
            "subgroups" | "warp" | "simd" => Some(Self::Subgroups),
            "shared_memory" | "sharedmemory" | "threadgroup" => Some(Self::SharedMemory),
            "dynamic_parallelism" | "dynamicparallelism" => Some(Self::DynamicParallelism),
            "f16" | "float16" | "half" => Some(Self::Float16),
            _ => None,
        }
    }

    fn as_str(&self) -> &'static str {
        match self {
            Self::Float64 => "f64",
            Self::Int64 => "i64",
            Self::Atomic64 => "atomic64",
            Self::CooperativeGroups => "cooperative_groups",
            Self::Subgroups => "subgroups",
            Self::SharedMemory => "shared_memory",
            Self::DynamicParallelism => "dynamic_parallelism",
            Self::Float16 => "f16",
        }
    }

    /// Check if a backend supports this capability.
    fn supported_by(&self, backend: GpuBackend) -> bool {
        match (self, backend) {
            // CUDA supports everything
            (_, GpuBackend::Cuda) => true,

            // Metal capabilities
            (Self::Float64, GpuBackend::Metal) => false,
            (Self::CooperativeGroups, GpuBackend::Metal) => false,
            (Self::DynamicParallelism, GpuBackend::Metal) => false,
            (_, GpuBackend::Metal) => true,

            // WebGPU capabilities
            (Self::Float64, GpuBackend::Wgpu) => false,
            (Self::Int64, GpuBackend::Wgpu) => false,
            (Self::Atomic64, GpuBackend::Wgpu) => false, // Emulated only
            (Self::CooperativeGroups, GpuBackend::Wgpu) => false,
            (Self::DynamicParallelism, GpuBackend::Wgpu) => false,
            (Self::Subgroups, GpuBackend::Wgpu) => true, // Optional extension
            (_, GpuBackend::Wgpu) => true,

            // CPU supports everything (in emulation)
            (_, GpuBackend::Cpu) => true,
        }
    }
}

/// Attributes for the gpu_kernel macro.
#[derive(Debug)]
struct GpuKernelArgs {
    /// Kernel identifier.
    id: Option<String>,
    /// Target backends to generate code for.
    backends: Vec<GpuBackend>,
    /// Fallback order for backend selection.
    fallback: Vec<GpuBackend>,
    /// Required capabilities.
    requires: Vec<GpuCapability>,
    /// Block/workgroup size.
    block_size: Option<u32>,
}

impl Default for GpuKernelArgs {
    fn default() -> Self {
        Self {
            id: None,
            backends: vec![GpuBackend::Cuda, GpuBackend::Metal, GpuBackend::Wgpu],
            fallback: vec![
                GpuBackend::Cuda,
                GpuBackend::Metal,
                GpuBackend::Wgpu,
                GpuBackend::Cpu,
            ],
            requires: Vec::new(),
            block_size: None,
        }
    }
}

impl GpuKernelArgs {
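    /// Parse `#[gpu_kernel(...)]` arguments by scanning the attribute's token
    /// string rather than using a structured parser. Accepts, for example:
    /// `backends = [cuda, metal], fallback = [wgpu, cpu], requires = [f64], id = "saxpy", block_size = 256`.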
    fn parse(attr: proc_macro2::TokenStream) -> Result<Self, darling::Error> {
        let mut args = Self::default();
        let attr_str = attr.to_string();

        // Parse backends = [...]
        if let Some(start) = attr_str.find("backends") {
            if let Some(bracket_start) = attr_str[start..].find('[') {
                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
                    let backends_str =
                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
                    args.backends = backends_str
                        .split(',')
                        .filter_map(|s| GpuBackend::from_str(s.trim()))
                        .collect();
                }
            }
        }

        // Parse fallback = [...]
        if let Some(start) = attr_str.find("fallback") {
            if let Some(bracket_start) = attr_str[start..].find('[') {
                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
                    let fallback_str =
                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
                    args.fallback = fallback_str
                        .split(',')
                        .filter_map(|s| GpuBackend::from_str(s.trim()))
                        .collect();
                }
            }
        }

        // Parse requires = [...]
        if let Some(start) = attr_str.find("requires") {
            if let Some(bracket_start) = attr_str[start..].find('[') {
                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
                    let requires_str =
                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
                    args.requires = requires_str
                        .split(',')
                        .filter_map(|s| GpuCapability::from_str(s.trim()))
                        .collect();
                }
            }
        }

        // Parse id = "..."
        if let Some(start) = attr_str.find("id") {
            if let Some(quote_start) = attr_str[start..].find('"') {
                if let Some(quote_end) = attr_str[start + quote_start + 1..].find('"') {
                    args.id = Some(
                        attr_str[start + quote_start + 1..start + quote_start + 1 + quote_end]
                            .to_string(),
                    );
                }
            }
        }

        // Parse block_size = N
        if let Some(start) = attr_str.find("block_size") {
            if let Some(eq) = attr_str[start..].find('=') {
                let rest = &attr_str[start + eq + 1..];
                let num_end = rest
                    .find(|c: char| !c.is_numeric() && c != ' ')
                    .unwrap_or(rest.len());
                if let Ok(n) = rest[..num_end].trim().parse() {
                    args.block_size = Some(n);
                }
            }
        }

        Ok(args)
    }

    /// Validate that all required capabilities are supported by at least one backend.
    fn validate_capabilities(&self) -> Result<(), String> {
        for cap in &self.requires {
            let mut supported_by_any = false;
            for backend in &self.backends {
                if cap.supported_by(*backend) {
                    supported_by_any = true;
                    break;
                }
            }
            if !supported_by_any {
                return Err(format!(
                    "Capability '{}' is not supported by any of the specified backends: {:?}",
                    cap.as_str(),
                    self.backends.iter().map(|b| b.as_str()).collect::<Vec<_>>()
                ));
            }
        }
        Ok(())
    }

    /// Get backends that support all required capabilities.
    fn compatible_backends(&self) -> Vec<GpuBackend> {
        self.backends
            .iter()
            .filter(|backend| self.requires.iter().all(|cap| cap.supported_by(**backend)))
            .copied()
            .collect()
    }
}

/// Attribute macro for defining multi-backend GPU kernels.
///
/// This macro generates code for multiple GPU backends with compile-time
/// capability validation. It integrates with the `ringkernel-ir` crate
/// to lower Rust DSL to backend-specific shader code.
///
/// # Attributes
///
/// - `backends = [cuda, metal, wgpu]` - Target backends (default: all)
/// - `fallback = [cuda, metal, wgpu, cpu]` - Fallback order for runtime selection
/// - `requires = [f64, atomic64]` - Required capabilities (validated at compile time)
/// - `id = "kernel_name"` - Explicit kernel identifier
/// - `block_size = 256` - Thread block size
///
/// # Example
///
/// ```ignore
/// use ringkernel_derive::gpu_kernel;
///
/// #[gpu_kernel(backends = [cuda, metal], requires = [subgroups])]
/// fn warp_reduce(data: &mut [f32], n: i32) {
///     let idx = global_thread_id_x();
///     if idx < n {
///         // Use warp shuffle for reduction
///         let val = data[idx as usize];
///         let reduced = warp_reduce_sum(val);
///         if lane_id() == 0 {
///             data[idx as usize] = reduced;
///         }
///     }
/// }
/// ```
///
/// # Capability Checking
///
/// The macro validates at compile time that all required capabilities are
/// supported by at least one target backend:
///
/// | Capability | CUDA | Metal | WebGPU | CPU |
/// |------------|------|-------|--------|-----|
/// | f64        | Yes  | No    | No     | Yes |
/// | i64        | Yes  | Yes   | No     | Yes |
/// | atomic64   | Yes  | Yes   | No*    | Yes |
/// | cooperative_groups | Yes | No | No | Yes |
/// | dynamic_parallelism | Yes | No | No | Yes |
/// | subgroups  | Yes  | Yes   | Opt    | Yes |
/// | shared_memory | Yes | Yes | Yes    | Yes |
/// | f16        | Yes  | Yes   | Yes    | Yes |
///
/// *WebGPU emulates 64-bit atomics with 32-bit pairs.
///
/// # Generated Code
///
/// For each compatible backend, the macro generates:
/// - A backend-specific source code constant (e.g., `KERNEL_NAME_CUDA`)
/// - A registration entry for runtime discovery
/// - The original Rust function, preserved as a CPU fallback
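///
/// A sketch of what the generated info module exposes, assuming a kernel
/// function named `saxpy` with no explicit `id` or `block_size` (so the
/// defaults apply):
///
/// ```ignore
/// assert_eq!(SAXPY_INFO::ID, "saxpy");
/// assert_eq!(SAXPY_INFO::BLOCK_SIZE, 256);
/// assert!(SAXPY_INFO::BACKENDS.contains(&"cuda"));
/// ```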
#[proc_macro_attribute]
pub fn gpu_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
    let attr2: proc_macro2::TokenStream = attr.into();
    let args = match GpuKernelArgs::parse(attr2) {
        Ok(args) => args,
        Err(e) => return TokenStream::from(e.write_errors()),
    };

    let input = parse_macro_input!(item as ItemFn);

    // Validate capabilities
    if let Err(msg) = args.validate_capabilities() {
        return TokenStream::from(
            syn::Error::new_spanned(&input.sig.ident, msg).to_compile_error(),
        );
    }

    gpu_kernel_impl(args, input)
}

fn gpu_kernel_impl(args: GpuKernelArgs, input: ItemFn) -> TokenStream {
    let fn_name = &input.sig.ident;
    let fn_vis = &input.vis;
    let fn_block = &input.block;
    let fn_inputs = &input.sig.inputs;
    let fn_output = &input.sig.output;
    let fn_attrs = &input.attrs;

    let kernel_id = args.id.clone().unwrap_or_else(|| fn_name.to_string());
    let block_size = args.block_size.unwrap_or(256);

    // Get compatible backends
    let compatible_backends = args.compatible_backends();

    // Generate backend-specific source constants
    let mut source_constants = Vec::new();

    for backend in &compatible_backends {
        let const_name = format_ident!(
            "{}_{}",
            fn_name.to_string().to_uppercase(),
            backend.as_str().to_uppercase()
        );

        let backend_str = backend.as_str();

        // Generate placeholder source (actual IR lowering happens at build time)
        // In a full implementation, this would call ringkernel-ir lowering
        let source_placeholder = format!(
            "// {} source for kernel '{}'\n// Generated by ringkernel-derive\n// Capabilities: {:?}\n",
            backend_str.to_uppercase(),
            kernel_id,
            args.requires.iter().map(|c| c.as_str()).collect::<Vec<_>>()
        );

        source_constants.push(quote! {
            /// Generated source code for this kernel.
            #fn_vis const #const_name: &str = #source_placeholder;
        });
    }

    // Generate capability flags as strings
    let capability_strs: Vec<_> = args.requires.iter().map(|c| c.as_str()).collect();
    let backend_strs: Vec<_> = compatible_backends.iter().map(|b| b.as_str()).collect();
    let fallback_strs: Vec<_> = args.fallback.iter().map(|b| b.as_str()).collect();

    // Generate info module name
    let info_name = format_ident!("{}_INFO", fn_name.to_string().to_uppercase());

    // Generate the expanded code
    let expanded = quote! {
        // Original function (CPU fallback / documentation / testing)
        #(#fn_attrs)*
        #fn_vis fn #fn_name(#fn_inputs) #fn_output #fn_block

        // Backend source constants
        #(#source_constants)*

        /// Multi-backend kernel information.
        #fn_vis mod #info_name {
            /// Kernel identifier.
            pub const ID: &str = #kernel_id;

            /// Block/workgroup size.
            pub const BLOCK_SIZE: u32 = #block_size;

            /// Required capabilities.
            pub const CAPABILITIES: &[&str] = &[#(#capability_strs),*];

            /// Compatible backends (those that support all required capabilities).
            pub const BACKENDS: &[&str] = &[#(#backend_strs),*];

            /// Fallback order for runtime backend selection.
            pub const FALLBACK_ORDER: &[&str] = &[#(#fallback_strs),*];
        }

        // GPU kernel registration for runtime discovery
        ::inventory::submit! {
            ::ringkernel_core::__private::GpuKernelRegistration {
                id: #kernel_id,
                block_size: #block_size,
                capabilities: &[#(#capability_strs),*],
                backends: &[#(#backend_strs),*],
                fallback_order: &[#(#fallback_strs),*],
            }
        }
    };

    TokenStream::from(expanded)
}

// ============================================================================
// ControlBlockState Derive Macro (FR-4)
// ============================================================================

/// Attributes for the ControlBlockState derive macro.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(state), supports(struct_named))]
struct ControlBlockStateArgs {
    ident: syn::Ident,
    generics: syn::Generics,
    /// State version for forward compatibility.
    #[darling(default)]
    version: Option<u32>,
}

/// Derive macro for implementing the EmbeddedState trait.
///
/// This macro generates implementations for types that can be stored in
/// the ControlBlock's 24-byte `_reserved` field for zero-copy state access.
///
/// # Requirements
///
/// The type must:
/// - Be `#[repr(C)]` for a stable memory layout
/// - Be <= 24 bytes in size (checked at compile time)
/// - Implement `Clone`, `Copy`, and `Default`
/// - Contain only POD (Plain Old Data) types
///
/// # Attributes
///
/// - `#[state(version = N)]` - Set state version for migrations (default: 1)
///
/// # Example
///
/// ```ignore
/// #[derive(ControlBlockState, Default, Clone, Copy)]
/// #[repr(C, align(8))]
/// #[state(version = 1)]
/// pub struct OrderBookState {
///     pub best_bid: u64,    // 8 bytes
///     pub best_ask: u64,    // 8 bytes
///     pub order_count: u32, // 4 bytes
///     pub _pad: u32,        // 4 bytes (padding for alignment)
/// }  // Total: 24 bytes - fits in ControlBlock._reserved
///
/// // Use with ControlBlockStateHelper:
/// let mut block = ControlBlock::new();
/// let state = OrderBookState { best_bid: 100, best_ask: 101, order_count: 42, _pad: 0 };
/// ControlBlockStateHelper::write_embedded(&mut block, &state)?;
/// ```
///
/// # Size Validation
///
/// The macro generates a compile-time assertion that fails if the type
/// exceeds 24 bytes:
///
/// ```ignore
/// #[derive(ControlBlockState, Default, Clone, Copy)]
/// #[repr(C)]
/// struct TooLarge {
///     data: [u8; 32],  // 32 bytes - COMPILE ERROR!
/// }
/// ```
#[proc_macro_derive(ControlBlockState, attributes(state))]
pub fn derive_control_block_state(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let args = match ControlBlockStateArgs::from_derive_input(&input) {
        Ok(args) => args,
        Err(e) => return e.write_errors().into(),
    };

    let name = &args.ident;
    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
    let version = args.version.unwrap_or(1);

    let expanded = quote! {
        // Compile-time size check: EmbeddedState must fit in 24 bytes
        const _: () = {
            assert!(
                ::std::mem::size_of::<#name #ty_generics>() <= 24,
                "ControlBlockState types must fit in 24 bytes (ControlBlock._reserved size)"
            );
        };

        // Verify type is Copy (required for GPU transfer)
        const _: fn() = || {
            fn assert_copy<T: Copy>() {}
            assert_copy::<#name #ty_generics>();
        };

        // Implement Pod and Zeroable (required by EmbeddedState)
        // SAFETY: Type is #[repr(C)] with only primitive types, verified by user
        unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
        unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}

        // Implement EmbeddedState
        impl #impl_generics ::ringkernel_core::state::EmbeddedState for #name #ty_generics #where_clause {
            const VERSION: u32 = #version;

            fn is_embedded() -> bool {
                true
            }
        }
    };

    TokenStream::from(expanded)
}