ringkernel_cuda_codegen/
ring_kernel.rs

1//! Ring kernel configuration and code generation for persistent actor kernels.
2//!
3//! This module provides the infrastructure for generating CUDA code for
4//! persistent actor kernels that process messages in a loop.
5//!
6//! # Overview
7//!
8//! Ring kernels are persistent GPU kernels that:
9//! - Run continuously until terminated
10//! - Process messages from input queues
11//! - Produce responses to output queues
12//! - Use HLC (Hybrid Logical Clocks) for causal ordering
13//! - Support kernel-to-kernel (K2K) messaging
14//!
15//! # Generated Kernel Structure
16//!
17//! ```cuda
18//! extern "C" __global__ void ring_kernel_NAME(
19//!     ControlBlock* __restrict__ control,
20//!     unsigned char* __restrict__ input_buffer,
21//!     unsigned char* __restrict__ output_buffer,
22//!     void* __restrict__ shared_state
23//! ) {
24//!     // Preamble: thread setup
25//!     int tid = threadIdx.x + blockIdx.x * blockDim.x;
26//!
27//!     // Persistent message loop
28//!     while (!atomicLoad(&control->should_terminate)) {
29//!         if (!atomicLoad(&control->is_active)) {
30//!             __nanosleep(1000);
31//!             continue;
32//!         }
33//!
34//!         // Message processing...
35//!         // (user handler code inserted here)
36//!     }
37//!
38//!     // Epilogue: mark terminated
39//!     if (tid == 0) {
40//!         atomicStore(&control->has_terminated, 1);
41//!     }
42//! }
43//! ```
44
45use std::fmt::Write;
46
47/// Configuration for a ring kernel.
48#[derive(Debug, Clone)]
49pub struct RingKernelConfig {
50    /// Kernel identifier (used in function name).
51    pub id: String,
52    /// Block size (threads per block).
53    pub block_size: u32,
54    /// Input queue capacity (must be power of 2).
55    pub queue_capacity: u32,
56    /// Enable kernel-to-kernel messaging.
57    pub enable_k2k: bool,
58    /// Enable HLC clock operations.
59    pub enable_hlc: bool,
60    /// Message size in bytes (for buffer offset calculations).
61    pub message_size: usize,
62    /// Response size in bytes.
63    pub response_size: usize,
64    /// Use cooperative thread groups.
65    pub cooperative_groups: bool,
66    /// Nanosleep duration when idle (0 to spin).
67    pub idle_sleep_ns: u32,
68    /// Use MessageEnvelope format (256-byte header + payload).
69    /// When enabled, messages in queues are full envelopes with headers.
70    pub use_envelope_format: bool,
71    /// Kernel numeric ID (for response routing).
72    pub kernel_id_num: u64,
73    /// HLC node ID for this kernel.
74    pub hlc_node_id: u64,
75}
76
77impl Default for RingKernelConfig {
78    fn default() -> Self {
79        Self {
80            id: "ring_kernel".to_string(),
81            block_size: 128,
82            queue_capacity: 1024,
83            enable_k2k: false,
84            enable_hlc: true,
85            message_size: 64,
86            response_size: 64,
87            cooperative_groups: false,
88            idle_sleep_ns: 1000,
89            use_envelope_format: true, // Default to envelope format for proper serialization
90            kernel_id_num: 0,
91            hlc_node_id: 0,
92        }
93    }
94}
95
96impl RingKernelConfig {
97    /// Create a new ring kernel configuration with the given ID.
98    pub fn new(id: impl Into<String>) -> Self {
99        Self {
100            id: id.into(),
101            ..Default::default()
102        }
103    }
104
105    /// Set the block size.
106    pub fn with_block_size(mut self, size: u32) -> Self {
107        self.block_size = size;
108        self
109    }
110
111    /// Set the queue capacity (must be power of 2).
112    pub fn with_queue_capacity(mut self, capacity: u32) -> Self {
113        debug_assert!(
114            capacity.is_power_of_two(),
115            "Queue capacity must be power of 2"
116        );
117        self.queue_capacity = capacity;
118        self
119    }
120
121    /// Enable kernel-to-kernel messaging.
122    pub fn with_k2k(mut self, enabled: bool) -> Self {
123        self.enable_k2k = enabled;
124        self
125    }
126
127    /// Enable HLC clock operations.
128    pub fn with_hlc(mut self, enabled: bool) -> Self {
129        self.enable_hlc = enabled;
130        self
131    }
132
133    /// Set message and response sizes.
134    pub fn with_message_sizes(mut self, message: usize, response: usize) -> Self {
135        self.message_size = message;
136        self.response_size = response;
137        self
138    }
139
140    /// Set idle sleep duration in nanoseconds.
141    pub fn with_idle_sleep(mut self, ns: u32) -> Self {
142        self.idle_sleep_ns = ns;
143        self
144    }
145
146    /// Enable or disable MessageEnvelope format.
147    /// When enabled, messages use the 256-byte header + payload format.
148    pub fn with_envelope_format(mut self, enabled: bool) -> Self {
149        self.use_envelope_format = enabled;
150        self
151    }
152
153    /// Set the kernel numeric ID (for K2K routing and response headers).
154    pub fn with_kernel_id(mut self, id: u64) -> Self {
155        self.kernel_id_num = id;
156        self
157    }
158
159    /// Set the HLC node ID for this kernel.
160    pub fn with_hlc_node_id(mut self, node_id: u64) -> Self {
161        self.hlc_node_id = node_id;
162        self
163    }
164
165    /// Generate the kernel function name.
166    pub fn kernel_name(&self) -> String {
167        format!("ring_kernel_{}", self.id)
168    }
169
170    /// Generate CUDA kernel signature.
171    pub fn generate_signature(&self) -> String {
172        let mut sig = String::new();
173
174        writeln!(sig, "extern \"C\" __global__ void {}(", self.kernel_name()).unwrap();
175        writeln!(sig, "    ControlBlock* __restrict__ control,").unwrap();
176        writeln!(sig, "    unsigned char* __restrict__ input_buffer,").unwrap();
177        writeln!(sig, "    unsigned char* __restrict__ output_buffer,").unwrap();
178
179        if self.enable_k2k {
180            writeln!(sig, "    K2KRoutingTable* __restrict__ k2k_routes,").unwrap();
181            writeln!(sig, "    unsigned char* __restrict__ k2k_inbox,").unwrap();
182            writeln!(sig, "    unsigned char* __restrict__ k2k_outbox,").unwrap();
183        }
184
185        write!(sig, "    void* __restrict__ shared_state").unwrap();
186        write!(sig, "\n)").unwrap();
187
188        sig
189    }
190
191    /// Generate the kernel preamble (thread setup, variable declarations).
192    pub fn generate_preamble(&self, indent: &str) -> String {
193        let mut code = String::new();
194
195        // Thread identification
196        writeln!(code, "{}// Thread identification", indent).unwrap();
197        writeln!(
198            code,
199            "{}int tid = threadIdx.x + blockIdx.x * blockDim.x;",
200            indent
201        )
202        .unwrap();
203        writeln!(code, "{}int lane_id = threadIdx.x % 32;", indent).unwrap();
204        writeln!(code, "{}int warp_id = threadIdx.x / 32;", indent).unwrap();
205        writeln!(code).unwrap();
206
207        // Kernel ID constant
208        writeln!(code, "{}// Kernel identity", indent).unwrap();
209        writeln!(
210            code,
211            "{}const unsigned long long KERNEL_ID = {}ULL;",
212            indent, self.kernel_id_num
213        )
214        .unwrap();
215        writeln!(
216            code,
217            "{}const unsigned long long HLC_NODE_ID = {}ULL;",
218            indent, self.hlc_node_id
219        )
220        .unwrap();
221        writeln!(code).unwrap();
222
223        // Message size constants
224        writeln!(code, "{}// Message buffer constants", indent).unwrap();
225        if self.use_envelope_format {
226            // With envelope format, MSG_SIZE is envelope size (header + payload)
227            writeln!(
228                code,
229                "{}const unsigned int PAYLOAD_SIZE = {};  // User payload size",
230                indent, self.message_size
231            )
232            .unwrap();
233            writeln!(
234                code,
235                "{}const unsigned int MSG_SIZE = MESSAGE_HEADER_SIZE + PAYLOAD_SIZE;  // Full envelope",
236                indent
237            )
238            .unwrap();
239            writeln!(
240                code,
241                "{}const unsigned int RESP_PAYLOAD_SIZE = {};",
242                indent, self.response_size
243            )
244            .unwrap();
245            writeln!(
246                code,
247                "{}const unsigned int RESP_SIZE = MESSAGE_HEADER_SIZE + RESP_PAYLOAD_SIZE;",
248                indent
249            )
250            .unwrap();
251        } else {
252            // Legacy: raw message data
253            writeln!(
254                code,
255                "{}const unsigned int MSG_SIZE = {};",
256                indent, self.message_size
257            )
258            .unwrap();
259            writeln!(
260                code,
261                "{}const unsigned int RESP_SIZE = {};",
262                indent, self.response_size
263            )
264            .unwrap();
265        }
266        writeln!(
267            code,
268            "{}const unsigned int QUEUE_MASK = {};",
269            indent,
270            self.queue_capacity - 1
271        )
272        .unwrap();
273        writeln!(code).unwrap();
274
275        // HLC state (if enabled)
276        if self.enable_hlc {
277            writeln!(code, "{}// HLC clock state", indent).unwrap();
278            writeln!(code, "{}unsigned long long hlc_physical = 0;", indent).unwrap();
279            writeln!(code, "{}unsigned long long hlc_logical = 0;", indent).unwrap();
280            writeln!(code).unwrap();
281        }
282
283        code
284    }
285
286    /// Generate the persistent message loop header.
287    pub fn generate_loop_header(&self, indent: &str) -> String {
288        let mut code = String::new();
289
290        writeln!(code, "{}// Persistent message processing loop", indent).unwrap();
291        writeln!(code, "{}while (true) {{", indent).unwrap();
292        writeln!(code, "{}    // Check for termination signal", indent).unwrap();
293        writeln!(
294            code,
295            "{}    if (atomicAdd(&control->should_terminate, 0) != 0) {{",
296            indent
297        )
298        .unwrap();
299        writeln!(code, "{}        break;", indent).unwrap();
300        writeln!(code, "{}    }}", indent).unwrap();
301        writeln!(code).unwrap();
302
303        // Check if active
304        writeln!(code, "{}    // Check if kernel is active", indent).unwrap();
305        writeln!(
306            code,
307            "{}    if (atomicAdd(&control->is_active, 0) == 0) {{",
308            indent
309        )
310        .unwrap();
311        if self.idle_sleep_ns > 0 {
312            writeln!(
313                code,
314                "{}        __nanosleep({});",
315                indent, self.idle_sleep_ns
316            )
317            .unwrap();
318        }
319        writeln!(code, "{}        continue;", indent).unwrap();
320        writeln!(code, "{}    }}", indent).unwrap();
321        writeln!(code).unwrap();
322
323        // Check for messages
324        writeln!(code, "{}    // Check input queue for messages", indent).unwrap();
325        writeln!(
326            code,
327            "{}    unsigned long long head = atomicAdd(&control->input_head, 0);",
328            indent
329        )
330        .unwrap();
331        writeln!(
332            code,
333            "{}    unsigned long long tail = atomicAdd(&control->input_tail, 0);",
334            indent
335        )
336        .unwrap();
337        writeln!(code).unwrap();
338        writeln!(code, "{}    if (head == tail) {{", indent).unwrap();
339        writeln!(code, "{}        // No messages, yield", indent).unwrap();
340        if self.idle_sleep_ns > 0 {
341            writeln!(
342                code,
343                "{}        __nanosleep({});",
344                indent, self.idle_sleep_ns
345            )
346            .unwrap();
347        }
348        writeln!(code, "{}        continue;", indent).unwrap();
349        writeln!(code, "{}    }}", indent).unwrap();
350        writeln!(code).unwrap();
351
352        // Calculate message pointer
353        writeln!(code, "{}    // Get message from queue", indent).unwrap();
354        writeln!(
355            code,
356            "{}    unsigned int msg_idx = (unsigned int)(tail & QUEUE_MASK);",
357            indent
358        )
359        .unwrap();
360        writeln!(
361            code,
362            "{}    unsigned char* envelope_ptr = &input_buffer[msg_idx * MSG_SIZE];",
363            indent
364        )
365        .unwrap();
366
367        if self.use_envelope_format {
368            // Envelope format: parse header and get payload pointer
369            writeln!(code).unwrap();
370            writeln!(code, "{}    // Parse message envelope", indent).unwrap();
371            writeln!(
372                code,
373                "{}    MessageHeader* msg_header = message_get_header(envelope_ptr);",
374                indent
375            )
376            .unwrap();
377            writeln!(
378                code,
379                "{}    unsigned char* msg_ptr = message_get_payload(envelope_ptr);",
380                indent
381            )
382            .unwrap();
383            writeln!(code).unwrap();
384            writeln!(code, "{}    // Validate message (skip invalid)", indent).unwrap();
385            writeln!(
386                code,
387                "{}    if (!message_header_validate(msg_header)) {{",
388                indent
389            )
390            .unwrap();
391            writeln!(
392                code,
393                "{}        atomicAdd(&control->input_tail, 1);",
394                indent
395            )
396            .unwrap();
397            writeln!(
398                code,
399                "{}        atomicAdd(&control->last_error, 1);  // Track errors",
400                indent
401            )
402            .unwrap();
403            writeln!(code, "{}        continue;", indent).unwrap();
404            writeln!(code, "{}    }}", indent).unwrap();
405
406            // Update HLC from incoming message timestamp
407            if self.enable_hlc {
408                writeln!(code).unwrap();
409                writeln!(code, "{}    // Update HLC from message timestamp", indent).unwrap();
410                writeln!(
411                    code,
412                    "{}    if (msg_header->timestamp.physical > hlc_physical) {{",
413                    indent
414                )
415                .unwrap();
416                writeln!(
417                    code,
418                    "{}        hlc_physical = msg_header->timestamp.physical;",
419                    indent
420                )
421                .unwrap();
422                writeln!(code, "{}        hlc_logical = 0;", indent).unwrap();
423                writeln!(code, "{}    }}", indent).unwrap();
424                writeln!(code, "{}    hlc_logical++;", indent).unwrap();
425            }
426        } else {
427            // Legacy raw format
428            writeln!(
429                code,
430                "{}    unsigned char* msg_ptr = envelope_ptr;  // Raw message data",
431                indent
432            )
433            .unwrap();
434        }
435        writeln!(code).unwrap();
436
437        code
438    }
439
440    /// Generate message processing completion code.
441    pub fn generate_message_complete(&self, indent: &str) -> String {
442        let mut code = String::new();
443
444        writeln!(code).unwrap();
445        writeln!(code, "{}    // Mark message as processed", indent).unwrap();
446        writeln!(code, "{}    atomicAdd(&control->input_tail, 1);", indent).unwrap();
447        writeln!(
448            code,
449            "{}    atomicAdd(&control->messages_processed, 1);",
450            indent
451        )
452        .unwrap();
453
454        if self.enable_hlc {
455            writeln!(code).unwrap();
456            writeln!(code, "{}    // Update HLC", indent).unwrap();
457            writeln!(code, "{}    hlc_logical++;", indent).unwrap();
458        }
459
460        code
461    }
462
463    /// Generate the loop footer (end of while loop).
464    pub fn generate_loop_footer(&self, indent: &str) -> String {
465        let mut code = String::new();
466
467        writeln!(code, "{}    __syncthreads();", indent).unwrap();
468        writeln!(code, "{}}}", indent).unwrap();
469
470        code
471    }
472
473    /// Generate kernel epilogue (termination marking).
474    pub fn generate_epilogue(&self, indent: &str) -> String {
475        let mut code = String::new();
476
477        writeln!(code).unwrap();
478        writeln!(code, "{}// Mark kernel as terminated", indent).unwrap();
479        writeln!(code, "{}if (tid == 0) {{", indent).unwrap();
480
481        if self.enable_hlc {
482            writeln!(code, "{}    // Store final HLC state", indent).unwrap();
483            writeln!(
484                code,
485                "{}    control->hlc_state.physical = hlc_physical;",
486                indent
487            )
488            .unwrap();
489            writeln!(
490                code,
491                "{}    control->hlc_state.logical = hlc_logical;",
492                indent
493            )
494            .unwrap();
495        }
496
497        writeln!(
498            code,
499            "{}    atomicExch(&control->has_terminated, 1);",
500            indent
501        )
502        .unwrap();
503        writeln!(code, "{}}}", indent).unwrap();
504
505        code
506    }
507
508    /// Generate complete kernel wrapper (without handler body).
509    pub fn generate_kernel_wrapper(&self, handler_placeholder: &str) -> String {
510        let mut code = String::new();
511
512        // Struct definitions
513        code.push_str(&generate_control_block_struct());
514        code.push('\n');
515
516        if self.enable_hlc {
517            code.push_str(&generate_hlc_struct());
518            code.push('\n');
519        }
520
521        // MessageEnvelope structs (requires HLC for HlcTimestamp)
522        if self.use_envelope_format {
523            code.push_str(&generate_message_envelope_structs());
524            code.push('\n');
525        }
526
527        if self.enable_k2k {
528            code.push_str(&generate_k2k_structs());
529            code.push('\n');
530        }
531
532        // Kernel signature
533        code.push_str(&self.generate_signature());
534        code.push_str(" {\n");
535
536        // Preamble
537        code.push_str(&self.generate_preamble("    "));
538
539        // Message loop
540        code.push_str(&self.generate_loop_header("    "));
541
542        // Handler placeholder
543        writeln!(code, "        // === USER HANDLER CODE ===").unwrap();
544        for line in handler_placeholder.lines() {
545            writeln!(code, "        {}", line).unwrap();
546        }
547        writeln!(code, "        // === END HANDLER CODE ===").unwrap();
548
549        // Message completion
550        code.push_str(&self.generate_message_complete("    "));
551
552        // Loop footer
553        code.push_str(&self.generate_loop_footer("    "));
554
555        // Epilogue
556        code.push_str(&self.generate_epilogue("    "));
557
558        code.push_str("}\n");
559
560        code
561    }
562}
563
564/// Generate CUDA ControlBlock struct definition.
565pub fn generate_control_block_struct() -> String {
566    r#"// Control block for kernel state management (128 bytes, cache-line aligned)
567struct __align__(128) ControlBlock {
568    // Lifecycle state
569    unsigned int is_active;
570    unsigned int should_terminate;
571    unsigned int has_terminated;
572    unsigned int _pad1;
573
574    // Counters
575    unsigned long long messages_processed;
576    unsigned long long messages_in_flight;
577
578    // Queue pointers
579    unsigned long long input_head;
580    unsigned long long input_tail;
581    unsigned long long output_head;
582    unsigned long long output_tail;
583
584    // Queue metadata
585    unsigned int input_capacity;
586    unsigned int output_capacity;
587    unsigned int input_mask;
588    unsigned int output_mask;
589
590    // HLC state
591    struct {
592        unsigned long long physical;
593        unsigned long long logical;
594    } hlc_state;
595
596    // Error state
597    unsigned int last_error;
598    unsigned int error_count;
599
600    // Reserved padding
601    unsigned char _reserved[24];
602};
603"#
604    .to_string()
605}
606
607/// Generate CUDA HLC helper struct.
608pub fn generate_hlc_struct() -> String {
609    r#"// Hybrid Logical Clock state
610struct HlcState {
611    unsigned long long physical;
612    unsigned long long logical;
613};
614
615// HLC timestamp (24 bytes)
616struct __align__(8) HlcTimestamp {
617    unsigned long long physical;
618    unsigned long long logical;
619    unsigned long long node_id;
620};
621"#
622    .to_string()
623}
624
625/// Generate CUDA MessageEnvelope structs matching ringkernel-core layout.
626///
627/// This generates the GPU-side structures that match the Rust `MessageHeader`
628/// and `MessageEnvelope` types, enabling proper serialization between host and device.
629pub fn generate_message_envelope_structs() -> String {
630    r#"// Magic number for message validation
631#define MESSAGE_MAGIC 0x52494E474B45524E ULL  // "RINGKERN"
632#define MESSAGE_VERSION 1
633#define MESSAGE_HEADER_SIZE 256
634#define MAX_PAYLOAD_SIZE (64 * 1024)
635
636// Message priority levels
637#define PRIORITY_LOW 0
638#define PRIORITY_NORMAL 1
639#define PRIORITY_HIGH 2
640#define PRIORITY_CRITICAL 3
641
642// Message header structure (256 bytes, cache-line aligned)
643// Matches ringkernel_core::message::MessageHeader exactly
644struct __align__(64) MessageHeader {
645    // Magic number for validation (0xRINGKERN)
646    unsigned long long magic;
647    // Header version
648    unsigned int version;
649    // Message flags
650    unsigned int flags;
651    // Unique message identifier
652    unsigned long long message_id;
653    // Correlation ID for request-response
654    unsigned long long correlation_id;
655    // Source kernel ID (0 for host)
656    unsigned long long source_kernel;
657    // Destination kernel ID (0 for host)
658    unsigned long long dest_kernel;
659    // Message type discriminator
660    unsigned long long message_type;
661    // Priority level
662    unsigned char priority;
663    // Reserved for alignment
664    unsigned char _reserved1[7];
665    // Payload size in bytes
666    unsigned long long payload_size;
667    // Checksum of payload (CRC32)
668    unsigned int checksum;
669    // Reserved for alignment
670    unsigned int _reserved2;
671    // HLC timestamp when message was created
672    HlcTimestamp timestamp;
673    // Deadline timestamp (0 = no deadline)
674    HlcTimestamp deadline;
675    // Reserved for future use (104 bytes total: 32+32+32+8)
676    unsigned char _reserved3[104];
677};
678
679// Validate message header
680__device__ inline int message_header_validate(const MessageHeader* header) {
681    return header->magic == MESSAGE_MAGIC &&
682           header->version <= MESSAGE_VERSION &&
683           header->payload_size <= MAX_PAYLOAD_SIZE;
684}
685
686// Get payload pointer from header
687__device__ inline unsigned char* message_get_payload(unsigned char* envelope_ptr) {
688    return envelope_ptr + MESSAGE_HEADER_SIZE;
689}
690
691// Get header from envelope pointer
692__device__ inline MessageHeader* message_get_header(unsigned char* envelope_ptr) {
693    return (MessageHeader*)envelope_ptr;
694}
695
696// Create a response header based on request
697__device__ inline void message_create_response_header(
698    MessageHeader* response,
699    const MessageHeader* request,
700    unsigned long long this_kernel_id,
701    unsigned long long payload_size,
702    unsigned long long hlc_physical,
703    unsigned long long hlc_logical,
704    unsigned long long hlc_node_id
705) {
706    response->magic = MESSAGE_MAGIC;
707    response->version = MESSAGE_VERSION;
708    response->flags = 0;
709    // Generate new message ID (simple increment from request)
710    response->message_id = request->message_id + 0x100000000ULL;
711    // Preserve correlation ID for request-response matching
712    response->correlation_id = request->correlation_id != 0
713        ? request->correlation_id
714        : request->message_id;
715    response->source_kernel = this_kernel_id;
716    response->dest_kernel = request->source_kernel;  // Response goes back to sender
717    response->message_type = request->message_type + 1;  // Convention: response type = request + 1
718    response->priority = request->priority;
719    response->payload_size = payload_size;
720    response->checksum = 0;  // TODO: compute checksum
721    response->timestamp.physical = hlc_physical;
722    response->timestamp.logical = hlc_logical;
723    response->timestamp.node_id = hlc_node_id;
724    response->deadline.physical = 0;
725    response->deadline.logical = 0;
726    response->deadline.node_id = 0;
727}
728
729// Calculate total envelope size
730__device__ inline unsigned int message_envelope_size(const MessageHeader* header) {
731    return MESSAGE_HEADER_SIZE + (unsigned int)header->payload_size;
732}
733"#
734    .to_string()
735}
736
737/// Generate CUDA K2K routing structs and helper functions.
738pub fn generate_k2k_structs() -> String {
739    r#"// Kernel-to-kernel routing table entry (48 bytes)
740// Matches ringkernel_cuda::k2k_gpu::K2KRouteEntry
741struct K2KRoute {
742    unsigned long long target_kernel_id;
743    unsigned long long target_inbox;      // device pointer
744    unsigned long long target_head;       // device pointer to head
745    unsigned long long target_tail;       // device pointer to tail
746    unsigned int capacity;
747    unsigned int mask;
748    unsigned int msg_size;
749    unsigned int _pad;
750};
751
752// K2K routing table (8 + 48*16 = 776 bytes)
753struct K2KRoutingTable {
754    unsigned int num_routes;
755    unsigned int _pad;
756    K2KRoute routes[16];  // Max 16 K2K connections
757};
758
759// K2K inbox header (64 bytes, cache-line aligned)
760// Matches ringkernel_cuda::k2k_gpu::K2KInboxHeader
761struct __align__(64) K2KInboxHeader {
762    unsigned long long head;
763    unsigned long long tail;
764    unsigned int capacity;
765    unsigned int mask;
766    unsigned int msg_size;
767    unsigned int _pad;
768};
769
770// Send a message envelope to another kernel via K2K
771// The entire envelope (header + payload) is copied to the target's inbox
772__device__ inline int k2k_send_envelope(
773    K2KRoutingTable* routes,
774    unsigned long long target_id,
775    unsigned long long source_kernel_id,
776    const void* payload_ptr,
777    unsigned int payload_size,
778    unsigned long long message_type,
779    unsigned long long hlc_physical,
780    unsigned long long hlc_logical,
781    unsigned long long hlc_node_id
782) {
783    // Find route for target
784    for (unsigned int i = 0; i < routes->num_routes; i++) {
785        if (routes->routes[i].target_kernel_id == target_id) {
786            K2KRoute* route = &routes->routes[i];
787
788            // Calculate total envelope size
789            unsigned int envelope_size = MESSAGE_HEADER_SIZE + payload_size;
790            if (envelope_size > route->msg_size) {
791                return -1;  // Message too large
792            }
793
794            // Atomically claim a slot in target's inbox
795            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
796            unsigned long long slot = atomicAdd(target_head_ptr, 1);
797            unsigned int idx = (unsigned int)(slot & route->mask);
798
799            // Calculate destination pointer
800            unsigned char* dest = ((unsigned char*)route->target_inbox) +
801                                  sizeof(K2KInboxHeader) + idx * route->msg_size;
802
803            // Build message header
804            MessageHeader* header = (MessageHeader*)dest;
805            header->magic = MESSAGE_MAGIC;
806            header->version = MESSAGE_VERSION;
807            header->flags = 0;
808            header->message_id = (source_kernel_id << 32) | (slot & 0xFFFFFFFF);
809            header->correlation_id = 0;
810            header->source_kernel = source_kernel_id;
811            header->dest_kernel = target_id;
812            header->message_type = message_type;
813            header->priority = PRIORITY_NORMAL;
814            header->payload_size = payload_size;
815            header->checksum = 0;
816            header->timestamp.physical = hlc_physical;
817            header->timestamp.logical = hlc_logical;
818            header->timestamp.node_id = hlc_node_id;
819            header->deadline.physical = 0;
820            header->deadline.logical = 0;
821            header->deadline.node_id = 0;
822
823            // Copy payload after header
824            if (payload_size > 0 && payload_ptr != NULL) {
825                memcpy(dest + MESSAGE_HEADER_SIZE, payload_ptr, payload_size);
826            }
827
828            __threadfence();  // Ensure write is visible
829            return 1;  // Success
830        }
831    }
832    return 0;  // Route not found
833}
834
835// Legacy k2k_send for raw message data (no envelope)
836__device__ inline int k2k_send(
837    K2KRoutingTable* routes,
838    unsigned long long target_id,
839    const void* msg_ptr,
840    unsigned int msg_size
841) {
842    // Find route for target
843    for (unsigned int i = 0; i < routes->num_routes; i++) {
844        if (routes->routes[i].target_kernel_id == target_id) {
845            K2KRoute* route = &routes->routes[i];
846
847            // Atomically claim a slot in target's inbox
848            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
849            unsigned long long slot = atomicAdd(target_head_ptr, 1);
850            unsigned int idx = (unsigned int)(slot & route->mask);
851
852            // Copy message to target inbox
853            unsigned char* dest = ((unsigned char*)route->target_inbox) +
854                                  sizeof(K2KInboxHeader) + idx * route->msg_size;
855            memcpy(dest, msg_ptr, msg_size < route->msg_size ? msg_size : route->msg_size);
856
857            __threadfence();
858            return 1;  // Success
859        }
860    }
861    return 0;  // Route not found
862}
863
864// Check if there are K2K messages in inbox
865__device__ inline int k2k_has_message(unsigned char* k2k_inbox) {
866    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
867    unsigned long long head = atomicAdd(&header->head, 0);
868    unsigned long long tail = atomicAdd(&header->tail, 0);
869    return head != tail;
870}
871
872// Try to receive a K2K message envelope
873// Returns pointer to MessageHeader if available, NULL otherwise
874__device__ inline MessageHeader* k2k_try_recv_envelope(unsigned char* k2k_inbox) {
875    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
876
877    unsigned long long head = atomicAdd(&header->head, 0);
878    unsigned long long tail = atomicAdd(&header->tail, 0);
879
880    if (head == tail) {
881        return NULL;  // No messages
882    }
883
884    // Get message pointer
885    unsigned int idx = (unsigned int)(tail & header->mask);
886    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
887    MessageHeader* msg_header = (MessageHeader*)(data_start + idx * header->msg_size);
888
889    // Validate header
890    if (!message_header_validate(msg_header)) {
891        // Invalid message, skip it
892        atomicAdd(&header->tail, 1);
893        return NULL;
894    }
895
896    // Advance tail (consume message)
897    atomicAdd(&header->tail, 1);
898
899    return msg_header;
900}
901
902// Legacy k2k_try_recv for raw data
903__device__ inline void* k2k_try_recv(unsigned char* k2k_inbox) {
904    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
905
906    unsigned long long head = atomicAdd(&header->head, 0);
907    unsigned long long tail = atomicAdd(&header->tail, 0);
908
909    if (head == tail) {
910        return NULL;  // No messages
911    }
912
913    // Get message pointer
914    unsigned int idx = (unsigned int)(tail & header->mask);
915    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
916
917    // Advance tail (consume message)
918    atomicAdd(&header->tail, 1);
919
920    return data_start + idx * header->msg_size;
921}
922
923// Peek at next K2K message without consuming
924__device__ inline MessageHeader* k2k_peek_envelope(unsigned char* k2k_inbox) {
925    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
926
927    unsigned long long head = atomicAdd(&header->head, 0);
928    unsigned long long tail = atomicAdd(&header->tail, 0);
929
930    if (head == tail) {
931        return NULL;  // No messages
932    }
933
934    unsigned int idx = (unsigned int)(tail & header->mask);
935    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
936
937    return (MessageHeader*)(data_start + idx * header->msg_size);
938}
939
940// Legacy k2k_peek
941__device__ inline void* k2k_peek(unsigned char* k2k_inbox) {
942    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
943
944    unsigned long long head = atomicAdd(&header->head, 0);
945    unsigned long long tail = atomicAdd(&header->tail, 0);
946
947    if (head == tail) {
948        return NULL;  // No messages
949    }
950
951    unsigned int idx = (unsigned int)(tail & header->mask);
952    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
953
954    return data_start + idx * header->msg_size;
955}
956
957// Get number of pending K2K messages
958__device__ inline unsigned int k2k_pending_count(unsigned char* k2k_inbox) {
959    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
960    unsigned long long head = atomicAdd(&header->head, 0);
961    unsigned long long tail = atomicAdd(&header->tail, 0);
962    return (unsigned int)(head - tail);
963}
964"#
965    .to_string()
966}
967
968/// Intrinsic functions available in ring kernel handlers.
969#[derive(Debug, Clone, Copy, PartialEq, Eq)]
970pub enum RingKernelIntrinsic {
971    // Control block access
972    IsActive,
973    ShouldTerminate,
974    MarkTerminated,
975    GetMessagesProcessed,
976
977    // Queue operations
978    InputQueueSize,
979    OutputQueueSize,
980    InputQueueEmpty,
981    OutputQueueEmpty,
982    EnqueueResponse,
983
984    // HLC operations
985    HlcTick,
986    HlcUpdate,
987    HlcNow,
988
989    // K2K operations
990    K2kSend,
991    K2kTryRecv,
992    K2kHasMessage,
993    K2kPeek,
994    K2kPendingCount,
995}
996
997impl RingKernelIntrinsic {
998    /// Get the CUDA code for this intrinsic.
999    pub fn to_cuda(&self, args: &[String]) -> String {
1000        match self {
1001            Self::IsActive => "atomicAdd(&control->is_active, 0) != 0".to_string(),
1002            Self::ShouldTerminate => "atomicAdd(&control->should_terminate, 0) != 0".to_string(),
1003            Self::MarkTerminated => "atomicExch(&control->has_terminated, 1)".to_string(),
1004            Self::GetMessagesProcessed => "atomicAdd(&control->messages_processed, 0)".to_string(),
1005
1006            Self::InputQueueSize => {
1007                "(atomicAdd(&control->input_head, 0) - atomicAdd(&control->input_tail, 0))"
1008                    .to_string()
1009            }
1010            Self::OutputQueueSize => {
1011                "(atomicAdd(&control->output_head, 0) - atomicAdd(&control->output_tail, 0))"
1012                    .to_string()
1013            }
1014            Self::InputQueueEmpty => {
1015                "(atomicAdd(&control->input_head, 0) == atomicAdd(&control->input_tail, 0))"
1016                    .to_string()
1017            }
1018            Self::OutputQueueEmpty => {
1019                "(atomicAdd(&control->output_head, 0) == atomicAdd(&control->output_tail, 0))"
1020                    .to_string()
1021            }
1022            Self::EnqueueResponse => {
1023                if !args.is_empty() {
1024                    format!(
1025                        "{{ unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask; \
1026                         memcpy(&output_buffer[_out_idx * RESP_SIZE], {}, RESP_SIZE); }}",
1027                        args[0]
1028                    )
1029                } else {
1030                    "/* enqueue_response requires response pointer */".to_string()
1031                }
1032            }
1033
1034            Self::HlcTick => "hlc_logical++".to_string(),
1035            Self::HlcUpdate => {
1036                if !args.is_empty() {
1037                    format!(
1038                        "{{ if ({} > hlc_physical) {{ hlc_physical = {}; hlc_logical = 0; }} else {{ hlc_logical++; }} }}",
1039                        args[0], args[0]
1040                    )
1041                } else {
1042                    "hlc_logical++".to_string()
1043                }
1044            }
1045            Self::HlcNow => "(hlc_physical << 32) | (hlc_logical & 0xFFFFFFFF)".to_string(),
1046
1047            Self::K2kSend => {
1048                if args.len() >= 2 {
1049                    // k2k_send(target_id, msg_ptr) -> k2k_send(k2k_routes, target_id, msg_ptr, sizeof(*msg_ptr))
1050                    format!(
1051                        "k2k_send(k2k_routes, {}, {}, sizeof(*{}))",
1052                        args[0], args[1], args[1]
1053                    )
1054                } else {
1055                    "/* k2k_send requires target_id and msg_ptr */".to_string()
1056                }
1057            }
1058            Self::K2kTryRecv => "k2k_try_recv(k2k_inbox)".to_string(),
1059            Self::K2kHasMessage => "k2k_has_message(k2k_inbox)".to_string(),
1060            Self::K2kPeek => "k2k_peek(k2k_inbox)".to_string(),
1061            Self::K2kPendingCount => "k2k_pending_count(k2k_inbox)".to_string(),
1062        }
1063    }
1064
1065    /// Parse a function name to get the intrinsic.
1066    pub fn from_name(name: &str) -> Option<Self> {
1067        match name {
1068            "is_active" | "is_kernel_active" => Some(Self::IsActive),
1069            "should_terminate" => Some(Self::ShouldTerminate),
1070            "mark_terminated" => Some(Self::MarkTerminated),
1071            "messages_processed" | "get_messages_processed" => Some(Self::GetMessagesProcessed),
1072
1073            "input_queue_size" => Some(Self::InputQueueSize),
1074            "output_queue_size" => Some(Self::OutputQueueSize),
1075            "input_queue_empty" => Some(Self::InputQueueEmpty),
1076            "output_queue_empty" => Some(Self::OutputQueueEmpty),
1077            "enqueue_response" | "enqueue" => Some(Self::EnqueueResponse),
1078
1079            "hlc_tick" => Some(Self::HlcTick),
1080            "hlc_update" => Some(Self::HlcUpdate),
1081            "hlc_now" => Some(Self::HlcNow),
1082
1083            "k2k_send" => Some(Self::K2kSend),
1084            "k2k_try_recv" => Some(Self::K2kTryRecv),
1085            "k2k_has_message" => Some(Self::K2kHasMessage),
1086            "k2k_peek" => Some(Self::K2kPeek),
1087            "k2k_pending_count" | "k2k_pending" => Some(Self::K2kPendingCount),
1088
1089            _ => None,
1090        }
1091    }
1092
1093    /// Check if this intrinsic requires K2K support.
1094    pub fn requires_k2k(&self) -> bool {
1095        matches!(
1096            self,
1097            Self::K2kSend
1098                | Self::K2kTryRecv
1099                | Self::K2kHasMessage
1100                | Self::K2kPeek
1101                | Self::K2kPendingCount
1102        )
1103    }
1104
1105    /// Check if this intrinsic requires HLC support.
1106    pub fn requires_hlc(&self) -> bool {
1107        matches!(self, Self::HlcTick | Self::HlcUpdate | Self::HlcNow)
1108    }
1109
1110    /// Check if this intrinsic requires control block access.
1111    pub fn requires_control_block(&self) -> bool {
1112        matches!(
1113            self,
1114            Self::IsActive
1115                | Self::ShouldTerminate
1116                | Self::MarkTerminated
1117                | Self::GetMessagesProcessed
1118                | Self::InputQueueSize
1119                | Self::OutputQueueSize
1120                | Self::InputQueueEmpty
1121                | Self::OutputQueueEmpty
1122                | Self::EnqueueResponse
1123        )
1124    }
1125}
1126
1127#[cfg(test)]
1128mod tests {
1129    use super::*;
1130
1131    #[test]
1132    fn test_default_config() {
1133        let config = RingKernelConfig::default();
1134        assert_eq!(config.block_size, 128);
1135        assert_eq!(config.queue_capacity, 1024);
1136        assert!(config.enable_hlc);
1137        assert!(!config.enable_k2k);
1138    }
1139
1140    #[test]
1141    fn test_config_builder() {
1142        let config = RingKernelConfig::new("processor")
1143            .with_block_size(256)
1144            .with_queue_capacity(2048)
1145            .with_k2k(true)
1146            .with_hlc(true);
1147
1148        assert_eq!(config.id, "processor");
1149        assert_eq!(config.kernel_name(), "ring_kernel_processor");
1150        assert_eq!(config.block_size, 256);
1151        assert_eq!(config.queue_capacity, 2048);
1152        assert!(config.enable_k2k);
1153        assert!(config.enable_hlc);
1154    }
1155
1156    #[test]
1157    fn test_kernel_signature() {
1158        let config = RingKernelConfig::new("test");
1159        let sig = config.generate_signature();
1160
1161        assert!(sig.contains("extern \"C\" __global__ void ring_kernel_test"));
1162        assert!(sig.contains("ControlBlock* __restrict__ control"));
1163        assert!(sig.contains("input_buffer"));
1164        assert!(sig.contains("output_buffer"));
1165        assert!(sig.contains("shared_state"));
1166    }
1167
1168    #[test]
1169    fn test_kernel_signature_with_k2k() {
1170        let config = RingKernelConfig::new("k2k_test").with_k2k(true);
1171        let sig = config.generate_signature();
1172
1173        assert!(sig.contains("K2KRoutingTable"));
1174        assert!(sig.contains("k2k_inbox"));
1175        assert!(sig.contains("k2k_outbox"));
1176    }
1177
1178    #[test]
1179    fn test_preamble_generation() {
1180        let config = RingKernelConfig::new("test").with_hlc(true);
1181        let preamble = config.generate_preamble("    ");
1182
1183        assert!(preamble.contains("int tid = threadIdx.x + blockIdx.x * blockDim.x"));
1184        assert!(preamble.contains("int lane_id"));
1185        assert!(preamble.contains("int warp_id"));
1186        assert!(preamble.contains("MSG_SIZE"));
1187        assert!(preamble.contains("hlc_physical"));
1188        assert!(preamble.contains("hlc_logical"));
1189    }
1190
1191    #[test]
1192    fn test_loop_header() {
1193        let config = RingKernelConfig::new("test");
1194        let header = config.generate_loop_header("    ");
1195
1196        assert!(header.contains("while (true)"));
1197        assert!(header.contains("should_terminate"));
1198        assert!(header.contains("is_active"));
1199        assert!(header.contains("input_head"));
1200        assert!(header.contains("input_tail"));
1201        assert!(header.contains("msg_ptr"));
1202    }
1203
1204    #[test]
1205    fn test_epilogue() {
1206        let config = RingKernelConfig::new("test").with_hlc(true);
1207        let epilogue = config.generate_epilogue("    ");
1208
1209        assert!(epilogue.contains("has_terminated"));
1210        assert!(epilogue.contains("hlc_state.physical"));
1211        assert!(epilogue.contains("hlc_state.logical"));
1212    }
1213
1214    #[test]
1215    fn test_control_block_struct() {
1216        let code = generate_control_block_struct();
1217
1218        assert!(code.contains("struct __align__(128) ControlBlock"));
1219        assert!(code.contains("is_active"));
1220        assert!(code.contains("should_terminate"));
1221        assert!(code.contains("has_terminated"));
1222        assert!(code.contains("messages_processed"));
1223        assert!(code.contains("input_head"));
1224        assert!(code.contains("input_tail"));
1225        assert!(code.contains("hlc_state"));
1226    }
1227
1228    #[test]
1229    fn test_full_kernel_wrapper() {
1230        let config = RingKernelConfig::new("example")
1231            .with_block_size(128)
1232            .with_hlc(true);
1233
1234        let kernel = config.generate_kernel_wrapper("// Process message here");
1235
1236        assert!(kernel.contains("struct __align__(128) ControlBlock"));
1237        assert!(kernel.contains("extern \"C\" __global__ void ring_kernel_example"));
1238        assert!(kernel.contains("while (true)"));
1239        assert!(kernel.contains("// Process message here"));
1240        assert!(kernel.contains("has_terminated"));
1241
1242        println!("Generated kernel:\n{}", kernel);
1243    }
1244
1245    #[test]
1246    fn test_intrinsic_lookup() {
1247        assert_eq!(
1248            RingKernelIntrinsic::from_name("is_active"),
1249            Some(RingKernelIntrinsic::IsActive)
1250        );
1251        assert_eq!(
1252            RingKernelIntrinsic::from_name("should_terminate"),
1253            Some(RingKernelIntrinsic::ShouldTerminate)
1254        );
1255        assert_eq!(
1256            RingKernelIntrinsic::from_name("hlc_tick"),
1257            Some(RingKernelIntrinsic::HlcTick)
1258        );
1259        assert_eq!(RingKernelIntrinsic::from_name("unknown"), None);
1260    }
1261
1262    #[test]
1263    fn test_intrinsic_cuda_output() {
1264        assert!(RingKernelIntrinsic::IsActive
1265            .to_cuda(&[])
1266            .contains("is_active"));
1267        assert!(RingKernelIntrinsic::ShouldTerminate
1268            .to_cuda(&[])
1269            .contains("should_terminate"));
1270        assert!(RingKernelIntrinsic::HlcTick
1271            .to_cuda(&[])
1272            .contains("hlc_logical++"));
1273    }
1274
1275    #[test]
1276    fn test_k2k_structs_generation() {
1277        let k2k_code = generate_k2k_structs();
1278
1279        // Check struct definitions
1280        assert!(
1281            k2k_code.contains("struct K2KRoute"),
1282            "Should have K2KRoute struct"
1283        );
1284        assert!(
1285            k2k_code.contains("struct K2KRoutingTable"),
1286            "Should have K2KRoutingTable struct"
1287        );
1288        assert!(
1289            k2k_code.contains("K2KInboxHeader"),
1290            "Should have K2KInboxHeader struct"
1291        );
1292
1293        // Check helper functions
1294        assert!(
1295            k2k_code.contains("__device__ inline int k2k_send"),
1296            "Should have k2k_send function"
1297        );
1298        assert!(
1299            k2k_code.contains("__device__ inline int k2k_has_message"),
1300            "Should have k2k_has_message function"
1301        );
1302        assert!(
1303            k2k_code.contains("__device__ inline void* k2k_try_recv"),
1304            "Should have k2k_try_recv function"
1305        );
1306        assert!(
1307            k2k_code.contains("__device__ inline void* k2k_peek"),
1308            "Should have k2k_peek function"
1309        );
1310        assert!(
1311            k2k_code.contains("__device__ inline unsigned int k2k_pending_count"),
1312            "Should have k2k_pending_count function"
1313        );
1314
1315        println!("K2K code:\n{}", k2k_code);
1316    }
1317
1318    #[test]
1319    fn test_full_k2k_kernel() {
1320        let config = RingKernelConfig::new("k2k_processor")
1321            .with_block_size(128)
1322            .with_k2k(true)
1323            .with_hlc(true);
1324
1325        let kernel = config.generate_kernel_wrapper("// K2K handler code");
1326
1327        // Check K2K-specific components
1328        assert!(
1329            kernel.contains("K2KRoutingTable"),
1330            "Should have K2KRoutingTable"
1331        );
1332        assert!(kernel.contains("K2KRoute"), "Should have K2KRoute struct");
1333        assert!(
1334            kernel.contains("K2KInboxHeader"),
1335            "Should have K2KInboxHeader"
1336        );
1337        assert!(
1338            kernel.contains("k2k_routes"),
1339            "Should have k2k_routes param"
1340        );
1341        assert!(kernel.contains("k2k_inbox"), "Should have k2k_inbox param");
1342        assert!(
1343            kernel.contains("k2k_outbox"),
1344            "Should have k2k_outbox param"
1345        );
1346        assert!(kernel.contains("k2k_send"), "Should have k2k_send function");
1347        assert!(
1348            kernel.contains("k2k_try_recv"),
1349            "Should have k2k_try_recv function"
1350        );
1351
1352        println!("Full K2K kernel:\n{}", kernel);
1353    }
1354
1355    #[test]
1356    fn test_k2k_intrinsic_requirements() {
1357        assert!(RingKernelIntrinsic::K2kSend.requires_k2k());
1358        assert!(RingKernelIntrinsic::K2kTryRecv.requires_k2k());
1359        assert!(RingKernelIntrinsic::K2kHasMessage.requires_k2k());
1360        assert!(RingKernelIntrinsic::K2kPeek.requires_k2k());
1361        assert!(RingKernelIntrinsic::K2kPendingCount.requires_k2k());
1362
1363        assert!(!RingKernelIntrinsic::HlcTick.requires_k2k());
1364        assert!(!RingKernelIntrinsic::IsActive.requires_k2k());
1365    }
1366}