Skip to main content

ringkernel_cuda_codegen/
ring_kernel.rs

1//! Ring kernel configuration and code generation for persistent actor kernels.
2//!
3//! This module provides the infrastructure for generating CUDA code for
4//! persistent actor kernels that process messages in a loop.
5//!
6//! # Overview
7//!
8//! Ring kernels are persistent GPU kernels that:
9//! - Run continuously until terminated
10//! - Process messages from input queues
11//! - Produce responses to output queues
12//! - Use HLC (Hybrid Logical Clocks) for causal ordering
13//! - Support kernel-to-kernel (K2K) messaging
14//!
15//! # Generated Kernel Structure
16//!
17//! ```cuda
18//! extern "C" __global__ void ring_kernel_NAME(
19//!     ControlBlock* __restrict__ control,
20//!     unsigned char* __restrict__ input_buffer,
21//!     unsigned char* __restrict__ output_buffer,
22//!     void* __restrict__ shared_state
23//! ) {
24//!     // Preamble: thread setup
25//!     int tid = threadIdx.x + blockIdx.x * blockDim.x;
26//!
27//!     // Persistent message loop
//!     while (true) {
//!         if (atomicAdd(&control->should_terminate, 0) != 0) break;
//!         if (atomicAdd(&control->is_active, 0) == 0) {
//!             __nanosleep(1000);
//!             continue;
//!         }
//!
//!         // Message processing...
//!         // (user handler code inserted here)
//!     }
//!
//!     // Epilogue: mark terminated
//!     if (tid == 0) {
//!         atomicExch(&control->has_terminated, 1);
//!     }
42//! }
43//! ```
44
45use std::fmt::Write;
46
47use crate::reduction_intrinsics::ReductionOp;
48
49/// Configuration for global reductions in a ring kernel.
50///
51/// This enables algorithms like PageRank to perform global reductions
52/// (e.g., computing dangling node sum) within the kernel using a two-phase
53/// reduce-and-broadcast pattern.
54#[derive(Debug, Clone, Default)]
55pub struct KernelReductionConfig {
56    /// Whether global reduction is enabled.
57    pub enabled: bool,
58    /// The reduction operation to perform.
59    pub op: ReductionOp,
60    /// Type of the accumulator (e.g., "double", "float", "int").
61    pub accumulator_type: String,
62    /// Use cooperative groups for grid-wide sync (requires CC 6.0+).
63    /// If false, falls back to software barrier.
64    pub use_cooperative: bool,
65    /// Name of the shared memory array for block reduction.
66    pub shared_array_name: String,
67    /// Name of the global accumulator variable.
68    pub accumulator_name: String,
69}
70
71impl KernelReductionConfig {
72    /// Create a new reduction configuration.
73    pub fn new() -> Self {
74        Self {
75            enabled: false,
76            op: ReductionOp::Sum,
77            accumulator_type: "double".to_string(),
78            use_cooperative: true,
79            shared_array_name: "__reduction_shared".to_string(),
80            accumulator_name: "__reduction_accumulator".to_string(),
81        }
82    }
83
84    /// Enable reduction with the specified operation.
85    pub fn with_op(mut self, op: ReductionOp) -> Self {
86        self.enabled = true;
87        self.op = op;
88        self
89    }
90
91    /// Set the accumulator type.
92    pub fn with_type(mut self, ty: &str) -> Self {
93        self.accumulator_type = ty.to_string();
94        self
95    }
96
97    /// Enable or disable cooperative groups.
98    pub fn with_cooperative(mut self, use_cooperative: bool) -> Self {
99        self.use_cooperative = use_cooperative;
100        self
101    }
102
103    /// Set the shared memory array name.
104    pub fn with_shared_name(mut self, name: &str) -> Self {
105        self.shared_array_name = name.to_string();
106        self
107    }
108
109    /// Set the accumulator variable name.
110    pub fn with_accumulator_name(mut self, name: &str) -> Self {
111        self.accumulator_name = name.to_string();
112        self
113    }
114
115    /// Generate the shared memory declaration for block reduction.
116    pub fn generate_shared_declaration(&self, block_size: u32) -> String {
117        if !self.enabled {
118            return String::new();
119        }
120        format!(
121            "    __shared__ {} {}[{}];",
122            self.accumulator_type, self.shared_array_name, block_size
123        )
124    }
125
126    /// Generate the accumulator parameter declaration.
127    pub fn generate_accumulator_param(&self) -> String {
128        if !self.enabled {
129            return String::new();
130        }
131        format!(
132            "    {}* __restrict__ {},",
133            self.accumulator_type, self.accumulator_name
134        )
135    }
136}
137
/// Configuration for a ring kernel.
///
/// Built with [`RingKernelConfig::new`] and the `with_*` builder methods,
/// then consumed by the `generate_*` methods to emit CUDA source.
#[derive(Debug, Clone)]
pub struct RingKernelConfig {
    /// Kernel identifier; the generated function is named `ring_kernel_<id>`.
    pub id: String,
    /// Block size (threads per block).
    pub block_size: u32,
    /// Input queue capacity (must be power of 2). The generated code uses
    /// `QUEUE_MASK = queue_capacity - 1` to wrap queue indices.
    pub queue_capacity: u32,
    /// Enable kernel-to-kernel messaging (adds the K2K routing/inbox/outbox
    /// kernel parameters and K2K struct definitions).
    pub enable_k2k: bool,
    /// Enable HLC clock operations (local HLC state, timestamp merging on
    /// receive, and storing the final clock into the control block).
    pub enable_hlc: bool,
    /// Message size in bytes (for buffer offset calculations). With the
    /// envelope format this is the payload size; the header is added on top.
    pub message_size: usize,
    /// Response size in bytes (payload size when using the envelope format).
    pub response_size: usize,
    /// Use cooperative thread groups. Set automatically by `with_reduction`
    /// / `with_sum_reduction` when a cooperative reduction is configured.
    pub cooperative_groups: bool,
    /// Nanosleep duration when idle (0 to spin; no `__nanosleep` is emitted).
    pub idle_sleep_ns: u32,
    /// Use MessageEnvelope format (256-byte header + payload).
    /// When enabled, messages in queues are full envelopes with headers.
    pub use_envelope_format: bool,
    /// Kernel numeric ID (emitted as the `KERNEL_ID` constant, for response routing).
    pub kernel_id_num: u64,
    /// HLC node ID for this kernel (emitted as the `HLC_NODE_ID` constant).
    pub hlc_node_id: u64,
    /// Configuration for global reductions (e.g., for PageRank dangling sum).
    pub reduction: KernelReductionConfig,
}
169
170impl Default for RingKernelConfig {
171    fn default() -> Self {
172        Self {
173            id: "ring_kernel".to_string(),
174            block_size: 128,
175            queue_capacity: 1024,
176            enable_k2k: false,
177            enable_hlc: true,
178            message_size: 64,
179            response_size: 64,
180            cooperative_groups: false,
181            idle_sleep_ns: 1000,
182            use_envelope_format: true, // Default to envelope format for proper serialization
183            kernel_id_num: 0,
184            hlc_node_id: 0,
185            reduction: KernelReductionConfig::new(),
186        }
187    }
188}
189
190impl RingKernelConfig {
191    /// Create a new ring kernel configuration with the given ID.
192    pub fn new(id: impl Into<String>) -> Self {
193        Self {
194            id: id.into(),
195            ..Default::default()
196        }
197    }
198
199    /// Set the block size.
200    pub fn with_block_size(mut self, size: u32) -> Self {
201        self.block_size = size;
202        self
203    }
204
205    /// Set the queue capacity (must be power of 2).
206    pub fn with_queue_capacity(mut self, capacity: u32) -> Self {
207        debug_assert!(
208            capacity.is_power_of_two(),
209            "Queue capacity must be power of 2"
210        );
211        self.queue_capacity = capacity;
212        self
213    }
214
215    /// Enable kernel-to-kernel messaging.
216    pub fn with_k2k(mut self, enabled: bool) -> Self {
217        self.enable_k2k = enabled;
218        self
219    }
220
221    /// Enable HLC clock operations.
222    pub fn with_hlc(mut self, enabled: bool) -> Self {
223        self.enable_hlc = enabled;
224        self
225    }
226
227    /// Set message and response sizes.
228    pub fn with_message_sizes(mut self, message: usize, response: usize) -> Self {
229        self.message_size = message;
230        self.response_size = response;
231        self
232    }
233
234    /// Set idle sleep duration in nanoseconds.
235    pub fn with_idle_sleep(mut self, ns: u32) -> Self {
236        self.idle_sleep_ns = ns;
237        self
238    }
239
240    /// Enable or disable MessageEnvelope format.
241    /// When enabled, messages use the 256-byte header + payload format.
242    pub fn with_envelope_format(mut self, enabled: bool) -> Self {
243        self.use_envelope_format = enabled;
244        self
245    }
246
247    /// Set the kernel numeric ID (for K2K routing and response headers).
248    pub fn with_kernel_id(mut self, id: u64) -> Self {
249        self.kernel_id_num = id;
250        self
251    }
252
253    /// Set the HLC node ID for this kernel.
254    pub fn with_hlc_node_id(mut self, node_id: u64) -> Self {
255        self.hlc_node_id = node_id;
256        self
257    }
258
259    /// Configure global reduction for this kernel.
260    ///
261    /// # Example
262    ///
263    /// ```ignore
264    /// use ringkernel_cuda_codegen::{RingKernelConfig, KernelReductionConfig};
265    /// use ringkernel_cuda_codegen::reduction_intrinsics::ReductionOp;
266    ///
267    /// let config = RingKernelConfig::new("pagerank")
268    ///     .with_reduction(
269    ///         KernelReductionConfig::new()
270    ///             .with_op(ReductionOp::Sum)
271    ///             .with_type("double")
272    ///     );
273    /// ```
274    pub fn with_reduction(mut self, reduction: KernelReductionConfig) -> Self {
275        self.reduction = reduction;
276        // If reduction is enabled and uses cooperative groups, enable them for the kernel
277        if self.reduction.enabled && self.reduction.use_cooperative {
278            self.cooperative_groups = true;
279        }
280        self
281    }
282
283    /// Enable global sum reduction with default settings.
284    ///
285    /// This is a convenience method for enabling a sum reduction with double precision.
286    pub fn with_sum_reduction(mut self) -> Self {
287        self.reduction = KernelReductionConfig::new()
288            .with_op(ReductionOp::Sum)
289            .with_type("double");
290        self.cooperative_groups = true;
291        self
292    }
293
294    /// Generate the kernel function name.
295    pub fn kernel_name(&self) -> String {
296        format!("ring_kernel_{}", self.id)
297    }
298
299    /// Generate CUDA kernel signature.
300    pub fn generate_signature(&self) -> String {
301        let mut sig = String::new();
302
303        writeln!(sig, "extern \"C\" __global__ void {}(", self.kernel_name()).unwrap();
304        writeln!(sig, "    ControlBlock* __restrict__ control,").unwrap();
305        writeln!(sig, "    unsigned char* __restrict__ input_buffer,").unwrap();
306        writeln!(sig, "    unsigned char* __restrict__ output_buffer,").unwrap();
307
308        if self.enable_k2k {
309            writeln!(sig, "    K2KRoutingTable* __restrict__ k2k_routes,").unwrap();
310            writeln!(sig, "    unsigned char* __restrict__ k2k_inbox,").unwrap();
311            writeln!(sig, "    unsigned char* __restrict__ k2k_outbox,").unwrap();
312        }
313
314        write!(sig, "    void* __restrict__ shared_state").unwrap();
315        write!(sig, "\n)").unwrap();
316
317        sig
318    }
319
320    /// Generate the kernel preamble (thread setup, variable declarations).
321    pub fn generate_preamble(&self, indent: &str) -> String {
322        let mut code = String::new();
323
324        // Thread identification
325        writeln!(code, "{}// Thread identification", indent).unwrap();
326        writeln!(
327            code,
328            "{}int tid = threadIdx.x + blockIdx.x * blockDim.x;",
329            indent
330        )
331        .unwrap();
332        writeln!(code, "{}int lane_id = threadIdx.x % 32;", indent).unwrap();
333        writeln!(code, "{}int warp_id = threadIdx.x / 32;", indent).unwrap();
334        writeln!(code).unwrap();
335
336        // Kernel ID constant
337        writeln!(code, "{}// Kernel identity", indent).unwrap();
338        writeln!(
339            code,
340            "{}const unsigned long long KERNEL_ID = {}ULL;",
341            indent, self.kernel_id_num
342        )
343        .unwrap();
344        writeln!(
345            code,
346            "{}const unsigned long long HLC_NODE_ID = {}ULL;",
347            indent, self.hlc_node_id
348        )
349        .unwrap();
350        writeln!(code).unwrap();
351
352        // Message size constants
353        writeln!(code, "{}// Message buffer constants", indent).unwrap();
354        if self.use_envelope_format {
355            // With envelope format, MSG_SIZE is envelope size (header + payload)
356            writeln!(
357                code,
358                "{}const unsigned int PAYLOAD_SIZE = {};  // User payload size",
359                indent, self.message_size
360            )
361            .unwrap();
362            writeln!(
363                code,
364                "{}const unsigned int MSG_SIZE = MESSAGE_HEADER_SIZE + PAYLOAD_SIZE;  // Full envelope",
365                indent
366            )
367            .unwrap();
368            writeln!(
369                code,
370                "{}const unsigned int RESP_PAYLOAD_SIZE = {};",
371                indent, self.response_size
372            )
373            .unwrap();
374            writeln!(
375                code,
376                "{}const unsigned int RESP_SIZE = MESSAGE_HEADER_SIZE + RESP_PAYLOAD_SIZE;",
377                indent
378            )
379            .unwrap();
380        } else {
381            // Legacy: raw message data
382            writeln!(
383                code,
384                "{}const unsigned int MSG_SIZE = {};",
385                indent, self.message_size
386            )
387            .unwrap();
388            writeln!(
389                code,
390                "{}const unsigned int RESP_SIZE = {};",
391                indent, self.response_size
392            )
393            .unwrap();
394        }
395        writeln!(
396            code,
397            "{}const unsigned int QUEUE_MASK = {};",
398            indent,
399            self.queue_capacity - 1
400        )
401        .unwrap();
402        writeln!(code).unwrap();
403
404        // HLC state (if enabled)
405        if self.enable_hlc {
406            writeln!(code, "{}// HLC clock state", indent).unwrap();
407            writeln!(code, "{}unsigned long long hlc_physical = 0;", indent).unwrap();
408            writeln!(code, "{}unsigned long long hlc_logical = 0;", indent).unwrap();
409            writeln!(code).unwrap();
410        }
411
412        code
413    }
414
415    /// Generate the persistent message loop header.
416    pub fn generate_loop_header(&self, indent: &str) -> String {
417        let mut code = String::new();
418
419        writeln!(code, "{}// Persistent message processing loop", indent).unwrap();
420        writeln!(code, "{}while (true) {{", indent).unwrap();
421        writeln!(code, "{}    // Check for termination signal", indent).unwrap();
422        writeln!(
423            code,
424            "{}    if (atomicAdd(&control->should_terminate, 0) != 0) {{",
425            indent
426        )
427        .unwrap();
428        writeln!(code, "{}        break;", indent).unwrap();
429        writeln!(code, "{}    }}", indent).unwrap();
430        writeln!(code).unwrap();
431
432        // Check if active
433        writeln!(code, "{}    // Check if kernel is active", indent).unwrap();
434        writeln!(
435            code,
436            "{}    if (atomicAdd(&control->is_active, 0) == 0) {{",
437            indent
438        )
439        .unwrap();
440        if self.idle_sleep_ns > 0 {
441            writeln!(
442                code,
443                "{}        __nanosleep({});",
444                indent, self.idle_sleep_ns
445            )
446            .unwrap();
447        }
448        writeln!(code, "{}        continue;", indent).unwrap();
449        writeln!(code, "{}    }}", indent).unwrap();
450        writeln!(code).unwrap();
451
452        // Check for messages
453        writeln!(code, "{}    // Check input queue for messages", indent).unwrap();
454        writeln!(
455            code,
456            "{}    unsigned long long head = atomicAdd(&control->input_head, 0);",
457            indent
458        )
459        .unwrap();
460        writeln!(
461            code,
462            "{}    unsigned long long tail = atomicAdd(&control->input_tail, 0);",
463            indent
464        )
465        .unwrap();
466        writeln!(code).unwrap();
467        writeln!(code, "{}    if (head == tail) {{", indent).unwrap();
468        writeln!(code, "{}        // No messages, yield", indent).unwrap();
469        if self.idle_sleep_ns > 0 {
470            writeln!(
471                code,
472                "{}        __nanosleep({});",
473                indent, self.idle_sleep_ns
474            )
475            .unwrap();
476        }
477        writeln!(code, "{}        continue;", indent).unwrap();
478        writeln!(code, "{}    }}", indent).unwrap();
479        writeln!(code).unwrap();
480
481        // Calculate message pointer
482        writeln!(code, "{}    // Get message from queue", indent).unwrap();
483        writeln!(
484            code,
485            "{}    unsigned int msg_idx = (unsigned int)(tail & QUEUE_MASK);",
486            indent
487        )
488        .unwrap();
489        writeln!(
490            code,
491            "{}    unsigned char* envelope_ptr = &input_buffer[msg_idx * MSG_SIZE];",
492            indent
493        )
494        .unwrap();
495
496        if self.use_envelope_format {
497            // Envelope format: parse header and get payload pointer
498            writeln!(code).unwrap();
499            writeln!(code, "{}    // Parse message envelope", indent).unwrap();
500            writeln!(
501                code,
502                "{}    MessageHeader* msg_header = message_get_header(envelope_ptr);",
503                indent
504            )
505            .unwrap();
506            writeln!(
507                code,
508                "{}    unsigned char* msg_ptr = message_get_payload(envelope_ptr);",
509                indent
510            )
511            .unwrap();
512            writeln!(code).unwrap();
513            writeln!(code, "{}    // Validate message (skip invalid)", indent).unwrap();
514            writeln!(
515                code,
516                "{}    if (!message_header_validate(msg_header)) {{",
517                indent
518            )
519            .unwrap();
520            writeln!(
521                code,
522                "{}        atomicAdd(&control->input_tail, 1);",
523                indent
524            )
525            .unwrap();
526            writeln!(
527                code,
528                "{}        atomicAdd(&control->last_error, 1);  // Track errors",
529                indent
530            )
531            .unwrap();
532            writeln!(code, "{}        continue;", indent).unwrap();
533            writeln!(code, "{}    }}", indent).unwrap();
534
535            // Update HLC from incoming message timestamp
536            if self.enable_hlc {
537                writeln!(code).unwrap();
538                writeln!(code, "{}    // Update HLC from message timestamp", indent).unwrap();
539                writeln!(
540                    code,
541                    "{}    if (msg_header->timestamp.physical > hlc_physical) {{",
542                    indent
543                )
544                .unwrap();
545                writeln!(
546                    code,
547                    "{}        hlc_physical = msg_header->timestamp.physical;",
548                    indent
549                )
550                .unwrap();
551                writeln!(code, "{}        hlc_logical = 0;", indent).unwrap();
552                writeln!(code, "{}    }}", indent).unwrap();
553                writeln!(code, "{}    hlc_logical++;", indent).unwrap();
554            }
555        } else {
556            // Legacy raw format
557            writeln!(
558                code,
559                "{}    unsigned char* msg_ptr = envelope_ptr;  // Raw message data",
560                indent
561            )
562            .unwrap();
563        }
564        writeln!(code).unwrap();
565
566        code
567    }
568
569    /// Generate message processing completion code.
570    pub fn generate_message_complete(&self, indent: &str) -> String {
571        let mut code = String::new();
572
573        writeln!(code).unwrap();
574        writeln!(code, "{}    // Mark message as processed", indent).unwrap();
575        writeln!(code, "{}    atomicAdd(&control->input_tail, 1);", indent).unwrap();
576        writeln!(
577            code,
578            "{}    atomicAdd(&control->messages_processed, 1);",
579            indent
580        )
581        .unwrap();
582
583        if self.enable_hlc {
584            writeln!(code).unwrap();
585            writeln!(code, "{}    // Update HLC", indent).unwrap();
586            writeln!(code, "{}    hlc_logical++;", indent).unwrap();
587        }
588
589        code
590    }
591
592    /// Generate the loop footer (end of while loop).
593    pub fn generate_loop_footer(&self, indent: &str) -> String {
594        let mut code = String::new();
595
596        writeln!(code, "{}    __syncthreads();", indent).unwrap();
597        writeln!(code, "{}}}", indent).unwrap();
598
599        code
600    }
601
602    /// Generate kernel epilogue (termination marking).
603    pub fn generate_epilogue(&self, indent: &str) -> String {
604        let mut code = String::new();
605
606        writeln!(code).unwrap();
607        writeln!(code, "{}// Mark kernel as terminated", indent).unwrap();
608        writeln!(code, "{}if (tid == 0) {{", indent).unwrap();
609
610        if self.enable_hlc {
611            writeln!(code, "{}    // Store final HLC state", indent).unwrap();
612            writeln!(
613                code,
614                "{}    control->hlc_state.physical = hlc_physical;",
615                indent
616            )
617            .unwrap();
618            writeln!(
619                code,
620                "{}    control->hlc_state.logical = hlc_logical;",
621                indent
622            )
623            .unwrap();
624        }
625
626        writeln!(
627            code,
628            "{}    atomicExch(&control->has_terminated, 1);",
629            indent
630        )
631        .unwrap();
632        writeln!(code, "{}}}", indent).unwrap();
633
634        code
635    }
636
637    /// Generate complete kernel wrapper (without handler body).
638    pub fn generate_kernel_wrapper(&self, handler_placeholder: &str) -> String {
639        let mut code = String::new();
640
641        // Struct definitions
642        code.push_str(&generate_control_block_struct());
643        code.push('\n');
644
645        if self.enable_hlc {
646            code.push_str(&generate_hlc_struct());
647            code.push('\n');
648        }
649
650        // MessageEnvelope structs (requires HLC for HlcTimestamp)
651        if self.use_envelope_format {
652            code.push_str(&generate_message_envelope_structs());
653            code.push('\n');
654        }
655
656        if self.enable_k2k {
657            code.push_str(&generate_k2k_structs());
658            code.push('\n');
659        }
660
661        // Kernel signature
662        code.push_str(&self.generate_signature());
663        code.push_str(" {\n");
664
665        // Preamble
666        code.push_str(&self.generate_preamble("    "));
667
668        // Message loop
669        code.push_str(&self.generate_loop_header("    "));
670
671        // Handler placeholder
672        writeln!(code, "        // === USER HANDLER CODE ===").unwrap();
673        for line in handler_placeholder.lines() {
674            writeln!(code, "        {}", line).unwrap();
675        }
676        writeln!(code, "        // === END HANDLER CODE ===").unwrap();
677
678        // Message completion
679        code.push_str(&self.generate_message_complete("    "));
680
681        // Loop footer
682        code.push_str(&self.generate_loop_footer("    "));
683
684        // Epilogue
685        code.push_str(&self.generate_epilogue("    "));
686
687        code.push_str("}\n");
688
689        code
690    }
691}
692
/// Generate the CUDA `ControlBlock` struct definition.
///
/// The emitted struct is declared `__align__(128)`. It holds, in order:
/// - lifecycle flags
/// - message counters
/// - input/output queue pointers
/// - queue metadata
/// - HLC state
/// - error state
/// - 24 bytes of reserved padding
///
/// These add up to 128 bytes. The returned text ends with a newline.
pub fn generate_control_block_struct() -> String {
    const SOURCE_LINES: [&str; 37] = [
        "// Control block for kernel state management (128 bytes, cache-line aligned)",
        "struct __align__(128) ControlBlock {",
        "    // Lifecycle state",
        "    unsigned int is_active;",
        "    unsigned int should_terminate;",
        "    unsigned int has_terminated;",
        "    unsigned int _pad1;",
        "",
        "    // Counters",
        "    unsigned long long messages_processed;",
        "    unsigned long long messages_in_flight;",
        "",
        "    // Queue pointers",
        "    unsigned long long input_head;",
        "    unsigned long long input_tail;",
        "    unsigned long long output_head;",
        "    unsigned long long output_tail;",
        "",
        "    // Queue metadata",
        "    unsigned int input_capacity;",
        "    unsigned int output_capacity;",
        "    unsigned int input_mask;",
        "    unsigned int output_mask;",
        "",
        "    // HLC state",
        "    struct {",
        "        unsigned long long physical;",
        "        unsigned long long logical;",
        "    } hlc_state;",
        "",
        "    // Error state",
        "    unsigned int last_error;",
        "    unsigned int error_count;",
        "",
        "    // Reserved padding",
        "    unsigned char _reserved[24];",
        "};",
    ];

    // Every line, including the closing `};`, is newline-terminated.
    let mut cuda = SOURCE_LINES.join("\n");
    cuda.push('\n');
    cuda
}
735
/// Generate the CUDA Hybrid Logical Clock helper types.
///
/// Emits two structs:
/// - `HlcState`: physical and logical counters.
/// - `HlcTimestamp`: 24 bytes, `__align__(8)`, holding physical, logical
///   and node_id. This is the type the message envelope header uses for its
///   `timestamp` and `deadline` fields.
///
/// The returned text ends with a blank line.
pub fn generate_hlc_struct() -> String {
    const SOURCE_LINES: [&str; 13] = [
        "// Hybrid Logical Clock state",
        "struct HlcState {",
        "    unsigned long long physical;",
        "    unsigned long long logical;",
        "};",
        "",
        "// HLC timestamp (24 bytes)",
        "struct __align__(8) HlcTimestamp {",
        "    unsigned long long physical;",
        "    unsigned long long logical;",
        "    unsigned long long node_id;",
        "};",
        "",
    ];

    // The trailing "" entry plus the final push yields the closing "};\n\n".
    let mut cuda = SOURCE_LINES.join("\n");
    cuda.push('\n');
    cuda
}
753
/// Generate CUDA MessageEnvelope structs matching ringkernel-core layout.
///
/// This generates the GPU-side structures that match the Rust `MessageHeader`
/// and `MessageEnvelope` types, enabling proper serialization between host and device.
///
/// The output contains:
/// - the `MESSAGE_*` and `PRIORITY_*` macros;
/// - the 256-byte `MessageHeader` struct;
/// - device helpers for validation, payload/header access, CRC32-C checksums,
///   building response headers, and computing envelope size.
///
/// The emitted `MessageHeader` uses `HlcTimestamp`, so the HLC struct
/// definitions must be emitted before this text.
pub fn generate_message_envelope_structs() -> String {
    // NOTE: the `ULL` suffix must be attached to the literal. Written as
    // `0x... ULL`, the macro expands to two tokens and every use of
    // MESSAGE_MAGIC fails to compile.
    r#"// Magic number for message validation
#define MESSAGE_MAGIC 0x52494E474B45524EULL  // "RINGKERN"
#define MESSAGE_VERSION 1
#define MESSAGE_HEADER_SIZE 256
#define MAX_PAYLOAD_SIZE (64 * 1024)

// Message priority levels
#define PRIORITY_LOW 0
#define PRIORITY_NORMAL 1
#define PRIORITY_HIGH 2
#define PRIORITY_CRITICAL 3

// Message header structure (256 bytes, cache-line aligned)
// Matches ringkernel_core::message::MessageHeader exactly
struct __align__(64) MessageHeader {
    // Magic number for validation (0xRINGKERN)
    unsigned long long magic;
    // Header version
    unsigned int version;
    // Message flags
    unsigned int flags;
    // Unique message identifier
    unsigned long long message_id;
    // Correlation ID for request-response
    unsigned long long correlation_id;
    // Source kernel ID (0 for host)
    unsigned long long source_kernel;
    // Destination kernel ID (0 for host)
    unsigned long long dest_kernel;
    // Message type discriminator
    unsigned long long message_type;
    // Priority level
    unsigned char priority;
    // Reserved for alignment
    unsigned char _reserved1[7];
    // Payload size in bytes
    unsigned long long payload_size;
    // Checksum of payload (CRC32)
    unsigned int checksum;
    // Reserved for alignment
    unsigned int _reserved2;
    // HLC timestamp when message was created
    HlcTimestamp timestamp;
    // Deadline timestamp (0 = no deadline)
    HlcTimestamp deadline;
    // Reserved for future use (104 bytes; fields above total 128 bytes,
    // 128 + 104 = 232, padded to 256 by __align__(64))
    unsigned char _reserved3[104];
};

// Validate message header
__device__ inline int message_header_validate(const MessageHeader* header) {
    return header->magic == MESSAGE_MAGIC &&
           header->version <= MESSAGE_VERSION &&
           header->payload_size <= MAX_PAYLOAD_SIZE;
}

// Get payload pointer from header
__device__ inline unsigned char* message_get_payload(unsigned char* envelope_ptr) {
    return envelope_ptr + MESSAGE_HEADER_SIZE;
}

// Get header from envelope pointer
__device__ inline MessageHeader* message_get_header(unsigned char* envelope_ptr) {
    return (MessageHeader*)envelope_ptr;
}

// CRC32 computation for payload checksums
// Uses CRC32-C (Castagnoli) polynomial 0x1EDC6F41
__device__ inline unsigned int message_compute_checksum(const unsigned char* data, unsigned long long size) {
    unsigned int crc = 0xFFFFFFFF;
    for (unsigned long long i = 0; i < size; i++) {
        crc ^= data[i];
        for (int j = 0; j < 8; j++) {
            crc = (crc >> 1) ^ (0x82F63B78 * (crc & 1));  // CRC32-C polynomial
        }
    }
    return ~crc;
}

// Verify payload checksum
__device__ inline int message_verify_checksum(const MessageHeader* header, const unsigned char* payload) {
    if (header->payload_size == 0) {
        return header->checksum == 0;  // Empty payload should have zero checksum
    }
    unsigned int computed = message_compute_checksum(payload, header->payload_size);
    return computed == header->checksum;
}

// Create a response header based on request
// Note: payload_ptr is optional, pass NULL if checksum not needed yet
__device__ inline void message_create_response_header(
    MessageHeader* response,
    const MessageHeader* request,
    unsigned long long this_kernel_id,
    unsigned long long payload_size,
    unsigned long long hlc_physical,
    unsigned long long hlc_logical,
    unsigned long long hlc_node_id,
    const unsigned char* payload_ptr  // Optional: for checksum computation
) {
    response->magic = MESSAGE_MAGIC;
    response->version = MESSAGE_VERSION;
    response->flags = 0;
    // Generate new message ID (simple increment from request)
    response->message_id = request->message_id + 0x100000000ULL;
    // Preserve correlation ID for request-response matching
    response->correlation_id = request->correlation_id != 0
        ? request->correlation_id
        : request->message_id;
    response->source_kernel = this_kernel_id;
    response->dest_kernel = request->source_kernel;  // Response goes back to sender
    response->message_type = request->message_type + 1;  // Convention: response type = request + 1
    response->priority = request->priority;
    response->payload_size = payload_size;
    // Compute checksum if payload provided
    if (payload_ptr != NULL && payload_size > 0) {
        response->checksum = message_compute_checksum(payload_ptr, payload_size);
    } else {
        response->checksum = 0;
    }
    response->timestamp.physical = hlc_physical;
    response->timestamp.logical = hlc_logical;
    response->timestamp.node_id = hlc_node_id;
    response->deadline.physical = 0;
    response->deadline.logical = 0;
    response->deadline.node_id = 0;
}

// Calculate total envelope size
__device__ inline unsigned int message_envelope_size(const MessageHeader* header) {
    return MESSAGE_HEADER_SIZE + (unsigned int)header->payload_size;
}
"#
    .to_string()
}
894
/// Generate CUDA K2K routing structs and helper functions.
///
/// The returned CUDA source defines:
///
/// - `K2KRoute` / `K2KRoutingTable`: a fixed table of at most 16 routes. Each route holds the
///   target kernel id and device pointers (stored as `unsigned long long`) to the target's
///   inbox, head and tail, plus `capacity`, `mask` and `msg_size`. The generated comments state
///   that these match `ringkernel_cuda::k2k_gpu`.
/// - `K2KInboxHeader`: a `__align__(64)` header (`head`, `tail`, `capacity`, `mask`,
///   `msg_size`) that sits in front of the inbox slots. Slot `i` starts
///   `sizeof(K2KInboxHeader) + (i & mask) * msg_size` bytes into the inbox.
/// - Send helpers:
///   - `k2k_send_envelope` writes a full `MessageHeader` plus payload into the target slot.
///   - `k2k_send` (legacy) copies raw bytes, truncated to the route's `msg_size`.
///   - Both return `1` on success and `0` when no route matches `target_id`.
///   - `k2k_send_envelope` additionally returns `-1` when `MESSAGE_HEADER_SIZE + payload_size`
///     exceeds the route's `msg_size`.
/// - Receive helpers: `k2k_has_message`, `k2k_try_recv_envelope`, `k2k_try_recv`,
///   `k2k_peek_envelope`, `k2k_peek` and `k2k_pending_count`.
///   - Each takes a pointer to the start of an inbox, i.e. its `K2KInboxHeader`.
///   - The `try_recv` variants advance `tail`; the `peek` variants do not.
///   - `k2k_try_recv_envelope` returns `NULL` both when the inbox is empty and when it skips
///     a slot whose header fails `message_header_validate`.
///
/// The envelope helpers reference `MessageHeader`, `MESSAGE_HEADER_SIZE`, `MESSAGE_MAGIC`,
/// `MESSAGE_VERSION`, `PRIORITY_NORMAL`, `message_compute_checksum` and
/// `message_header_validate`, none of which are defined in this string.
/// NOTE(review): presumably these come from the message-envelope definitions generated
/// elsewhere in this module. Confirm every caller emits those before this code.
/// For example, `generate_multi_handler_kernel` decides the two independently, via
/// `use_envelope_format` and `enable_k2k`.
///
/// NOTE(review): the send helpers bump the target `head` with `atomicAdd` *before* writing the
/// slot, and neither compares against `tail`/`capacity`. As a result:
/// - a receiver polling `head != tail` may see a slot that is not fully written yet;
/// - a full inbox wraps over unread slots.
///
/// Similarly, the receive helpers advance `tail` before the caller reads the returned slot
/// pointer. Confirm that the host/runtime protocol tolerates these orderings.
pub fn generate_k2k_structs() -> String {
    r#"// Kernel-to-kernel routing table entry (48 bytes)
// Matches ringkernel_cuda::k2k_gpu::K2KRouteEntry
struct K2KRoute {
    unsigned long long target_kernel_id;
    unsigned long long target_inbox;      // device pointer
    unsigned long long target_head;       // device pointer to head
    unsigned long long target_tail;       // device pointer to tail
    unsigned int capacity;
    unsigned int mask;
    unsigned int msg_size;
    unsigned int _pad;
};

// K2K routing table (8 + 48*16 = 776 bytes)
struct K2KRoutingTable {
    unsigned int num_routes;
    unsigned int _pad;
    K2KRoute routes[16];  // Max 16 K2K connections
};

// K2K inbox header (64 bytes, cache-line aligned)
// Matches ringkernel_cuda::k2k_gpu::K2KInboxHeader
struct __align__(64) K2KInboxHeader {
    unsigned long long head;
    unsigned long long tail;
    unsigned int capacity;
    unsigned int mask;
    unsigned int msg_size;
    unsigned int _pad;
};

// Send a message envelope to another kernel via K2K
// The entire envelope (header + payload) is copied to the target's inbox
__device__ inline int k2k_send_envelope(
    K2KRoutingTable* routes,
    unsigned long long target_id,
    unsigned long long source_kernel_id,
    const void* payload_ptr,
    unsigned int payload_size,
    unsigned long long message_type,
    unsigned long long hlc_physical,
    unsigned long long hlc_logical,
    unsigned long long hlc_node_id
) {
    // Find route for target
    for (unsigned int i = 0; i < routes->num_routes; i++) {
        if (routes->routes[i].target_kernel_id == target_id) {
            K2KRoute* route = &routes->routes[i];

            // Calculate total envelope size
            unsigned int envelope_size = MESSAGE_HEADER_SIZE + payload_size;
            if (envelope_size > route->msg_size) {
                return -1;  // Message too large
            }

            // Atomically claim a slot in target's inbox
            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
            unsigned long long slot = atomicAdd(target_head_ptr, 1);
            unsigned int idx = (unsigned int)(slot & route->mask);

            // Calculate destination pointer
            unsigned char* dest = ((unsigned char*)route->target_inbox) +
                                  sizeof(K2KInboxHeader) + idx * route->msg_size;

            // Build message header
            MessageHeader* header = (MessageHeader*)dest;
            header->magic = MESSAGE_MAGIC;
            header->version = MESSAGE_VERSION;
            header->flags = 0;
            header->message_id = (source_kernel_id << 32) | (slot & 0xFFFFFFFF);
            header->correlation_id = 0;
            header->source_kernel = source_kernel_id;
            header->dest_kernel = target_id;
            header->message_type = message_type;
            header->priority = PRIORITY_NORMAL;
            header->payload_size = payload_size;
            header->timestamp.physical = hlc_physical;
            header->timestamp.logical = hlc_logical;
            header->timestamp.node_id = hlc_node_id;
            header->deadline.physical = 0;
            header->deadline.logical = 0;
            header->deadline.node_id = 0;

            // Copy payload after header and compute checksum
            if (payload_size > 0 && payload_ptr != NULL) {
                memcpy(dest + MESSAGE_HEADER_SIZE, payload_ptr, payload_size);
                header->checksum = message_compute_checksum(
                    (const unsigned char*)payload_ptr, payload_size);
            } else {
                header->checksum = 0;
            }

            __threadfence();  // Ensure write is visible
            return 1;  // Success
        }
    }
    return 0;  // Route not found
}

// Legacy k2k_send for raw message data (no envelope)
__device__ inline int k2k_send(
    K2KRoutingTable* routes,
    unsigned long long target_id,
    const void* msg_ptr,
    unsigned int msg_size
) {
    // Find route for target
    for (unsigned int i = 0; i < routes->num_routes; i++) {
        if (routes->routes[i].target_kernel_id == target_id) {
            K2KRoute* route = &routes->routes[i];

            // Atomically claim a slot in target's inbox
            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
            unsigned long long slot = atomicAdd(target_head_ptr, 1);
            unsigned int idx = (unsigned int)(slot & route->mask);

            // Copy message to target inbox
            unsigned char* dest = ((unsigned char*)route->target_inbox) +
                                  sizeof(K2KInboxHeader) + idx * route->msg_size;
            memcpy(dest, msg_ptr, msg_size < route->msg_size ? msg_size : route->msg_size);

            __threadfence();
            return 1;  // Success
        }
    }
    return 0;  // Route not found
}

// Check if there are K2K messages in inbox
__device__ inline int k2k_has_message(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);
    return head != tail;
}

