// ringkernel_derive/lib.rs
1//! Procedural macros for RingKernel.
2//!
3//! This crate provides the following macros:
4//!
5//! - `#[derive(RingMessage)]` - Implement the RingMessage trait for message types
6//! - `#[derive(PersistentMessage)]` - Implement PersistentMessage for GPU kernel dispatch
7//! - `#[ring_kernel]` - Define a ring kernel handler
8//! - `#[stencil_kernel]` - Define a GPU stencil kernel (with `cuda-codegen` feature)
9//! - `#[gpu_kernel]` - Define a multi-backend GPU kernel with capability checking
10//!
11//! # Example
12//!
13//! ```ignore
14//! use ringkernel_derive::{RingMessage, ring_kernel};
15//!
16//! #[derive(RingMessage)]
17//! struct AddRequest {
18//!     #[message(id)]
19//!     id: MessageId,
20//!     a: f32,
21//!     b: f32,
22//! }
23//!
24//! #[derive(RingMessage)]
25//! struct AddResponse {
26//!     #[message(id)]
27//!     id: MessageId,
28//!     result: f32,
29//! }
30//!
31//! #[ring_kernel(id = "adder")]
32//! async fn process(ctx: &mut RingContext, req: AddRequest) -> AddResponse {
33//!     AddResponse {
34//!         id: MessageId::generate(),
35//!         result: req.a + req.b,
36//!     }
37//! }
38//! ```
39//!
40//! # Multi-Backend GPU Kernels
41//!
42//! The `#[gpu_kernel]` macro enables multi-backend code generation with capability checking:
43//!
44//! ```ignore
45//! use ringkernel_derive::gpu_kernel;
46//!
47//! // Generate code for CUDA and Metal, with fallback order
48//! #[gpu_kernel(backends = [cuda, metal], fallback = [wgpu, cpu])]
49//! fn saxpy(x: &[f32], y: &mut [f32], a: f32, n: i32) {
50//!     let idx = global_thread_id_x();
51//!     if idx < n {
52//!         y[idx as usize] = a * x[idx as usize] + y[idx as usize];
53//!     }
54//! }
55//!
56//! // Require specific capabilities at compile time
57//! #[gpu_kernel(backends = [cuda], requires = [f64, atomic64])]
58//! fn double_precision(data: &mut [f64], n: i32) {
59//!     // Uses f64 operations - validated at compile time
60//! }
61//! ```
62//!
63//! # Stencil Kernels (with `cuda-codegen` feature)
64//!
65//! ```ignore
66//! use ringkernel_derive::stencil_kernel;
67//! use ringkernel_cuda_codegen::GridPos;
68//!
69//! #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
70//! fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
71//!     let curr = p[pos.idx()];
72//!     let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
73//!     p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
74//! }
75//! ```
76
77use darling::{ast, FromDeriveInput, FromField, FromMeta};
78use proc_macro::TokenStream;
79use quote::{format_ident, quote};
80use syn::{parse_macro_input, DeriveInput, ItemFn};
81
82/// Attributes for the RingMessage derive macro.
/// Attributes for the RingMessage derive macro.
///
/// Parsed by darling from `#[message(...)]` / `#[ring_message(...)]` on the
/// deriving struct; only named structs are supported.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(message, ring_message), supports(struct_named))]
struct RingMessageArgs {
    /// Name of the deriving struct.
    ident: syn::Ident,
    /// Generic parameters of the deriving struct.
    generics: syn::Generics,
    /// Named fields, parsed into per-field marker flags.
    data: ast::Data<(), RingMessageField>,
    /// Optional explicit message type ID.
    /// If domain is specified, this is the offset within the domain (0-99).
    /// If domain is not specified, this is the absolute type ID.
    #[darling(default)]
    type_id: Option<u64>,
    /// Optional domain for message classification.
    /// When specified, the final type ID = domain.base_type_id() + type_id.
    #[darling(default)]
    domain: Option<String>,
    /// Whether this message is routable via K2K.
    /// When true, generates a K2KMessageRegistration for runtime discovery.
    #[darling(default)]
    k2k_routable: bool,
    /// Optional category for K2K routing groups.
    /// Multiple messages can share a category for grouped routing.
    #[darling(default)]
    category: Option<String>,
}
107
108/// Field attributes for RingMessage.
/// Field attributes for RingMessage.
///
/// Each flag marks the field as the source of one accessor in the generated
/// `RingMessage` impl; unmarked accessors fall back to default values.
#[derive(Debug, FromField)]
#[darling(attributes(message))]
struct RingMessageField {
    /// Field name (always `Some` — only named structs are supported).
    ident: Option<syn::Ident>,
    /// Field type (currently unused; kept for future validation).
    #[allow(dead_code)]
    ty: syn::Type,
    /// Mark this field as the message ID.
    #[darling(default)]
    id: bool,
    /// Mark this field as the correlation ID.
    #[darling(default)]
    correlation: bool,
    /// Mark this field as the priority.
    #[darling(default)]
    priority: bool,
}
125
126/// Derive macro for implementing the RingMessage trait.
127///
128/// # Attributes
129///
130/// On the struct (via `#[message(...)]` or `#[ring_message(...)]`):
131/// - `type_id = 123` - Set explicit message type ID (or domain offset if domain is set)
132/// - `domain = "OrderMatching"` - Assign to a business domain (adds base type ID)
133/// - `k2k_routable = true` - Register for K2K routing discovery
134/// - `category = "orders"` - Group messages for K2K routing
135///
136/// On fields:
137/// - `#[message(id)]` - Mark as message ID field
138/// - `#[message(correlation)]` - Mark as correlation ID field
139/// - `#[message(priority)]` - Mark as priority field
140///
141/// # Examples
142///
143/// Basic usage:
144/// ```ignore
145/// #[derive(RingMessage)]
146/// #[message(type_id = 1)]
147/// struct MyMessage {
148///     #[message(id)]
149///     id: MessageId,
150///     #[message(correlation)]
151///     correlation: CorrelationId,
152///     #[message(priority)]
153///     priority: Priority,
154///     payload: Vec<u8>,
155/// }
156/// ```
157///
158/// With domain (type ID = 500 + 1 = 501):
159/// ```ignore
160/// #[derive(RingMessage)]
161/// #[ring_message(type_id = 1, domain = "OrderMatching")]
162/// pub struct SubmitOrderInput {
163///     #[message(id)]
164///     id: MessageId,
165///     pub order: Order,
166/// }
167/// // Also implements DomainMessage trait
168/// assert_eq!(SubmitOrderInput::domain(), Domain::OrderMatching);
169/// ```
170///
171/// K2K-routable message:
172/// ```ignore
173/// #[derive(RingMessage)]
174/// #[ring_message(type_id = 1, domain = "OrderMatching", k2k_routable = true, category = "orders")]
175/// pub struct SubmitOrderInput { ... }
176///
177/// // Runtime discovery:
178/// let registry = K2KTypeRegistry::discover();
179/// assert!(registry.is_routable(501));
180/// ```
#[proc_macro_derive(RingMessage, attributes(message, ring_message))]
pub fn derive_ring_message(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let args = match RingMessageArgs::from_derive_input(&input) {
        Ok(args) => args,
        Err(e) => return e.write_errors().into(),
    };

    let name = &args.ident;
    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();

    // Calculate base type ID (offset within domain, or absolute if no domain).
    // NOTE(review): when `type_id` is omitted the ID is derived from a hash of
    // the type name, so renaming the type silently changes its wire ID —
    // confirm that auto-IDs are never persisted across versions.
    let base_type_id = args.type_id.unwrap_or_else(|| {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        name.to_string().hash(&mut hasher);
        // If domain is set, hash to a value within 0-99 range
        if args.domain.is_some() {
            hasher.finish() % 100
        } else {
            hasher.finish()
        }
    });

    // Find annotated fields. darling's `supports(struct_named)` already
    // rejects non-structs, so the panic arm is effectively unreachable.
    let fields = match &args.data {
        ast::Data::Struct(fields) => fields,
        _ => panic!("RingMessage can only be derived for structs"),
    };

    let mut id_field: Option<&syn::Ident> = None;
    let mut correlation_field: Option<&syn::Ident> = None;
    let mut priority_field: Option<&syn::Ident> = None;

    // Last annotated field wins if a marker is repeated on multiple fields.
    for field in fields.iter() {
        if field.id {
            id_field = field.ident.as_ref();
        }
        if field.correlation {
            correlation_field = field.ident.as_ref();
        }
        if field.priority {
            priority_field = field.ident.as_ref();
        }
    }

    // Generate message_id method (defaults to MessageId::new(0) if unmarked).
    let message_id_impl = if let Some(field) = id_field {
        quote! { self.#field }
    } else {
        quote! { ::ringkernel_core::message::MessageId::new(0) }
    };

    // Generate correlation_id method (defaults to CorrelationId::none()).
    let correlation_id_impl = if let Some(field) = correlation_field {
        quote! { self.#field }
    } else {
        quote! { ::ringkernel_core::message::CorrelationId::none() }
    };

    // Generate priority method (defaults to Priority::Normal).
    let priority_impl = if let Some(field) = priority_field {
        quote! { self.#field }
    } else {
        quote! { ::ringkernel_core::message::Priority::Normal }
    };

    // Generate message_type() implementation based on whether domain is specified.
    let message_type_impl = if let Some(ref domain_str) = args.domain {
        // With domain: type_id = domain.base_type_id() + offset.
        // Unknown domain strings fall back to Domain::General rather than failing.
        quote! {
            ::ringkernel_core::domain::Domain::from_str(#domain_str)
                .unwrap_or(::ringkernel_core::domain::Domain::General)
                .base_type_id() + #base_type_id
        }
    } else {
        // Without domain: use absolute type_id
        quote! { #base_type_id }
    };

    // Generate DomainMessage impl if domain is specified.
    let domain_impl = if let Some(ref domain_str) = args.domain {
        quote! {
            impl #impl_generics ::ringkernel_core::domain::DomainMessage for #name #ty_generics #where_clause {
                fn domain() -> ::ringkernel_core::domain::Domain {
                    ::ringkernel_core::domain::Domain::from_str(#domain_str)
                        .unwrap_or(::ringkernel_core::domain::Domain::General)
                }
            }
        }
    } else {
        quote! {}
    };

    // Generate K2K registration if k2k_routable is set.
    let k2k_registration = if args.k2k_routable {
        let registration_name = format_ident!(
            "__K2K_MESSAGE_REGISTRATION_{}",
            name.to_string().to_uppercase()
        );
        let type_name_str = name.to_string();
        let category_tokens = match &args.category {
            Some(cat) => quote! { ::std::option::Option::Some(#cat) },
            None => quote! { ::std::option::Option::None },
        };

        quote! {
            #[allow(non_upper_case_globals)]
            #[::inventory::submit]
            static #registration_name: ::ringkernel_core::k2k::K2KMessageRegistration =
                ::ringkernel_core::k2k::K2KMessageRegistration {
                    type_id: {
                        // Note: This is a const context, so we use the base calculation
                        // For domain types, we need to add the base manually
                        // NOTE(review): for domain messages this registers the raw
                        // offset, NOT `domain.base_type_id() + offset` as returned
                        // by `message_type()` — confirm registry lookups by full
                        // type ID account for this mismatch.
                        #base_type_id
                    },
                    type_name: #type_name_str,
                    k2k_routable: true,
                    category: #category_tokens,
                };
        }
    } else {
        quote! {}
    };

    let expanded = quote! {
        impl #impl_generics ::ringkernel_core::message::RingMessage for #name #ty_generics #where_clause {
            fn message_type() -> u64 {
                #message_type_impl
            }

            fn message_id(&self) -> ::ringkernel_core::message::MessageId {
                #message_id_impl
            }

            fn correlation_id(&self) -> ::ringkernel_core::message::CorrelationId {
                #correlation_id_impl
            }

            fn priority(&self) -> ::ringkernel_core::message::Priority {
                #priority_impl
            }

            fn serialize(&self) -> Vec<u8> {
                // Use rkyv for serialization with a 4KB scratch buffer
                // For larger payloads, rkyv will allocate as needed
                ::rkyv::to_bytes::<_, 4096>(self)
                    .map(|v| v.to_vec())
                    .unwrap_or_default()
            }

            fn deserialize(bytes: &[u8]) -> ::ringkernel_core::error::Result<Self>
            where
                Self: Sized,
            {
                use ::rkyv::Deserialize as _;
                let archived = unsafe { ::rkyv::archived_root::<Self>(bytes) };
                let deserialized: Self = archived.deserialize(&mut ::rkyv::Infallible)
                    .map_err(|_| ::ringkernel_core::error::RingKernelError::DeserializationError(
                        "rkyv deserialization failed".to_string()
                    ))?;
                Ok(deserialized)
            }

            fn size_hint(&self) -> usize {
                ::std::mem::size_of::<Self>()
            }
        }

        #domain_impl

        #k2k_registration
    };

    TokenStream::from(expanded)
}
359
// ============================================================================
// PersistentMessage Derive Macro
// ============================================================================

/// Maximum size for inline payload in persistent messages.
/// NOTE(review): appears to mirror
/// `ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE` (the generated
/// code uses the core constant directly) — confirm the two stay in sync.
#[allow(dead_code)]
const MAX_INLINE_PAYLOAD_SIZE: usize = 32;
367
/// Attributes for the PersistentMessage derive macro.
///
/// Parsed by darling from `#[persistent_message(...)]` on the deriving struct.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(persistent_message), supports(struct_named))]
struct PersistentMessageArgs {
    /// Name of the deriving struct.
    ident: syn::Ident,
    /// Generic parameters of the deriving struct.
    generics: syn::Generics,
    /// Field data (reserved for future per-field attributes).
    #[allow(dead_code)]
    data: ast::Data<(), PersistentMessageField>,
    /// Handler ID for CUDA dispatch (0-255).
    /// NOTE(review): the 0-255 range is documented but not enforced here —
    /// darling accepts any `u32`; confirm whether validation belongs in the macro.
    handler_id: u32,
    /// Whether this message type expects a response.
    #[darling(default)]
    requires_response: bool,
}
383
/// Field attributes for PersistentMessage (reserved for future use).
///
/// Currently no per-field attributes are recognized; this type exists so the
/// darling `data` slot in [`PersistentMessageArgs`] has a concrete field shape.
#[derive(Debug, FromField)]
#[darling(attributes(persistent_message))]
struct PersistentMessageField {
    /// Field identifier.
    #[allow(dead_code)]
    ident: Option<syn::Ident>,
    /// Field type.
    #[allow(dead_code)]
    ty: syn::Type,
}
395
396/// Derive macro for implementing the PersistentMessage trait.
397///
398/// This macro enables type-based dispatch within persistent GPU kernels by
399/// generating handler_id, inline payload serialization, and deserialization.
400///
401/// # Requirements
402///
403/// The struct must:
404/// - Already implement `RingMessage` (use `#[derive(RingMessage)]`)
405/// - Be `#[repr(C)]` for safe memory layout
406/// - Be `Copy` + `Clone` for inline payload serialization
407///
408/// # Attributes
409///
410/// On the struct:
411/// - `handler_id = N` (required) - CUDA dispatch handler ID (0-255)
412/// - `requires_response = true/false` (optional) - Whether this message expects a response
413///
414/// # Example
415///
416/// ```ignore
417/// use ringkernel_derive::{RingMessage, PersistentMessage};
418///
419/// #[derive(RingMessage, PersistentMessage, Clone, Copy)]
420/// #[repr(C)]
421/// #[message(type_id = 1001)]
422/// #[persistent_message(handler_id = 1, requires_response = true)]
423/// pub struct FraudCheckRequest {
424///     pub transaction_id: u64,
425///     pub amount: f32,
426///     pub account_id: u32,
427/// }
428///
429/// // Generated implementations:
430/// // - handler_id() returns 1
431/// // - requires_response() returns true
432/// // - to_inline_payload() serializes the struct to [u8; 32] if it fits
433/// // - from_inline_payload() deserializes from bytes
434/// // - payload_size() returns the struct size
435/// ```
436///
437/// # Size Validation
438///
439/// For inline payload serialization, structs must be <= 32 bytes.
440/// Larger structs will return `None` from `to_inline_payload()`.
441///
442/// # CUDA Integration
443///
444/// The handler_id maps to a switch case in generated CUDA code:
445///
446/// ```cuda
447/// switch (msg->handler_id) {
448///     case 1: handle_fraud_check(msg, state, response); break;
449///     case 2: handle_aggregate(msg, state, response); break;
450///     // ...
451/// }
452/// ```
453#[proc_macro_derive(PersistentMessage, attributes(persistent_message))]
454pub fn derive_persistent_message(input: TokenStream) -> TokenStream {
455    let input = parse_macro_input!(input as DeriveInput);
456
457    let args = match PersistentMessageArgs::from_derive_input(&input) {
458        Ok(args) => args,
459        Err(e) => return e.write_errors().into(),
460    };
461
462    let name = &args.ident;
463    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
464
465    let handler_id = args.handler_id;
466    let requires_response = args.requires_response;
467
468    // Generate the PersistentMessage implementation
469    let expanded = quote! {
470        impl #impl_generics ::ringkernel_core::persistent_message::PersistentMessage for #name #ty_generics #where_clause {
471            fn handler_id() -> u32 {
472                #handler_id
473            }
474
475            fn requires_response() -> bool {
476                #requires_response
477            }
478
479            fn payload_size() -> usize {
480                ::std::mem::size_of::<Self>()
481            }
482
483            fn to_inline_payload(&self) -> ::std::option::Option<[u8; ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE]> {
484                // Only serialize if the struct fits in the inline payload
485                if ::std::mem::size_of::<Self>() > ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE {
486                    return ::std::option::Option::None;
487                }
488
489                let mut payload = [0u8; ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE];
490
491                // Safety: We've verified the struct fits in the payload,
492                // and the struct is repr(C) + Copy
493                unsafe {
494                    ::std::ptr::copy_nonoverlapping(
495                        self as *const Self as *const u8,
496                        payload.as_mut_ptr(),
497                        ::std::mem::size_of::<Self>()
498                    );
499                }
500
501                ::std::option::Option::Some(payload)
502            }
503
504            fn from_inline_payload(payload: &[u8]) -> ::ringkernel_core::error::Result<Self> {
505                let size = ::std::mem::size_of::<Self>();
506
507                if payload.len() < size {
508                    return ::std::result::Result::Err(
509                        ::ringkernel_core::error::RingKernelError::DeserializationError(
510                            ::std::format!(
511                                "Payload too small: expected {} bytes, got {}",
512                                size,
513                                payload.len()
514                            )
515                        )
516                    );
517                }
518
519                // Safety: We've verified the payload is large enough,
520                // and the struct is repr(C) + Copy
521                let value = unsafe {
522                    ::std::ptr::read(payload.as_ptr() as *const Self)
523                };
524
525                ::std::result::Result::Ok(value)
526            }
527        }
528    };
529
530    TokenStream::from(expanded)
531}
532
/// Attributes for the ring_kernel macro.
///
/// Parsed by darling from the `#[ring_kernel(...)]` attribute argument list.
#[derive(Debug, FromMeta)]
struct RingKernelArgs {
    /// Kernel identifier (required).
    id: String,
    /// Execution mode (persistent or event_driven).
    /// Any value other than "event_driven" is treated as persistent.
    #[darling(default)]
    mode: Option<String>,
    /// Grid size (number of blocks; defaults to 1 at expansion time).
    #[darling(default)]
    grid_size: Option<u32>,
    /// Block size (threads per block; defaults to 256 at expansion time).
    #[darling(default)]
    block_size: Option<u32>,
    /// Target kernels this kernel publishes to, as a comma-separated string.
    #[darling(default)]
    publishes_to: Option<String>,
}
551
552/// Attribute macro for defining ring kernel handlers.
553///
554/// # Attributes
555///
556/// - `id` (required) - Unique kernel identifier
557/// - `mode` - Execution mode: "persistent" (default) or "event_driven"
558/// - `grid_size` - Number of blocks (default: 1)
559/// - `block_size` - Threads per block (default: 256)
560/// - `publishes_to` - Comma-separated list of target kernel IDs
561///
562/// # Example
563///
564/// ```ignore
565/// #[ring_kernel(id = "processor", mode = "persistent", block_size = 128)]
566/// async fn handle(ctx: &mut RingContext, msg: MyMessage) -> MyResponse {
567///     // Process message
568///     MyResponse { ... }
569/// }
570/// ```
571#[proc_macro_attribute]
572pub fn ring_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
573    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
574        Ok(v) => v,
575        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
576    };
577
578    let args = match RingKernelArgs::from_list(&args) {
579        Ok(v) => v,
580        Err(e) => return TokenStream::from(e.write_errors()),
581    };
582
583    let input = parse_macro_input!(item as ItemFn);
584
585    let kernel_id = &args.id;
586    let fn_name = &input.sig.ident;
587    let fn_vis = &input.vis;
588    let fn_block = &input.block;
589    let fn_attrs = &input.attrs;
590
591    // Parse function signature
592    let inputs = &input.sig.inputs;
593    let output = &input.sig.output;
594
595    // Extract context and message types from signature
596    let (_ctx_arg, msg_arg) = if inputs.len() >= 2 {
597        let ctx = inputs.first();
598        let msg = inputs.iter().nth(1);
599        (ctx, msg)
600    } else {
601        (None, None)
602    };
603
604    // Get message type
605    let msg_type = msg_arg
606        .map(|arg| {
607            if let syn::FnArg::Typed(pat_type) = arg {
608                pat_type.ty.clone()
609            } else {
610                syn::parse_quote!(())
611            }
612        })
613        .unwrap_or_else(|| syn::parse_quote!(()));
614
615    // Generate kernel mode
616    let mode = args.mode.as_deref().unwrap_or("persistent");
617    let mode_expr = if mode == "event_driven" {
618        quote! { ::ringkernel_core::types::KernelMode::EventDriven }
619    } else {
620        quote! { ::ringkernel_core::types::KernelMode::Persistent }
621    };
622
623    // Generate grid/block size
624    let grid_size = args.grid_size.unwrap_or(1);
625    let block_size = args.block_size.unwrap_or(256);
626
627    // Parse publishes_to into a list of target kernel IDs
628    let publishes_to_targets: Vec<String> = args
629        .publishes_to
630        .as_ref()
631        .map(|s| s.split(',').map(|t| t.trim().to_string()).collect())
632        .unwrap_or_default();
633
634    // Generate registration struct name
635    let registration_name = format_ident!(
636        "__RINGKERNEL_REGISTRATION_{}",
637        fn_name.to_string().to_uppercase()
638    );
639    let handler_name = format_ident!("{}_handler", fn_name);
640
641    // Generate the expanded code
642    let expanded = quote! {
643        // Original function (preserved for documentation/testing)
644        #(#fn_attrs)*
645        #fn_vis async fn #fn_name #inputs #output #fn_block
646
647        // Kernel handler wrapper
648        #fn_vis fn #handler_name(
649            ctx: &mut ::ringkernel_core::RingContext<'_>,
650            envelope: ::ringkernel_core::message::MessageEnvelope,
651        ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::ringkernel_core::error::Result<::ringkernel_core::message::MessageEnvelope>> + Send + '_>> {
652            Box::pin(async move {
653                // Deserialize input message
654                let msg: #msg_type = ::ringkernel_core::message::RingMessage::deserialize(&envelope.payload)?;
655
656                // Call the actual handler
657                let response = #fn_name(ctx, msg).await;
658
659                // Serialize response
660                let response_payload = ::ringkernel_core::message::RingMessage::serialize(&response);
661                let response_header = ::ringkernel_core::message::MessageHeader::new(
662                    <_ as ::ringkernel_core::message::RingMessage>::message_type(),
663                    envelope.header.dest_kernel,
664                    envelope.header.source_kernel,
665                    response_payload.len(),
666                    ctx.now(),
667                ).with_correlation(envelope.header.correlation_id);
668
669                Ok(::ringkernel_core::message::MessageEnvelope {
670                    header: response_header,
671                    payload: response_payload,
672                    ..::std::default::Default::default()
673                })
674            })
675        }
676
677        // Kernel registration
678        #[allow(non_upper_case_globals)]
679        #[::inventory::submit]
680        static #registration_name: ::ringkernel_core::__private::KernelRegistration = ::ringkernel_core::__private::KernelRegistration {
681            id: #kernel_id,
682            mode: #mode_expr,
683            grid_size: #grid_size,
684            block_size: #block_size,
685            publishes_to: &[#(#publishes_to_targets),*],
686        };
687    };
688
689    TokenStream::from(expanded)
690}
691
692/// Derive macro for GPU-compatible types.
693///
694/// Ensures the type has a stable memory layout suitable for GPU transfer.
695#[proc_macro_derive(GpuType)]
696pub fn derive_gpu_type(input: TokenStream) -> TokenStream {
697    let input = parse_macro_input!(input as DeriveInput);
698    let name = &input.ident;
699    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
700
701    // Generate assertions for GPU compatibility
702    let expanded = quote! {
703        // Verify type is Copy (required for GPU transfer)
704        const _: fn() = || {
705            fn assert_copy<T: Copy>() {}
706            assert_copy::<#name #ty_generics>();
707        };
708
709        // Verify type is Pod (plain old data)
710        unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
711        unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
712    };
713
714    TokenStream::from(expanded)
715}
716
717// ============================================================================
718// Stencil Kernel Macro (requires cuda-codegen feature)
719// ============================================================================
720
/// Attributes for the stencil_kernel macro.
///
/// Parsed by darling from the `#[stencil_kernel(...)]` attribute argument list.
#[derive(Debug, FromMeta)]
struct StencilKernelArgs {
    /// Kernel identifier (required).
    id: String,
    /// Grid dimensionality: "1d", "2d", or "3d" (defaults to "2d").
    #[darling(default)]
    grid: Option<String>,
    /// Tile/block size (single value for square tiles).
    /// Used as the fallback when tile_width/tile_height are not given.
    #[darling(default)]
    tile_size: Option<u32>,
    /// Tile width (for non-square tiles; overrides tile_size).
    #[darling(default)]
    tile_width: Option<u32>,
    /// Tile height (for non-square tiles; overrides tile_size).
    #[darling(default)]
    tile_height: Option<u32>,
    /// Halo/ghost cell width (stencil radius; defaults to 1).
    #[darling(default)]
    halo: Option<u32>,
}
742
743/// Attribute macro for defining stencil kernels that transpile to CUDA.
744///
745/// This macro generates CUDA C code from Rust stencil kernel functions at compile time.
746/// The generated CUDA source is embedded in the binary and can be compiled at runtime
747/// using NVRTC.
748///
749/// # Attributes
750///
751/// - `id` (required) - Unique kernel identifier
752/// - `grid` - Grid dimensionality: "1d", "2d" (default), or "3d"
753/// - `tile_size` - Tile/block size (default: 16)
754/// - `tile_width` / `tile_height` - Non-square tile dimensions
755/// - `halo` - Stencil radius / ghost cell width (default: 1)
756///
757/// # Supported Rust Subset
758///
759/// - Primitives: `f32`, `f64`, `i32`, `u32`, `i64`, `u64`, `bool`
760/// - Slices: `&[T]`, `&mut [T]`
761/// - Arithmetic: `+`, `-`, `*`, `/`, `%`
762/// - Comparisons: `<`, `>`, `<=`, `>=`, `==`, `!=`
763/// - Let bindings: `let x = expr;`
764/// - If/else: `if cond { a } else { b }`
765/// - Stencil intrinsics via `GridPos`
766///
767/// # Example
768///
769/// ```ignore
770/// use ringkernel_derive::stencil_kernel;
771/// use ringkernel_cuda_codegen::GridPos;
772///
773/// #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
774/// fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
775///     let curr = p[pos.idx()];
776///     let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
777///     p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
778/// }
779///
780/// // Access generated CUDA source:
781/// assert!(FDTD_CUDA_SOURCE.contains("__global__"));
782/// ```
783#[proc_macro_attribute]
784pub fn stencil_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
785    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
786        Ok(v) => v,
787        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
788    };
789
790    let args = match StencilKernelArgs::from_list(&args) {
791        Ok(v) => v,
792        Err(e) => return TokenStream::from(e.write_errors()),
793    };
794
795    let input = parse_macro_input!(item as ItemFn);
796
797    // Generate the stencil kernel code
798    stencil_kernel_impl(args, input)
799}
800
801fn stencil_kernel_impl(args: StencilKernelArgs, input: ItemFn) -> TokenStream {
802    let kernel_id = &args.id;
803    let fn_name = &input.sig.ident;
804    let fn_vis = &input.vis;
805    let fn_block = &input.block;
806    let fn_inputs = &input.sig.inputs;
807    let fn_output = &input.sig.output;
808    let fn_attrs = &input.attrs;
809
810    // Parse configuration
811    let grid = args.grid.as_deref().unwrap_or("2d");
812    let tile_width = args
813        .tile_width
814        .unwrap_or_else(|| args.tile_size.unwrap_or(16));
815    let tile_height = args
816        .tile_height
817        .unwrap_or_else(|| args.tile_size.unwrap_or(16));
818    let halo = args.halo.unwrap_or(1);
819
820    // Generate CUDA source constant name
821    let cuda_const_name = format_ident!("{}_CUDA_SOURCE", fn_name.to_string().to_uppercase());
822
823    // Generate registration name
824    let registration_name = format_ident!(
825        "__STENCIL_KERNEL_REGISTRATION_{}",
826        fn_name.to_string().to_uppercase()
827    );
828
829    // Transpile to CUDA (if feature enabled)
830    #[cfg(feature = "cuda-codegen")]
831    let cuda_source_code = {
832        use ringkernel_cuda_codegen::{transpile_stencil_kernel, Grid, StencilConfig};
833
834        let grid_type = match grid {
835            "1d" => Grid::Grid1D,
836            "2d" => Grid::Grid2D,
837            "3d" => Grid::Grid3D,
838            _ => Grid::Grid2D,
839        };
840
841        let config = StencilConfig::new(kernel_id.clone())
842            .with_grid(grid_type)
843            .with_tile_size(tile_width as usize, tile_height as usize)
844            .with_halo(halo as usize);
845
846        match transpile_stencil_kernel(&input, &config) {
847            Ok(cuda) => cuda,
848            Err(e) => {
849                return TokenStream::from(
850                    syn::Error::new_spanned(
851                        &input.sig.ident,
852                        format!("CUDA transpilation failed: {}", e),
853                    )
854                    .to_compile_error(),
855                );
856            }
857        }
858    };
859
860    #[cfg(not(feature = "cuda-codegen"))]
861    let cuda_source_code = format!(
862        "// CUDA codegen not enabled. Enable 'cuda-codegen' feature.\n// Kernel: {}\n",
863        kernel_id
864    );
865
866    // Generate the expanded code
867    let expanded = quote! {
868        // Original function (for documentation/testing/CPU fallback)
869        #(#fn_attrs)*
870        #fn_vis fn #fn_name #fn_inputs #fn_output #fn_block
871
872        /// Generated CUDA source code for this stencil kernel.
873        #fn_vis const #cuda_const_name: &str = #cuda_source_code;
874
875        /// Stencil kernel registration for runtime discovery.
876        #[allow(non_upper_case_globals)]
877        #[::inventory::submit]
878        static #registration_name: ::ringkernel_core::__private::StencilKernelRegistration =
879            ::ringkernel_core::__private::StencilKernelRegistration {
880                id: #kernel_id,
881                grid: #grid,
882                tile_width: #tile_width,
883                tile_height: #tile_height,
884                halo: #halo,
885                cuda_source: #cuda_source_code,
886            };
887    };
888
889    TokenStream::from(expanded)
890}
891
892// ============================================================================
893// Multi-Backend GPU Kernel Macro
894// ============================================================================
895
/// GPU backend targets (internal use only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuBackend {
    /// NVIDIA CUDA backend.
    Cuda,
    /// Apple Metal backend.
    Metal,
    /// WebGPU backend (cross-platform).
    Wgpu,
    /// CPU fallback backend.
    Cpu,
}

