// ringkernel_derive/lib.rs
1//! Procedural macros for RingKernel.
2//!
3//! This crate provides the following macros:
4//!
5//! - `#[derive(RingMessage)]` - Implement the RingMessage trait for message types
6//! - `#[derive(PersistentMessage)]` - Implement PersistentMessage for GPU kernel dispatch
7//! - `#[ring_kernel]` - Define a ring kernel handler
8//! - `#[stencil_kernel]` - Define a GPU stencil kernel (with `cuda-codegen` feature)
9//! - `#[gpu_kernel]` - Define a multi-backend GPU kernel with capability checking
10//!
11//! # Example
12//!
13//! ```ignore
14//! use ringkernel_derive::{RingMessage, ring_kernel};
15//!
16//! #[derive(RingMessage)]
17//! struct AddRequest {
18//! #[message(id)]
19//! id: MessageId,
20//! a: f32,
21//! b: f32,
22//! }
23//!
24//! #[derive(RingMessage)]
25//! struct AddResponse {
26//! #[message(id)]
27//! id: MessageId,
28//! result: f32,
29//! }
30//!
31//! #[ring_kernel(id = "adder")]
32//! async fn process(ctx: &mut RingContext, req: AddRequest) -> AddResponse {
33//! AddResponse {
34//! id: MessageId::generate(),
35//! result: req.a + req.b,
36//! }
37//! }
38//! ```
39//!
40//! # Multi-Backend GPU Kernels
41//!
42//! The `#[gpu_kernel]` macro enables multi-backend code generation with capability checking:
43//!
44//! ```ignore
45//! use ringkernel_derive::gpu_kernel;
46//!
47//! // Generate code for CUDA and Metal, with fallback order
48//! #[gpu_kernel(backends = [cuda, metal], fallback = [wgpu, cpu])]
49//! fn saxpy(x: &[f32], y: &mut [f32], a: f32, n: i32) {
50//! let idx = global_thread_id_x();
51//! if idx < n {
52//! y[idx as usize] = a * x[idx as usize] + y[idx as usize];
53//! }
54//! }
55//!
56//! // Require specific capabilities at compile time
57//! #[gpu_kernel(backends = [cuda], requires = [f64, atomic64])]
58//! fn double_precision(data: &mut [f64], n: i32) {
59//! // Uses f64 operations - validated at compile time
60//! }
61//! ```
62//!
63//! # Stencil Kernels (with `cuda-codegen` feature)
64//!
65//! ```ignore
66//! use ringkernel_derive::stencil_kernel;
67//! use ringkernel_cuda_codegen::GridPos;
68//!
69//! #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
70//! fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
71//! let curr = p[pos.idx()];
72//! let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
73//! p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
74//! }
75//! ```
76
77use darling::{ast, FromDeriveInput, FromField, FromMeta};
78use proc_macro::TokenStream;
79use quote::{format_ident, quote};
80use syn::{parse_macro_input, DeriveInput, ItemFn};
81
/// Attributes for the RingMessage derive macro.
///
/// Parsed by `darling` from the deriving struct. Accepts both
/// `#[message(...)]` and `#[ring_message(...)]` attribute forms and only
/// supports structs with named fields (`supports(struct_named)`).
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(message, ring_message), supports(struct_named))]
struct RingMessageArgs {
    /// Name of the deriving struct.
    ident: syn::Ident,
    /// Generics of the deriving struct, forwarded into the generated impls.
    generics: syn::Generics,
    /// Parsed field list; always the `Struct` variant due to `supports`.
    data: ast::Data<(), RingMessageField>,
    /// Optional explicit message type ID.
    /// If domain is specified, this is the offset within the domain (0-99).
    /// If domain is not specified, this is the absolute type ID.
    #[darling(default)]
    type_id: Option<u64>,
    /// Optional domain for message classification.
    /// When specified, the final type ID = domain.base_type_id() + type_id.
    #[darling(default)]
    domain: Option<String>,
    /// Whether this message is routable via K2K.
    /// When true, generates a K2KMessageRegistration for runtime discovery.
    #[darling(default)]
    k2k_routable: bool,
    /// Optional category for K2K routing groups.
    /// Multiple messages can share a category for grouped routing.
    #[darling(default)]
    category: Option<String>,
}
107
/// Field attributes for RingMessage.
///
/// Each boolean flag marks the annotated field as the backing storage for
/// one of the generated `RingMessage` accessor methods; accessors without a
/// marked field fall back to defaults in `derive_ring_message`.
#[derive(Debug, FromField)]
#[darling(attributes(message))]
struct RingMessageField {
    /// Field name (always `Some` for named-field structs).
    ident: Option<syn::Ident>,
    /// Field type (currently unused; retained for future validation).
    #[allow(dead_code)]
    ty: syn::Type,
    /// Mark this field as the message ID.
    #[darling(default)]
    id: bool,
    /// Mark this field as the correlation ID.
    #[darling(default)]
    correlation: bool,
    /// Mark this field as the priority.
    #[darling(default)]
    priority: bool,
}
125
126/// Derive macro for implementing the RingMessage trait.
127///
128/// # Attributes
129///
130/// On the struct (via `#[message(...)]` or `#[ring_message(...)]`):
131/// - `type_id = 123` - Set explicit message type ID (or domain offset if domain is set)
132/// - `domain = "OrderMatching"` - Assign to a business domain (adds base type ID)
133/// - `k2k_routable = true` - Register for K2K routing discovery
134/// - `category = "orders"` - Group messages for K2K routing
135///
136/// On fields:
137/// - `#[message(id)]` - Mark as message ID field
138/// - `#[message(correlation)]` - Mark as correlation ID field
139/// - `#[message(priority)]` - Mark as priority field
140///
141/// # Examples
142///
143/// Basic usage:
144/// ```ignore
145/// #[derive(RingMessage)]
146/// #[message(type_id = 1)]
147/// struct MyMessage {
148/// #[message(id)]
149/// id: MessageId,
150/// #[message(correlation)]
151/// correlation: CorrelationId,
152/// #[message(priority)]
153/// priority: Priority,
154/// payload: Vec<u8>,
155/// }
156/// ```
157///
158/// With domain (type ID = 500 + 1 = 501):
159/// ```ignore
160/// #[derive(RingMessage)]
161/// #[ring_message(type_id = 1, domain = "OrderMatching")]
162/// pub struct SubmitOrderInput {
163/// #[message(id)]
164/// id: MessageId,
165/// pub order: Order,
166/// }
167/// // Also implements DomainMessage trait
168/// assert_eq!(SubmitOrderInput::domain(), Domain::OrderMatching);
169/// ```
170///
171/// K2K-routable message:
172/// ```ignore
173/// #[derive(RingMessage)]
174/// #[ring_message(type_id = 1, domain = "OrderMatching", k2k_routable = true, category = "orders")]
175/// pub struct SubmitOrderInput { ... }
176///
177/// // Runtime discovery:
178/// let registry = K2KTypeRegistry::discover();
179/// assert!(registry.is_routable(501));
180/// ```
181#[proc_macro_derive(RingMessage, attributes(message, ring_message))]
182pub fn derive_ring_message(input: TokenStream) -> TokenStream {
183 let input = parse_macro_input!(input as DeriveInput);
184
185 let args = match RingMessageArgs::from_derive_input(&input) {
186 Ok(args) => args,
187 Err(e) => return e.write_errors().into(),
188 };
189
190 let name = &args.ident;
191 let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
192
193 // Calculate base type ID (offset within domain, or absolute if no domain)
194 let base_type_id = args.type_id.unwrap_or_else(|| {
195 use std::collections::hash_map::DefaultHasher;
196 use std::hash::{Hash, Hasher};
197 let mut hasher = DefaultHasher::new();
198 name.to_string().hash(&mut hasher);
199 // If domain is set, hash to a value within 0-99 range
200 if args.domain.is_some() {
201 hasher.finish() % 100
202 } else {
203 hasher.finish()
204 }
205 });
206
207 // Find annotated fields
208 let fields = match &args.data {
209 ast::Data::Struct(fields) => fields,
210 _ => panic!("RingMessage can only be derived for structs"),
211 };
212
213 let mut id_field: Option<&syn::Ident> = None;
214 let mut correlation_field: Option<&syn::Ident> = None;
215 let mut priority_field: Option<&syn::Ident> = None;
216
217 for field in fields.iter() {
218 if field.id {
219 id_field = field.ident.as_ref();
220 }
221 if field.correlation {
222 correlation_field = field.ident.as_ref();
223 }
224 if field.priority {
225 priority_field = field.ident.as_ref();
226 }
227 }
228
229 // Generate message_id method
230 let message_id_impl = if let Some(field) = id_field {
231 quote! { self.#field }
232 } else {
233 quote! { ::ringkernel_core::message::MessageId::new(0) }
234 };
235
236 // Generate correlation_id method
237 let correlation_id_impl = if let Some(field) = correlation_field {
238 quote! { self.#field }
239 } else {
240 quote! { ::ringkernel_core::message::CorrelationId::none() }
241 };
242
243 // Generate priority method
244 let priority_impl = if let Some(field) = priority_field {
245 quote! { self.#field }
246 } else {
247 quote! { ::ringkernel_core::message::Priority::Normal }
248 };
249
250 // Generate message_type() implementation based on whether domain is specified
251 let message_type_impl = if let Some(ref domain_str) = args.domain {
252 // With domain: type_id = domain.base_type_id() + offset
253 quote! {
254 ::ringkernel_core::domain::Domain::from_str(#domain_str)
255 .unwrap_or(::ringkernel_core::domain::Domain::General)
256 .base_type_id() + #base_type_id
257 }
258 } else {
259 // Without domain: use absolute type_id
260 quote! { #base_type_id }
261 };
262
263 // Generate DomainMessage impl if domain is specified
264 let domain_impl = if let Some(ref domain_str) = args.domain {
265 quote! {
266 impl #impl_generics ::ringkernel_core::domain::DomainMessage for #name #ty_generics #where_clause {
267 fn domain() -> ::ringkernel_core::domain::Domain {
268 ::ringkernel_core::domain::Domain::from_str(#domain_str)
269 .unwrap_or(::ringkernel_core::domain::Domain::General)
270 }
271 }
272 }
273 } else {
274 quote! {}
275 };
276
277 // Generate K2K registration if k2k_routable is set
278 let k2k_registration = if args.k2k_routable {
279 let registration_name = format_ident!(
280 "__K2K_MESSAGE_REGISTRATION_{}",
281 name.to_string().to_uppercase()
282 );
283 let type_name_str = name.to_string();
284 let category_tokens = match &args.category {
285 Some(cat) => quote! { ::std::option::Option::Some(#cat) },
286 None => quote! { ::std::option::Option::None },
287 };
288
289 quote! {
290 #[allow(non_upper_case_globals)]
291 #[::inventory::submit]
292 static #registration_name: ::ringkernel_core::k2k::K2KMessageRegistration =
293 ::ringkernel_core::k2k::K2KMessageRegistration {
294 type_id: {
295 // Note: This is a const context, so we use the base calculation
296 // For domain types, we need to add the base manually
297 #base_type_id
298 },
299 type_name: #type_name_str,
300 k2k_routable: true,
301 category: #category_tokens,
302 };
303 }
304 } else {
305 quote! {}
306 };
307
308 let expanded = quote! {
309 impl #impl_generics ::ringkernel_core::message::RingMessage for #name #ty_generics #where_clause {
310 fn message_type() -> u64 {
311 #message_type_impl
312 }
313
314 fn message_id(&self) -> ::ringkernel_core::message::MessageId {
315 #message_id_impl
316 }
317
318 fn correlation_id(&self) -> ::ringkernel_core::message::CorrelationId {
319 #correlation_id_impl
320 }
321
322 fn priority(&self) -> ::ringkernel_core::message::Priority {
323 #priority_impl
324 }
325
326 fn serialize(&self) -> Vec<u8> {
327 // Use rkyv for serialization with a 4KB scratch buffer
328 // For larger payloads, rkyv will allocate as needed
329 ::rkyv::to_bytes::<_, 4096>(self)
330 .map(|v| v.to_vec())
331 .unwrap_or_default()
332 }
333
334 fn deserialize(bytes: &[u8]) -> ::ringkernel_core::error::Result<Self>
335 where
336 Self: Sized,
337 {
338 use ::rkyv::Deserialize as _;
339 let archived = unsafe { ::rkyv::archived_root::<Self>(bytes) };
340 let deserialized: Self = archived.deserialize(&mut ::rkyv::Infallible)
341 .map_err(|_| ::ringkernel_core::error::RingKernelError::DeserializationError(
342 "rkyv deserialization failed".to_string()
343 ))?;
344 Ok(deserialized)
345 }
346
347 fn size_hint(&self) -> usize {
348 ::std::mem::size_of::<Self>()
349 }
350 }
351
352 #domain_impl
353
354 #k2k_registration
355 };
356
357 TokenStream::from(expanded)
358}
359
360// ============================================================================
361// PersistentMessage Derive Macro
362// ============================================================================
363
/// Maximum size for inline payload in persistent messages.
///
/// NOTE(review): the generated code references
/// `ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE` directly,
/// so this local constant is currently unused — presumably it mirrors the
/// core value; confirm they stay in sync.
#[allow(dead_code)]
const MAX_INLINE_PAYLOAD_SIZE: usize = 32;
367
/// Attributes for the PersistentMessage derive macro.
///
/// Parsed by `darling` from the deriving struct via
/// `#[persistent_message(...)]`; only named-field structs are supported.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(persistent_message), supports(struct_named))]
struct PersistentMessageArgs {
    /// Name of the deriving struct.
    ident: syn::Ident,
    /// Generics of the deriving struct, forwarded into the generated impl.
    generics: syn::Generics,
    /// Field data (reserved for future per-field attributes).
    #[allow(dead_code)]
    data: ast::Data<(), PersistentMessageField>,
    /// Handler ID for CUDA dispatch (0-255).
    /// NOTE(review): the 0-255 range is documented but not validated here.
    handler_id: u32,
    /// Whether this message type expects a response.
    #[darling(default)]
    requires_response: bool,
}
383
/// Field attributes for PersistentMessage (reserved for future use).
///
/// No per-field attributes are currently consumed; this type exists so the
/// darling `data` field of [`PersistentMessageArgs`] has a concrete shape.
#[derive(Debug, FromField)]
#[darling(attributes(persistent_message))]
struct PersistentMessageField {
    /// Field identifier.
    #[allow(dead_code)]
    ident: Option<syn::Ident>,
    /// Field type.
    #[allow(dead_code)]
    ty: syn::Type,
}
395
396/// Derive macro for implementing the PersistentMessage trait.
397///
398/// This macro enables type-based dispatch within persistent GPU kernels by
399/// generating handler_id, inline payload serialization, and deserialization.
400///
401/// # Requirements
402///
403/// The struct must:
404/// - Already implement `RingMessage` (use `#[derive(RingMessage)]`)
405/// - Be `#[repr(C)]` for safe memory layout
406/// - Be `Copy` + `Clone` for inline payload serialization
407///
408/// # Attributes
409///
410/// On the struct:
411/// - `handler_id = N` (required) - CUDA dispatch handler ID (0-255)
412/// - `requires_response = true/false` (optional) - Whether this message expects a response
413///
414/// # Example
415///
416/// ```ignore
417/// use ringkernel_derive::{RingMessage, PersistentMessage};
418///
419/// #[derive(RingMessage, PersistentMessage, Clone, Copy)]
420/// #[repr(C)]
421/// #[message(type_id = 1001)]
422/// #[persistent_message(handler_id = 1, requires_response = true)]
423/// pub struct FraudCheckRequest {
424/// pub transaction_id: u64,
425/// pub amount: f32,
426/// pub account_id: u32,
427/// }
428///
429/// // Generated implementations:
430/// // - handler_id() returns 1
431/// // - requires_response() returns true
432/// // - to_inline_payload() serializes the struct to [u8; 32] if it fits
433/// // - from_inline_payload() deserializes from bytes
434/// // - payload_size() returns the struct size
435/// ```
436///
437/// # Size Validation
438///
439/// For inline payload serialization, structs must be <= 32 bytes.
440/// Larger structs will return `None` from `to_inline_payload()`.
441///
442/// # CUDA Integration
443///
444/// The handler_id maps to a switch case in generated CUDA code:
445///
446/// ```cuda
447/// switch (msg->handler_id) {
448/// case 1: handle_fraud_check(msg, state, response); break;
449/// case 2: handle_aggregate(msg, state, response); break;
450/// // ...
451/// }
452/// ```
453#[proc_macro_derive(PersistentMessage, attributes(persistent_message))]
454pub fn derive_persistent_message(input: TokenStream) -> TokenStream {
455 let input = parse_macro_input!(input as DeriveInput);
456
457 let args = match PersistentMessageArgs::from_derive_input(&input) {
458 Ok(args) => args,
459 Err(e) => return e.write_errors().into(),
460 };
461
462 let name = &args.ident;
463 let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
464
465 let handler_id = args.handler_id;
466 let requires_response = args.requires_response;
467
468 // Generate the PersistentMessage implementation
469 let expanded = quote! {
470 impl #impl_generics ::ringkernel_core::persistent_message::PersistentMessage for #name #ty_generics #where_clause {
471 fn handler_id() -> u32 {
472 #handler_id
473 }
474
475 fn requires_response() -> bool {
476 #requires_response
477 }
478
479 fn payload_size() -> usize {
480 ::std::mem::size_of::<Self>()
481 }
482
483 fn to_inline_payload(&self) -> ::std::option::Option<[u8; ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE]> {
484 // Only serialize if the struct fits in the inline payload
485 if ::std::mem::size_of::<Self>() > ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE {
486 return ::std::option::Option::None;
487 }
488
489 let mut payload = [0u8; ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE];
490
491 // Safety: We've verified the struct fits in the payload,
492 // and the struct is repr(C) + Copy
493 unsafe {
494 ::std::ptr::copy_nonoverlapping(
495 self as *const Self as *const u8,
496 payload.as_mut_ptr(),
497 ::std::mem::size_of::<Self>()
498 );
499 }
500
501 ::std::option::Option::Some(payload)
502 }
503
504 fn from_inline_payload(payload: &[u8]) -> ::ringkernel_core::error::Result<Self> {
505 let size = ::std::mem::size_of::<Self>();
506
507 if payload.len() < size {
508 return ::std::result::Result::Err(
509 ::ringkernel_core::error::RingKernelError::DeserializationError(
510 ::std::format!(
511 "Payload too small: expected {} bytes, got {}",
512 size,
513 payload.len()
514 )
515 )
516 );
517 }
518
519 // Safety: We've verified the payload is large enough,
520 // and the struct is repr(C) + Copy
521 let value = unsafe {
522 ::std::ptr::read(payload.as_ptr() as *const Self)
523 };
524
525 ::std::result::Result::Ok(value)
526 }
527 }
528 };
529
530 TokenStream::from(expanded)
531}
532
/// Attributes for the ring_kernel macro.
///
/// Parsed by `darling` from the `#[ring_kernel(...)]` attribute list;
/// defaults for the optional values are applied in [`ring_kernel`].
#[derive(Debug, FromMeta)]
struct RingKernelArgs {
    /// Kernel identifier.
    id: String,
    /// Execution mode ("persistent" or "event_driven").
    /// Anything other than "event_driven" is treated as persistent.
    #[darling(default)]
    mode: Option<String>,
    /// Grid size (number of blocks); defaults to 1 when absent.
    #[darling(default)]
    grid_size: Option<u32>,
    /// Block size (threads per block); defaults to 256 when absent.
    #[darling(default)]
    block_size: Option<u32>,
    /// Target kernels this kernel publishes to (comma-separated IDs).
    #[darling(default)]
    publishes_to: Option<String>,
}
551
552/// Attribute macro for defining ring kernel handlers.
553///
554/// # Attributes
555///
556/// - `id` (required) - Unique kernel identifier
557/// - `mode` - Execution mode: "persistent" (default) or "event_driven"
558/// - `grid_size` - Number of blocks (default: 1)
559/// - `block_size` - Threads per block (default: 256)
560/// - `publishes_to` - Comma-separated list of target kernel IDs
561///
562/// # Example
563///
564/// ```ignore
565/// #[ring_kernel(id = "processor", mode = "persistent", block_size = 128)]
566/// async fn handle(ctx: &mut RingContext, msg: MyMessage) -> MyResponse {
567/// // Process message
568/// MyResponse { ... }
569/// }
570/// ```
571#[proc_macro_attribute]
572pub fn ring_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
573 let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
574 Ok(v) => v,
575 Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
576 };
577
578 let args = match RingKernelArgs::from_list(&args) {
579 Ok(v) => v,
580 Err(e) => return TokenStream::from(e.write_errors()),
581 };
582
583 let input = parse_macro_input!(item as ItemFn);
584
585 let kernel_id = &args.id;
586 let fn_name = &input.sig.ident;
587 let fn_vis = &input.vis;
588 let fn_block = &input.block;
589 let fn_attrs = &input.attrs;
590
591 // Parse function signature
592 let inputs = &input.sig.inputs;
593 let output = &input.sig.output;
594
595 // Extract context and message types from signature
596 let (_ctx_arg, msg_arg) = if inputs.len() >= 2 {
597 let ctx = inputs.first();
598 let msg = inputs.iter().nth(1);
599 (ctx, msg)
600 } else {
601 (None, None)
602 };
603
604 // Get message type
605 let msg_type = msg_arg
606 .map(|arg| {
607 if let syn::FnArg::Typed(pat_type) = arg {
608 pat_type.ty.clone()
609 } else {
610 syn::parse_quote!(())
611 }
612 })
613 .unwrap_or_else(|| syn::parse_quote!(()));
614
615 // Generate kernel mode
616 let mode = args.mode.as_deref().unwrap_or("persistent");
617 let mode_expr = if mode == "event_driven" {
618 quote! { ::ringkernel_core::types::KernelMode::EventDriven }
619 } else {
620 quote! { ::ringkernel_core::types::KernelMode::Persistent }
621 };
622
623 // Generate grid/block size
624 let grid_size = args.grid_size.unwrap_or(1);
625 let block_size = args.block_size.unwrap_or(256);
626
627 // Parse publishes_to into a list of target kernel IDs
628 let publishes_to_targets: Vec<String> = args
629 .publishes_to
630 .as_ref()
631 .map(|s| s.split(',').map(|t| t.trim().to_string()).collect())
632 .unwrap_or_default();
633
634 // Generate registration struct name
635 let registration_name = format_ident!(
636 "__RINGKERNEL_REGISTRATION_{}",
637 fn_name.to_string().to_uppercase()
638 );
639 let handler_name = format_ident!("{}_handler", fn_name);
640
641 // Generate the expanded code
642 let expanded = quote! {
643 // Original function (preserved for documentation/testing)
644 #(#fn_attrs)*
645 #fn_vis async fn #fn_name #inputs #output #fn_block
646
647 // Kernel handler wrapper
648 #fn_vis fn #handler_name(
649 ctx: &mut ::ringkernel_core::RingContext<'_>,
650 envelope: ::ringkernel_core::message::MessageEnvelope,
651 ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::ringkernel_core::error::Result<::ringkernel_core::message::MessageEnvelope>> + Send + '_>> {
652 Box::pin(async move {
653 // Deserialize input message
654 let msg: #msg_type = ::ringkernel_core::message::RingMessage::deserialize(&envelope.payload)?;
655
656 // Call the actual handler
657 let response = #fn_name(ctx, msg).await;
658
659 // Serialize response
660 let response_payload = ::ringkernel_core::message::RingMessage::serialize(&response);
661 let response_header = ::ringkernel_core::message::MessageHeader::new(
662 <_ as ::ringkernel_core::message::RingMessage>::message_type(),
663 envelope.header.dest_kernel,
664 envelope.header.source_kernel,
665 response_payload.len(),
666 ctx.now(),
667 ).with_correlation(envelope.header.correlation_id);
668
669 Ok(::ringkernel_core::message::MessageEnvelope {
670 header: response_header,
671 payload: response_payload,
672 ..::std::default::Default::default()
673 })
674 })
675 }
676
677 // Kernel registration
678 #[allow(non_upper_case_globals)]
679 #[::inventory::submit]
680 static #registration_name: ::ringkernel_core::__private::KernelRegistration = ::ringkernel_core::__private::KernelRegistration {
681 id: #kernel_id,
682 mode: #mode_expr,
683 grid_size: #grid_size,
684 block_size: #block_size,
685 publishes_to: &[#(#publishes_to_targets),*],
686 };
687 };
688
689 TokenStream::from(expanded)
690}
691
692/// Derive macro for GPU-compatible types.
693///
694/// Ensures the type has a stable memory layout suitable for GPU transfer.
695#[proc_macro_derive(GpuType)]
696pub fn derive_gpu_type(input: TokenStream) -> TokenStream {
697 let input = parse_macro_input!(input as DeriveInput);
698 let name = &input.ident;
699 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
700
701 // Generate assertions for GPU compatibility
702 let expanded = quote! {
703 // Verify type is Copy (required for GPU transfer)
704 const _: fn() = || {
705 fn assert_copy<T: Copy>() {}
706 assert_copy::<#name #ty_generics>();
707 };
708
709 // Verify type is Pod (plain old data)
710 unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
711 unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
712 };
713
714 TokenStream::from(expanded)
715}
716
717// ============================================================================
718// Stencil Kernel Macro (requires cuda-codegen feature)
719// ============================================================================
720
/// Attributes for the stencil_kernel macro.
///
/// Parsed by `darling` from the `#[stencil_kernel(...)]` attribute list;
/// defaults (grid "2d", tile 16, halo 1) are applied in
/// `stencil_kernel_impl`.
#[derive(Debug, FromMeta)]
struct StencilKernelArgs {
    /// Kernel identifier.
    id: String,
    /// Grid dimensionality: "1d", "2d", or "3d" (unknown values fall back to 2d).
    #[darling(default)]
    grid: Option<String>,
    /// Tile/block size (single value for square tiles).
    #[darling(default)]
    tile_size: Option<u32>,
    /// Tile width (for non-square tiles; overrides `tile_size`).
    #[darling(default)]
    tile_width: Option<u32>,
    /// Tile height (for non-square tiles; overrides `tile_size`).
    #[darling(default)]
    tile_height: Option<u32>,
    /// Halo/ghost cell width (stencil radius).
    #[darling(default)]
    halo: Option<u32>,
}
742
743/// Attribute macro for defining stencil kernels that transpile to CUDA.
744///
745/// This macro generates CUDA C code from Rust stencil kernel functions at compile time.
746/// The generated CUDA source is embedded in the binary and can be compiled at runtime
747/// using NVRTC.
748///
749/// # Attributes
750///
751/// - `id` (required) - Unique kernel identifier
752/// - `grid` - Grid dimensionality: "1d", "2d" (default), or "3d"
753/// - `tile_size` - Tile/block size (default: 16)
754/// - `tile_width` / `tile_height` - Non-square tile dimensions
755/// - `halo` - Stencil radius / ghost cell width (default: 1)
756///
757/// # Supported Rust Subset
758///
759/// - Primitives: `f32`, `f64`, `i32`, `u32`, `i64`, `u64`, `bool`
760/// - Slices: `&[T]`, `&mut [T]`
761/// - Arithmetic: `+`, `-`, `*`, `/`, `%`
762/// - Comparisons: `<`, `>`, `<=`, `>=`, `==`, `!=`
763/// - Let bindings: `let x = expr;`
764/// - If/else: `if cond { a } else { b }`
765/// - Stencil intrinsics via `GridPos`
766///
767/// # Example
768///
769/// ```ignore
770/// use ringkernel_derive::stencil_kernel;
771/// use ringkernel_cuda_codegen::GridPos;
772///
773/// #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
774/// fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
775/// let curr = p[pos.idx()];
776/// let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
777/// p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
778/// }
779///
780/// // Access generated CUDA source:
781/// assert!(FDTD_CUDA_SOURCE.contains("__global__"));
782/// ```
783#[proc_macro_attribute]
784pub fn stencil_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
785 let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
786 Ok(v) => v,
787 Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
788 };
789
790 let args = match StencilKernelArgs::from_list(&args) {
791 Ok(v) => v,
792 Err(e) => return TokenStream::from(e.write_errors()),
793 };
794
795 let input = parse_macro_input!(item as ItemFn);
796
797 // Generate the stencil kernel code
798 stencil_kernel_impl(args, input)
799}
800
801fn stencil_kernel_impl(args: StencilKernelArgs, input: ItemFn) -> TokenStream {
802 let kernel_id = &args.id;
803 let fn_name = &input.sig.ident;
804 let fn_vis = &input.vis;
805 let fn_block = &input.block;
806 let fn_inputs = &input.sig.inputs;
807 let fn_output = &input.sig.output;
808 let fn_attrs = &input.attrs;
809
810 // Parse configuration
811 let grid = args.grid.as_deref().unwrap_or("2d");
812 let tile_width = args
813 .tile_width
814 .unwrap_or_else(|| args.tile_size.unwrap_or(16));
815 let tile_height = args
816 .tile_height
817 .unwrap_or_else(|| args.tile_size.unwrap_or(16));
818 let halo = args.halo.unwrap_or(1);
819
820 // Generate CUDA source constant name
821 let cuda_const_name = format_ident!("{}_CUDA_SOURCE", fn_name.to_string().to_uppercase());
822
823 // Generate registration name
824 let registration_name = format_ident!(
825 "__STENCIL_KERNEL_REGISTRATION_{}",
826 fn_name.to_string().to_uppercase()
827 );
828
829 // Transpile to CUDA (if feature enabled)
830 #[cfg(feature = "cuda-codegen")]
831 let cuda_source_code = {
832 use ringkernel_cuda_codegen::{transpile_stencil_kernel, Grid, StencilConfig};
833
834 let grid_type = match grid {
835 "1d" => Grid::Grid1D,
836 "2d" => Grid::Grid2D,
837 "3d" => Grid::Grid3D,
838 _ => Grid::Grid2D,
839 };
840
841 let config = StencilConfig::new(kernel_id.clone())
842 .with_grid(grid_type)
843 .with_tile_size(tile_width as usize, tile_height as usize)
844 .with_halo(halo as usize);
845
846 match transpile_stencil_kernel(&input, &config) {
847 Ok(cuda) => cuda,
848 Err(e) => {
849 return TokenStream::from(
850 syn::Error::new_spanned(
851 &input.sig.ident,
852 format!("CUDA transpilation failed: {}", e),
853 )
854 .to_compile_error(),
855 );
856 }
857 }
858 };
859
860 #[cfg(not(feature = "cuda-codegen"))]
861 let cuda_source_code = format!(
862 "// CUDA codegen not enabled. Enable 'cuda-codegen' feature.\n// Kernel: {}\n",
863 kernel_id
864 );
865
866 // Generate the expanded code
867 let expanded = quote! {
868 // Original function (for documentation/testing/CPU fallback)
869 #(#fn_attrs)*
870 #fn_vis fn #fn_name #fn_inputs #fn_output #fn_block
871
872 /// Generated CUDA source code for this stencil kernel.
873 #fn_vis const #cuda_const_name: &str = #cuda_source_code;
874
875 /// Stencil kernel registration for runtime discovery.
876 #[allow(non_upper_case_globals)]
877 #[::inventory::submit]
878 static #registration_name: ::ringkernel_core::__private::StencilKernelRegistration =
879 ::ringkernel_core::__private::StencilKernelRegistration {
880 id: #kernel_id,
881 grid: #grid,
882 tile_width: #tile_width,
883 tile_height: #tile_height,
884 halo: #halo,
885 cuda_source: #cuda_source_code,
886 };
887 };
888
889 TokenStream::from(expanded)
890}
891
892// ============================================================================
893// Multi-Backend GPU Kernel Macro
894// ============================================================================
895
/// GPU backend targets (internal use only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuBackend {
    /// NVIDIA CUDA backend.
    Cuda,
    /// Apple Metal backend.
    Metal,
    /// WebGPU backend (cross-platform).
    Wgpu,
    /// CPU fallback backend.
    Cpu,
}