// Try to receive a K2K message envelope
// Returns pointer to MessageHeader if available, NULL otherwise
__device__ inline MessageHeader* k2k_try_recv_envelope(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL;  // No messages
    }

    // Get message pointer
    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
    MessageHeader* msg_header = (MessageHeader*)(data_start + idx * header->msg_size);

    // Validate header
    if (!message_header_validate(msg_header)) {
        // Invalid message, skip it
        atomicAdd(&header->tail, 1);
        return NULL;
    }

    // Advance tail (consume message)
    atomicAdd(&header->tail, 1);

    return msg_header;
}

// Legacy k2k_try_recv for raw data
__device__ inline void* k2k_try_recv(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL;  // No messages
    }

    // Get message pointer
    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    // Advance tail (consume message)
    atomicAdd(&header->tail, 1);

    return data_start + idx * header->msg_size;
}

// Peek at next K2K message without consuming
__device__ inline MessageHeader* k2k_peek_envelope(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL;  // No messages
    }

    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    return (MessageHeader*)(data_start + idx * header->msg_size);
}

// Legacy k2k_peek
__device__ inline void* k2k_peek(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL;  // No messages
    }

    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    return data_start + idx * header->msg_size;
}

// Get number of pending K2K messages
__device__ inline unsigned int k2k_pending_count(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);
    return (unsigned int)(head - tail);
}
"#
    .to_string()
}
1128
/// Intrinsic functions available in ring kernel handlers.
///
/// Each variant is a handler-level call, resolved from a name by
/// [`RingKernelIntrinsic::from_name`], that expands to an inline CUDA snippet via
/// [`RingKernelIntrinsic::to_cuda`]. Depending on the variant, the snippets refer to
/// identifiers that must be in scope in the generated kernel: `control`, `output_buffer`,
/// `RESP_SIZE`, `hlc_physical`, `hlc_logical`, `k2k_routes` and `k2k_inbox`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RingKernelIntrinsic {
    // Control block access
    /// True when `control->is_active` is nonzero.
    IsActive,
    /// True when `control->should_terminate` is nonzero.
    ShouldTerminate,
    /// Atomically sets `control->has_terminated` to 1.
    MarkTerminated,
    /// Atomically reads `control->messages_processed`.
    GetMessagesProcessed,

    // Queue operations
    /// `control->input_head - control->input_tail`.
    InputQueueSize,
    /// `control->output_head - control->output_tail`.
    OutputQueueSize,
    /// True when the input head equals the input tail.
    InputQueueEmpty,
    /// True when the output head equals the output tail.
    OutputQueueEmpty,
    /// Claims the next output slot and copies `RESP_SIZE` bytes from a response pointer into it.
    EnqueueResponse,

    // HLC operations
    /// Increments `hlc_logical`.
    HlcTick,
    /// Folds an observed physical time into `hlc_physical` / `hlc_logical`.
    HlcUpdate,
    /// `hlc_physical << 32`, OR'd with the low 32 bits of `hlc_logical`.
    HlcNow,

    // K2K operations
    /// Sends a raw message to another kernel through `k2k_routes`.
    K2kSend,
    /// Consumes the next raw message from `k2k_inbox` (NULL when empty).
    K2kTryRecv,
    /// Nonzero when `k2k_inbox` has unconsumed messages.
    K2kHasMessage,
    /// Returns the next raw message in `k2k_inbox` without consuming it (NULL when empty).
    K2kPeek,
    /// Number of unconsumed messages in `k2k_inbox`.
    K2kPendingCount,
}

