ringkernel_derive/lib.rs
//! Procedural macros for RingKernel.
//!
//! This crate provides the following macros:
//!
//! - `#[derive(RingMessage)]` - Implement the RingMessage trait for message types
//! - `#[derive(PersistentMessage)]` - Implement PersistentMessage for GPU kernel dispatch
//! - `#[derive(GpuType)]` - Assert a stable, GPU-compatible memory layout
//! - `#[derive(ControlBlockState)]` - Implement EmbeddedState for control-block state
//! - `#[ring_kernel]` - Define a ring kernel handler
//! - `#[stencil_kernel]` - Define a GPU stencil kernel (with the `cuda-codegen` feature)
//! - `#[gpu_kernel]` - Define a multi-backend GPU kernel with capability checking
//!
//! # Example
//!
//! ```ignore
//! use ringkernel_derive::{RingMessage, ring_kernel};
//!
//! #[derive(RingMessage)]
//! struct AddRequest {
//!     #[message(id)]
//!     id: MessageId,
//!     a: f32,
//!     b: f32,
//! }
//!
//! #[derive(RingMessage)]
//! struct AddResponse {
//!     #[message(id)]
//!     id: MessageId,
//!     result: f32,
//! }
//!
//! #[ring_kernel(id = "adder")]
//! async fn process(ctx: &mut RingContext, req: AddRequest) -> AddResponse {
//!     AddResponse {
//!         id: MessageId::generate(),
//!         result: req.a + req.b,
//!     }
//! }
//! ```
//!
//! # Multi-Backend GPU Kernels
//!
//! The `#[gpu_kernel]` macro enables multi-backend code generation with capability checking:
//!
//! ```ignore
//! use ringkernel_derive::gpu_kernel;
//!
//! // Generate code for CUDA and Metal, with a fallback order for runtime selection.
//! #[gpu_kernel(backends = [cuda, metal], fallback = [wgpu, cpu])]
//! fn saxpy(x: &[f32], y: &mut [f32], a: f32, n: i32) {
//!     let idx = global_thread_id_x();
//!     if idx < n {
//!         y[idx as usize] = a * x[idx as usize] + y[idx as usize];
//!     }
//! }
//!
//! // Require specific capabilities at compile time.
//! #[gpu_kernel(backends = [cuda], requires = [f64, atomic64])]
//! fn double_precision(data: &mut [f64], n: i32) {
//!     // Uses f64 operations - validated at compile time.
//! }
//! ```
//!
//! # Stencil Kernels (with the `cuda-codegen` feature)
//!
//! ```ignore
//! use ringkernel_derive::stencil_kernel;
//! use ringkernel_cuda_codegen::GridPos;
//!
//! #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
//! fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
//!     let curr = p[pos.idx()];
//!     let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
//!     p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
//! }
//! ```

use darling::{ast, FromDeriveInput, FromField, FromMeta};
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, DeriveInput, ItemFn};

/// Attributes for the RingMessage derive macro.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(message, ring_message), supports(struct_named))]
struct RingMessageArgs {
    ident: syn::Ident,
    generics: syn::Generics,
    data: ast::Data<(), RingMessageField>,
    /// Optional explicit message type ID.
    /// If domain is specified, this is the offset within the domain (0-99).
    /// If domain is not specified, this is the absolute type ID.
    #[darling(default)]
    type_id: Option<u64>,
    /// Optional domain for message classification.
    /// When specified, the final type ID = domain.base_type_id() + type_id.
    #[darling(default)]
    domain: Option<String>,
    /// Whether this message is routable via K2K.
    /// When true, generates a K2KMessageRegistration for runtime discovery.
    #[darling(default)]
    k2k_routable: bool,
    /// Optional category for K2K routing groups.
    /// Multiple messages can share a category for grouped routing.
    #[darling(default)]
    category: Option<String>,
}

/// Field attributes for RingMessage.
#[derive(Debug, FromField)]
#[darling(attributes(message))]
struct RingMessageField {
    ident: Option<syn::Ident>,
    #[allow(dead_code)]
    ty: syn::Type,
    /// Mark this field as the message ID.
    #[darling(default)]
    id: bool,
    /// Mark this field as the correlation ID.
    #[darling(default)]
    correlation: bool,
    /// Mark this field as the priority.
    #[darling(default)]
    priority: bool,
}

/// Derive macro for implementing the RingMessage trait.
///
/// # Attributes
///
/// On the struct (via `#[message(...)]` or `#[ring_message(...)]`):
/// - `type_id = 123` - Set explicit message type ID (or domain offset if domain is set)
/// - `domain = "OrderMatching"` - Assign to a business domain (adds base type ID)
/// - `k2k_routable = true` - Register for K2K routing discovery
/// - `category = "orders"` - Group messages for K2K routing
///
/// On fields:
/// - `#[message(id)]` - Mark as message ID field
/// - `#[message(correlation)]` - Mark as correlation ID field
/// - `#[message(priority)]` - Mark as priority field
///
/// # Examples
///
/// Basic usage:
/// ```ignore
/// #[derive(RingMessage)]
/// #[message(type_id = 1)]
/// struct MyMessage {
///     #[message(id)]
///     id: MessageId,
///     #[message(correlation)]
///     correlation: CorrelationId,
///     #[message(priority)]
///     priority: Priority,
///     payload: Vec<u8>,
/// }
/// ```
///
/// With domain (type ID = 500 + 1 = 501):
/// ```ignore
/// #[derive(RingMessage)]
/// #[ring_message(type_id = 1, domain = "OrderMatching")]
/// pub struct SubmitOrderInput {
///     #[message(id)]
///     id: MessageId,
///     pub order: Order,
/// }
/// // Also implements the DomainMessage trait:
/// assert_eq!(SubmitOrderInput::domain(), Domain::OrderMatching);
/// ```
///
/// K2K-routable message:
/// ```ignore
/// #[derive(RingMessage)]
/// #[ring_message(type_id = 1, domain = "OrderMatching", k2k_routable = true, category = "orders")]
/// pub struct SubmitOrderInput { ... }
///
/// // Runtime discovery:
/// let registry = K2KTypeRegistry::discover();
/// assert!(registry.is_routable(501));
/// ```
#[proc_macro_derive(RingMessage, attributes(message, ring_message))]
pub fn derive_ring_message(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let args = match RingMessageArgs::from_derive_input(&input) {
        Ok(args) => args,
        Err(e) => return e.write_errors().into(),
    };

    let name = &args.ident;
    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();

    // Calculate the base type ID (offset within the domain, or absolute if no domain).
    let base_type_id = args.type_id.unwrap_or_else(|| {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        name.to_string().hash(&mut hasher);
        // If a domain is set, hash to a value within the 0-99 offset range.
        if args.domain.is_some() {
            hasher.finish() % 100
        } else {
            hasher.finish()
        }
    });

    // Find annotated fields.
    let fields = match &args.data {
        ast::Data::Struct(fields) => fields,
        _ => panic!("RingMessage can only be derived for structs"),
    };

    let mut id_field: Option<&syn::Ident> = None;
    let mut correlation_field: Option<&syn::Ident> = None;
    let mut priority_field: Option<&syn::Ident> = None;

    for field in fields.iter() {
        if field.id {
            id_field = field.ident.as_ref();
        }
        if field.correlation {
            correlation_field = field.ident.as_ref();
        }
        if field.priority {
            priority_field = field.ident.as_ref();
        }
    }

    // Generate the message_id method.
    let message_id_impl = if let Some(field) = id_field {
        quote! { self.#field }
    } else {
        quote! { ::ringkernel_core::message::MessageId::new(0) }
    };

    // Generate the correlation_id method.
    let correlation_id_impl = if let Some(field) = correlation_field {
        quote! { self.#field }
    } else {
        quote! { ::ringkernel_core::message::CorrelationId::none() }
    };

    // Generate the priority method.
    let priority_impl = if let Some(field) = priority_field {
        quote! { self.#field }
    } else {
        quote! { ::ringkernel_core::message::Priority::Normal }
    };

    // Generate the message_type() implementation based on whether a domain is specified.
    let message_type_impl = if let Some(ref domain_str) = args.domain {
        // With a domain: type_id = domain.base_type_id() + offset.
        quote! {
            ::ringkernel_core::domain::Domain::from_str(#domain_str)
                .unwrap_or(::ringkernel_core::domain::Domain::General)
                .base_type_id() + #base_type_id
        }
    } else {
        // Without a domain: use the absolute type_id.
        quote! { #base_type_id }
    };

    // Generate a DomainMessage impl if a domain is specified.
    let domain_impl = if let Some(ref domain_str) = args.domain {
        quote! {
            impl #impl_generics ::ringkernel_core::domain::DomainMessage for #name #ty_generics #where_clause {
                fn domain() -> ::ringkernel_core::domain::Domain {
                    ::ringkernel_core::domain::Domain::from_str(#domain_str)
                        .unwrap_or(::ringkernel_core::domain::Domain::General)
                }
            }
        }
    } else {
        quote! {}
    };

    // Generate a K2K registration if k2k_routable is set.
    let k2k_registration = if args.k2k_routable {
        let type_name_str = name.to_string();
        let category_tokens = match &args.category {
            Some(cat) => quote! { ::std::option::Option::Some(#cat) },
            None => quote! { ::std::option::Option::None },
        };

        quote! {
            ::inventory::submit! {
                ::ringkernel_core::k2k::K2KMessageRegistration {
                    // Note: this initializer is const-evaluated, so only the base
                    // type ID is stored here. For domain-scoped types the domain
                    // base must be added when the registry is consulted.
                    type_id: #base_type_id,
                    type_name: #type_name_str,
                    k2k_routable: true,
                    category: #category_tokens,
                }
            }
        }
    } else {
        quote! {}
    };

    let expanded = quote! {
        impl #impl_generics ::ringkernel_core::message::RingMessage for #name #ty_generics #where_clause {
            fn message_type() -> u64 {
                #message_type_impl
            }

            fn message_id(&self) -> ::ringkernel_core::message::MessageId {
                #message_id_impl
            }

            fn correlation_id(&self) -> ::ringkernel_core::message::CorrelationId {
                #correlation_id_impl
            }

            fn priority(&self) -> ::ringkernel_core::message::Priority {
                #priority_impl
            }

            fn serialize(&self) -> Vec<u8> {
                // Use rkyv for serialization with a 4KB scratch buffer.
                // For larger payloads, rkyv will allocate as needed.
                ::rkyv::to_bytes::<_, 4096>(self)
                    .map(|v| v.to_vec())
                    .unwrap_or_default()
            }

            fn deserialize(bytes: &[u8]) -> ::ringkernel_core::error::Result<Self>
            where
                Self: Sized,
            {
                use ::rkyv::Deserialize as _;
                // SAFETY: callers must pass bytes produced by `serialize` for this type.
                let archived = unsafe { ::rkyv::archived_root::<Self>(bytes) };
                let deserialized: Self = archived.deserialize(&mut ::rkyv::Infallible)
                    .map_err(|_| ::ringkernel_core::error::RingKernelError::DeserializationError(
                        "rkyv deserialization failed".to_string()
                    ))?;
                Ok(deserialized)
            }

            fn size_hint(&self) -> usize {
                ::std::mem::size_of::<Self>()
            }
        }

        #domain_impl

        #k2k_registration
    };

    TokenStream::from(expanded)
}

// ============================================================================
// PersistentMessage Derive Macro
// ============================================================================

/// Maximum size in bytes for the inline payload of persistent messages.
#[allow(dead_code)]
const MAX_INLINE_PAYLOAD_SIZE: usize = 32;

/// Attributes for the PersistentMessage derive macro.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(persistent_message), supports(struct_named))]
struct PersistentMessageArgs {
    ident: syn::Ident,
    generics: syn::Generics,
    /// Field data (reserved for future per-field attributes).
    #[allow(dead_code)]
    data: ast::Data<(), PersistentMessageField>,
    /// Handler ID for CUDA dispatch (0-255).
    handler_id: u32,
    /// Whether this message type expects a response.
    #[darling(default)]
    requires_response: bool,
}

/// Field attributes for PersistentMessage (reserved for future use).
#[derive(Debug, FromField)]
#[darling(attributes(persistent_message))]
struct PersistentMessageField {
    /// Field identifier.
    #[allow(dead_code)]
    ident: Option<syn::Ident>,
    /// Field type.
    #[allow(dead_code)]
    ty: syn::Type,
}

/// Derive macro for implementing the PersistentMessage trait.
///
/// This macro enables type-based dispatch within persistent GPU kernels by
/// generating handler_id, inline payload serialization, and deserialization.
///
/// # Requirements
///
/// The struct must:
/// - Already implement `RingMessage` (use `#[derive(RingMessage)]`)
/// - Be `#[repr(C)]` for a safe memory layout
/// - Be `Copy` + `Clone` for inline payload serialization
///
/// # Attributes
///
/// On the struct:
/// - `handler_id = N` (required) - CUDA dispatch handler ID (0-255)
/// - `requires_response = true/false` (optional) - Whether this message expects a response
///
/// # Example
///
/// ```ignore
/// use ringkernel_derive::{RingMessage, PersistentMessage};
///
/// #[derive(RingMessage, PersistentMessage, Clone, Copy)]
/// #[repr(C)]
/// #[message(type_id = 1001)]
/// #[persistent_message(handler_id = 1, requires_response = true)]
/// pub struct FraudCheckRequest {
///     pub transaction_id: u64,
///     pub amount: f32,
///     pub account_id: u32,
/// }
///
/// // Generated implementations:
/// // - handler_id() returns 1
/// // - requires_response() returns true
/// // - to_inline_payload() serializes the struct to [u8; 32] if it fits
/// // - from_inline_payload() deserializes from bytes
/// // - payload_size() returns the struct size
/// ```
///
/// # Size Validation
///
/// For inline payload serialization, structs must be <= 32 bytes.
/// Larger structs will return `None` from `to_inline_payload()`.
///
/// # CUDA Integration
///
/// The handler_id maps to a switch case in generated CUDA code:
///
/// ```cuda
/// switch (msg->handler_id) {
///     case 1: handle_fraud_check(msg, state, response); break;
///     case 2: handle_aggregate(msg, state, response); break;
///     // ...
/// }
/// ```
#[proc_macro_derive(PersistentMessage, attributes(persistent_message))]
pub fn derive_persistent_message(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let args = match PersistentMessageArgs::from_derive_input(&input) {
        Ok(args) => args,
        Err(e) => return e.write_errors().into(),
    };

    let name = &args.ident;
    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();

    let handler_id = args.handler_id;
    let requires_response = args.requires_response;

    // Generate the PersistentMessage implementation.
    let expanded = quote! {
        impl #impl_generics ::ringkernel_core::persistent_message::PersistentMessage for #name #ty_generics #where_clause {
            fn handler_id() -> u32 {
                #handler_id
            }

            fn requires_response() -> bool {
                #requires_response
            }

            fn payload_size() -> usize {
                ::std::mem::size_of::<Self>()
            }

            fn to_inline_payload(&self) -> ::std::option::Option<[u8; ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE]> {
                // Only serialize if the struct fits in the inline payload.
                if ::std::mem::size_of::<Self>() > ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE {
                    return ::std::option::Option::None;
                }

                let mut payload = [0u8; ::ringkernel_core::persistent_message::MAX_INLINE_PAYLOAD_SIZE];

                // SAFETY: We've verified the struct fits in the payload,
                // and the struct is repr(C) + Copy.
                unsafe {
                    ::std::ptr::copy_nonoverlapping(
                        self as *const Self as *const u8,
                        payload.as_mut_ptr(),
                        ::std::mem::size_of::<Self>()
                    );
                }

                ::std::option::Option::Some(payload)
            }

            fn from_inline_payload(payload: &[u8]) -> ::ringkernel_core::error::Result<Self> {
                let size = ::std::mem::size_of::<Self>();

                if payload.len() < size {
                    return ::std::result::Result::Err(
                        ::ringkernel_core::error::RingKernelError::DeserializationError(
                            ::std::format!(
                                "Payload too small: expected {} bytes, got {}",
                                size,
                                payload.len()
                            )
                        )
                    );
                }

                // SAFETY: We've verified the payload is large enough, and the
                // struct is repr(C) + Copy. read_unaligned is required because
                // a &[u8] provides no alignment guarantee for Self.
                let value = unsafe {
                    ::std::ptr::read_unaligned(payload.as_ptr() as *const Self)
                };

                ::std::result::Result::Ok(value)
            }
        }
    };

    TokenStream::from(expanded)
}

/// Attributes for the ring_kernel macro.
#[derive(Debug, FromMeta)]
struct RingKernelArgs {
    /// Kernel identifier.
    id: String,
    /// Execution mode (persistent or event_driven).
    #[darling(default)]
    mode: Option<String>,
    /// Grid size.
    #[darling(default)]
    grid_size: Option<u32>,
    /// Block size.
    #[darling(default)]
    block_size: Option<u32>,
    /// Target kernels this kernel publishes to.
    #[darling(default)]
    publishes_to: Option<String>,
}

/// Attribute macro for defining ring kernel handlers.
///
/// # Attributes
///
/// - `id` (required) - Unique kernel identifier
/// - `mode` - Execution mode: "persistent" (default) or "event_driven"
/// - `grid_size` - Number of blocks (default: 1)
/// - `block_size` - Threads per block (default: 256)
/// - `publishes_to` - Comma-separated list of target kernel IDs
///
/// # Example
///
/// ```ignore
/// #[ring_kernel(id = "processor", mode = "persistent", block_size = 128)]
/// async fn handle(ctx: &mut RingContext, msg: MyMessage) -> MyResponse {
///     // Process the message.
///     MyResponse { ... }
/// }
/// ```
#[proc_macro_attribute]
pub fn ring_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
        Ok(v) => v,
        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
    };

    let args = match RingKernelArgs::from_list(&args) {
        Ok(v) => v,
        Err(e) => return TokenStream::from(e.write_errors()),
    };

    let input = parse_macro_input!(item as ItemFn);

    let kernel_id = &args.id;
    let fn_name = &input.sig.ident;
    let fn_vis = &input.vis;
    let fn_block = &input.block;
    let fn_attrs = &input.attrs;

    // Parse the function signature.
    let inputs = &input.sig.inputs;
    let output = &input.sig.output;

    // Extract the context and message arguments from the signature.
    let (_ctx_arg, msg_arg) = if inputs.len() >= 2 {
        let ctx = inputs.first();
        let msg = inputs.iter().nth(1);
        (ctx, msg)
    } else {
        (None, None)
    };

    // Get the message type.
    let msg_type = msg_arg
        .map(|arg| {
            if let syn::FnArg::Typed(pat_type) = arg {
                pat_type.ty.clone()
            } else {
                syn::parse_quote!(())
            }
        })
        .unwrap_or_else(|| syn::parse_quote!(()));

    // Get the response type from the return type, so the generated wrapper can
    // name it explicitly when calling RingMessage::message_type().
    let resp_type: syn::Type = match output {
        syn::ReturnType::Type(_, ty) => (**ty).clone(),
        syn::ReturnType::Default => syn::parse_quote!(()),
    };

    // Generate the kernel mode.
    let mode = args.mode.as_deref().unwrap_or("persistent");
    let mode_expr = if mode == "event_driven" {
        quote! { ::ringkernel_core::types::KernelMode::EventDriven }
    } else {
        quote! { ::ringkernel_core::types::KernelMode::Persistent }
    };

    // Generate grid/block sizes.
    let grid_size = args.grid_size.unwrap_or(1);
    let block_size = args.block_size.unwrap_or(256);

    // Parse publishes_to into a list of target kernel IDs.
    let publishes_to_targets: Vec<String> = args
        .publishes_to
        .as_ref()
        .map(|s| s.split(',').map(|t| t.trim().to_string()).collect())
        .unwrap_or_default();

    let handler_name = format_ident!("{}_handler", fn_name);

    // Generate the expanded code.
    let expanded = quote! {
        // Original function (preserved for documentation/testing).
        #(#fn_attrs)*
        #fn_vis async fn #fn_name(#inputs) #output #fn_block

        // Kernel handler wrapper.
        #fn_vis fn #handler_name(
            ctx: &mut ::ringkernel_core::RingContext<'_>,
            envelope: ::ringkernel_core::message::MessageEnvelope,
        ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::ringkernel_core::error::Result<::ringkernel_core::message::MessageEnvelope>> + Send + '_>> {
            Box::pin(async move {
                // Deserialize the input message.
                let msg: #msg_type = ::ringkernel_core::message::RingMessage::deserialize(&envelope.payload)?;

                // Call the actual handler.
                let response = #fn_name(ctx, msg).await;

                // Serialize the response.
                let response_payload = ::ringkernel_core::message::RingMessage::serialize(&response);
                let response_header = ::ringkernel_core::message::MessageHeader::new(
                    <#resp_type as ::ringkernel_core::message::RingMessage>::message_type(),
                    envelope.header.dest_kernel,
                    envelope.header.source_kernel,
                    response_payload.len(),
                    ctx.now(),
                ).with_correlation(envelope.header.correlation_id);

                Ok(::ringkernel_core::message::MessageEnvelope {
                    header: response_header,
                    payload: response_payload,
                })
            })
        }

        // Kernel registration for runtime discovery.
        ::inventory::submit! {
            ::ringkernel_core::__private::KernelRegistration {
                id: #kernel_id,
                mode: #mode_expr,
                grid_size: #grid_size,
                block_size: #block_size,
                publishes_to: &[#(#publishes_to_targets),*],
            }
        }
    };

    TokenStream::from(expanded)
}

/// Derive macro for GPU-compatible types.
///
/// Asserts that the type is `Copy` and implements `bytemuck::Pod` and
/// `bytemuck::Zeroable`, giving it a stable memory layout suitable for
/// GPU transfer.
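///
/// # Example
///
/// A minimal usage sketch (the field layout is illustrative):
///
/// ```ignore
/// #[derive(GpuType, Clone, Copy)]
/// #[repr(C)]
/// struct Particle {
///     position: [f32; 3],
///     velocity: [f32; 3],
/// }
/// ```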
#[proc_macro_derive(GpuType)]
pub fn derive_gpu_type(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Generate assertions for GPU compatibility.
    let expanded = quote! {
        // Verify the type is Copy (required for GPU transfer).
        const _: fn() = || {
            fn assert_copy<T: Copy>() {}
            assert_copy::<#name #ty_generics>();
        };

        // Implement Pod (plain old data) and Zeroable.
        // SAFETY: the deriving type must be #[repr(C)] and contain only POD
        // fields; upholding this is the caller's responsibility.
        unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
        unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
    };

    TokenStream::from(expanded)
}

// ============================================================================
// Stencil Kernel Macro (requires cuda-codegen feature)
// ============================================================================

/// Attributes for the stencil_kernel macro.
#[derive(Debug, FromMeta)]
struct StencilKernelArgs {
    /// Kernel identifier.
    id: String,
    /// Grid dimensionality: "1d", "2d", or "3d".
    #[darling(default)]
    grid: Option<String>,
    /// Tile/block size (single value for square tiles).
    #[darling(default)]
    tile_size: Option<u32>,
    /// Tile width (for non-square tiles).
    #[darling(default)]
    tile_width: Option<u32>,
    /// Tile height (for non-square tiles).
    #[darling(default)]
    tile_height: Option<u32>,
    /// Halo/ghost cell width (stencil radius).
    #[darling(default)]
    halo: Option<u32>,
}

/// Attribute macro for defining stencil kernels that transpile to CUDA.
///
/// This macro generates CUDA C code from Rust stencil kernel functions at compile time.
/// The generated CUDA source is embedded in the binary and can be compiled at runtime
/// using NVRTC.
///
/// # Attributes
///
/// - `id` (required) - Unique kernel identifier
/// - `grid` - Grid dimensionality: "1d", "2d" (default), or "3d"
/// - `tile_size` - Tile/block size (default: 16)
/// - `tile_width` / `tile_height` - Non-square tile dimensions
/// - `halo` - Stencil radius / ghost cell width (default: 1)
///
/// # Supported Rust Subset
///
/// - Primitives: `f32`, `f64`, `i32`, `u32`, `i64`, `u64`, `bool`
/// - Slices: `&[T]`, `&mut [T]`
/// - Arithmetic: `+`, `-`, `*`, `/`, `%`
/// - Comparisons: `<`, `>`, `<=`, `>=`, `==`, `!=`
/// - Let bindings: `let x = expr;`
/// - If/else: `if cond { a } else { b }`
/// - Stencil intrinsics via `GridPos`
///
/// # Example
///
/// ```ignore
/// use ringkernel_derive::stencil_kernel;
/// use ringkernel_cuda_codegen::GridPos;
///
/// #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
/// fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
///     let curr = p[pos.idx()];
///     let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
///     p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
/// }
///
/// // Access the generated CUDA source:
/// assert!(FDTD_CUDA_SOURCE.contains("__global__"));
/// ```
#[proc_macro_attribute]
pub fn stencil_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
        Ok(v) => v,
        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
    };

    let args = match StencilKernelArgs::from_list(&args) {
        Ok(v) => v,
        Err(e) => return TokenStream::from(e.write_errors()),
    };

    let input = parse_macro_input!(item as ItemFn);

    // Generate the stencil kernel code.
    stencil_kernel_impl(args, input)
}

fn stencil_kernel_impl(args: StencilKernelArgs, input: ItemFn) -> TokenStream {
    let kernel_id = &args.id;
    let fn_name = &input.sig.ident;
    let fn_vis = &input.vis;
    let fn_block = &input.block;
    let fn_inputs = &input.sig.inputs;
    let fn_output = &input.sig.output;
    let fn_attrs = &input.attrs;

    // Parse the configuration.
    let grid = args.grid.as_deref().unwrap_or("2d");
    let tile_width = args
        .tile_width
        .unwrap_or_else(|| args.tile_size.unwrap_or(16));
    let tile_height = args
        .tile_height
        .unwrap_or_else(|| args.tile_size.unwrap_or(16));
    let halo = args.halo.unwrap_or(1);

    // Generate the CUDA source constant name.
    let cuda_const_name = format_ident!("{}_CUDA_SOURCE", fn_name.to_string().to_uppercase());

    // Transpile to CUDA (if the feature is enabled).
    #[cfg(feature = "cuda-codegen")]
    let cuda_source_code = {
        use ringkernel_cuda_codegen::{transpile_stencil_kernel, Grid, StencilConfig};

        let grid_type = match grid {
            "1d" => Grid::Grid1D,
            "2d" => Grid::Grid2D,
            "3d" => Grid::Grid3D,
            _ => Grid::Grid2D,
        };

        let config = StencilConfig::new(kernel_id.clone())
            .with_grid(grid_type)
            .with_tile_size(tile_width as usize, tile_height as usize)
            .with_halo(halo as usize);

        match transpile_stencil_kernel(&input, &config) {
            Ok(cuda) => cuda,
            Err(e) => {
                return TokenStream::from(
                    syn::Error::new_spanned(
                        &input.sig.ident,
                        format!("CUDA transpilation failed: {}", e),
                    )
                    .to_compile_error(),
                );
            }
        }
    };

    #[cfg(not(feature = "cuda-codegen"))]
    let cuda_source_code = format!(
        "// CUDA codegen not enabled. Enable the 'cuda-codegen' feature.\n// Kernel: {}\n",
        kernel_id
    );

    // Generate the expanded code.
    let expanded = quote! {
        // Original function (for documentation/testing/CPU fallback).
        #(#fn_attrs)*
        #fn_vis fn #fn_name(#fn_inputs) #fn_output #fn_block

        /// Generated CUDA source code for this stencil kernel.
        #fn_vis const #cuda_const_name: &str = #cuda_source_code;

        // Stencil kernel registration for runtime discovery.
        ::inventory::submit! {
            ::ringkernel_core::__private::StencilKernelRegistration {
                id: #kernel_id,
                grid: #grid,
                tile_width: #tile_width,
                tile_height: #tile_height,
                halo: #halo,
                cuda_source: #cuda_source_code,
            }
        }
    };

    TokenStream::from(expanded)
}

// ============================================================================
// Multi-Backend GPU Kernel Macro
// ============================================================================

/// GPU backend targets (internal use only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuBackend {
    /// NVIDIA CUDA backend.
    Cuda,
    /// Apple Metal backend.
    Metal,
    /// WebGPU backend (cross-platform).
    Wgpu,
    /// CPU fallback backend.
    Cpu,
}

impl GpuBackend {
    fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "cuda" => Some(Self::Cuda),
            "metal" => Some(Self::Metal),
            "wgpu" | "webgpu" => Some(Self::Wgpu),
            "cpu" => Some(Self::Cpu),
            _ => None,
        }
    }

    fn as_str(&self) -> &'static str {
        match self {
            Self::Cuda => "cuda",
            Self::Metal => "metal",
            Self::Wgpu => "wgpu",
            Self::Cpu => "cpu",
        }
    }
}

/// GPU capability flags that can be required by a kernel (internal use only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GpuCapability {
    /// 64-bit floating point support.
    Float64,
    /// 64-bit integer support.
    Int64,
    /// 64-bit atomics support.
    Atomic64,
    /// Cooperative groups / grid-wide sync.
    CooperativeGroups,
    /// Subgroup / warp / SIMD operations.
    Subgroups,
    /// Shared memory / threadgroup memory.
    SharedMemory,
    /// Dynamic parallelism (launching kernels from kernels).
    DynamicParallelism,
    /// Half-precision (f16) support.
    Float16,
}

impl GpuCapability {
    fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "f64" | "float64" => Some(Self::Float64),
            "i64" | "int64" => Some(Self::Int64),
            "atomic64" => Some(Self::Atomic64),
            "cooperative_groups" | "cooperativegroups" | "grid_sync" => {
                Some(Self::CooperativeGroups)
            }
            "subgroups" | "warp" | "simd" => Some(Self::Subgroups),
            "shared_memory" | "sharedmemory" | "threadgroup" => Some(Self::SharedMemory),
            "dynamic_parallelism" | "dynamicparallelism" => Some(Self::DynamicParallelism),
            "f16" | "float16" | "half" => Some(Self::Float16),
            _ => None,
        }
    }

    fn as_str(&self) -> &'static str {
        match self {
            Self::Float64 => "f64",
            Self::Int64 => "i64",
            Self::Atomic64 => "atomic64",
            Self::CooperativeGroups => "cooperative_groups",
            Self::Subgroups => "subgroups",
            Self::SharedMemory => "shared_memory",
            Self::DynamicParallelism => "dynamic_parallelism",
            Self::Float16 => "f16",
        }
    }

    /// Check whether a backend supports this capability.
    fn supported_by(&self, backend: GpuBackend) -> bool {
        match (self, backend) {
            // CUDA supports everything.
            (_, GpuBackend::Cuda) => true,

            // Metal capabilities.
            (Self::Float64, GpuBackend::Metal) => false,
            (Self::CooperativeGroups, GpuBackend::Metal) => false,
            (Self::DynamicParallelism, GpuBackend::Metal) => false,
            (_, GpuBackend::Metal) => true,

            // WebGPU capabilities.
            (Self::Float64, GpuBackend::Wgpu) => false,
            (Self::Int64, GpuBackend::Wgpu) => false,
            (Self::Atomic64, GpuBackend::Wgpu) => false, // Emulated only
            (Self::CooperativeGroups, GpuBackend::Wgpu) => false,
            (Self::DynamicParallelism, GpuBackend::Wgpu) => false,
            (Self::Subgroups, GpuBackend::Wgpu) => true, // Optional extension
            (_, GpuBackend::Wgpu) => true,

            // CPU supports everything (in emulation).
            (_, GpuBackend::Cpu) => true,
        }
    }
}
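
// A small compile-and-run check of the capability matrix; the assertions
// simply mirror the `supported_by` match arms above.
#[cfg(test)]
mod capability_matrix_tests {
    use super::*;

    #[test]
    fn cuda_supports_everything() {
        assert!(GpuCapability::Float64.supported_by(GpuBackend::Cuda));
        assert!(GpuCapability::DynamicParallelism.supported_by(GpuBackend::Cuda));
    }

    #[test]
    fn wgpu_lacks_64_bit_features() {
        assert!(!GpuCapability::Float64.supported_by(GpuBackend::Wgpu));
        assert!(!GpuCapability::Atomic64.supported_by(GpuBackend::Wgpu));
        assert!(GpuCapability::SharedMemory.supported_by(GpuBackend::Wgpu));
    }
}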

/// Attributes for the gpu_kernel macro.
#[derive(Debug)]
struct GpuKernelArgs {
    /// Kernel identifier.
    id: Option<String>,
    /// Target backends to generate code for.
    backends: Vec<GpuBackend>,
    /// Fallback order for backend selection.
    fallback: Vec<GpuBackend>,
    /// Required capabilities.
    requires: Vec<GpuCapability>,
    /// Block/workgroup size.
    block_size: Option<u32>,
}

impl Default for GpuKernelArgs {
    fn default() -> Self {
        Self {
            id: None,
            backends: vec![GpuBackend::Cuda, GpuBackend::Metal, GpuBackend::Wgpu],
            fallback: vec![
                GpuBackend::Cuda,
                GpuBackend::Metal,
                GpuBackend::Wgpu,
                GpuBackend::Cpu,
            ],
            requires: Vec::new(),
            block_size: None,
        }
    }
}

impl GpuKernelArgs {
    /// Parse macro arguments using a deliberately lightweight string scan of
    /// the attribute tokens; the `name = [list]` syntax is handled here rather
    /// than through darling.
    fn parse(attr: proc_macro2::TokenStream) -> Result<Self, darling::Error> {
        let mut args = Self::default();
        let attr_str = attr.to_string();

        // Parse backends = [...]
        if let Some(start) = attr_str.find("backends") {
            if let Some(bracket_start) = attr_str[start..].find('[') {
                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
                    let backends_str =
                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
                    args.backends = backends_str
                        .split(',')
                        .filter_map(|s| GpuBackend::from_str(s.trim()))
                        .collect();
                }
            }
        }

        // Parse fallback = [...]
        if let Some(start) = attr_str.find("fallback") {
            if let Some(bracket_start) = attr_str[start..].find('[') {
                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
                    let fallback_str =
                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
                    args.fallback = fallback_str
                        .split(',')
                        .filter_map(|s| GpuBackend::from_str(s.trim()))
                        .collect();
                }
            }
        }

        // Parse requires = [...]
        if let Some(start) = attr_str.find("requires") {
            if let Some(bracket_start) = attr_str[start..].find('[') {
                if let Some(bracket_end) = attr_str[start + bracket_start..].find(']') {
                    let requires_str =
                        &attr_str[start + bracket_start + 1..start + bracket_start + bracket_end];
                    args.requires = requires_str
                        .split(',')
                        .filter_map(|s| GpuCapability::from_str(s.trim()))
                        .collect();
                }
            }
        }

        // Parse id = "..."
        if let Some(start) = attr_str.find("id") {
            if let Some(quote_start) = attr_str[start..].find('"') {
                if let Some(quote_end) = attr_str[start + quote_start + 1..].find('"') {
                    args.id = Some(
                        attr_str[start + quote_start + 1..start + quote_start + 1 + quote_end]
                            .to_string(),
                    );
                }
            }
        }

        // Parse block_size = N
        if let Some(start) = attr_str.find("block_size") {
            if let Some(eq) = attr_str[start..].find('=') {
                let rest = &attr_str[start + eq + 1..];
                let num_end = rest
                    .find(|c: char| !c.is_numeric() && c != ' ')
                    .unwrap_or(rest.len());
                if let Ok(n) = rest[..num_end].trim().parse() {
                    args.block_size = Some(n);
                }
            }
        }

        Ok(args)
    }

    /// Validate that every required capability is supported by at least one backend.
    fn validate_capabilities(&self) -> Result<(), String> {
        for cap in &self.requires {
            let mut supported_by_any = false;
            for backend in &self.backends {
                if cap.supported_by(*backend) {
                    supported_by_any = true;
                    break;
                }
            }
            if !supported_by_any {
                return Err(format!(
                    "Capability '{}' is not supported by any of the specified backends: {:?}",
                    cap.as_str(),
                    self.backends.iter().map(|b| b.as_str()).collect::<Vec<_>>()
                ));
            }
        }
        Ok(())
    }

    /// Get the backends that support all required capabilities.
    fn compatible_backends(&self) -> Vec<GpuBackend> {
        self.backends
            .iter()
            .filter(|backend| self.requires.iter().all(|cap| cap.supported_by(**backend)))
            .copied()
            .collect()
    }
}
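
// A sketch of the lightweight attribute parsing in action; this exercises
// only the string-scanning logic above.
#[cfg(test)]
mod gpu_kernel_args_tests {
    use super::*;

    #[test]
    fn parses_backends_requires_and_filters_compatibility() {
        let attr: proc_macro2::TokenStream =
            "backends = [cuda, metal], requires = [f64]".parse().unwrap();
        let args = GpuKernelArgs::parse(attr).unwrap();
        assert_eq!(args.backends, vec![GpuBackend::Cuda, GpuBackend::Metal]);
        assert_eq!(args.requires, vec![GpuCapability::Float64]);
        // Metal lacks f64, so only CUDA survives the capability filter.
        assert_eq!(args.compatible_backends(), vec![GpuBackend::Cuda]);
    }
}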

/// Attribute macro for defining multi-backend GPU kernels.
///
/// This macro generates code for multiple GPU backends with compile-time
/// capability validation. It integrates with the `ringkernel-ir` crate
/// to lower the Rust DSL to backend-specific shader code.
///
/// # Attributes
///
/// - `backends = [cuda, metal, wgpu]` - Target backends (default: all)
/// - `fallback = [cuda, metal, wgpu, cpu]` - Fallback order for runtime selection
/// - `requires = [f64, atomic64]` - Required capabilities (validated at compile time)
/// - `id = "kernel_name"` - Explicit kernel identifier
/// - `block_size = 256` - Thread block size
///
/// # Example
///
/// ```ignore
/// use ringkernel_derive::gpu_kernel;
///
/// #[gpu_kernel(backends = [cuda, metal], requires = [subgroups])]
/// fn warp_reduce(data: &mut [f32], n: i32) {
///     let idx = global_thread_id_x();
///     if idx < n {
///         // Use warp shuffle for the reduction.
///         let val = data[idx as usize];
///         let reduced = warp_reduce_sum(val);
///         if lane_id() == 0 {
///             data[idx as usize] = reduced;
///         }
///     }
/// }
/// ```
///
/// # Capability Checking
///
/// The macro validates at compile time that all required capabilities are
/// supported by at least one target backend:
///
/// | Capability | CUDA | Metal | WebGPU | CPU |
/// |------------|------|-------|--------|-----|
/// | f64 | Yes | No | No | Yes |
/// | i64 | Yes | Yes | No | Yes |
/// | atomic64 | Yes | Yes | No* | Yes |
/// | cooperative_groups | Yes | No | No | Yes |
/// | subgroups | Yes | Yes | Opt | Yes |
/// | shared_memory | Yes | Yes | Yes | Yes |
/// | f16 | Yes | Yes | Yes | Yes |
///
/// *WebGPU emulates 64-bit atomics with 32-bit pairs.
///
/// # Generated Code
///
/// For each compatible backend, the macro generates:
/// - A backend-specific source code constant (e.g., `KERNEL_NAME_CUDA`)
/// - A `KERNEL_NAME_INFO` module with the kernel's metadata
/// - A registration entry for runtime discovery
/// - The original Rust function, kept as a CPU fallback
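///
/// # Inspecting Generated Metadata
///
/// A sketch of consuming the generated info module at runtime; the
/// `WARP_REDUCE_INFO` name follows the macro's uppercase naming convention,
/// and the selection logic is illustrative:
///
/// ```ignore
/// for backend in WARP_REDUCE_INFO::FALLBACK_ORDER {
///     println!("candidate backend: {backend}");
/// }
/// assert!(WARP_REDUCE_INFO::CAPABILITIES.contains(&"subgroups"));
/// ```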
#[proc_macro_attribute]
pub fn gpu_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
    let attr2: proc_macro2::TokenStream = attr.into();
    let args = match GpuKernelArgs::parse(attr2) {
        Ok(args) => args,
        Err(e) => return TokenStream::from(e.write_errors()),
    };

    let input = parse_macro_input!(item as ItemFn);

    // Validate capabilities.
    if let Err(msg) = args.validate_capabilities() {
        return TokenStream::from(
            syn::Error::new_spanned(&input.sig.ident, msg).to_compile_error(),
        );
    }

    gpu_kernel_impl(args, input)
}

fn gpu_kernel_impl(args: GpuKernelArgs, input: ItemFn) -> TokenStream {
    let fn_name = &input.sig.ident;
    let fn_vis = &input.vis;
    let fn_block = &input.block;
    let fn_inputs = &input.sig.inputs;
    let fn_output = &input.sig.output;
    let fn_attrs = &input.attrs;

    let kernel_id = args.id.clone().unwrap_or_else(|| fn_name.to_string());
    let block_size = args.block_size.unwrap_or(256);

    // Get the backends that support all required capabilities.
    let compatible_backends = args.compatible_backends();

    // Generate backend-specific source constants.
    let mut source_constants = Vec::new();

    for backend in &compatible_backends {
        let const_name = format_ident!(
            "{}_{}",
            fn_name.to_string().to_uppercase(),
            backend.as_str().to_uppercase()
        );

        let backend_str = backend.as_str();

        // Generate a placeholder source (actual IR lowering happens at build time).
        // In a full implementation, this would call ringkernel-ir lowering.
        let source_placeholder = format!(
            "// {} source for kernel '{}'\n// Generated by ringkernel-derive\n// Capabilities: {:?}\n",
            backend_str.to_uppercase(),
            kernel_id,
            args.requires.iter().map(|c| c.as_str()).collect::<Vec<_>>()
        );

        source_constants.push(quote! {
            /// Generated source code for this kernel.
            #fn_vis const #const_name: &str = #source_placeholder;
        });
    }

    // Generate capability, backend, and fallback flags as strings.
    let capability_strs: Vec<_> = args.requires.iter().map(|c| c.as_str()).collect();
    let backend_strs: Vec<_> = compatible_backends.iter().map(|b| b.as_str()).collect();
    let fallback_strs: Vec<_> = args.fallback.iter().map(|b| b.as_str()).collect();

    // Generate the info module name.
    let info_name = format_ident!("{}_INFO", fn_name.to_string().to_uppercase());

    // Generate the expanded code.
    let expanded = quote! {
        // Original function (CPU fallback / documentation / testing).
        #(#fn_attrs)*
        #fn_vis fn #fn_name(#fn_inputs) #fn_output #fn_block

        // Backend source constants.
        #(#source_constants)*

        /// Multi-backend kernel information.
        #[allow(non_snake_case)]
        #fn_vis mod #info_name {
            /// Kernel identifier.
            pub const ID: &str = #kernel_id;

            /// Block/workgroup size.
            pub const BLOCK_SIZE: u32 = #block_size;

            /// Required capabilities.
            pub const CAPABILITIES: &[&str] = &[#(#capability_strs),*];

            /// Compatible backends (those that support all required capabilities).
            pub const BACKENDS: &[&str] = &[#(#backend_strs),*];

            /// Fallback order for runtime backend selection.
            pub const FALLBACK_ORDER: &[&str] = &[#(#fallback_strs),*];
        }

        // GPU kernel registration for runtime discovery.
        ::inventory::submit! {
            ::ringkernel_core::__private::GpuKernelRegistration {
                id: #kernel_id,
                block_size: #block_size,
                capabilities: &[#(#capability_strs),*],
                backends: &[#(#backend_strs),*],
                fallback_order: &[#(#fallback_strs),*],
            }
        }
    };

    TokenStream::from(expanded)
}

// ============================================================================
// ControlBlockState Derive Macro (FR-4)
// ============================================================================

/// Attributes for the ControlBlockState derive macro.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(state), supports(struct_named))]
struct ControlBlockStateArgs {
    ident: syn::Ident,
    generics: syn::Generics,
    /// State version for forward compatibility.
    #[darling(default)]
    version: Option<u32>,
}

/// Derive macro for implementing the EmbeddedState trait.
///
/// This macro generates implementations for types that can be stored in
/// the ControlBlock's 24-byte `_reserved` field for zero-copy state access.
///
/// # Requirements
///
/// The type must:
/// - Be `#[repr(C)]` for a stable memory layout
/// - Be <= 24 bytes in size (checked at compile time)
/// - Implement `Clone`, `Copy`, and `Default`
/// - Contain only POD (Plain Old Data) types
///
/// # Attributes
///
/// - `#[state(version = N)]` - Set the state version for migrations (default: 1)
///
/// # Example
///
/// ```ignore
/// #[derive(ControlBlockState, Default, Clone, Copy)]
/// #[repr(C, align(8))]
/// #[state(version = 1)]
/// pub struct OrderBookState {
///     pub best_bid: u64,    // 8 bytes
///     pub best_ask: u64,    // 8 bytes
///     pub order_count: u32, // 4 bytes
///     pub _pad: u32,        // 4 bytes (padding for alignment)
/// } // Total: 24 bytes - fits in ControlBlock._reserved
///
/// // Use with ControlBlockStateHelper:
/// let mut block = ControlBlock::new();
/// let state = OrderBookState { best_bid: 100, best_ask: 101, order_count: 42, _pad: 0 };
/// ControlBlockStateHelper::write_embedded(&mut block, &state)?;
/// ```
///
/// # Size Validation
///
/// The macro generates a compile-time assertion that fails if the type
/// exceeds 24 bytes:
///
/// ```ignore
/// #[derive(ControlBlockState, Default, Clone, Copy)]
/// #[repr(C)]
/// struct TooLarge {
///     data: [u8; 32], // 32 bytes - COMPILE ERROR!
/// }
/// ```
#[proc_macro_derive(ControlBlockState, attributes(state))]
pub fn derive_control_block_state(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let args = match ControlBlockStateArgs::from_derive_input(&input) {
        Ok(args) => args,
        Err(e) => return e.write_errors().into(),
    };

    let name = &args.ident;
    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
    let version = args.version.unwrap_or(1);

    let expanded = quote! {
        // Compile-time size check: embedded state must fit in 24 bytes.
        const _: () = {
            assert!(
                ::std::mem::size_of::<#name #ty_generics>() <= 24,
                "ControlBlockState types must fit in 24 bytes (ControlBlock._reserved size)"
            );
        };

        // Verify the type is Copy (required for GPU transfer).
        const _: fn() = || {
            fn assert_copy<T: Copy>() {}
            assert_copy::<#name #ty_generics>();
        };

        // Implement Pod and Zeroable (required by EmbeddedState).
        // SAFETY: the type must be #[repr(C)] with only primitive fields;
        // upholding this is the user's responsibility.
        unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
        unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}

        // Implement EmbeddedState.
        impl #impl_generics ::ringkernel_core::state::EmbeddedState for #name #ty_generics #where_clause {
            const VERSION: u32 = #version;

            fn is_embedded() -> bool {
                true
            }
        }
    };

    TokenStream::from(expanded)
}