impl GpuBackend {
    /// Parse a backend name, case-insensitively; `None` for unknown names.
    fn from_str(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        if normalized == "cuda" {
            Some(Self::Cuda)
        } else if normalized == "metal" {
            Some(Self::Metal)
        } else if normalized == "wgpu" || normalized == "webgpu" {
            Some(Self::Wgpu)
        } else if normalized == "cpu" {
            Some(Self::Cpu)
        } else {
            None
        }
    }

    /// Canonical lowercase name of this backend.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Cuda => "cuda",
            Self::Metal => "metal",
            Self::Wgpu => "wgpu",
            Self::Cpu => "cpu",
        }
    }
}

/// GPU capability flags that can be required by a kernel (internal use only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuCapability {
    /// 64-bit floating point support.
    Float64,
    /// 64-bit integer support.
    Int64,
    /// 64-bit atomics support.
    Atomic64,
    /// Cooperative groups / grid-wide sync.
    CooperativeGroups,
    /// Subgroup / warp / SIMD operations.
    Subgroups,
    /// Shared memory / threadgroup memory.
    SharedMemory,
    /// Dynamic parallelism (launching kernels from kernels).
    DynamicParallelism,
    /// Half-precision (f16) support.
    Float16,
}

impl GpuCapability {
    /// Parse a capability name; several aliases are accepted, case-insensitively.
    fn from_str(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        match normalized.as_str() {
            "f64" | "float64" => Some(Self::Float64),
            "i64" | "int64" => Some(Self::Int64),
            "atomic64" => Some(Self::Atomic64),
            "cooperative_groups" | "cooperativegroups" | "grid_sync" => {
                Some(Self::CooperativeGroups)
            }
            "subgroups" | "warp" | "simd" => Some(Self::Subgroups),
            "shared_memory" | "sharedmemory" | "threadgroup" => Some(Self::SharedMemory),
            "dynamic_parallelism" | "dynamicparallelism" => Some(Self::DynamicParallelism),
            "f16" | "float16" | "half" => Some(Self::Float16),
            _ => None,
        }
    }