impl RingKernelIntrinsic {
    /// Get the CUDA code for this intrinsic.
    ///
    /// `args` are the already-translated CUDA argument expressions. Only these variants use
    /// them:
    /// - `EnqueueResponse`: `args[0]` is the response pointer.
    /// - `HlcUpdate`: `args[0]` is the observed physical time.
    /// - `K2kSend`: `args[0]` is the target id and `args[1]` is the message pointer.
    ///
    /// When a required argument is missing, a placeholder comment is emitted instead.
    /// `HlcUpdate` is the exception: it falls back to a plain tick.
    ///
    /// Expression-valued snippets are fully parenthesized so they can be embedded in larger
    /// C expressions (e.g. `hlc_now() + 1`, `is_active() == x`) without operator-precedence
    /// surprises. Argument expressions that are evaluated inside an operator context are
    /// parenthesized as well.
    pub fn to_cuda(&self, args: &[String]) -> String {
        match self {
            Self::IsActive => "(atomicAdd(&control->is_active, 0) != 0)".to_string(),
            Self::ShouldTerminate => {
                "(atomicAdd(&control->should_terminate, 0) != 0)".to_string()
            }
            Self::MarkTerminated => "atomicExch(&control->has_terminated, 1)".to_string(),
            Self::GetMessagesProcessed => "atomicAdd(&control->messages_processed, 0)".to_string(),

            Self::InputQueueSize => {
                "(atomicAdd(&control->input_head, 0) - atomicAdd(&control->input_tail, 0))"
                    .to_string()
            }
            Self::OutputQueueSize => {
                "(atomicAdd(&control->output_head, 0) - atomicAdd(&control->output_tail, 0))"
                    .to_string()
            }
            Self::InputQueueEmpty => {
                "(atomicAdd(&control->input_head, 0) == atomicAdd(&control->input_tail, 0))"
                    .to_string()
            }
            Self::OutputQueueEmpty => {
                "(atomicAdd(&control->output_head, 0) == atomicAdd(&control->output_tail, 0))"
                    .to_string()
            }
            Self::EnqueueResponse => match args.first() {
                Some(response) => format!(
                    "{{ unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask; \
                     memcpy(&output_buffer[_out_idx * RESP_SIZE], {}, RESP_SIZE); }}",
                    response
                ),
                None => "/* enqueue_response requires response pointer */".to_string(),
            },

            Self::HlcTick => "hlc_logical++".to_string(),
            Self::HlcUpdate => match args.first() {
                // Bind the observed time once, so an argument with side effects is evaluated a
                // single time and a low-precedence argument is compared as a unit.
                Some(observed) => format!(
                    "{{ unsigned long long _hlc_observed = ({}); if (_hlc_observed > hlc_physical) {{ hlc_physical = _hlc_observed; hlc_logical = 0; }} else {{ hlc_logical++; }} }}",
                    observed
                ),
                None => "hlc_logical++".to_string(),
            },
            Self::HlcNow => "((hlc_physical << 32) | (hlc_logical & 0xFFFFFFFF))".to_string(),

            Self::K2kSend => match (args.first(), args.get(1)) {
                // k2k_send(target_id, msg_ptr) -> k2k_send(k2k_routes, target_id, msg_ptr, sizeof(*(msg_ptr)))
                (Some(target), Some(msg)) => format!(
                    "k2k_send(k2k_routes, {}, {}, sizeof(*({})))",
                    target, msg, msg
                ),
                _ => "/* k2k_send requires target_id and msg_ptr */".to_string(),
            },
            Self::K2kTryRecv => "k2k_try_recv(k2k_inbox)".to_string(),
            Self::K2kHasMessage => "k2k_has_message(k2k_inbox)".to_string(),
            Self::K2kPeek => "k2k_peek(k2k_inbox)".to_string(),
            Self::K2kPendingCount => "k2k_pending_count(k2k_inbox)".to_string(),
        }
    }