impl GpuBackend {
    /// Parse a backend name, case-insensitively; returns `None` for unknown names.
    fn from_str(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        let backend = match normalized.as_str() {
            "cuda" => Self::Cuda,
            "metal" => Self::Metal,
            "wgpu" | "webgpu" => Self::Wgpu,
            "cpu" => Self::Cpu,
            _ => return None,
        };
        Some(backend)
    }

    /// Canonical lowercase name of this backend.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Cuda => "cuda",
            Self::Metal => "metal",
            Self::Wgpu => "wgpu",
            Self::Cpu => "cpu",
        }
    }
}

/// GPU capability flags that can be required by a kernel (internal use only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuCapability {
    /// 64-bit floating point support.
    Float64,
    /// 64-bit integer support.
    Int64,
    /// 64-bit atomics support.
    Atomic64,
    /// Cooperative groups / grid-wide sync.
    CooperativeGroups,
    /// Subgroup / warp / SIMD operations.
    Subgroups,
    /// Shared memory / threadgroup memory.
    SharedMemory,
    /// Dynamic parallelism (launching kernels from kernels).
    DynamicParallelism,
    /// Half-precision (f16) support.
    Float16,
}

impl GpuCapability {
    /// Parse a capability name, case-insensitively, accepting the common
    /// aliases for each flag (e.g. "warp"/"simd" for subgroups).
    /// Returns `None` for unknown names.
    fn from_str(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        let cap = match normalized.as_str() {
            "f64" | "float64" => Self::Float64,
            "i64" | "int64" => Self::Int64,
            "atomic64" => Self::Atomic64,
            "cooperative_groups" | "cooperativegroups" | "grid_sync" => Self::CooperativeGroups,
            "subgroups" | "warp" | "simd" => Self::Subgroups,
            "shared_memory" | "sharedmemory" | "threadgroup" => Self::SharedMemory,
            "dynamic_parallelism" | "dynamicparallelism" => Self::DynamicParallelism,
            "f16" | "float16" | "half" => Self::Float16,
            _ => return None,
        };
        Some(cap)
    }