    /// Canonical name of this capability.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Float64 => "f64",
            Self::Int64 => "i64",
            Self::Atomic64 => "atomic64",
            Self::CooperativeGroups => "cooperative_groups",
            Self::Subgroups => "subgroups",
            Self::SharedMemory => "shared_memory",
            Self::DynamicParallelism => "dynamic_parallelism",
            Self::Float16 => "f16",
        }
    }

    /// Check if a backend supports this capability.
    fn supported_by(&self, backend: GpuBackend) -> bool {
        match backend {
            // CUDA supports everything; CPU supports everything in emulation.
            GpuBackend::Cuda | GpuBackend::Cpu => true,
            // Metal lacks f64, cooperative groups, and dynamic parallelism.
            GpuBackend::Metal => !matches!(
                self,
                Self::Float64 | Self::CooperativeGroups | Self::DynamicParallelism
            ),
            // WebGPU lacks 64-bit types/atomics (atomic64 is emulated only),
            // cooperative groups, and dynamic parallelism; subgroup ops are an
            // optional extension and treated as available.
            GpuBackend::Wgpu => !matches!(
                self,
                Self::Float64
                    | Self::Int64
                    | Self::Atomic64
                    | Self::CooperativeGroups
                    | Self::DynamicParallelism
            ),
        }
    }
}
1007
/// Attributes for the gpu_kernel macro.
///
/// Populated by `GpuKernelArgs::parse` from the raw attribute token stream;
/// keys absent from the attribute keep the values from `Default::default()`.
#[derive(Debug)]
struct GpuKernelArgs {
    /// Kernel identifier.
    id: Option<String>,
    /// Target backends to generate code for.
    backends: Vec<GpuBackend>,
    /// Fallback order for backend selection.
    fallback: Vec<GpuBackend>,
    /// Required capabilities.
    requires: Vec<GpuCapability>,
    /// Block/workgroup size.
    block_size: Option<u32>,
}
1022
1023impl Default for GpuKernelArgs {
1024 fn default() -> Self {
1025 Self {
1026 id: None,
1027 backends: vec![GpuBackend::Cuda, GpuBackend::Metal, GpuBackend::Wgpu],
1028 fallback: vec![
1029 GpuBackend::Cuda,
1030 GpuBackend::Metal,
1031 GpuBackend::Wgpu,
1032 GpuBackend::Cpu,
1033 ],
1034 requires: Vec::new(),
1035 block_size: None,
1036 }
1037 }
1038}
1039
1040impl GpuKernelArgs {
1041 fn parse(attr: proc_macro2::TokenStream) -> Result<Self, darling::Error> {
1042 let mut args = Self::default();
1043 let attr_str = attr.to_string();
1044
1045 // Parse backends = [...]
1046 if let Some(start) = attr_str.find("backends") {
1047 if let Some(bracket_start) = attr_str[start..].find('[') {
1048 if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
1049 let backends_str =
1050 &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
1051 args.backends = backends_str
1052 .split(',')
1053 .filter_map(|s| GpuBackend::from_str(s.trim()))
1054 .collect();
1055 }
1056 }
1057 }
1058
1059 // Parse fallback = [...]
1060 if let Some(start) = attr_str.find("fallback") {
1061 if let Some(bracket_start) = attr_str[start..].find('[') {
1062 if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
1063 let fallback_str =
1064 &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
1065 args.fallback = fallback_str
1066 .split(',')
1067 .filter_map(|s| GpuBackend::from_str(s.trim()))
1068 .collect();
1069 }
1070 }
1071 }
1072
1073 // Parse requires = [...]
1074 if let Some(start) = attr_str.find("requires") {
1075 if let Some(bracket_start) = attr_str[start..].find('[') {
1076 if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
1077 let requires_str =
1078 &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
1079 args.requires = requires_str
1080 .split(',')
1081 .filter_map(|s| GpuCapability::from_str(s.trim()))
1082 .collect();
1083 }
1084 }
1085 }
1086
1087 // Parse id = "..."
1088 if let Some(start) = attr_str.find("id") {
1089 if let Some(quote_start) = attr_str[start..].find('"') {
1090 if let Some(quote_end) = attr_str[start + quote_start + 1..].find('"') {
1091 args.id = Some(
1092 attr_str[start + quote_start + 1..start + quote_start + 1 + quote_end]
1093 .to_string(),
1094 );
1095 }
1096 }
1097 }
1098
1099 // Parse block_size = N
1100 if let Some(start) = attr_str.find("block_size") {
1101 if let Some(eq) = attr_str[start..].find('=') {
1102 let rest = &attr_str[start + eq + 1..];
1103 let num_end = rest
1104 .find(|c: char| !c.is_numeric() && c != ' ')
1105 .unwrap_or(rest.len());
1106 if let Ok(n) = rest[..num_end].trim().parse() {
1107 args.block_size = Some(n);
1108 }
1109 }
1110 }
1111
1112 Ok(args)
1113 }
1114
1115 /// Validate that all required capabilities are supported by at least one backend.
1116 fn validate_capabilities(&self) -> Result<(), String> {
1117 for cap in &self.requires {
1118 let mut supported_by_any = false;
1119 for backend in &self.backends {
1120 if cap.supported_by(*backend) {
1121 supported_by_any = true;
1122 break;
1123 }
1124 }
1125 if !supported_by_any {
1126 return Err(format!(
1127 "Capability '{}' is not supported by any of the specified backends: {:?}",
1128 cap.as_str(),
1129 self.backends.iter().map(|b| b.as_str()).collect::<Vec<_>>()
1130 ));
1131 }
1132 }
1133 Ok(())
1134 }
1135
1136 /// Get backends that support all required capabilities.
1137 fn compatible_backends(&self) -> Vec<GpuBackend> {
1138 self.backends
1139 .iter()
1140 .filter(|backend| self.requires.iter().all(|cap| cap.supported_by(**backend)))
1141 .copied()
1142 .collect()
1143 }
1144}
1145
1146/// Attribute macro for defining multi-backend GPU kernels.
1147///
1148/// This macro generates code for multiple GPU backends with compile-time
1149/// capability validation. It integrates with the `ringkernel-ir` crate
1150/// to lower Rust DSL to backend-specific shader code.
1151///
1152/// # Attributes
1153///
1154/// - `backends = [cuda, metal, wgpu]` - Target backends (default: all)
1155/// - `fallback = [cuda, metal, wgpu, cpu]` - Fallback order for runtime selection
1156/// - `requires = [f64, atomic64]` - Required capabilities (validated at compile time)
1157/// - `id = "kernel_name"` - Explicit kernel identifier
1158/// - `block_size = 256` - Thread block size
1159///
1160/// # Example
1161///
1162/// ```ignore
1163/// use ringkernel_derive::gpu_kernel;
1164///
1165/// #[gpu_kernel(backends = [cuda, metal], requires = [subgroups])]
1166/// fn warp_reduce(data: &mut [f32], n: i32) {
1167/// let idx = global_thread_id_x();
1168/// if idx < n {
1169/// // Use warp shuffle for reduction
1170/// let val = data[idx as usize];
1171/// let reduced = warp_reduce_sum(val);
1172/// if lane_id() == 0 {
1173/// data[idx as usize] = reduced;
1174/// }
1175/// }
1176/// }
1177/// ```
1178///
1179/// # Capability Checking
1180///
1181/// The macro validates at compile time that all required capabilities are
1182/// supported by at least one target backend:
1183///
1184/// | Capability | CUDA | Metal | WebGPU | CPU |
1185/// |------------|------|-------|--------|-----|
1186/// | f64 | Yes | No | No | Yes |
1187/// | i64 | Yes | Yes | No | Yes |
1188/// | atomic64 | Yes | Yes | No* | Yes |
1189/// | cooperative_groups | Yes | No | No | Yes |
1190/// | subgroups | Yes | Yes | Opt | Yes |
1191/// | shared_memory | Yes | Yes | Yes | Yes |
1192/// | f16 | Yes | Yes | Yes | Yes |
1193///
1194/// *WebGPU emulates 64-bit atomics with 32-bit pairs.
1195///
1196/// # Generated Code
1197///
1198/// For each compatible backend, the macro generates:
1199/// - Backend-specific source code constant (e.g., `KERNEL_NAME_CUDA_SOURCE`)
1200/// - Registration entry for runtime discovery
1201/// - CPU fallback function (if `cpu_fallback = true`)
1202#[proc_macro_attribute]
1203pub fn gpu_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
1204 let attr2: proc_macro2::TokenStream = attr.into();
1205 let args = match GpuKernelArgs::parse(attr2) {
1206 Ok(args) => args,
1207 Err(e) => return TokenStream::from(e.write_errors()),
1208 };
1209
1210 let input = parse_macro_input!(item as ItemFn);
1211
1212 // Validate capabilities
1213 if let Err(msg) = args.validate_capabilities() {
1214 return TokenStream::from(
1215 syn::Error::new_spanned(&input.sig.ident, msg).to_compile_error(),
1216 );
1217 }
1218
1219 gpu_kernel_impl(args, input)
1220}
1221
1222fn gpu_kernel_impl(args: GpuKernelArgs, input: ItemFn) -> TokenStream {
1223 let fn_name = &input.sig.ident;
1224 let fn_vis = &input.vis;
1225 let fn_block = &input.block;
1226 let fn_inputs = &input.sig.inputs;
1227 let fn_output = &input.sig.output;
1228 let fn_attrs = &input.attrs;
1229
1230 let kernel_id = args.id.clone().unwrap_or_else(|| fn_name.to_string());
1231 let block_size = args.block_size.unwrap_or(256);
1232
1233 // Get compatible backends
1234 let compatible_backends = args.compatible_backends();
1235
1236 // Generate backend-specific source constants
1237 let mut source_constants = Vec::new();
1238
1239 for backend in &compatible_backends {
1240 let const_name = format_ident!(
1241 "{}_{}",
1242 fn_name.to_string().to_uppercase(),
1243 backend.as_str().to_uppercase()
1244 );
1245
1246 let backend_str = backend.as_str();
1247
1248 // Generate placeholder source (actual IR lowering happens at build time)
1249 // In a full implementation, this would call ringkernel-ir lowering
1250 let source_placeholder = format!(
1251 "// {} source for kernel '{}'\n// Generated by ringkernel-derive\n// Capabilities: {:?}\n",
1252 backend_str.to_uppercase(),
1253 kernel_id,
1254 args.requires.iter().map(|c| c.as_str()).collect::<Vec<_>>()
1255 );
1256
1257 source_constants.push(quote! {
1258 /// Generated source code for this kernel.
1259 #fn_vis const #const_name: &str = #source_placeholder;
1260 });
1261 }
1262
1263 // Generate capability flags as strings
1264 let capability_strs: Vec<_> = args.requires.iter().map(|c| c.as_str()).collect();
1265 let backend_strs: Vec<_> = compatible_backends.iter().map(|b| b.as_str()).collect();
1266 let fallback_strs: Vec<_> = args.fallback.iter().map(|b| b.as_str()).collect();
1267
1268 // Generate registration struct name
1269 let registration_name = format_ident!(
1270 "__GPU_KERNEL_REGISTRATION_{}",
1271 fn_name.to_string().to_uppercase()
1272 );
1273
1274 // Generate info struct name
1275 let info_name = format_ident!("{}_INFO", fn_name.to_string().to_uppercase());
1276
1277 // Generate the expanded code
1278 let expanded = quote! {
1279 // Original function (CPU fallback / documentation / testing)
1280 #(#fn_attrs)*
1281 #fn_vis fn #fn_name #fn_inputs #fn_output #fn_block
1282
1283 // Backend source constants
1284 #(#source_constants)*
1285
1286 /// Multi-backend kernel information.
1287 #fn_vis mod #info_name {
1288 /// Kernel identifier.
1289 pub const ID: &str = #kernel_id;
1290
1291 /// Block/workgroup size.
1292 pub const BLOCK_SIZE: u32 = #block_size;
1293
1294 /// Required capabilities.
1295 pub const CAPABILITIES: &[&str] = &[#(#capability_strs),*];
1296
1297 /// Compatible backends (those that support all required capabilities).
1298 pub const BACKENDS: &[&str] = &[#(#backend_strs),*];
1299
1300 /// Fallback order for runtime backend selection.
1301 pub const FALLBACK_ORDER: &[&str] = &[#(#fallback_strs),*];
1302 }
1303
1304 /// GPU kernel registration for runtime discovery.
1305 #[allow(non_upper_case_globals)]
1306 #[::inventory::submit]
1307 static #registration_name: ::ringkernel_core::__private::GpuKernelRegistration =
1308 ::ringkernel_core::__private::GpuKernelRegistration {
1309 id: #kernel_id,
1310 block_size: #block_size,
1311 capabilities: &[#(#capability_strs),*],
1312 backends: &[#(#backend_strs),*],
1313 fallback_order: &[#(#fallback_strs),*],
1314 };
1315 };
1316
1317 TokenStream::from(expanded)
1318}
1319
1320// ============================================================================
1321// ControlBlockState Derive Macro (FR-4)
1322// ============================================================================
1323
/// Attributes for the ControlBlockState derive macro.
///
/// Parsed by darling from the derive input; only named structs are accepted
/// (`supports(struct_named)`), and the optional `#[state(version = N)]`
/// attribute is captured in `version`.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(state), supports(struct_named))]
struct ControlBlockStateArgs {
    /// Name of the deriving type.
    ident: syn::Ident,
    /// Generics of the deriving type, forwarded to the generated impls.
    generics: syn::Generics,
    /// State version for forward compatibility.
    #[darling(default)]
    version: Option<u32>,
}
1334
1335/// Derive macro for implementing EmbeddedState trait.
1336///
1337/// This macro generates implementations for types that can be stored in
1338/// the ControlBlock's 24-byte `_reserved` field for zero-copy state access.
1339///
1340/// # Requirements
1341///
1342/// The type must:
1343/// - Be `#[repr(C)]` for stable memory layout
1344/// - Be <= 24 bytes in size (checked at compile time)
1345/// - Implement `Clone`, `Copy`, and `Default`
1346/// - Contain only POD (Plain Old Data) types
1347///
1348/// # Attributes
1349///
1350/// - `#[state(version = N)]` - Set state version for migrations (default: 1)
1351///
1352/// # Example
1353///
1354/// ```ignore
1355/// #[derive(ControlBlockState, Default, Clone, Copy)]
1356/// #[repr(C, align(8))]
1357/// #[state(version = 1)]
1358/// pub struct OrderBookState {
1359/// pub best_bid: u64, // 8 bytes
1360/// pub best_ask: u64, // 8 bytes
1361/// pub order_count: u32, // 4 bytes
1362/// pub _pad: u32, // 4 bytes (padding for alignment)
1363/// } // Total: 24 bytes - fits in ControlBlock._reserved
1364///
1365/// // Use with ControlBlockStateHelper:
1366/// let mut block = ControlBlock::new();
1367/// let state = OrderBookState { best_bid: 100, best_ask: 101, order_count: 42, _pad: 0 };
1368/// ControlBlockStateHelper::write_embedded(&mut block, &state)?;
1369/// ```
1370///
1371/// # Size Validation
1372///
1373/// The macro generates a compile-time assertion that fails if the type
1374/// exceeds 24 bytes:
1375///
1376/// ```ignore
1377/// #[derive(ControlBlockState, Default, Clone, Copy)]
1378/// #[repr(C)]
1379/// struct TooLarge {
1380/// data: [u8; 32], // 32 bytes - COMPILE ERROR!
1381/// }
1382/// ```
1383#[proc_macro_derive(ControlBlockState, attributes(state))]
1384pub fn derive_control_block_state(input: TokenStream) -> TokenStream {
1385 let input = parse_macro_input!(input as DeriveInput);
1386
1387 let args = match ControlBlockStateArgs::from_derive_input(&input) {
1388 Ok(args) => args,
1389 Err(e) => return e.write_errors().into(),
1390 };
1391
1392 let name = &args.ident;
1393 let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
1394 let version = args.version.unwrap_or(1);
1395
1396 let expanded = quote! {
1397 // Compile-time size check: EmbeddedState must fit in 24 bytes
1398 const _: () = {
1399 assert!(
1400 ::std::mem::size_of::<#name #ty_generics>() <= 24,
1401 "ControlBlockState types must fit in 24 bytes (ControlBlock._reserved size)"
1402 );
1403 };
1404
1405 // Verify type is Copy (required for GPU transfer)
1406 const _: fn() = || {
1407 fn assert_copy<T: Copy>() {}
1408 assert_copy::<#name #ty_generics>();
1409 };
1410
1411 // Implement Pod and Zeroable (required by EmbeddedState)
1412 // SAFETY: Type is #[repr(C)] with only primitive types, verified by user
1413 unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
1414 unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
1415
1416 // Implement EmbeddedState
1417 impl #impl_generics ::ringkernel_core::state::EmbeddedState for #name #ty_generics #where_clause {
1418 const VERSION: u32 = #version;
1419
1420 fn is_embedded() -> bool {
1421 true
1422 }
1423 }
1424 };
1425
1426 TokenStream::from(expanded)
1427}