    /// Parse a function name to get the intrinsic.
    ///
    /// Several aliases are accepted (e.g. `is_kernel_active`, `enqueue`, `k2k_pending`).
    /// Returns `None` for names that are not ring kernel intrinsics.
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "is_active" | "is_kernel_active" => Some(Self::IsActive),
            "should_terminate" => Some(Self::ShouldTerminate),
            "mark_terminated" => Some(Self::MarkTerminated),
            "messages_processed" | "get_messages_processed" => Some(Self::GetMessagesProcessed),

            "input_queue_size" => Some(Self::InputQueueSize),
            "output_queue_size" => Some(Self::OutputQueueSize),
            "input_queue_empty" => Some(Self::InputQueueEmpty),
            "output_queue_empty" => Some(Self::OutputQueueEmpty),
            "enqueue_response" | "enqueue" => Some(Self::EnqueueResponse),

            "hlc_tick" => Some(Self::HlcTick),
            "hlc_update" => Some(Self::HlcUpdate),
            "hlc_now" => Some(Self::HlcNow),

            "k2k_send" => Some(Self::K2kSend),
            "k2k_try_recv" => Some(Self::K2kTryRecv),
            "k2k_has_message" => Some(Self::K2kHasMessage),
            "k2k_peek" => Some(Self::K2kPeek),
            "k2k_pending_count" | "k2k_pending" => Some(Self::K2kPendingCount),

