// ringkernel_cuda_codegen/ring_kernel.rs

1//! Ring kernel configuration and code generation for persistent actor kernels.
2//!
3//! This module provides the infrastructure for generating CUDA code for
4//! persistent actor kernels that process messages in a loop.
5//!
6//! # Overview
7//!
8//! Ring kernels are persistent GPU kernels that:
9//! - Run continuously until terminated
10//! - Process messages from input queues
11//! - Produce responses to output queues
12//! - Use HLC (Hybrid Logical Clocks) for causal ordering
13//! - Support kernel-to-kernel (K2K) messaging
14//!
15//! # Generated Kernel Structure
16//!
17//! ```cuda
18//! extern "C" __global__ void ring_kernel_NAME(
19//!     ControlBlock* __restrict__ control,
20//!     unsigned char* __restrict__ input_buffer,
21//!     unsigned char* __restrict__ output_buffer,
22//!     void* __restrict__ shared_state
23//! ) {
24//!     // Preamble: thread setup
25//!     int tid = threadIdx.x + blockIdx.x * blockDim.x;
26//!
27//!     // Persistent message loop
//!     while (atomicAdd(&control->should_terminate, 0) == 0) {
//!         if (atomicAdd(&control->is_active, 0) == 0) {
//!             __nanosleep(1000);
//!             continue;
//!         }
//!
//!         // Message processing...
//!         // (user handler code inserted here)
//!     }
//!
//!     // Epilogue: mark terminated
//!     if (tid == 0) {
//!         atomicExch(&control->has_terminated, 1);
//!     }
42//! }
43//! ```
44
45use std::fmt::Write;
46
/// Configuration for a ring kernel.
///
/// A ring kernel is a persistent GPU kernel that loops over ring-buffer
/// message queues. This struct drives the CUDA source produced by
/// `generate_kernel_wrapper` and the individual `generate_*` methods.
#[derive(Debug, Clone)]
pub struct RingKernelConfig {
    /// Kernel identifier (used in function name).
    pub id: String,
    /// Block size (threads per block).
    pub block_size: u32,
    /// Input queue capacity (must be power of 2, so `capacity - 1`
    /// can serve as an index mask in the generated code).
    pub queue_capacity: u32,
    /// Enable kernel-to-kernel messaging (adds K2K kernel parameters
    /// and routing structs to the generated source).
    pub enable_k2k: bool,
    /// Enable HLC clock operations.
    pub enable_hlc: bool,
    /// Message size in bytes (for buffer offset calculations). With
    /// envelope format this is the *payload* size, excluding the header.
    pub message_size: usize,
    /// Response size in bytes.
    pub response_size: usize,
    /// Use cooperative thread groups.
    pub cooperative_groups: bool,
    /// Nanosleep duration when idle (0 to spin).
    pub idle_sleep_ns: u32,
    /// Use MessageEnvelope format (256-byte header + payload).
    /// When enabled, messages in queues are full envelopes with headers.
    pub use_envelope_format: bool,
    /// Kernel numeric ID (for response routing).
    pub kernel_id_num: u64,
    /// HLC node ID for this kernel.
    pub hlc_node_id: u64,
}
76
77impl Default for RingKernelConfig {
78    fn default() -> Self {
79        Self {
80            id: "ring_kernel".to_string(),
81            block_size: 128,
82            queue_capacity: 1024,
83            enable_k2k: false,
84            enable_hlc: true,
85            message_size: 64,
86            response_size: 64,
87            cooperative_groups: false,
88            idle_sleep_ns: 1000,
89            use_envelope_format: true, // Default to envelope format for proper serialization
90            kernel_id_num: 0,
91            hlc_node_id: 0,
92        }
93    }
94}
95
impl RingKernelConfig {
    /// Create a new ring kernel configuration with the given ID.
    ///
    /// All other fields take their values from the `Default` impl.
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            ..Default::default()
        }
    }

    /// Set the block size.
    pub fn with_block_size(mut self, size: u32) -> Self {
        self.block_size = size;
        self
    }

    /// Set the queue capacity (must be power of 2).
    ///
    /// The capacity is turned into an index mask (`capacity - 1`) in
    /// `generate_preamble`, which only works for powers of two. Note this
    /// check is a `debug_assert!` and is compiled out in release builds.
    pub fn with_queue_capacity(mut self, capacity: u32) -> Self {
        debug_assert!(
            capacity.is_power_of_two(),
            "Queue capacity must be power of 2"
        );
        self.queue_capacity = capacity;
        self
    }

    /// Enable kernel-to-kernel messaging.
    pub fn with_k2k(mut self, enabled: bool) -> Self {
        self.enable_k2k = enabled;
        self
    }

    /// Enable HLC clock operations.
    pub fn with_hlc(mut self, enabled: bool) -> Self {
        self.enable_hlc = enabled;
        self
    }

    /// Set message and response sizes.
    ///
    /// With envelope format enabled these are payload sizes; the 256-byte
    /// envelope header is added on top in the generated constants.
    pub fn with_message_sizes(mut self, message: usize, response: usize) -> Self {
        self.message_size = message;
        self.response_size = response;
        self
    }

    /// Set idle sleep duration in nanoseconds.
    pub fn with_idle_sleep(mut self, ns: u32) -> Self {
        self.idle_sleep_ns = ns;
        self
    }

    /// Enable or disable MessageEnvelope format.
    /// When enabled, messages use the 256-byte header + payload format.
    pub fn with_envelope_format(mut self, enabled: bool) -> Self {
        self.use_envelope_format = enabled;
        self
    }

    /// Set the kernel numeric ID (for K2K routing and response headers).
    pub fn with_kernel_id(mut self, id: u64) -> Self {
        self.kernel_id_num = id;
        self
    }

    /// Set the HLC node ID for this kernel.
    pub fn with_hlc_node_id(mut self, node_id: u64) -> Self {
        self.hlc_node_id = node_id;
        self
    }

    /// Generate the kernel function name.
    pub fn kernel_name(&self) -> String {
        format!("ring_kernel_{}", self.id)
    }

    /// Generate CUDA kernel signature.
    ///
    /// The three K2K parameters are only emitted when `enable_k2k` is set,
    /// so the host-side launch argument list must match this configuration.
    pub fn generate_signature(&self) -> String {
        let mut sig = String::new();

        writeln!(sig, "extern \"C\" __global__ void {}(", self.kernel_name()).unwrap();
        writeln!(sig, "    ControlBlock* __restrict__ control,").unwrap();
        writeln!(sig, "    unsigned char* __restrict__ input_buffer,").unwrap();
        writeln!(sig, "    unsigned char* __restrict__ output_buffer,").unwrap();

        if self.enable_k2k {
            writeln!(sig, "    K2KRoutingTable* __restrict__ k2k_routes,").unwrap();
            writeln!(sig, "    unsigned char* __restrict__ k2k_inbox,").unwrap();
            writeln!(sig, "    unsigned char* __restrict__ k2k_outbox,").unwrap();
        }

        // Last parameter: no trailing comma; closing paren on its own line.
        write!(sig, "    void* __restrict__ shared_state").unwrap();
        write!(sig, "\n)").unwrap();

        sig
    }

    /// Generate the kernel preamble (thread setup, variable declarations).
    ///
    /// Emits `tid`/`lane_id`/`warp_id`, kernel-identity constants, message
    /// size constants (envelope or raw layout), the queue index mask, and —
    /// when HLC is enabled — local HLC state variables.
    ///
    /// NOTE(review): `QUEUE_MASK` is `queue_capacity - 1`; a zero capacity
    /// would underflow (panics in debug builds). Callers are expected to go
    /// through `with_queue_capacity`, which asserts power-of-two.
    pub fn generate_preamble(&self, indent: &str) -> String {
        let mut code = String::new();

        // Thread identification
        writeln!(code, "{}// Thread identification", indent).unwrap();
        writeln!(
            code,
            "{}int tid = threadIdx.x + blockIdx.x * blockDim.x;",
            indent
        )
        .unwrap();
        writeln!(code, "{}int lane_id = threadIdx.x % 32;", indent).unwrap();
        writeln!(code, "{}int warp_id = threadIdx.x / 32;", indent).unwrap();
        writeln!(code).unwrap();

        // Kernel ID constant
        writeln!(code, "{}// Kernel identity", indent).unwrap();
        writeln!(
            code,
            "{}const unsigned long long KERNEL_ID = {}ULL;",
            indent, self.kernel_id_num
        )
        .unwrap();
        writeln!(
            code,
            "{}const unsigned long long HLC_NODE_ID = {}ULL;",
            indent, self.hlc_node_id
        )
        .unwrap();
        writeln!(code).unwrap();

        // Message size constants
        writeln!(code, "{}// Message buffer constants", indent).unwrap();
        if self.use_envelope_format {
            // With envelope format, MSG_SIZE is envelope size (header + payload)
            writeln!(
                code,
                "{}const unsigned int PAYLOAD_SIZE = {};  // User payload size",
                indent, self.message_size
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int MSG_SIZE = MESSAGE_HEADER_SIZE + PAYLOAD_SIZE;  // Full envelope",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int RESP_PAYLOAD_SIZE = {};",
                indent, self.response_size
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int RESP_SIZE = MESSAGE_HEADER_SIZE + RESP_PAYLOAD_SIZE;",
                indent
            )
            .unwrap();
        } else {
            // Legacy: raw message data
            writeln!(
                code,
                "{}const unsigned int MSG_SIZE = {};",
                indent, self.message_size
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int RESP_SIZE = {};",
                indent, self.response_size
            )
            .unwrap();
        }
        writeln!(
            code,
            "{}const unsigned int QUEUE_MASK = {};",
            indent,
            self.queue_capacity - 1
        )
        .unwrap();
        writeln!(code).unwrap();

        // HLC state (if enabled)
        if self.enable_hlc {
            writeln!(code, "{}// HLC clock state", indent).unwrap();
            writeln!(code, "{}unsigned long long hlc_physical = 0;", indent).unwrap();
            writeln!(code, "{}unsigned long long hlc_logical = 0;", indent).unwrap();
            writeln!(code).unwrap();
        }

        code
    }

    /// Generate the persistent message loop header.
    ///
    /// The generated CUDA uses `atomicAdd(ptr, 0)` as an atomic-read idiom
    /// for control flags and queue indices. Loop structure: check terminate,
    /// check active (idle-sleep if not), check for queued messages
    /// (idle-sleep if empty), then compute `msg_ptr` for the handler. With
    /// envelope format the header is validated first (invalid messages are
    /// skipped and counted in `last_error`) and the message timestamp is
    /// folded into the local HLC state.
    pub fn generate_loop_header(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code, "{}// Persistent message processing loop", indent).unwrap();
        writeln!(code, "{}while (true) {{", indent).unwrap();
        writeln!(code, "{}    // Check for termination signal", indent).unwrap();
        writeln!(
            code,
            "{}    if (atomicAdd(&control->should_terminate, 0) != 0) {{",
            indent
        )
        .unwrap();
        writeln!(code, "{}        break;", indent).unwrap();
        writeln!(code, "{}    }}", indent).unwrap();
        writeln!(code).unwrap();

        // Check if active
        writeln!(code, "{}    // Check if kernel is active", indent).unwrap();
        writeln!(
            code,
            "{}    if (atomicAdd(&control->is_active, 0) == 0) {{",
            indent
        )
        .unwrap();
        if self.idle_sleep_ns > 0 {
            writeln!(
                code,
                "{}        __nanosleep({});",
                indent, self.idle_sleep_ns
            )
            .unwrap();
        }
        writeln!(code, "{}        continue;", indent).unwrap();
        writeln!(code, "{}    }}", indent).unwrap();
        writeln!(code).unwrap();

        // Check for messages
        writeln!(code, "{}    // Check input queue for messages", indent).unwrap();
        writeln!(
            code,
            "{}    unsigned long long head = atomicAdd(&control->input_head, 0);",
            indent
        )
        .unwrap();
        writeln!(
            code,
            "{}    unsigned long long tail = atomicAdd(&control->input_tail, 0);",
            indent
        )
        .unwrap();
        writeln!(code).unwrap();
        writeln!(code, "{}    if (head == tail) {{", indent).unwrap();
        writeln!(code, "{}        // No messages, yield", indent).unwrap();
        if self.idle_sleep_ns > 0 {
            writeln!(
                code,
                "{}        __nanosleep({});",
                indent, self.idle_sleep_ns
            )
            .unwrap();
        }
        writeln!(code, "{}        continue;", indent).unwrap();
        writeln!(code, "{}    }}", indent).unwrap();
        writeln!(code).unwrap();

        // Calculate message pointer
        writeln!(code, "{}    // Get message from queue", indent).unwrap();
        writeln!(
            code,
            "{}    unsigned int msg_idx = (unsigned int)(tail & QUEUE_MASK);",
            indent
        )
        .unwrap();
        writeln!(
            code,
            "{}    unsigned char* envelope_ptr = &input_buffer[msg_idx * MSG_SIZE];",
            indent
        )
        .unwrap();

        if self.use_envelope_format {
            // Envelope format: parse header and get payload pointer
            writeln!(code).unwrap();
            writeln!(code, "{}    // Parse message envelope", indent).unwrap();
            writeln!(
                code,
                "{}    MessageHeader* msg_header = message_get_header(envelope_ptr);",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}    unsigned char* msg_ptr = message_get_payload(envelope_ptr);",
                indent
            )
            .unwrap();
            writeln!(code).unwrap();
            writeln!(code, "{}    // Validate message (skip invalid)", indent).unwrap();
            writeln!(
                code,
                "{}    if (!message_header_validate(msg_header)) {{",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}        atomicAdd(&control->input_tail, 1);",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}        atomicAdd(&control->last_error, 1);  // Track errors",
                indent
            )
            .unwrap();
            writeln!(code, "{}        continue;", indent).unwrap();
            writeln!(code, "{}    }}", indent).unwrap();

            // Update HLC from incoming message timestamp
            if self.enable_hlc {
                writeln!(code).unwrap();
                writeln!(code, "{}    // Update HLC from message timestamp", indent).unwrap();
                writeln!(
                    code,
                    "{}    if (msg_header->timestamp.physical > hlc_physical) {{",
                    indent
                )
                .unwrap();
                writeln!(
                    code,
                    "{}        hlc_physical = msg_header->timestamp.physical;",
                    indent
                )
                .unwrap();
                writeln!(code, "{}        hlc_logical = 0;", indent).unwrap();
                writeln!(code, "{}    }}", indent).unwrap();
                writeln!(code, "{}    hlc_logical++;", indent).unwrap();
            }
        } else {
            // Legacy raw format
            writeln!(
                code,
                "{}    unsigned char* msg_ptr = envelope_ptr;  // Raw message data",
                indent
            )
            .unwrap();
        }
        writeln!(code).unwrap();

        code
    }

    /// Generate message processing completion code.
    ///
    /// NOTE(review): the emitted `atomicAdd`s run on *every* thread of the
    /// block, so `input_tail`/`messages_processed` advance once per thread
    /// per message unless the user handler guards on `tid` — confirm the
    /// intended consumption model against the host-side queue logic.
    pub fn generate_message_complete(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code).unwrap();
        writeln!(code, "{}    // Mark message as processed", indent).unwrap();
        writeln!(code, "{}    atomicAdd(&control->input_tail, 1);", indent).unwrap();
        writeln!(
            code,
            "{}    atomicAdd(&control->messages_processed, 1);",
            indent
        )
        .unwrap();

        if self.enable_hlc {
            writeln!(code).unwrap();
            writeln!(code, "{}    // Update HLC", indent).unwrap();
            writeln!(code, "{}    hlc_logical++;", indent).unwrap();
        }

        code
    }

    /// Generate the loop footer (end of while loop).
    ///
    /// Emits a `__syncthreads()` so the whole block finishes one message
    /// iteration before starting the next, then closes the while loop.
    pub fn generate_loop_footer(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code, "{}    __syncthreads();", indent).unwrap();
        writeln!(code, "{}}}", indent).unwrap();

        code
    }

    /// Generate kernel epilogue (termination marking).
    ///
    /// Thread 0 persists the local HLC state into the control block (plain
    /// stores) and then atomically sets `has_terminated` so the host can
    /// observe shutdown.
    pub fn generate_epilogue(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code).unwrap();
        writeln!(code, "{}// Mark kernel as terminated", indent).unwrap();
        writeln!(code, "{}if (tid == 0) {{", indent).unwrap();

        if self.enable_hlc {
            writeln!(code, "{}    // Store final HLC state", indent).unwrap();
            writeln!(
                code,
                "{}    control->hlc_state.physical = hlc_physical;",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}    control->hlc_state.logical = hlc_logical;",
                indent
            )
            .unwrap();
        }

        writeln!(
            code,
            "{}    atomicExch(&control->has_terminated, 1);",
            indent
        )
        .unwrap();
        writeln!(code, "{}}}", indent).unwrap();

        code
    }

    /// Generate complete kernel wrapper (without handler body).
    ///
    /// Assembly order: struct definitions (control block, HLC, envelope,
    /// K2K — each gated by its config flag), then signature, preamble,
    /// loop header, the user handler (indented to loop-body depth),
    /// message completion, loop footer, and epilogue.
    pub fn generate_kernel_wrapper(&self, handler_placeholder: &str) -> String {
        let mut code = String::new();

        // Struct definitions
        code.push_str(&generate_control_block_struct());
        code.push('\n');

        if self.enable_hlc {
            code.push_str(&generate_hlc_struct());
            code.push('\n');
        }

        // MessageEnvelope structs (requires HLC for HlcTimestamp)
        if self.use_envelope_format {
            code.push_str(&generate_message_envelope_structs());
            code.push('\n');
        }

        if self.enable_k2k {
            code.push_str(&generate_k2k_structs());
            code.push('\n');
        }

        // Kernel signature
        code.push_str(&self.generate_signature());
        code.push_str(" {\n");

        // Preamble
        code.push_str(&self.generate_preamble("    "));

        // Message loop
        code.push_str(&self.generate_loop_header("    "));

        // Handler placeholder
        writeln!(code, "        // === USER HANDLER CODE ===").unwrap();
        for line in handler_placeholder.lines() {
            writeln!(code, "        {}", line).unwrap();
        }
        writeln!(code, "        // === END HANDLER CODE ===").unwrap();

        // Message completion
        code.push_str(&self.generate_message_complete("    "));

        // Loop footer
        code.push_str(&self.generate_loop_footer("    "));

        // Epilogue
        code.push_str(&self.generate_epilogue("    "));

        code.push_str("}\n");

        code
    }
}
563
/// Generate CUDA ControlBlock struct definition.
///
/// The emitted struct is 128 bytes and cache-line aligned; its field layout
/// must stay in sync with the host-side control block mapped onto the same
/// memory.
pub fn generate_control_block_struct() -> String {
    String::from(
        r#"// Control block for kernel state management (128 bytes, cache-line aligned)
struct __align__(128) ControlBlock {
    // Lifecycle state
    unsigned int is_active;
    unsigned int should_terminate;
    unsigned int has_terminated;
    unsigned int _pad1;

    // Counters
    unsigned long long messages_processed;
    unsigned long long messages_in_flight;

    // Queue pointers
    unsigned long long input_head;
    unsigned long long input_tail;
    unsigned long long output_head;
    unsigned long long output_tail;

    // Queue metadata
    unsigned int input_capacity;
    unsigned int output_capacity;
    unsigned int input_mask;
    unsigned int output_mask;

    // HLC state
    struct {
        unsigned long long physical;
        unsigned long long logical;
    } hlc_state;

    // Error state
    unsigned int last_error;
    unsigned int error_count;

    // Reserved padding
    unsigned char _reserved[24];
};
"#,
    )
}
606
/// Generate CUDA HLC helper struct.
///
/// Emits `HlcState` (mutable clock state) and the 24-byte `HlcTimestamp`
/// that message headers embed.
pub fn generate_hlc_struct() -> String {
    String::from(
        r#"// Hybrid Logical Clock state
struct HlcState {
    unsigned long long physical;
    unsigned long long logical;
};

// HLC timestamp (24 bytes)
struct __align__(8) HlcTimestamp {
    unsigned long long physical;
    unsigned long long logical;
    unsigned long long node_id;
};
"#,
    )
}
624
/// Generate CUDA MessageEnvelope structs matching ringkernel-core layout.
///
/// This generates the GPU-side structures that match the Rust `MessageHeader`
/// and `MessageEnvelope` types, enabling proper serialization between host and
/// device.
///
/// Fix: the `MESSAGE_MAGIC` define previously emitted `0x52494E474B45524E ULL`
/// with a space before the suffix, which is invalid C/CUDA (the literal suffix
/// must be adjacent) and broke compilation of every generated kernel that
/// enables envelope format. The suffix is now attached to the literal.
pub fn generate_message_envelope_structs() -> String {
    r#"// Magic number for message validation
#define MESSAGE_MAGIC 0x52494E474B45524EULL  // "RINGKERN"
#define MESSAGE_VERSION 1
#define MESSAGE_HEADER_SIZE 256
#define MAX_PAYLOAD_SIZE (64 * 1024)

// Message priority levels
#define PRIORITY_LOW 0
#define PRIORITY_NORMAL 1
#define PRIORITY_HIGH 2
#define PRIORITY_CRITICAL 3

// Message header structure (256 bytes, cache-line aligned)
// Matches ringkernel_core::message::MessageHeader exactly
struct __align__(64) MessageHeader {
    // Magic number for validation (0xRINGKERN)
    unsigned long long magic;
    // Header version
    unsigned int version;
    // Message flags
    unsigned int flags;
    // Unique message identifier
    unsigned long long message_id;
    // Correlation ID for request-response
    unsigned long long correlation_id;
    // Source kernel ID (0 for host)
    unsigned long long source_kernel;
    // Destination kernel ID (0 for host)
    unsigned long long dest_kernel;
    // Message type discriminator
    unsigned long long message_type;
    // Priority level
    unsigned char priority;
    // Reserved for alignment
    unsigned char _reserved1[7];
    // Payload size in bytes
    unsigned long long payload_size;
    // Checksum of payload (CRC32)
    unsigned int checksum;
    // Reserved for alignment
    unsigned int _reserved2;
    // HLC timestamp when message was created
    HlcTimestamp timestamp;
    // Deadline timestamp (0 = no deadline)
    HlcTimestamp deadline;
    // Reserved for future use (104 bytes total: 32+32+32+8)
    unsigned char _reserved3[104];
};

// Validate message header
__device__ inline int message_header_validate(const MessageHeader* header) {
    return header->magic == MESSAGE_MAGIC &&
           header->version <= MESSAGE_VERSION &&
           header->payload_size <= MAX_PAYLOAD_SIZE;
}

// Get payload pointer from header
__device__ inline unsigned char* message_get_payload(unsigned char* envelope_ptr) {
    return envelope_ptr + MESSAGE_HEADER_SIZE;
}

// Get header from envelope pointer
__device__ inline MessageHeader* message_get_header(unsigned char* envelope_ptr) {
    return (MessageHeader*)envelope_ptr;
}

// CRC32 computation for payload checksums
// Uses CRC32-C (Castagnoli) polynomial 0x1EDC6F41
__device__ inline unsigned int message_compute_checksum(const unsigned char* data, unsigned long long size) {
    unsigned int crc = 0xFFFFFFFF;
    for (unsigned long long i = 0; i < size; i++) {
        crc ^= data[i];
        for (int j = 0; j < 8; j++) {
            crc = (crc >> 1) ^ (0x82F63B78 * (crc & 1));  // CRC32-C polynomial
        }
    }
    return ~crc;
}

// Verify payload checksum
__device__ inline int message_verify_checksum(const MessageHeader* header, const unsigned char* payload) {
    if (header->payload_size == 0) {
        return header->checksum == 0;  // Empty payload should have zero checksum
    }
    unsigned int computed = message_compute_checksum(payload, header->payload_size);
    return computed == header->checksum;
}

// Create a response header based on request
// Note: payload_ptr is optional, pass NULL if checksum not needed yet
__device__ inline void message_create_response_header(
    MessageHeader* response,
    const MessageHeader* request,
    unsigned long long this_kernel_id,
    unsigned long long payload_size,
    unsigned long long hlc_physical,
    unsigned long long hlc_logical,
    unsigned long long hlc_node_id,
    const unsigned char* payload_ptr  // Optional: for checksum computation
) {
    response->magic = MESSAGE_MAGIC;
    response->version = MESSAGE_VERSION;
    response->flags = 0;
    // Generate new message ID (simple increment from request)
    response->message_id = request->message_id + 0x100000000ULL;
    // Preserve correlation ID for request-response matching
    response->correlation_id = request->correlation_id != 0
        ? request->correlation_id
        : request->message_id;
    response->source_kernel = this_kernel_id;
    response->dest_kernel = request->source_kernel;  // Response goes back to sender
    response->message_type = request->message_type + 1;  // Convention: response type = request + 1
    response->priority = request->priority;
    response->payload_size = payload_size;
    // Compute checksum if payload provided
    if (payload_ptr != NULL && payload_size > 0) {
        response->checksum = message_compute_checksum(payload_ptr, payload_size);
    } else {
        response->checksum = 0;
    }
    response->timestamp.physical = hlc_physical;
    response->timestamp.logical = hlc_logical;
    response->timestamp.node_id = hlc_node_id;
    response->deadline.physical = 0;
    response->deadline.logical = 0;
    response->deadline.node_id = 0;
}

// Calculate total envelope size
__device__ inline unsigned int message_envelope_size(const MessageHeader* header) {
    return MESSAGE_HEADER_SIZE + (unsigned int)header->payload_size;
}
"#
    .to_string()
}
765
766/// Generate CUDA K2K routing structs and helper functions.
767pub fn generate_k2k_structs() -> String {
768    r#"// Kernel-to-kernel routing table entry (48 bytes)
769// Matches ringkernel_cuda::k2k_gpu::K2KRouteEntry
770struct K2KRoute {
771    unsigned long long target_kernel_id;
772    unsigned long long target_inbox;      // device pointer
773    unsigned long long target_head;       // device pointer to head
774    unsigned long long target_tail;       // device pointer to tail
775    unsigned int capacity;
776    unsigned int mask;
777    unsigned int msg_size;
778    unsigned int _pad;
779};
780
781// K2K routing table (8 + 48*16 = 776 bytes)
782struct K2KRoutingTable {
783    unsigned int num_routes;
784    unsigned int _pad;
785    K2KRoute routes[16];  // Max 16 K2K connections
786};
787
788// K2K inbox header (64 bytes, cache-line aligned)
789// Matches ringkernel_cuda::k2k_gpu::K2KInboxHeader
790struct __align__(64) K2KInboxHeader {
791    unsigned long long head;
792    unsigned long long tail;
793    unsigned int capacity;
794    unsigned int mask;
795    unsigned int msg_size;
796    unsigned int _pad;
797};
798
799// Send a message envelope to another kernel via K2K
800// The entire envelope (header + payload) is copied to the target's inbox
801__device__ inline int k2k_send_envelope(
802    K2KRoutingTable* routes,
803    unsigned long long target_id,
804    unsigned long long source_kernel_id,
805    const void* payload_ptr,
806    unsigned int payload_size,
807    unsigned long long message_type,
808    unsigned long long hlc_physical,
809    unsigned long long hlc_logical,
810    unsigned long long hlc_node_id
811) {
812    // Find route for target
813    for (unsigned int i = 0; i < routes->num_routes; i++) {
814        if (routes->routes[i].target_kernel_id == target_id) {
815            K2KRoute* route = &routes->routes[i];
816
817            // Calculate total envelope size
818            unsigned int envelope_size = MESSAGE_HEADER_SIZE + payload_size;
819            if (envelope_size > route->msg_size) {
820                return -1;  // Message too large
821            }
822
823            // Atomically claim a slot in target's inbox
824            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
825            unsigned long long slot = atomicAdd(target_head_ptr, 1);
826            unsigned int idx = (unsigned int)(slot & route->mask);
827
828            // Calculate destination pointer
829            unsigned char* dest = ((unsigned char*)route->target_inbox) +
830                                  sizeof(K2KInboxHeader) + idx * route->msg_size;
831
832            // Build message header
833            MessageHeader* header = (MessageHeader*)dest;
834            header->magic = MESSAGE_MAGIC;
835            header->version = MESSAGE_VERSION;
836            header->flags = 0;
837            header->message_id = (source_kernel_id << 32) | (slot & 0xFFFFFFFF);
838            header->correlation_id = 0;
839            header->source_kernel = source_kernel_id;
840            header->dest_kernel = target_id;
841            header->message_type = message_type;
842            header->priority = PRIORITY_NORMAL;
843            header->payload_size = payload_size;
844            header->timestamp.physical = hlc_physical;
845            header->timestamp.logical = hlc_logical;
846            header->timestamp.node_id = hlc_node_id;
847            header->deadline.physical = 0;
848            header->deadline.logical = 0;
849            header->deadline.node_id = 0;
850
851            // Copy payload after header and compute checksum
852            if (payload_size > 0 && payload_ptr != NULL) {
853                memcpy(dest + MESSAGE_HEADER_SIZE, payload_ptr, payload_size);
854                header->checksum = message_compute_checksum(
855                    (const unsigned char*)payload_ptr, payload_size);
856            } else {
857                header->checksum = 0;
858            }
859
860            __threadfence();  // Ensure write is visible
861            return 1;  // Success
862        }
863    }
864    return 0;  // Route not found
865}
866
867// Legacy k2k_send for raw message data (no envelope)
868__device__ inline int k2k_send(
869    K2KRoutingTable* routes,
870    unsigned long long target_id,
871    const void* msg_ptr,
872    unsigned int msg_size
873) {
874    // Find route for target
875    for (unsigned int i = 0; i < routes->num_routes; i++) {
876        if (routes->routes[i].target_kernel_id == target_id) {
877            K2KRoute* route = &routes->routes[i];
878
879            // Atomically claim a slot in target's inbox
880            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
881            unsigned long long slot = atomicAdd(target_head_ptr, 1);
882            unsigned int idx = (unsigned int)(slot & route->mask);
883
884            // Copy message to target inbox
885            unsigned char* dest = ((unsigned char*)route->target_inbox) +
886                                  sizeof(K2KInboxHeader) + idx * route->msg_size;
887            memcpy(dest, msg_ptr, msg_size < route->msg_size ? msg_size : route->msg_size);
888
889            __threadfence();
890            return 1;  // Success
891        }
892    }
893    return 0;  // Route not found
894}
895
896// Check if there are K2K messages in inbox
897__device__ inline int k2k_has_message(unsigned char* k2k_inbox) {
898    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
899    unsigned long long head = atomicAdd(&header->head, 0);
900    unsigned long long tail = atomicAdd(&header->tail, 0);
901    return head != tail;
902}
903
904// Try to receive a K2K message envelope
905// Returns pointer to MessageHeader if available, NULL otherwise
906__device__ inline MessageHeader* k2k_try_recv_envelope(unsigned char* k2k_inbox) {
907    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
908
909    unsigned long long head = atomicAdd(&header->head, 0);
910    unsigned long long tail = atomicAdd(&header->tail, 0);
911
912    if (head == tail) {
913        return NULL;  // No messages
914    }
915
916    // Get message pointer
917    unsigned int idx = (unsigned int)(tail & header->mask);
918    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
919    MessageHeader* msg_header = (MessageHeader*)(data_start + idx * header->msg_size);
920
921    // Validate header
922    if (!message_header_validate(msg_header)) {
923        // Invalid message, skip it
924        atomicAdd(&header->tail, 1);
925        return NULL;
926    }
927
928    // Advance tail (consume message)
929    atomicAdd(&header->tail, 1);
930
931    return msg_header;
932}
933
934// Legacy k2k_try_recv for raw data
935__device__ inline void* k2k_try_recv(unsigned char* k2k_inbox) {
936    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
937
938    unsigned long long head = atomicAdd(&header->head, 0);
939    unsigned long long tail = atomicAdd(&header->tail, 0);
940
941    if (head == tail) {
942        return NULL;  // No messages
943    }
944
945    // Get message pointer
946    unsigned int idx = (unsigned int)(tail & header->mask);
947    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
948
949    // Advance tail (consume message)
950    atomicAdd(&header->tail, 1);
951
952    return data_start + idx * header->msg_size;
953}
954
955// Peek at next K2K message without consuming
956__device__ inline MessageHeader* k2k_peek_envelope(unsigned char* k2k_inbox) {
957    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
958
959    unsigned long long head = atomicAdd(&header->head, 0);
960    unsigned long long tail = atomicAdd(&header->tail, 0);
961
962    if (head == tail) {
963        return NULL;  // No messages
964    }
965
966    unsigned int idx = (unsigned int)(tail & header->mask);
967    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
968
969    return (MessageHeader*)(data_start + idx * header->msg_size);
970}
971
972// Legacy k2k_peek
973__device__ inline void* k2k_peek(unsigned char* k2k_inbox) {
974    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
975
976    unsigned long long head = atomicAdd(&header->head, 0);
977    unsigned long long tail = atomicAdd(&header->tail, 0);
978
979    if (head == tail) {
980        return NULL;  // No messages
981    }
982
983    unsigned int idx = (unsigned int)(tail & header->mask);
984    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
985
986    return data_start + idx * header->msg_size;
987}
988
989// Get number of pending K2K messages
990__device__ inline unsigned int k2k_pending_count(unsigned char* k2k_inbox) {
991    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
992    unsigned long long head = atomicAdd(&header->head, 0);
993    unsigned long long tail = atomicAdd(&header->tail, 0);
994    return (unsigned int)(head - tail);
995}
996"#
997    .to_string()
998}
999
/// Intrinsic functions available in ring kernel handlers.
///
/// Each variant corresponds to a CUDA snippet emitted by
/// [`RingKernelIntrinsic::to_cuda`]; handler source refers to variants by
/// the names (and aliases) accepted by [`RingKernelIntrinsic::from_name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RingKernelIntrinsic {
    // Control block access
    /// Atomic read of `control->is_active` (non-zero means active).
    IsActive,
    /// Atomic read of `control->should_terminate` (non-zero requests shutdown).
    ShouldTerminate,
    /// Atomic exchange setting `control->has_terminated` to 1.
    MarkTerminated,
    /// Atomic read of the `control->messages_processed` counter.
    GetMessagesProcessed,

    // Queue operations
    /// Pending input messages: `input_head - input_tail`.
    InputQueueSize,
    /// Pending output messages: `output_head - output_tail`.
    OutputQueueSize,
    /// True when `input_head == input_tail`.
    InputQueueEmpty,
    /// True when `output_head == output_tail`.
    OutputQueueEmpty,
    /// Claim an output slot via `output_head` and `memcpy` a response into it.
    EnqueueResponse,

    // HLC operations
    /// Increment the local HLC logical counter.
    HlcTick,
    /// Merge an observed physical timestamp into the local HLC state.
    HlcUpdate,
    /// Pack the HLC into one 64-bit value (`physical << 32 | logical`).
    HlcNow,

    // K2K operations
    /// Send a message to another kernel through the routing table.
    K2kSend,
    /// Non-blocking receive from this kernel's K2K inbox.
    K2kTryRecv,
    /// Check whether the K2K inbox holds a pending message.
    K2kHasMessage,
    /// Peek at the next K2K message without consuming it.
    K2kPeek,
    /// Number of pending messages in the K2K inbox.
    K2kPendingCount,
}
1028
1029impl RingKernelIntrinsic {
1030    /// Get the CUDA code for this intrinsic.
1031    pub fn to_cuda(&self, args: &[String]) -> String {
1032        match self {
1033            Self::IsActive => "atomicAdd(&control->is_active, 0) != 0".to_string(),
1034            Self::ShouldTerminate => "atomicAdd(&control->should_terminate, 0) != 0".to_string(),
1035            Self::MarkTerminated => "atomicExch(&control->has_terminated, 1)".to_string(),
1036            Self::GetMessagesProcessed => "atomicAdd(&control->messages_processed, 0)".to_string(),
1037
1038            Self::InputQueueSize => {
1039                "(atomicAdd(&control->input_head, 0) - atomicAdd(&control->input_tail, 0))"
1040                    .to_string()
1041            }
1042            Self::OutputQueueSize => {
1043                "(atomicAdd(&control->output_head, 0) - atomicAdd(&control->output_tail, 0))"
1044                    .to_string()
1045            }
1046            Self::InputQueueEmpty => {
1047                "(atomicAdd(&control->input_head, 0) == atomicAdd(&control->input_tail, 0))"
1048                    .to_string()
1049            }
1050            Self::OutputQueueEmpty => {
1051                "(atomicAdd(&control->output_head, 0) == atomicAdd(&control->output_tail, 0))"
1052                    .to_string()
1053            }
1054            Self::EnqueueResponse => {
1055                if !args.is_empty() {
1056                    format!(
1057                        "{{ unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask; \
1058                         memcpy(&output_buffer[_out_idx * RESP_SIZE], {}, RESP_SIZE); }}",
1059                        args[0]
1060                    )
1061                } else {
1062                    "/* enqueue_response requires response pointer */".to_string()
1063                }
1064            }
1065
1066            Self::HlcTick => "hlc_logical++".to_string(),
1067            Self::HlcUpdate => {
1068                if !args.is_empty() {
1069                    format!(
1070                        "{{ if ({} > hlc_physical) {{ hlc_physical = {}; hlc_logical = 0; }} else {{ hlc_logical++; }} }}",
1071                        args[0], args[0]
1072                    )
1073                } else {
1074                    "hlc_logical++".to_string()
1075                }
1076            }
1077            Self::HlcNow => "(hlc_physical << 32) | (hlc_logical & 0xFFFFFFFF)".to_string(),
1078
1079            Self::K2kSend => {
1080                if args.len() >= 2 {
1081                    // k2k_send(target_id, msg_ptr) -> k2k_send(k2k_routes, target_id, msg_ptr, sizeof(*msg_ptr))
1082                    format!(
1083                        "k2k_send(k2k_routes, {}, {}, sizeof(*{}))",
1084                        args[0], args[1], args[1]
1085                    )
1086                } else {
1087                    "/* k2k_send requires target_id and msg_ptr */".to_string()
1088                }
1089            }
1090            Self::K2kTryRecv => "k2k_try_recv(k2k_inbox)".to_string(),
1091            Self::K2kHasMessage => "k2k_has_message(k2k_inbox)".to_string(),
1092            Self::K2kPeek => "k2k_peek(k2k_inbox)".to_string(),
1093            Self::K2kPendingCount => "k2k_pending_count(k2k_inbox)".to_string(),
1094        }
1095    }
1096
1097    /// Parse a function name to get the intrinsic.
1098    pub fn from_name(name: &str) -> Option<Self> {
1099        match name {
1100            "is_active" | "is_kernel_active" => Some(Self::IsActive),
1101            "should_terminate" => Some(Self::ShouldTerminate),
1102            "mark_terminated" => Some(Self::MarkTerminated),
1103            "messages_processed" | "get_messages_processed" => Some(Self::GetMessagesProcessed),
1104
1105            "input_queue_size" => Some(Self::InputQueueSize),
1106            "output_queue_size" => Some(Self::OutputQueueSize),
1107            "input_queue_empty" => Some(Self::InputQueueEmpty),
1108            "output_queue_empty" => Some(Self::OutputQueueEmpty),
1109            "enqueue_response" | "enqueue" => Some(Self::EnqueueResponse),
1110
1111            "hlc_tick" => Some(Self::HlcTick),
1112            "hlc_update" => Some(Self::HlcUpdate),
1113            "hlc_now" => Some(Self::HlcNow),
1114
1115            "k2k_send" => Some(Self::K2kSend),
1116            "k2k_try_recv" => Some(Self::K2kTryRecv),
1117            "k2k_has_message" => Some(Self::K2kHasMessage),
1118            "k2k_peek" => Some(Self::K2kPeek),
1119            "k2k_pending_count" | "k2k_pending" => Some(Self::K2kPendingCount),
1120
1121            _ => None,
1122        }
1123    }
1124
1125    /// Check if this intrinsic requires K2K support.
1126    pub fn requires_k2k(&self) -> bool {
1127        matches!(
1128            self,
1129            Self::K2kSend
1130                | Self::K2kTryRecv
1131                | Self::K2kHasMessage
1132                | Self::K2kPeek
1133                | Self::K2kPendingCount
1134        )
1135    }
1136
1137    /// Check if this intrinsic requires HLC support.
1138    pub fn requires_hlc(&self) -> bool {
1139        matches!(self, Self::HlcTick | Self::HlcUpdate | Self::HlcNow)
1140    }
1141
1142    /// Check if this intrinsic requires control block access.
1143    pub fn requires_control_block(&self) -> bool {
1144        matches!(
1145            self,
1146            Self::IsActive
1147                | Self::ShouldTerminate
1148                | Self::MarkTerminated
1149                | Self::GetMessagesProcessed
1150                | Self::InputQueueSize
1151                | Self::OutputQueueSize
1152                | Self::InputQueueEmpty
1153                | Self::OutputQueueEmpty
1154                | Self::EnqueueResponse
1155        )
1156    }
1157}
1158
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults: 128-thread blocks, 1024-entry queue, HLC on, K2K off.
    #[test]
    fn test_default_config() {
        let config = RingKernelConfig::default();
        assert_eq!(config.block_size, 128);
        assert_eq!(config.queue_capacity, 1024);
        assert!(config.enable_hlc);
        assert!(!config.enable_k2k);
    }

    // Builder methods override each field; kernel_name() derives from id.
    #[test]
    fn test_config_builder() {
        let config = RingKernelConfig::new("processor")
            .with_block_size(256)
            .with_queue_capacity(2048)
            .with_k2k(true)
            .with_hlc(true);

        assert_eq!(config.id, "processor");
        assert_eq!(config.kernel_name(), "ring_kernel_processor");
        assert_eq!(config.block_size, 256);
        assert_eq!(config.queue_capacity, 2048);
        assert!(config.enable_k2k);
        assert!(config.enable_hlc);
    }

    // Base signature carries the four standard kernel parameters.
    #[test]
    fn test_kernel_signature() {
        let config = RingKernelConfig::new("test");
        let sig = config.generate_signature();

        assert!(sig.contains("extern \"C\" __global__ void ring_kernel_test"));
        assert!(sig.contains("ControlBlock* __restrict__ control"));
        assert!(sig.contains("input_buffer"));
        assert!(sig.contains("output_buffer"));
        assert!(sig.contains("shared_state"));
    }

    // Enabling K2K adds routing-table, inbox, and outbox parameters.
    #[test]
    fn test_kernel_signature_with_k2k() {
        let config = RingKernelConfig::new("k2k_test").with_k2k(true);
        let sig = config.generate_signature();

        assert!(sig.contains("K2KRoutingTable"));
        assert!(sig.contains("k2k_inbox"));
        assert!(sig.contains("k2k_outbox"));
    }

    // Preamble declares thread identity, size constants, and HLC locals.
    #[test]
    fn test_preamble_generation() {
        let config = RingKernelConfig::new("test").with_hlc(true);
        let preamble = config.generate_preamble("    ");

        assert!(preamble.contains("int tid = threadIdx.x + blockIdx.x * blockDim.x"));
        assert!(preamble.contains("int lane_id"));
        assert!(preamble.contains("int warp_id"));
        assert!(preamble.contains("MSG_SIZE"));
        assert!(preamble.contains("hlc_physical"));
        assert!(preamble.contains("hlc_logical"));
    }

    // Loop header polls termination/activity flags and dequeues a message.
    #[test]
    fn test_loop_header() {
        let config = RingKernelConfig::new("test");
        let header = config.generate_loop_header("    ");

        assert!(header.contains("while (true)"));
        assert!(header.contains("should_terminate"));
        assert!(header.contains("is_active"));
        assert!(header.contains("input_head"));
        assert!(header.contains("input_tail"));
        assert!(header.contains("msg_ptr"));
    }

    // Epilogue marks termination and (with HLC) persists the clock state.
    #[test]
    fn test_epilogue() {
        let config = RingKernelConfig::new("test").with_hlc(true);
        let epilogue = config.generate_epilogue("    ");

        assert!(epilogue.contains("has_terminated"));
        assert!(epilogue.contains("hlc_state.physical"));
        assert!(epilogue.contains("hlc_state.logical"));
    }

    // ControlBlock struct exposes lifecycle flags, counters, and queue indices.
    #[test]
    fn test_control_block_struct() {
        let code = generate_control_block_struct();

        assert!(code.contains("struct __align__(128) ControlBlock"));
        assert!(code.contains("is_active"));
        assert!(code.contains("should_terminate"));
        assert!(code.contains("has_terminated"));
        assert!(code.contains("messages_processed"));
        assert!(code.contains("input_head"));
        assert!(code.contains("input_tail"));
        assert!(code.contains("hlc_state"));
    }

    // End-to-end wrapper: struct defs, entry point, loop, user body, epilogue.
    #[test]
    fn test_full_kernel_wrapper() {
        let config = RingKernelConfig::new("example")
            .with_block_size(128)
            .with_hlc(true);

        let kernel = config.generate_kernel_wrapper("// Process message here");

        assert!(kernel.contains("struct __align__(128) ControlBlock"));
        assert!(kernel.contains("extern \"C\" __global__ void ring_kernel_example"));
        assert!(kernel.contains("while (true)"));
        assert!(kernel.contains("// Process message here"));
        assert!(kernel.contains("has_terminated"));

        println!("Generated kernel:\n{}", kernel);
    }

    // from_name resolves known intrinsic names and rejects unknown ones.
    #[test]
    fn test_intrinsic_lookup() {
        assert_eq!(
            RingKernelIntrinsic::from_name("is_active"),
            Some(RingKernelIntrinsic::IsActive)
        );
        assert_eq!(
            RingKernelIntrinsic::from_name("should_terminate"),
            Some(RingKernelIntrinsic::ShouldTerminate)
        );
        assert_eq!(
            RingKernelIntrinsic::from_name("hlc_tick"),
            Some(RingKernelIntrinsic::HlcTick)
        );
        assert_eq!(RingKernelIntrinsic::from_name("unknown"), None);
    }

    // to_cuda output mentions the field/variable each intrinsic operates on.
    #[test]
    fn test_intrinsic_cuda_output() {
        assert!(RingKernelIntrinsic::IsActive
            .to_cuda(&[])
            .contains("is_active"));
        assert!(RingKernelIntrinsic::ShouldTerminate
            .to_cuda(&[])
            .contains("should_terminate"));
        assert!(RingKernelIntrinsic::HlcTick
            .to_cuda(&[])
            .contains("hlc_logical++"));
    }

    // K2K support code defines the structs and all device helper functions.
    #[test]
    fn test_k2k_structs_generation() {
        let k2k_code = generate_k2k_structs();

        // Check struct definitions
        assert!(
            k2k_code.contains("struct K2KRoute"),
            "Should have K2KRoute struct"
        );
        assert!(
            k2k_code.contains("struct K2KRoutingTable"),
            "Should have K2KRoutingTable struct"
        );
        assert!(
            k2k_code.contains("K2KInboxHeader"),
            "Should have K2KInboxHeader struct"
        );

        // Check helper functions
        assert!(
            k2k_code.contains("__device__ inline int k2k_send"),
            "Should have k2k_send function"
        );
        assert!(
            k2k_code.contains("__device__ inline int k2k_has_message"),
            "Should have k2k_has_message function"
        );
        assert!(
            k2k_code.contains("__device__ inline void* k2k_try_recv"),
            "Should have k2k_try_recv function"
        );
        assert!(
            k2k_code.contains("__device__ inline void* k2k_peek"),
            "Should have k2k_peek function"
        );
        assert!(
            k2k_code.contains("__device__ inline unsigned int k2k_pending_count"),
            "Should have k2k_pending_count function"
        );

        println!("K2K code:\n{}", k2k_code);
    }

    // A K2K-enabled wrapper includes the K2K structs, params, and helpers.
    #[test]
    fn test_full_k2k_kernel() {
        let config = RingKernelConfig::new("k2k_processor")
            .with_block_size(128)
            .with_k2k(true)
            .with_hlc(true);

        let kernel = config.generate_kernel_wrapper("// K2K handler code");

        // Check K2K-specific components
        assert!(
            kernel.contains("K2KRoutingTable"),
            "Should have K2KRoutingTable"
        );
        assert!(kernel.contains("K2KRoute"), "Should have K2KRoute struct");
        assert!(
            kernel.contains("K2KInboxHeader"),
            "Should have K2KInboxHeader"
        );
        assert!(
            kernel.contains("k2k_routes"),
            "Should have k2k_routes param"
        );
        assert!(kernel.contains("k2k_inbox"), "Should have k2k_inbox param");
        assert!(
            kernel.contains("k2k_outbox"),
            "Should have k2k_outbox param"
        );
        assert!(kernel.contains("k2k_send"), "Should have k2k_send function");
        assert!(
            kernel.contains("k2k_try_recv"),
            "Should have k2k_try_recv function"
        );

        println!("Full K2K kernel:\n{}", kernel);
    }

    // Exactly the K2k* intrinsics require K2K support; others do not.
    #[test]
    fn test_k2k_intrinsic_requirements() {
        assert!(RingKernelIntrinsic::K2kSend.requires_k2k());
        assert!(RingKernelIntrinsic::K2kTryRecv.requires_k2k());
        assert!(RingKernelIntrinsic::K2kHasMessage.requires_k2k());
        assert!(RingKernelIntrinsic::K2kPeek.requires_k2k());
        assert!(RingKernelIntrinsic::K2kPendingCount.requires_k2k());

        assert!(!RingKernelIntrinsic::HlcTick.requires_k2k());
        assert!(!RingKernelIntrinsic::IsActive.requires_k2k());
    }
}