    /// Canonical lowercase name of this capability.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Float64 => "f64",
            Self::Int64 => "i64",
            Self::Atomic64 => "atomic64",
            Self::CooperativeGroups => "cooperative_groups",
            Self::Subgroups => "subgroups",
            Self::SharedMemory => "shared_memory",
            Self::DynamicParallelism => "dynamic_parallelism",
            Self::Float16 => "f16",
        }
    }

    /// Check if a backend supports this capability.
    ///
    /// CUDA and the CPU fallback (via emulation) support every capability;
    /// Metal and WebGPU each have a small deny-list. Note that 64-bit atomics
    /// on WebGPU are emulated only, and subgroups are an optional WebGPU
    /// extension — both are reflected here as hard true/false answers.
    fn supported_by(&self, backend: GpuBackend) -> bool {
        match backend {
            // CUDA supports everything; CPU supports everything in emulation.
            GpuBackend::Cuda | GpuBackend::Cpu => true,
            // Metal lacks f64, cooperative groups, and dynamic parallelism.
            GpuBackend::Metal => !matches!(
                self,
                Self::Float64 | Self::CooperativeGroups | Self::DynamicParallelism
            ),
            // WebGPU additionally lacks i64 and (non-emulated) 64-bit atomics.
            GpuBackend::Wgpu => !matches!(
                self,
                Self::Float64
                    | Self::Int64
                    | Self::Atomic64
                    | Self::CooperativeGroups
                    | Self::DynamicParallelism
            ),
        }
    }
}
1007
/// Attributes for the gpu_kernel macro.
///
/// Populated by `GpuKernelArgs::parse` from the raw attribute token stream;
/// fields not mentioned in the attribute keep their `Default` values.
#[derive(Debug)]
struct GpuKernelArgs {
    /// Kernel identifier (defaults to the annotated function's name when `None`).
    id: Option<String>,
    /// Target backends to generate code for.
    backends: Vec<GpuBackend>,
    /// Fallback order for backend selection at runtime.
    fallback: Vec<GpuBackend>,
    /// Required capabilities, validated against `backends` at compile time.
    requires: Vec<GpuCapability>,
    /// Block/workgroup size (defaults to 256 when `None`).
    block_size: Option<u32>,
}
1022
1023impl Default for GpuKernelArgs {
1024    fn default() -> Self {
1025        Self {
1026            id: None,
1027            backends: vec![GpuBackend::Cuda, GpuBackend::Metal, GpuBackend::Wgpu],
1028            fallback: vec![
1029                GpuBackend::Cuda,
1030                GpuBackend::Metal,
1031                GpuBackend::Wgpu,
1032                GpuBackend::Cpu,
1033            ],
1034            requires: Vec::new(),
1035            block_size: None,
1036        }
1037    }
1038}
1039
impl GpuKernelArgs {
    /// Parse macro arguments from the raw attribute token stream.
    ///
    /// This is a lightweight textual scan over `attr.to_string()` rather than
    /// a token-tree parse: each clause (`key = [...]`, `id = "..."`,
    /// `block_size = N`) is located by substring search and sliced out via
    /// bracket/quote positions. Unknown backend or capability names are
    /// silently skipped by the `filter_map`s.
    ///
    /// NOTE(review): substring matching can mis-fire if a keyword occurs
    /// inside another value (e.g. an `id` string containing "backends") —
    /// confirm attribute inputs are always simple enough for this scan.
    ///
    /// Always returns `Ok`; the `darling::Error` in the signature exists so
    /// the caller can use a uniform `write_errors()` path.
    fn parse(attr: proc_macro2::TokenStream) -> Result<Self, darling::Error> {
        let mut args = Self::default();
        let attr_str = attr.to_string();

        // Parse backends = [...]
        if let Some(start) = attr_str.find("backends") {
            if let Some(bracket_start) = attr_str[start..].find('[') {
                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
                    // bracket_end is relative to the '[' position, so this
                    // slice covers the text strictly between '[' and ']'.
                    let backends_str =
                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
                    args.backends = backends_str
                        .split(',')
                        .filter_map(|s| GpuBackend::from_str(s.trim()))
                        .collect();
                }
            }
        }