            _ => None,
        }
    }

    /// Check if this intrinsic requires K2K support (`k2k_routes` / `k2k_inbox`).
    pub fn requires_k2k(&self) -> bool {
        matches!(
            self,
            Self::K2kSend
                | Self::K2kTryRecv
                | Self::K2kHasMessage
                | Self::K2kPeek
                | Self::K2kPendingCount
        )
    }

    /// Check if this intrinsic requires HLC support (`hlc_physical` / `hlc_logical`).
    pub fn requires_hlc(&self) -> bool {
        matches!(self, Self::HlcTick | Self::HlcUpdate | Self::HlcNow)
    }

    /// Check if this intrinsic requires control block access (`control`).
    pub fn requires_control_block(&self) -> bool {
        matches!(
            self,
            Self::IsActive
                | Self::ShouldTerminate
                | Self::MarkTerminated
                | Self::GetMessagesProcessed
                | Self::InputQueueSize
                | Self::OutputQueueSize
                | Self::InputQueueEmpty
                | Self::OutputQueueEmpty
                | Self::EnqueueResponse
        )
    }
}
1287
1288// ============================================================================
1289// Handler Dispatch Code Generation
1290// ============================================================================
1291
/// Handler registration for CUDA code generation.
///
/// This mirrors `ringkernel_core::persistent_message::HandlerRegistration`
/// but is designed for code generation use.
#[derive(Debug, Clone)]
pub struct CudaHandlerInfo {
    /// Handler ID (0-255) used in switch statement.
    pub handler_id: u32,
    /// Handler function name in CUDA code.
    pub func_name: String,
    /// Message type name (for documentation).
    pub message_type_name: String,
    /// Message type ID (for validation).
    pub message_type_id: u64,
    /// CUDA code body for this handler (optional).
    /// If None, generates a call to the named function.
    pub cuda_body: Option<String>,
    /// Whether this handler produces a response.
    pub produces_response: bool,
}

impl CudaHandlerInfo {
    /// Create a new handler info with no message type, no inline body, and no response.
    pub fn new(handler_id: u32, func_name: impl Into<String>) -> Self {
        Self {
            handler_id,
            func_name: func_name.into(),
            message_type_name: String::new(),
            message_type_id: 0,
            cuda_body: None,
            produces_response: false,
        }
    }

    /// Set the message type information.
    pub fn with_message_type(mut self, name: &str, type_id: u64) -> Self {
        self.message_type_name = name.to_string();
        self.message_type_id = type_id;
        self
    }

    /// Set the CUDA body for inline handler code.
    pub fn with_cuda_body(mut self, body: impl Into<String>) -> Self {
        self.cuda_body = Some(body.into());
        self
    }

    /// Mark this handler as producing a response.
    pub fn with_response(mut self) -> Self {
        self.produces_response = true;
        self
    }
}

/// Dispatch table for multi-handler kernel code generation.
///
/// `Default` is implemented manually and is identical to [`CudaDispatchTable::new`]. A derived
/// `Default` would leave the field-name strings empty and make the dispatch generator emit
/// invalid CUDA (e.g. `msg->;`).
#[derive(Debug, Clone)]
pub struct CudaDispatchTable {
    handlers: Vec<CudaHandlerInfo>,
    /// Name of the ExtendedH2KMessage struct (default: "ExtendedH2KMessage").
    pub message_struct_name: String,
    /// Field name for handler_id in the message (default: "handler_id").
    pub handler_id_field: String,
    /// Counter for unknown handlers (default: "unknown_handler_count").
    pub unknown_counter_field: String,
}

impl Default for CudaDispatchTable {
    /// Same as [`CudaDispatchTable::new`]: no handlers, standard field names.
    fn default() -> Self {
        Self::new()
    }
}

impl CudaDispatchTable {
    /// Create a new empty dispatch table using the documented default field names.
    pub fn new() -> Self {
        Self {
            handlers: Vec::new(),
            message_struct_name: "ExtendedH2KMessage".to_string(),
            handler_id_field: "handler_id".to_string(),
            unknown_counter_field: "unknown_handler_count".to_string(),
        }
    }

    /// Add a handler to the table.
    pub fn add_handler(&mut self, handler: CudaHandlerInfo) {
        self.handlers.push(handler);
    }

    /// Add a handler using builder pattern.
    pub fn with_handler(mut self, handler: CudaHandlerInfo) -> Self {
        self.add_handler(handler);
        self
    }

    /// Set the message struct name.
    pub fn with_message_struct(mut self, name: &str) -> Self {
        self.message_struct_name = name.to_string();
        self
    }

    /// Get all handlers, in registration order.
    pub fn handlers(&self) -> &[CudaHandlerInfo] {
        &self.handlers
    }

    /// Get handler count.
    pub fn len(&self) -> usize {
        self.handlers.len()
    }