        // Parse fallback = [...]
        if let Some(start) = attr_str.find("fallback") {
            if let Some(bracket_start) = attr_str[start..].find('[') {
                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
                    let fallback_str =
                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
                    args.fallback = fallback_str
                        .split(',')
                        .filter_map(|s| GpuBackend::from_str(s.trim()))
                        .collect();
                }
            }
        }

        // Parse requires = [...]
        if let Some(start) = attr_str.find("requires") {
            if let Some(bracket_start) = attr_str[start..].find('[') {
                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
                    let requires_str =
                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
                    args.requires = requires_str
                        .split(',')
                        .filter_map(|s| GpuCapability::from_str(s.trim()))
                        .collect();
                }
            }
        }

        // Parse id = "..." — takes the text between the first pair of double
        // quotes found after the substring "id".
        if let Some(start) = attr_str.find("id") {
            if let Some(quote_start) = attr_str[start..].find('"') {
                if let Some(quote_end) = attr_str[start + quote_start + 1..].find('"') {
                    args.id = Some(
                        attr_str[start + quote_start + 1..start + quote_start + 1 + quote_end]
                            .to_string(),
                    );
                }
            }
        }

        // Parse block_size = N — takes digits (and interior spaces) up to the
        // next delimiter after '='; a malformed number is silently ignored.
        if let Some(start) = attr_str.find("block_size") {
            if let Some(eq) = attr_str[start..].find('=') {
                let rest = &attr_str[start + eq + 1..];
                let num_end = rest
                    .find(|c: char| !c.is_numeric() && c != ' ')
                    .unwrap_or(rest.len());
                if let Ok(n) = rest[..num_end].trim().parse() {
                    args.block_size = Some(n);
                }
            }
        }

        Ok(args)
    }

    /// Validate that all required capabilities are supported by at least one backend.
    ///
    /// On failure, returns a human-readable message that the caller emits
    /// verbatim as a compile error on the kernel function.
    fn validate_capabilities(&self) -> Result<(), String> {
        for cap in &self.requires {
            let mut supported_by_any = false;
            for backend in &self.backends {
                if cap.supported_by(*backend) {
                    supported_by_any = true;
                    break;
                }
            }
            if !supported_by_any {
                return Err(format!(
                    "Capability '{}' is not supported by any of the specified backends: {:?}",
                    cap.as_str(),
                    self.backends.iter().map(|b| b.as_str()).collect::<Vec<_>>()
                ));
            }
        }
        Ok(())
    }

    /// Get backends that support all required capabilities.
    ///
    /// Stricter than `validate_capabilities` (which only needs *some* backend
    /// per capability): a backend is kept only if it satisfies *every*
    /// required capability. May return an empty list.
    fn compatible_backends(&self) -> Vec<GpuBackend> {
        self.backends
            .iter()
            .filter(|backend| self.requires.iter().all(|cap| cap.supported_by(**backend)))
            .copied()
            .collect()
    }
}
1145
1146/// Attribute macro for defining multi-backend GPU kernels.
1147///
1148/// This macro generates code for multiple GPU backends with compile-time
1149/// capability validation. It integrates with the `ringkernel-ir` crate
1150/// to lower Rust DSL to backend-specific shader code.
1151///
1152/// # Attributes
1153///
1154/// - `backends = [cuda, metal, wgpu]` - Target backends (default: all)
1155/// - `fallback = [cuda, metal, wgpu, cpu]` - Fallback order for runtime selection
1156/// - `requires = [f64, atomic64]` - Required capabilities (validated at compile time)
1157/// - `id = "kernel_name"` - Explicit kernel identifier
1158/// - `block_size = 256` - Thread block size
1159///
1160/// # Example
1161///
1162/// ```ignore
1163/// use ringkernel_derive::gpu_kernel;
1164///
1165/// #[gpu_kernel(backends = [cuda, metal], requires = [subgroups])]
1166/// fn warp_reduce(data: &mut [f32], n: i32) {
1167///     let idx = global_thread_id_x();
1168///     if idx < n {
1169///         // Use warp shuffle for reduction
1170///         let val = data[idx as usize];
1171///         let reduced = warp_reduce_sum(val);
1172///         if lane_id() == 0 {
1173///             data[idx as usize] = reduced;
1174///         }
1175///     }
1176/// }
1177/// ```
1178///
1179/// # Capability Checking
1180///
1181/// The macro validates at compile time that all required capabilities are
1182/// supported by at least one target backend:
1183///
1184/// | Capability | CUDA | Metal | WebGPU | CPU |
1185/// |------------|------|-------|--------|-----|
1186/// | f64        | Yes  | No    | No     | Yes |
1187/// | i64        | Yes  | Yes   | No     | Yes |
1188/// | atomic64   | Yes  | Yes   | No*    | Yes |
1189/// | cooperative_groups | Yes | No | No | Yes |
1190/// | subgroups  | Yes  | Yes   | Opt    | Yes |
1191/// | shared_memory | Yes | Yes | Yes    | Yes |
1192/// | f16        | Yes  | Yes   | Yes    | Yes |
1193///
1194/// *WebGPU emulates 64-bit atomics with 32-bit pairs.
1195///
1196/// # Generated Code
1197///
1198/// For each compatible backend, the macro generates:
1199/// - Backend-specific source code constant (e.g., `KERNEL_NAME_CUDA_SOURCE`)
1200/// - Registration entry for runtime discovery
1201/// - CPU fallback function (if `cpu_fallback = true`)
1202#[proc_macro_attribute]
1203pub fn gpu_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
1204    let attr2: proc_macro2::TokenStream = attr.into();
1205    let args = match GpuKernelArgs::parse(attr2) {
1206        Ok(args) => args,
1207        Err(e) => return TokenStream::from(e.write_errors()),
1208    };
1209
1210    let input = parse_macro_input!(item as ItemFn);
1211
1212    // Validate capabilities
1213    if let Err(msg) = args.validate_capabilities() {
1214        return TokenStream::from(
1215            syn::Error::new_spanned(&input.sig.ident, msg).to_compile_error(),
1216        );
1217    }
1218
1219    gpu_kernel_impl(args, input)
1220}
1221
1222fn gpu_kernel_impl(args: GpuKernelArgs, input: ItemFn) -> TokenStream {
1223    let fn_name = &input.sig.ident;
1224    let fn_vis = &input.vis;
1225    let fn_block = &input.block;
1226    let fn_inputs = &input.sig.inputs;
1227    let fn_output = &input.sig.output;
1228    let fn_attrs = &input.attrs;
1229
1230    let kernel_id = args.id.clone().unwrap_or_else(|| fn_name.to_string());
1231    let block_size = args.block_size.unwrap_or(256);
1232
1233    // Get compatible backends
1234    let compatible_backends = args.compatible_backends();
1235
1236    // Generate backend-specific source constants
1237    let mut source_constants = Vec::new();
1238
1239    for backend in &compatible_backends {
1240        let const_name = format_ident!(
1241            "{}_{}",
1242            fn_name.to_string().to_uppercase(),
1243            backend.as_str().to_uppercase()
1244        );
1245
1246        let backend_str = backend.as_str();
1247
1248        // Generate placeholder source (actual IR lowering happens at build time)
1249        // In a full implementation, this would call ringkernel-ir lowering
1250        let source_placeholder = format!(
1251            "// {} source for kernel '{}'\n// Generated by ringkernel-derive\n// Capabilities: {:?}\n",
1252            backend_str.to_uppercase(),
1253            kernel_id,
1254            args.requires.iter().map(|c| c.as_str()).collect::<Vec<_>>()
1255        );
1256
1257        source_constants.push(quote! {
1258            /// Generated source code for this kernel.
1259            #fn_vis const #const_name: &str = #source_placeholder;
1260        });
1261    }
1262
1263    // Generate capability flags as strings
1264    let capability_strs: Vec<_> = args.requires.iter().map(|c| c.as_str()).collect();
1265    let backend_strs: Vec<_> = compatible_backends.iter().map(|b| b.as_str()).collect();
1266    let fallback_strs: Vec<_> = args.fallback.iter().map(|b| b.as_str()).collect();
1267
1268    // Generate registration struct name
1269    let registration_name = format_ident!(
1270        "__GPU_KERNEL_REGISTRATION_{}",
1271        fn_name.to_string().to_uppercase()
1272    );
1273
1274    // Generate info struct name
1275    let info_name = format_ident!("{}_INFO", fn_name.to_string().to_uppercase());
1276
1277    // Generate the expanded code
1278    let expanded = quote! {
1279        // Original function (CPU fallback / documentation / testing)
1280        #(#fn_attrs)*
1281        #fn_vis fn #fn_name #fn_inputs #fn_output #fn_block
1282
1283        // Backend source constants
1284        #(#source_constants)*
1285
1286        /// Multi-backend kernel information.
1287        #fn_vis mod #info_name {
1288            /// Kernel identifier.
1289            pub const ID: &str = #kernel_id;
1290
1291            /// Block/workgroup size.
1292            pub const BLOCK_SIZE: u32 = #block_size;
1293
1294            /// Required capabilities.
1295            pub const CAPABILITIES: &[&str] = &[#(#capability_strs),*];
1296
1297            /// Compatible backends (those that support all required capabilities).
1298            pub const BACKENDS: &[&str] = &[#(#backend_strs),*];
1299
1300            /// Fallback order for runtime backend selection.
1301            pub const FALLBACK_ORDER: &[&str] = &[#(#fallback_strs),*];
1302        }
1303
1304        /// GPU kernel registration for runtime discovery.
1305        #[allow(non_upper_case_globals)]
1306        #[::inventory::submit]
1307        static #registration_name: ::ringkernel_core::__private::GpuKernelRegistration =
1308            ::ringkernel_core::__private::GpuKernelRegistration {
1309                id: #kernel_id,
1310                block_size: #block_size,
1311                capabilities: &[#(#capability_strs),*],
1312                backends: &[#(#backend_strs),*],
1313                fallback_order: &[#(#fallback_strs),*],
1314            };
1315    };
1316
1317    TokenStream::from(expanded)
1318}
1319
1320// ============================================================================
1321// ControlBlockState Derive Macro (FR-4)
1322// ============================================================================
1323
/// Attributes for the ControlBlockState derive macro.
///
/// Parsed by darling from the derive input; only named structs are supported,
/// and `#[state(...)]` is the recognized helper-attribute namespace.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(state), supports(struct_named))]
struct ControlBlockStateArgs {
    /// Identifier of the deriving type.
    ident: syn::Ident,
    /// Generics of the deriving type, split for the generated impls.
    generics: syn::Generics,
    /// State version for forward compatibility (defaults to 1 when the
    /// `#[state(version = N)]` attribute is absent).
    #[darling(default)]
    version: Option<u32>,
}
1334
1335/// Derive macro for implementing EmbeddedState trait.
1336///
1337/// This macro generates implementations for types that can be stored in
1338/// the ControlBlock's 24-byte `_reserved` field for zero-copy state access.
1339///
1340/// # Requirements
1341///
1342/// The type must:
1343/// - Be `#[repr(C)]` for stable memory layout
1344/// - Be <= 24 bytes in size (checked at compile time)
1345/// - Implement `Clone`, `Copy`, and `Default`
1346/// - Contain only POD (Plain Old Data) types
1347///
1348/// # Attributes
1349///
1350/// - `#[state(version = N)]` - Set state version for migrations (default: 1)
1351///
1352/// # Example
1353///
1354/// ```ignore
1355/// #[derive(ControlBlockState, Default, Clone, Copy)]
1356/// #[repr(C, align(8))]
1357/// #[state(version = 1)]
1358/// pub struct OrderBookState {
1359///     pub best_bid: u64,    // 8 bytes
1360///     pub best_ask: u64,    // 8 bytes
1361///     pub order_count: u32, // 4 bytes
1362///     pub _pad: u32,        // 4 bytes (padding for alignment)
1363/// }  // Total: 24 bytes - fits in ControlBlock._reserved
1364///
1365/// // Use with ControlBlockStateHelper:
1366/// let mut block = ControlBlock::new();
1367/// let state = OrderBookState { best_bid: 100, best_ask: 101, order_count: 42, _pad: 0 };
1368/// ControlBlockStateHelper::write_embedded(&mut block, &state)?;
1369/// ```
1370///
1371/// # Size Validation
1372///
1373/// The macro generates a compile-time assertion that fails if the type
1374/// exceeds 24 bytes:
1375///
1376/// ```ignore
1377/// #[derive(ControlBlockState, Default, Clone, Copy)]
1378/// #[repr(C)]
1379/// struct TooLarge {
1380///     data: [u8; 32],  // 32 bytes - COMPILE ERROR!
1381/// }
1382/// ```
#[proc_macro_derive(ControlBlockState, attributes(state))]
pub fn derive_control_block_state(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    // Extract the #[state(...)] attributes via darling; parse failures are
    // surfaced as compile errors at the derive site.
    let args = match ControlBlockStateArgs::from_derive_input(&input) {
        Ok(args) => args,
        Err(e) => return e.write_errors().into(),
    };

    let name = &args.ident;
    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
    // Default state version is 1 when #[state(version = N)] is absent.
    let version = args.version.unwrap_or(1);

    let expanded = quote! {
        // Compile-time size check: EmbeddedState must fit in 24 bytes
        const _: () = {
            assert!(
                ::std::mem::size_of::<#name #ty_generics>() <= 24,
                "ControlBlockState types must fit in 24 bytes (ControlBlock._reserved size)"
            );
        };

        // Verify type is Copy (required for GPU transfer)
        // (closure-to-fn-pointer coercion gives a const context in which the
        // bound is checked without ever calling the function)
        const _: fn() = || {
            fn assert_copy<T: Copy>() {}
            assert_copy::<#name #ty_generics>();
        };

        // Implement Pod and Zeroable (required by EmbeddedState)
        // SAFETY: Type is #[repr(C)] with only primitive types, verified by user
        // NOTE(review): Pod also requires the type to have no padding bytes;
        // this derive does not verify that, so a padded #[repr(C)] struct
        // would make this impl unsound — confirm padding is excluded by
        // convention (explicit _pad fields, as in the macro's doc example).
        unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
        unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}

        // Implement EmbeddedState
        impl #impl_generics ::ringkernel_core::state::EmbeddedState for #name #ty_generics #where_clause {
            const VERSION: u32 = #version;

            fn is_embedded() -> bool {
                true
            }
        }
    };

    TokenStream::from(expanded)
}