    /// Check if empty.
    pub fn is_empty(&self) -> bool {
        self.handlers.is_empty()
    }
}
1401
1402/// Generate CUDA switch statement for handler dispatch.
1403///
1404/// # Arguments
1405///
1406/// * `table` - Dispatch table with handler information
1407/// * `indent` - Indentation string for formatting
1408///
1409/// # Returns
1410///
1411/// CUDA code with switch statement for handler dispatch.
1412///
1413/// # Example Output
1414///
1415/// ```cuda
1416/// // Handler dispatch based on handler_id
1417/// uint32_t handler_id = msg->handler_id;
1418/// switch (handler_id) {
1419///     case 1: {
1420///         // Handler: fraud_check (type_id: 1001)
1421///         handle_fraud_check(msg, state, response);
1422///         break;
1423///     }
1424///     case 2: {
1425///         // Handler: aggregate (type_id: 1002)
1426///         handle_aggregate(msg, state, response);
1427///         break;
1428///     }
1429///     default:
1430///         atomicAdd(&ctrl->unknown_handler_count, 1);
1431///         break;
1432/// }
1433/// ```
1434pub fn generate_handler_dispatch_code(table: &CudaDispatchTable, indent: &str) -> String {
1435    let mut code = String::new();
1436
1437    if table.is_empty() {
1438        writeln!(
1439            code,
1440            "{}// No handlers registered - dispatch table empty",
1441            indent
1442        )
1443        .unwrap();
1444        return code;
1445    }
1446
1447    writeln!(code, "{}// Handler dispatch based on handler_id", indent).unwrap();
1448    writeln!(
1449        code,
1450        "{}uint32_t handler_id = msg->{};",
1451        indent, table.handler_id_field
1452    )
1453    .unwrap();
1454    writeln!(code, "{}switch (handler_id) {{", indent).unwrap();
1455
1456    for handler in &table.handlers {
1457        writeln!(code, "{}    case {}: {{", indent, handler.handler_id).unwrap();
1458
1459        // Add comment with handler info
1460        if !handler.message_type_name.is_empty() {
1461            writeln!(
1462                code,
1463                "{}        // Handler: {} (type_id: {})",
1464                indent, handler.message_type_name, handler.message_type_id
1465            )
1466            .unwrap();
1467        } else {
1468            writeln!(code, "{}        // Handler: {}", indent, handler.func_name).unwrap();
1469        }
1470
1471        // Generate handler code
1472        if let Some(body) = &handler.cuda_body {
1473            // Inline the handler body
1474            for line in body.lines() {
1475                writeln!(code, "{}        {}", indent, line).unwrap();
1476            }
1477        } else {
1478            // Generate function call
1479            if handler.produces_response {
1480                writeln!(
1481                    code,
1482                    "{}        {}(msg, state, response);",
1483                    indent, handler.func_name
1484                )
1485                .unwrap();
1486            } else {
1487                writeln!(code, "{}        {}(msg, state);", indent, handler.func_name).unwrap();
1488            }
1489        }
1490
1491        writeln!(code, "{}        break;", indent).unwrap();
1492        writeln!(code, "{}    }}", indent).unwrap();
1493    }
1494
1495    // Default case for unknown handlers
1496    writeln!(code, "{}    default:", indent).unwrap();
1497    writeln!(
1498        code,
1499        "{}        atomicAdd(&ctrl->{}, 1);",
1500        indent, table.unknown_counter_field
1501    )
1502    .unwrap();
1503    writeln!(code, "{}        break;", indent).unwrap();
1504    writeln!(code, "{}}}", indent).unwrap();
1505
1506    code
1507}
1508
/// Generate the ExtendedH2KMessage struct for handler dispatch.
///
/// Emits a 64-byte, `__align__(64)` message layout: `handler_id` (4 bytes), `flags` (4),
/// `cmd_id` (8), `timestamp` (8) and a 40-byte inline `payload`. It also emits the
/// `EXT_MSG_FLAG_*` bit definitions used in `flags`.
pub fn generate_extended_h2k_message_struct() -> String {
    const EXTENDED_H2K_MESSAGE_CUDA: &str = r#"// Extended H2K message for handler dispatch (64 bytes, cache-line aligned)
// This format supports type-based routing within persistent kernels
struct __align__(64) ExtendedH2KMessage {
    uint32_t handler_id;      // Handler ID for dispatch (0-255)
    uint32_t flags;           // Message flags (see message_flags)
    uint64_t cmd_id;          // Command/correlation ID for request tracking
    uint64_t timestamp;       // HLC timestamp
    uint8_t payload[40];      // Inline payload (up to 32 bytes used typically)
};

// Message flags for ExtendedH2KMessage
#define EXT_MSG_FLAG_EXTENDED       0x01
#define EXT_MSG_FLAG_HIGH_PRIORITY  0x02
#define EXT_MSG_FLAG_EXTERNAL_BUF   0x04
#define EXT_MSG_FLAG_REQUIRES_RESP  0x08
"#;
    String::from(EXTENDED_H2K_MESSAGE_CUDA)
}
1531
1532/// Generate a complete multi-handler kernel with dispatch table.
1533///
1534/// # Arguments
1535///
1536/// * `config` - Ring kernel configuration
1537/// * `table` - Dispatch table with handler information
1538///
1539/// # Returns
1540///
1541/// Complete CUDA kernel code with handler dispatch.
1542pub fn generate_multi_handler_kernel(
1543    config: &RingKernelConfig,
1544    table: &CudaDispatchTable,
1545) -> String {
1546    let mut code = String::new();
1547
1548    // Struct definitions
1549    code.push_str(&generate_control_block_struct());
1550    code.push('\n');
1551
1552    if config.enable_hlc {
1553        code.push_str(&generate_hlc_struct());
1554        code.push('\n');
1555    }
1556
1557    // Extended H2K message struct
1558    code.push_str(&generate_extended_h2k_message_struct());
1559    code.push('\n');
1560
1561    if config.use_envelope_format {
1562        code.push_str(&generate_message_envelope_structs());
1563        code.push('\n');
1564    }
1565
1566    if config.enable_k2k {
1567        code.push_str(&generate_k2k_structs());
1568        code.push('\n');
1569    }
1570
1571    // Kernel signature
1572    code.push_str(&config.generate_signature());
1573    code.push_str(" {\n");
1574
1575    // Preamble
1576    code.push_str(&config.generate_preamble("    "));
1577
1578    // Message loop
1579    code.push_str(&config.generate_loop_header("    "));
1580
1581    // Handler dispatch
1582    writeln!(code, "        // === MULTI-HANDLER DISPATCH ===").unwrap();
1583    writeln!(
1584        code,
1585        "        ExtendedH2KMessage* msg = (ExtendedH2KMessage*)(input_buffer + (msg_idx * MSG_SIZE));"
1586    )
1587    .unwrap();
1588    code.push_str(&generate_handler_dispatch_code(table, "        "));
1589    writeln!(code, "        // === END DISPATCH ===").unwrap();
1590
1591    // Message completion
1592    code.push_str(&config.generate_message_complete("    "));
1593
1594    // Loop footer
1595    code.push_str(&config.generate_loop_footer("    "));
1596
1597    // Epilogue
1598    code.push_str(&config.generate_epilogue("    "));
1599
1600    code.push_str("}\n");
1601
1602    code
1603}
1604
1605#[cfg(test)]
1606mod tests {
1607    use super::*;
1608
1609    #[test]
1610    fn test_default_config() {
1611        let config = RingKernelConfig::default();
1612        assert_eq!(config.block_size, 128);
1613        assert_eq!(config.queue_capacity, 1024);
1614        assert!(config.enable_hlc);
1615        assert!(!config.enable_k2k);
1616    }
1617
1618    #[test]
1619    fn test_config_builder() {
1620        let config = RingKernelConfig::new("processor")
1621            .with_block_size(256)
1622            .with_queue_capacity(2048)
1623            .with_k2k(true)
1624            .with_hlc(true);
1625
1626        assert_eq!(config.id, "processor");
1627        assert_eq!(config.kernel_name(), "ring_kernel_processor");
1628        assert_eq!(config.block_size, 256);
1629        assert_eq!(config.queue_capacity, 2048);
1630        assert!(config.enable_k2k);
1631        assert!(config.enable_hlc);
1632    }
1633
1634    #[test]
1635    fn test_kernel_signature() {
1636        let config = RingKernelConfig::new("test");
1637        let sig = config.generate_signature();
1638
1639        assert!(sig.contains("extern \"C\" __global__ void ring_kernel_test"));
1640        assert!(sig.contains("ControlBlock* __restrict__ control"));
1641        assert!(sig.contains("input_buffer"));
1642        assert!(sig.contains("output_buffer"));
1643        assert!(sig.contains("shared_state"));
1644    }
1645
1646    #[test]
1647    fn test_kernel_signature_with_k2k() {
1648        let config = RingKernelConfig::new("k2k_test").with_k2k(true);
1649        let sig = config.generate_signature();
1650
1651        assert!(sig.contains("K2KRoutingTable"));
1652        assert!(sig.contains("k2k_inbox"));
1653        assert!(sig.contains("k2k_outbox"));
1654    }
1655
1656    #[test]
1657    fn test_preamble_generation() {
1658        let config = RingKernelConfig::new("test").with_hlc(true);
1659        let preamble = config.generate_preamble("    ");
1660
1661        assert!(preamble.contains("int tid = threadIdx.x + blockIdx.x * blockDim.x"));
1662        assert!(preamble.contains("int lane_id"));
1663        assert!(preamble.contains("int warp_id"));
1664        assert!(preamble.contains("MSG_SIZE"));
1665        assert!(preamble.contains("hlc_physical"));
1666        assert!(preamble.contains("hlc_logical"));
1667    }
1668
1669    #[test]
1670    fn test_loop_header() {
1671        let config = RingKernelConfig::new("test");
1672        let header = config.generate_loop_header("    ");
1673
1674        assert!(header.contains("while (true)"));
1675        assert!(header.contains("should_terminate"));
1676        assert!(header.contains("is_active"));
1677        assert!(header.contains("input_head"));
1678        assert!(header.contains("input_tail"));
1679        assert!(header.contains("msg_ptr"));
1680    }
1681
1682    #[test]
1683    fn test_epilogue() {
1684        let config = RingKernelConfig::new("test").with_hlc(true);
1685        let epilogue = config.generate_epilogue("    ");
1686
1687        assert!(epilogue.contains("has_terminated"));
1688        assert!(epilogue.contains("hlc_state.physical"));
1689        assert!(epilogue.contains("hlc_state.logical"));
1690    }
1691
1692    #[test]
1693    fn test_control_block_struct() {
1694        let code = generate_control_block_struct();
1695
1696        assert!(code.contains("struct __align__(128) ControlBlock"));
1697        assert!(code.contains("is_active"));
1698        assert!(code.contains("should_terminate"));
1699        assert!(code.contains("has_terminated"));
1700        assert!(code.contains("messages_processed"));
1701        assert!(code.contains("input_head"));
1702        assert!(code.contains("input_tail"));
1703        assert!(code.contains("hlc_state"));
1704    }
1705
1706    #[test]
1707    fn test_full_kernel_wrapper() {
1708        let config = RingKernelConfig::new("example")
1709            .with_block_size(128)
1710            .with_hlc(true);
1711
1712        let kernel = config.generate_kernel_wrapper("// Process message here");
1713
1714        assert!(kernel.contains("struct __align__(128) ControlBlock"));
1715        assert!(kernel.contains("extern \"C\" __global__ void ring_kernel_example"));
1716        assert!(kernel.contains("while (true)"));
1717        assert!(kernel.contains("// Process message here"));
1718        assert!(kernel.contains("has_terminated"));
1719
1720        println!("Generated kernel:\n{}", kernel);
1721    }
1722
1723    #[test]
1724    fn test_intrinsic_lookup() {
1725        assert_eq!(
1726            RingKernelIntrinsic::from_name("is_active"),
1727            Some(RingKernelIntrinsic::IsActive)
1728        );
1729        assert_eq!(
1730            RingKernelIntrinsic::from_name("should_terminate"),
1731            Some(RingKernelIntrinsic::ShouldTerminate)
1732        );
1733        assert_eq!(
1734            RingKernelIntrinsic::from_name("hlc_tick"),
1735            Some(RingKernelIntrinsic::HlcTick)
1736        );
1737        assert_eq!(RingKernelIntrinsic::from_name("unknown"), None);
1738    }
1739
1740    #[test]
1741    fn test_intrinsic_cuda_output() {
1742        assert!(RingKernelIntrinsic::IsActive
1743            .to_cuda(&[])
1744            .contains("is_active"));
1745        assert!(RingKernelIntrinsic::ShouldTerminate
1746            .to_cuda(&[])
1747            .contains("should_terminate"));
1748        assert!(RingKernelIntrinsic::HlcTick
1749            .to_cuda(&[])
1750            .contains("hlc_logical++"));
1751    }
1752
1753    #[test]
1754    fn test_k2k_structs_generation() {
1755        let k2k_code = generate_k2k_structs();
1756
1757        // Check struct definitions
1758        assert!(
1759            k2k_code.contains("struct K2KRoute"),
1760            "Should have K2KRoute struct"
1761        );
1762        assert!(
1763            k2k_code.contains("struct K2KRoutingTable"),
1764            "Should have K2KRoutingTable struct"
1765        );
1766        assert!(
1767            k2k_code.contains("K2KInboxHeader"),
1768            "Should have K2KInboxHeader struct"
1769        );
1770
1771        // Check helper functions
1772        assert!(
1773            k2k_code.contains("__device__ inline int k2k_send"),
1774            "Should have k2k_send function"
1775        );
1776        assert!(
1777            k2k_code.contains("__device__ inline int k2k_has_message"),
1778            "Should have k2k_has_message function"
1779        );
1780        assert!(
1781            k2k_code.contains("__device__ inline void* k2k_try_recv"),
1782            "Should have k2k_try_recv function"
1783        );
1784        assert!(
1785            k2k_code.contains("__device__ inline void* k2k_peek"),
1786            "Should have k2k_peek function"
1787        );
1788        assert!(
1789            k2k_code.contains("__device__ inline unsigned int k2k_pending_count"),
1790            "Should have k2k_pending_count function"
1791        );
1792
1793        println!("K2K code:\n{}", k2k_code);
1794    }
1795
1796    #[test]
1797    fn test_full_k2k_kernel() {
1798        let config = RingKernelConfig::new("k2k_processor")
1799            .with_block_size(128)
1800            .with_k2k(true)
1801            .with_hlc(true);
1802
1803        let kernel = config.generate_kernel_wrapper("// K2K handler code");
1804
1805        // Check K2K-specific components
1806        assert!(
1807            kernel.contains("K2KRoutingTable"),
1808            "Should have K2KRoutingTable"
1809        );
1810        assert!(kernel.contains("K2KRoute"), "Should have K2KRoute struct");
1811        assert!(
1812            kernel.contains("K2KInboxHeader"),
1813            "Should have K2KInboxHeader"
1814        );
1815        assert!(
1816            kernel.contains("k2k_routes"),
1817            "Should have k2k_routes param"
1818        );
1819        assert!(kernel.contains("k2k_inbox"), "Should have k2k_inbox param");
1820        assert!(
1821            kernel.contains("k2k_outbox"),
1822            "Should have k2k_outbox param"
1823        );
1824        assert!(kernel.contains("k2k_send"), "Should have k2k_send function");
1825        assert!(
1826            kernel.contains("k2k_try_recv"),
1827            "Should have k2k_try_recv function"
1828        );
1829
1830        println!("Full K2K kernel:\n{}", kernel);
1831    }
1832
1833    #[test]
1834    fn test_k2k_intrinsic_requirements() {
1835        assert!(RingKernelIntrinsic::K2kSend.requires_k2k());
1836        assert!(RingKernelIntrinsic::K2kTryRecv.requires_k2k());
1837        assert!(RingKernelIntrinsic::K2kHasMessage.requires_k2k());
1838        assert!(RingKernelIntrinsic::K2kPeek.requires_k2k());
1839        assert!(RingKernelIntrinsic::K2kPendingCount.requires_k2k());
1840
1841        assert!(!RingKernelIntrinsic::HlcTick.requires_k2k());
1842        assert!(!RingKernelIntrinsic::IsActive.requires_k2k());
1843    }
1844
1845    // ============== REDUCTION CONFIGURATION TESTS ==============
1846
1847    #[test]
1848    fn test_kernel_reduction_config_default() {
1849        let config = KernelReductionConfig::new();
1850        assert!(!config.enabled);
1851        assert_eq!(config.op, ReductionOp::Sum);
1852        assert_eq!(config.accumulator_type, "double");
1853        assert!(config.use_cooperative);
1854    }
1855
1856    #[test]
1857    fn test_kernel_reduction_config_builder() {
1858        let config = KernelReductionConfig::new()
1859            .with_op(ReductionOp::Max)
1860            .with_type("float")
1861            .with_cooperative(false)
1862            .with_shared_name("my_shared")
1863            .with_accumulator_name("my_accumulator");
1864
1865        assert!(config.enabled);
1866        assert_eq!(config.op, ReductionOp::Max);
1867        assert_eq!(config.accumulator_type, "float");
1868        assert!(!config.use_cooperative);
1869        assert_eq!(config.shared_array_name, "my_shared");
1870        assert_eq!(config.accumulator_name, "my_accumulator");
1871    }
1872
1873    #[test]
1874    fn test_kernel_reduction_shared_declaration() {
1875        let config = KernelReductionConfig::new()
1876            .with_op(ReductionOp::Sum)
1877            .with_type("double")
1878            .with_shared_name("reduction_shared");
1879
1880        let decl = config.generate_shared_declaration(256);
1881        assert!(decl.contains("__shared__"));
1882        assert!(decl.contains("double"));
1883        assert!(decl.contains("reduction_shared"));
1884        assert!(decl.contains("[256]"));
1885    }
1886
1887    #[test]
1888    fn test_kernel_reduction_accumulator_param() {
1889        let config = KernelReductionConfig::new()
1890            .with_op(ReductionOp::Sum)
1891            .with_type("double")
1892            .with_accumulator_name("dangling_sum");
1893
1894        let param = config.generate_accumulator_param();
1895        assert!(param.contains("double*"));
1896        assert!(param.contains("__restrict__"));
1897        assert!(param.contains("dangling_sum"));
1898    }
1899
1900    #[test]
1901    fn test_kernel_reduction_disabled_generates_empty() {
1902        let config = KernelReductionConfig::new(); // Not enabled
1903        assert!(config.generate_shared_declaration(256).is_empty());
1904        assert!(config.generate_accumulator_param().is_empty());
1905    }
1906
1907    #[test]
1908    fn test_ring_kernel_config_with_reduction() {
1909        let reduction = KernelReductionConfig::new()
1910            .with_op(ReductionOp::Sum)
1911            .with_type("double");
1912
1913        let config = RingKernelConfig::new("pagerank")
1914            .with_block_size(256)
1915            .with_reduction(reduction);
1916
1917        assert!(config.reduction.enabled);
1918        assert_eq!(config.reduction.op, ReductionOp::Sum);
1919        // Should automatically enable cooperative groups
1920        assert!(config.cooperative_groups);
1921    }
1922
1923    #[test]
1924    fn test_ring_kernel_config_with_sum_reduction() {
1925        let config = RingKernelConfig::new("pagerank")
1926            .with_block_size(256)
1927            .with_sum_reduction();
1928
1929        assert!(config.reduction.enabled);
1930        assert_eq!(config.reduction.op, ReductionOp::Sum);
1931        assert_eq!(config.reduction.accumulator_type, "double");
1932        assert!(config.cooperative_groups);
1933    }
1934
1935    #[test]
1936    fn test_ring_kernel_config_reduction_without_cooperative() {
1937        let reduction = KernelReductionConfig::new()
1938            .with_op(ReductionOp::Sum)
1939            .with_cooperative(false);
1940
1941        let config = RingKernelConfig::new("pagerank").with_reduction(reduction);
1942
1943        assert!(config.reduction.enabled);
1944        // Should NOT automatically enable cooperative groups when use_cooperative is false
1945        assert!(!config.cooperative_groups);
1946    }
1947
1948    #[test]
1949    fn test_reduction_op_default() {
1950        let op = ReductionOp::default();
1951        assert_eq!(op, ReductionOp::Sum);
1952    }
1953
1954    // ========================================================================
1955    // Handler Dispatch Code Generation Tests
1956    // ========================================================================
1957
1958    #[test]
1959    fn test_cuda_handler_info_creation() {
1960        let handler = CudaHandlerInfo::new(1, "handle_fraud_check")
1961            .with_message_type("FraudCheckRequest", 1001)
1962            .with_response();
1963
1964        assert_eq!(handler.handler_id, 1);
1965        assert_eq!(handler.func_name, "handle_fraud_check");
1966        assert_eq!(handler.message_type_name, "FraudCheckRequest");
1967        assert_eq!(handler.message_type_id, 1001);
1968        assert!(handler.produces_response);
1969        assert!(handler.cuda_body.is_none());
1970    }
1971
1972    #[test]
1973    fn test_cuda_handler_info_with_body() {
1974        let handler = CudaHandlerInfo::new(2, "handle_aggregate")
1975            .with_cuda_body("float sum = 0.0f;\nfor (int i = 0; i < 10; i++) sum += data[i];");
1976
1977        assert_eq!(handler.handler_id, 2);
1978        assert!(handler.cuda_body.is_some());
1979        assert!(handler.cuda_body.as_ref().unwrap().contains("float sum"));
1980    }
1981
1982    #[test]
1983    fn test_cuda_dispatch_table_creation() {
1984        let table = CudaDispatchTable::new()
1985            .with_handler(CudaHandlerInfo::new(1, "handle_a"))
1986            .with_handler(CudaHandlerInfo::new(2, "handle_b"));
1987
1988        assert_eq!(table.len(), 2);
1989        assert!(!table.is_empty());
1990        assert_eq!(table.handlers()[0].handler_id, 1);
1991        assert_eq!(table.handlers()[1].handler_id, 2);
1992    }
1993
1994    #[test]
1995    fn test_generate_handler_dispatch_code_empty() {
1996        let table = CudaDispatchTable::new();
1997        let code = generate_handler_dispatch_code(&table, "    ");
1998
1999        assert!(code.contains("No handlers registered"));
2000        assert!(!code.contains("switch"));
2001    }
2002
2003    #[test]
2004    fn test_generate_handler_dispatch_code_single_handler() {
2005        let table = CudaDispatchTable::new().with_handler(
2006            CudaHandlerInfo::new(1, "handle_fraud").with_message_type("FraudCheck", 1001),
2007        );
2008
2009        let code = generate_handler_dispatch_code(&table, "    ");
2010
2011        assert!(code.contains("switch (handler_id)"));
2012        assert!(code.contains("case 1:"));
2013        assert!(code.contains("handle_fraud(msg, state)"));
2014        assert!(code.contains("default:"));
2015        assert!(code.contains("unknown_handler_count"));
2016    }
2017
2018    #[test]
2019    fn test_generate_handler_dispatch_code_multiple_handlers() {
2020        let table = CudaDispatchTable::new()
2021            .with_handler(
2022                CudaHandlerInfo::new(1, "handle_fraud").with_message_type("FraudCheck", 1001),
2023            )
2024            .with_handler(
2025                CudaHandlerInfo::new(2, "handle_aggregate")
2026                    .with_message_type("Aggregate", 1002)
2027                    .with_response(),
2028            )
2029            .with_handler(
2030                CudaHandlerInfo::new(5, "handle_pattern").with_message_type("Pattern", 1005),
2031            );
2032
2033        let code = generate_handler_dispatch_code(&table, "    ");
2034
2035        assert!(code.contains("case 1:"));
2036        assert!(code.contains("case 2:"));
2037        assert!(code.contains("case 5:"));
2038        assert!(code.contains("handle_fraud(msg, state)"));
2039        assert!(code.contains("handle_aggregate(msg, state, response)")); // With response
2040        assert!(code.contains("handle_pattern(msg, state)"));
2041    }
2042
2043    #[test]
2044    fn test_generate_handler_dispatch_code_with_inline_body() {
2045        let table = CudaDispatchTable::new().with_handler(
2046            CudaHandlerInfo::new(1, "inline_handler")
2047                .with_cuda_body("int result = msg->payload[0] * 2;\nresponse->result = result;"),
2048        );
2049
2050        let code = generate_handler_dispatch_code(&table, "    ");
2051
2052        assert!(code.contains("case 1:"));
2053        assert!(code.contains("int result = msg->payload[0] * 2;"));
2054        assert!(code.contains("response->result = result;"));
2055        // Should NOT contain function call when body is inline
2056        assert!(!code.contains("inline_handler(msg,"));
2057    }
2058
2059    #[test]
2060    fn test_generate_extended_h2k_message_struct() {
2061        let code = generate_extended_h2k_message_struct();
2062
2063        assert!(code.contains("struct __align__(64) ExtendedH2KMessage"));
2064        assert!(code.contains("uint32_t handler_id"));
2065        assert!(code.contains("uint32_t flags"));
2066        assert!(code.contains("uint64_t cmd_id"));
2067        assert!(code.contains("uint64_t timestamp"));
2068        assert!(code.contains("uint8_t payload[40]"));
2069        assert!(code.contains("EXT_MSG_FLAG_EXTENDED"));
2070        assert!(code.contains("EXT_MSG_FLAG_REQUIRES_RESP"));
2071    }
2072
2073    #[test]
2074    fn test_generate_multi_handler_kernel() {
2075        let config = RingKernelConfig::new("multi_handler")
2076            .with_block_size(128)
2077            .with_hlc(true);
2078
2079        let table = CudaDispatchTable::new()
2080            .with_handler(CudaHandlerInfo::new(1, "handle_fraud"))
2081            .with_handler(CudaHandlerInfo::new(2, "handle_aggregate"));
2082
2083        let code = generate_multi_handler_kernel(&config, &table);
2084
2085        // Should contain struct definitions
2086        assert!(code.contains("struct __align__(128) ControlBlock"));
2087        assert!(code.contains("struct __align__(64) ExtendedH2KMessage"));
2088
2089        // Should contain kernel signature
2090        assert!(code.contains("ring_kernel_multi_handler"));
2091
2092        // Should contain dispatch code
2093        assert!(code.contains("MULTI-HANDLER DISPATCH"));
2094        assert!(code.contains("switch (handler_id)"));
2095        assert!(code.contains("case 1:"));
2096        assert!(code.contains("case 2:"));
2097    }
2098}