// ringkernel_cuda_codegen/ring_kernel.rs

//! Ring kernel configuration and code generation for persistent actor kernels.
//!
//! This module provides the infrastructure for generating CUDA code for
//! persistent actor kernels that process messages in a loop.
//!
//! # Overview
//!
//! Ring kernels are persistent GPU kernels that:
//! - Run continuously until terminated
//! - Process messages from input queues
//! - Produce responses to output queues
//! - Use HLC (Hybrid Logical Clocks) for causal ordering
//! - Support kernel-to-kernel (K2K) messaging
//!
//! # Generated Kernel Structure
//!
//! The sketch below is simplified. The exact output depends on the
//! [`RingKernelConfig`] (K2K parameters, HLC state, idle sleep duration, queue
//! handling). Control-block fields are read with `atomicAdd(ptr, 0)`, which
//! serves as an atomic load.
//!
//! ```cuda
//! extern "C" __global__ void ring_kernel_NAME(
//!     ControlBlock* __restrict__ control,
//!     unsigned char* __restrict__ input_buffer,
//!     unsigned char* __restrict__ output_buffer,
//!     void* __restrict__ shared_state
//! ) {
//!     // Preamble: thread setup
//!     int tid = threadIdx.x + blockIdx.x * blockDim.x;
//!
//!     // Persistent message loop
//!     while (true) {
//!         if (atomicAdd(&control->should_terminate, 0) != 0) {
//!             break;
//!         }
//!         if (atomicAdd(&control->is_active, 0) == 0) {
//!             __nanosleep(1000);
//!             continue;
//!         }
//!
//!         // Message processing...
//!         // (user handler code inserted here)
//!     }
//!
//!     // Epilogue: mark terminated
//!     if (tid == 0) {
//!         atomicExch(&control->has_terminated, 1);
//!     }
//! }
//! ```
44
45use std::fmt::Write;
46
/// Configuration for a ring kernel.
///
/// Typically built with [`RingKernelConfig::new`] plus the `with_*` builder
/// methods. It is then handed to the `generate_*` methods, which emit CUDA
/// source text.
#[derive(Debug, Clone)]
pub struct RingKernelConfig {
    /// Kernel identifier; the generated function is named `ring_kernel_<id>`.
    pub id: String,
    /// Threads per block.
    pub block_size: u32,
    /// Input queue capacity. Must be a power of two: the generated code uses
    /// `capacity - 1` as the ring index mask.
    pub queue_capacity: u32,
    /// Emit kernel-to-kernel (K2K) messaging parameters and helpers.
    pub enable_k2k: bool,
    /// Emit Hybrid Logical Clock state and updates.
    pub enable_hlc: bool,
    /// Size in bytes of one input message slot (emitted as `MSG_SIZE`).
    pub message_size: usize,
    /// Size in bytes of one response slot (emitted as `RESP_SIZE`).
    pub response_size: usize,
    /// Use cooperative thread groups.
    /// NOTE(review): not referenced by any generator visible in this file — confirm intent.
    pub cooperative_groups: bool,
    /// `__nanosleep` duration in nanoseconds while idle; 0 busy-spins instead.
    pub idle_sleep_ns: u32,
}

impl Default for RingKernelConfig {
    /// Defaults:
    /// - id `"ring_kernel"`, 128 threads per block, 1024-slot queue;
    /// - HLC on, K2K off;
    /// - 64-byte messages and responses;
    /// - no cooperative groups, 1000 ns idle sleep.
    fn default() -> Self {
        RingKernelConfig {
            id: String::from("ring_kernel"),
            block_size: 128,
            queue_capacity: 1 << 10,
            enable_k2k: false,
            enable_hlc: true,
            message_size: 64,
            response_size: 64,
            cooperative_groups: false,
            idle_sleep_ns: 1_000,
        }
    }
}
85
86impl RingKernelConfig {
87    /// Create a new ring kernel configuration with the given ID.
88    pub fn new(id: impl Into<String>) -> Self {
89        Self {
90            id: id.into(),
91            ..Default::default()
92        }
93    }
94
95    /// Set the block size.
96    pub fn with_block_size(mut self, size: u32) -> Self {
97        self.block_size = size;
98        self
99    }
100
101    /// Set the queue capacity (must be power of 2).
102    pub fn with_queue_capacity(mut self, capacity: u32) -> Self {
103        debug_assert!(
104            capacity.is_power_of_two(),
105            "Queue capacity must be power of 2"
106        );
107        self.queue_capacity = capacity;
108        self
109    }
110
111    /// Enable kernel-to-kernel messaging.
112    pub fn with_k2k(mut self, enabled: bool) -> Self {
113        self.enable_k2k = enabled;
114        self
115    }
116
117    /// Enable HLC clock operations.
118    pub fn with_hlc(mut self, enabled: bool) -> Self {
119        self.enable_hlc = enabled;
120        self
121    }
122
123    /// Set message and response sizes.
124    pub fn with_message_sizes(mut self, message: usize, response: usize) -> Self {
125        self.message_size = message;
126        self.response_size = response;
127        self
128    }
129
130    /// Set idle sleep duration in nanoseconds.
131    pub fn with_idle_sleep(mut self, ns: u32) -> Self {
132        self.idle_sleep_ns = ns;
133        self
134    }
135
136    /// Generate the kernel function name.
137    pub fn kernel_name(&self) -> String {
138        format!("ring_kernel_{}", self.id)
139    }
140
141    /// Generate CUDA kernel signature.
142    pub fn generate_signature(&self) -> String {
143        let mut sig = String::new();
144
145        writeln!(sig, "extern \"C\" __global__ void {}(", self.kernel_name()).unwrap();
146        writeln!(sig, "    ControlBlock* __restrict__ control,").unwrap();
147        writeln!(sig, "    unsigned char* __restrict__ input_buffer,").unwrap();
148        writeln!(sig, "    unsigned char* __restrict__ output_buffer,").unwrap();
149
150        if self.enable_k2k {
151            writeln!(sig, "    K2KRoutingTable* __restrict__ k2k_routes,").unwrap();
152            writeln!(sig, "    unsigned char* __restrict__ k2k_inbox,").unwrap();
153            writeln!(sig, "    unsigned char* __restrict__ k2k_outbox,").unwrap();
154        }
155
156        write!(sig, "    void* __restrict__ shared_state").unwrap();
157        write!(sig, "\n)").unwrap();
158
159        sig
160    }
161
162    /// Generate the kernel preamble (thread setup, variable declarations).
163    pub fn generate_preamble(&self, indent: &str) -> String {
164        let mut code = String::new();
165
166        // Thread identification
167        writeln!(code, "{}// Thread identification", indent).unwrap();
168        writeln!(
169            code,
170            "{}int tid = threadIdx.x + blockIdx.x * blockDim.x;",
171            indent
172        )
173        .unwrap();
174        writeln!(code, "{}int lane_id = threadIdx.x % 32;", indent).unwrap();
175        writeln!(code, "{}int warp_id = threadIdx.x / 32;", indent).unwrap();
176        writeln!(code).unwrap();
177
178        // Message size constants
179        writeln!(code, "{}// Message buffer constants", indent).unwrap();
180        writeln!(
181            code,
182            "{}const unsigned int MSG_SIZE = {};",
183            indent, self.message_size
184        )
185        .unwrap();
186        writeln!(
187            code,
188            "{}const unsigned int RESP_SIZE = {};",
189            indent, self.response_size
190        )
191        .unwrap();
192        writeln!(
193            code,
194            "{}const unsigned int QUEUE_MASK = {};",
195            indent,
196            self.queue_capacity - 1
197        )
198        .unwrap();
199        writeln!(code).unwrap();
200
201        // HLC state (if enabled)
202        if self.enable_hlc {
203            writeln!(code, "{}// HLC clock state", indent).unwrap();
204            writeln!(code, "{}unsigned long long hlc_physical = 0;", indent).unwrap();
205            writeln!(code, "{}unsigned long long hlc_logical = 0;", indent).unwrap();
206            writeln!(code).unwrap();
207        }
208
209        code
210    }
211
212    /// Generate the persistent message loop header.
213    pub fn generate_loop_header(&self, indent: &str) -> String {
214        let mut code = String::new();
215
216        writeln!(code, "{}// Persistent message processing loop", indent).unwrap();
217        writeln!(code, "{}while (true) {{", indent).unwrap();
218        writeln!(code, "{}    // Check for termination signal", indent).unwrap();
219        writeln!(
220            code,
221            "{}    if (atomicAdd(&control->should_terminate, 0) != 0) {{",
222            indent
223        )
224        .unwrap();
225        writeln!(code, "{}        break;", indent).unwrap();
226        writeln!(code, "{}    }}", indent).unwrap();
227        writeln!(code).unwrap();
228
229        // Check if active
230        writeln!(code, "{}    // Check if kernel is active", indent).unwrap();
231        writeln!(
232            code,
233            "{}    if (atomicAdd(&control->is_active, 0) == 0) {{",
234            indent
235        )
236        .unwrap();
237        if self.idle_sleep_ns > 0 {
238            writeln!(
239                code,
240                "{}        __nanosleep({});",
241                indent, self.idle_sleep_ns
242            )
243            .unwrap();
244        }
245        writeln!(code, "{}        continue;", indent).unwrap();
246        writeln!(code, "{}    }}", indent).unwrap();
247        writeln!(code).unwrap();
248
249        // Check for messages
250        writeln!(code, "{}    // Check input queue for messages", indent).unwrap();
251        writeln!(
252            code,
253            "{}    unsigned long long head = atomicAdd(&control->input_head, 0);",
254            indent
255        )
256        .unwrap();
257        writeln!(
258            code,
259            "{}    unsigned long long tail = atomicAdd(&control->input_tail, 0);",
260            indent
261        )
262        .unwrap();
263        writeln!(code).unwrap();
264        writeln!(code, "{}    if (head == tail) {{", indent).unwrap();
265        writeln!(code, "{}        // No messages, yield", indent).unwrap();
266        if self.idle_sleep_ns > 0 {
267            writeln!(
268                code,
269                "{}        __nanosleep({});",
270                indent, self.idle_sleep_ns
271            )
272            .unwrap();
273        }
274        writeln!(code, "{}        continue;", indent).unwrap();
275        writeln!(code, "{}    }}", indent).unwrap();
276        writeln!(code).unwrap();
277
278        // Calculate message pointer
279        writeln!(code, "{}    // Get message from queue", indent).unwrap();
280        writeln!(
281            code,
282            "{}    unsigned int msg_idx = (unsigned int)(tail & QUEUE_MASK);",
283            indent
284        )
285        .unwrap();
286        writeln!(
287            code,
288            "{}    unsigned char* msg_ptr = &input_buffer[msg_idx * MSG_SIZE];",
289            indent
290        )
291        .unwrap();
292        writeln!(code).unwrap();
293
294        code
295    }
296
297    /// Generate message processing completion code.
298    pub fn generate_message_complete(&self, indent: &str) -> String {
299        let mut code = String::new();
300
301        writeln!(code).unwrap();
302        writeln!(code, "{}    // Mark message as processed", indent).unwrap();
303        writeln!(code, "{}    atomicAdd(&control->input_tail, 1);", indent).unwrap();
304        writeln!(
305            code,
306            "{}    atomicAdd(&control->messages_processed, 1);",
307            indent
308        )
309        .unwrap();
310
311        if self.enable_hlc {
312            writeln!(code).unwrap();
313            writeln!(code, "{}    // Update HLC", indent).unwrap();
314            writeln!(code, "{}    hlc_logical++;", indent).unwrap();
315        }
316
317        code
318    }
319
320    /// Generate the loop footer (end of while loop).
321    pub fn generate_loop_footer(&self, indent: &str) -> String {
322        let mut code = String::new();
323
324        writeln!(code, "{}    __syncthreads();", indent).unwrap();
325        writeln!(code, "{}}}", indent).unwrap();
326
327        code
328    }
329
330    /// Generate kernel epilogue (termination marking).
331    pub fn generate_epilogue(&self, indent: &str) -> String {
332        let mut code = String::new();
333
334        writeln!(code).unwrap();
335        writeln!(code, "{}// Mark kernel as terminated", indent).unwrap();
336        writeln!(code, "{}if (tid == 0) {{", indent).unwrap();
337
338        if self.enable_hlc {
339            writeln!(code, "{}    // Store final HLC state", indent).unwrap();
340            writeln!(
341                code,
342                "{}    control->hlc_state.physical = hlc_physical;",
343                indent
344            )
345            .unwrap();
346            writeln!(
347                code,
348                "{}    control->hlc_state.logical = hlc_logical;",
349                indent
350            )
351            .unwrap();
352        }
353
354        writeln!(
355            code,
356            "{}    atomicExch(&control->has_terminated, 1);",
357            indent
358        )
359        .unwrap();
360        writeln!(code, "{}}}", indent).unwrap();
361
362        code
363    }
364
365    /// Generate complete kernel wrapper (without handler body).
366    pub fn generate_kernel_wrapper(&self, handler_placeholder: &str) -> String {
367        let mut code = String::new();
368
369        // Struct definitions
370        code.push_str(&generate_control_block_struct());
371        code.push('\n');
372
373        if self.enable_hlc {
374            code.push_str(&generate_hlc_struct());
375            code.push('\n');
376        }
377
378        if self.enable_k2k {
379            code.push_str(&generate_k2k_structs());
380            code.push('\n');
381        }
382
383        // Kernel signature
384        code.push_str(&self.generate_signature());
385        code.push_str(" {\n");
386
387        // Preamble
388        code.push_str(&self.generate_preamble("    "));
389
390        // Message loop
391        code.push_str(&self.generate_loop_header("    "));
392
393        // Handler placeholder
394        writeln!(code, "        // === USER HANDLER CODE ===").unwrap();
395        for line in handler_placeholder.lines() {
396            writeln!(code, "        {}", line).unwrap();
397        }
398        writeln!(code, "        // === END HANDLER CODE ===").unwrap();
399
400        // Message completion
401        code.push_str(&self.generate_message_complete("    "));
402
403        // Loop footer
404        code.push_str(&self.generate_loop_footer("    "));
405
406        // Epilogue
407        code.push_str(&self.generate_epilogue("    "));
408
409        code.push_str("}\n");
410
411        code
412    }
413}
414
415/// Generate CUDA ControlBlock struct definition.
416pub fn generate_control_block_struct() -> String {
417    r#"// Control block for kernel state management (128 bytes, cache-line aligned)
418struct __align__(128) ControlBlock {
419    // Lifecycle state
420    unsigned int is_active;
421    unsigned int should_terminate;
422    unsigned int has_terminated;
423    unsigned int _pad1;
424
425    // Counters
426    unsigned long long messages_processed;
427    unsigned long long messages_in_flight;
428
429    // Queue pointers
430    unsigned long long input_head;
431    unsigned long long input_tail;
432    unsigned long long output_head;
433    unsigned long long output_tail;
434
435    // Queue metadata
436    unsigned int input_capacity;
437    unsigned int output_capacity;
438    unsigned int input_mask;
439    unsigned int output_mask;
440
441    // HLC state
442    struct {
443        unsigned long long physical;
444        unsigned long long logical;
445    } hlc_state;
446
447    // Error state
448    unsigned int last_error;
449    unsigned int error_count;
450
451    // Reserved padding
452    unsigned char _reserved[24];
453};
454"#
455    .to_string()
456}
457
/// Generate the CUDA `HlcState` helper struct: a physical and a logical
/// 64-bit counter.
///
/// `ControlBlock` does not reference this type; it declares its own anonymous
/// `hlc_state` struct with the same two fields.
pub fn generate_hlc_struct() -> String {
    const HLC_STATE_CUDA: &str = r#"// Hybrid Logical Clock state
struct HlcState {
    unsigned long long physical;
    unsigned long long logical;
};
"#;
    String::from(HLC_STATE_CUDA)
}
468
/// Generate CUDA K2K routing structs and device helper functions.
///
/// Emits `K2KRoute`, `K2KRoutingTable` (up to 16 routes) and `K2KInboxHeader`.
/// It also emits the helpers `k2k_send`, `k2k_has_message`, `k2k_try_recv`,
/// `k2k_peek` and `k2k_pending_count`.
///
/// `k2k_try_recv` claims a message with `atomicCAS` on the inbox tail. Only
/// one of several concurrent callers consumes a given slot; the losers get
/// `NULL` and may retry. A plain `atomicAdd` would hand every caller the same
/// message while advancing the tail once per caller, silently skipping the
/// following messages.
///
/// `k2k_send` clamps `num_routes` to the size of the `routes` array, so a bad
/// count cannot read past the table.
///
/// # Known limitations
///
/// NOTE(review):
/// - `k2k_send` advances the target head *before* copying the payload, so a
///   receiver polling head/tail can observe the slot before the copy lands.
/// - `k2k_send` does not check capacity, so a full inbox is overwritten.
/// - `k2k_send` writes at `target_inbox + idx * msg_size`, while receivers read
///   at `k2k_inbox + sizeof(K2KInboxHeader) + idx * msg_size`. Presumably
///   `target_inbox` points at the data area after the header; confirm.
pub fn generate_k2k_structs() -> String {
    r#"// Kernel-to-kernel routing table entry
struct K2KRoute {
    unsigned long long target_kernel_id;
    unsigned char* target_inbox;
    unsigned long long* target_head;
    unsigned long long* target_tail;
    unsigned int capacity;
    unsigned int mask;
    unsigned int msg_size;
};

// K2K routing table
struct K2KRoutingTable {
    unsigned int num_routes;
    unsigned int _pad;
    K2KRoute routes[16];  // Max 16 K2K connections
};

// K2K inbox header (at start of k2k_inbox buffer)
struct K2KInboxHeader {
    unsigned long long head;
    unsigned long long tail;
    unsigned int capacity;
    unsigned int mask;
    unsigned int msg_size;
    unsigned int _pad;
};

// Send a message to another kernel
__device__ inline int k2k_send(
    K2KRoutingTable* routes,
    unsigned long long target_id,
    const void* msg_ptr,
    unsigned int msg_size
) {
    // Never scan past the fixed-size routes array, whatever num_routes says
    const unsigned int max_routes = sizeof(routes->routes) / sizeof(routes->routes[0]);

    // Find route for target
    for (unsigned int i = 0; i < routes->num_routes && i < max_routes; i++) {
        if (routes->routes[i].target_kernel_id == target_id) {
            K2KRoute* route = &routes->routes[i];

            // Atomically claim a slot in target's inbox
            unsigned long long slot = atomicAdd(route->target_head, 1);
            unsigned int idx = (unsigned int)(slot & route->mask);

            // Copy message to target inbox
            memcpy(
                route->target_inbox + idx * route->msg_size,
                msg_ptr,
                msg_size < route->msg_size ? msg_size : route->msg_size
            );

            return 1;  // Success
        }
    }
    return 0;  // Route not found
}

// Check if there are K2K messages in inbox
__device__ inline int k2k_has_message(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
    return atomicAdd(&header->head, 0) != atomicAdd(&header->tail, 0);
}

// Try to receive a K2K message
__device__ inline void* k2k_try_recv(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL;  // No messages
    }

    // Claim the message with compare-and-swap so that exactly one of several
    // racing callers consumes this slot; the others see NULL and may retry.
    if (atomicCAS(&header->tail, tail, tail + 1) != tail) {
        return NULL;  // Another thread consumed it first
    }

    // Get message pointer
    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    return data_start + idx * header->msg_size;
}

// Peek at next K2K message without consuming
__device__ inline void* k2k_peek(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL;  // No messages
    }

    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    return data_start + idx * header->msg_size;
}

// Get number of pending K2K messages
__device__ inline unsigned int k2k_pending_count(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);
    return (unsigned int)(head - tail);
}
"#
    .to_string()
}
582
/// Intrinsic functions available in ring kernel handlers.
///
/// [`RingKernelIntrinsic::from_name`] maps handler-level function names to
/// variants, and [`RingKernelIntrinsic::to_cuda`] expands a variant into a
/// CUDA snippet. The snippets refer to names the generated kernel declares:
/// - `control`, `output_buffer` and `RESP_SIZE`;
/// - the HLC locals `hlc_physical` and `hlc_logical`;
/// - the K2K parameters `k2k_routes` and `k2k_inbox`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RingKernelIntrinsic {
    // Control block access
    /// Whether `control->is_active` is non-zero.
    IsActive,
    /// Whether `control->should_terminate` is non-zero.
    ShouldTerminate,
    /// Atomically set `control->has_terminated` to 1.
    MarkTerminated,
    /// Atomic read of `control->messages_processed`.
    GetMessagesProcessed,

    // Queue operations
    /// `input_head - input_tail`.
    InputQueueSize,
    /// `output_head - output_tail`.
    OutputQueueSize,
    /// `input_head == input_tail`.
    InputQueueEmpty,
    /// `output_head == output_tail`.
    OutputQueueEmpty,
    /// Claim an output slot and copy `RESP_SIZE` bytes from the argument into it.
    EnqueueResponse,

    // HLC operations
    /// Increment the logical clock.
    HlcTick,
    /// Merge a physical timestamp argument into the clock.
    HlcUpdate,
    /// Packed `(physical << 32) | (logical & 0xFFFFFFFF)` timestamp.
    HlcNow,

    // K2K operations
    /// Send a message to another kernel through `k2k_routes`.
    K2kSend,
    /// Consume the next K2K inbox message, or `NULL`.
    K2kTryRecv,
    /// Whether the K2K inbox is non-empty.
    K2kHasMessage,
    /// Look at the next K2K inbox message without consuming it.
    K2kPeek,
    /// Number of pending K2K inbox messages.
    K2kPendingCount,
}

impl RingKernelIntrinsic {
    /// Get the CUDA code for this intrinsic.
    ///
    /// Value-producing snippets are fully parenthesized, so they can be
    /// spliced into any surrounding expression (e.g. `hlc_now() + 1`).
    /// Arguments are inserted verbatim. `HlcUpdate` binds its argument to a
    /// temporary so the argument is evaluated exactly once.
    ///
    /// Intrinsics that need arguments emit a `/* ... */` comment when they are
    /// missing. The exception is `HlcUpdate`, which falls back to a tick.
    pub fn to_cuda(&self, args: &[String]) -> String {
        match self {
            Self::IsActive => "(atomicAdd(&control->is_active, 0) != 0)".to_string(),
            Self::ShouldTerminate => {
                "(atomicAdd(&control->should_terminate, 0) != 0)".to_string()
            }
            Self::MarkTerminated => "atomicExch(&control->has_terminated, 1)".to_string(),
            Self::GetMessagesProcessed => "atomicAdd(&control->messages_processed, 0)".to_string(),

            Self::InputQueueSize => {
                "(atomicAdd(&control->input_head, 0) - atomicAdd(&control->input_tail, 0))"
                    .to_string()
            }
            Self::OutputQueueSize => {
                "(atomicAdd(&control->output_head, 0) - atomicAdd(&control->output_tail, 0))"
                    .to_string()
            }
            Self::InputQueueEmpty => {
                "(atomicAdd(&control->input_head, 0) == atomicAdd(&control->input_tail, 0))"
                    .to_string()
            }
            Self::OutputQueueEmpty => {
                "(atomicAdd(&control->output_head, 0) == atomicAdd(&control->output_tail, 0))"
                    .to_string()
            }
            Self::EnqueueResponse => match args.first() {
                Some(response) => format!(
                    "{{ unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask; \
                     memcpy(&output_buffer[_out_idx * RESP_SIZE], {}, RESP_SIZE); }}",
                    response
                ),
                None => "/* enqueue_response requires response pointer */".to_string(),
            },

            Self::HlcTick => "hlc_logical++".to_string(),
            Self::HlcUpdate => match args.first() {
                // Evaluate the timestamp once; the original spliced it twice,
                // duplicating side effects and exposing it to precedence bugs.
                Some(ts) => format!(
                    "{{ unsigned long long _hlc_ts = ({}); if (_hlc_ts > hlc_physical) {{ hlc_physical = _hlc_ts; hlc_logical = 0; }} else {{ hlc_logical++; }} }}",
                    ts
                ),
                None => "hlc_logical++".to_string(),
            },
            Self::HlcNow => "((hlc_physical << 32) | (hlc_logical & 0xFFFFFFFF))".to_string(),

            Self::K2kSend => match args {
                // k2k_send(target_id, msg_ptr)
                //   -> k2k_send(k2k_routes, target_id, msg_ptr, sizeof(*(msg_ptr)))
                [target, msg, ..] => format!(
                    "k2k_send(k2k_routes, {}, {}, sizeof(*({})))",
                    target, msg, msg
                ),
                _ => "/* k2k_send requires target_id and msg_ptr */".to_string(),
            },
            Self::K2kTryRecv => "k2k_try_recv(k2k_inbox)".to_string(),
            Self::K2kHasMessage => "k2k_has_message(k2k_inbox)".to_string(),
            Self::K2kPeek => "k2k_peek(k2k_inbox)".to_string(),
            Self::K2kPendingCount => "k2k_pending_count(k2k_inbox)".to_string(),
        }
    }

    /// Parse a handler function name (including aliases) into an intrinsic.
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "is_active" | "is_kernel_active" => Some(Self::IsActive),
            "should_terminate" => Some(Self::ShouldTerminate),
            "mark_terminated" => Some(Self::MarkTerminated),
            "messages_processed" | "get_messages_processed" => Some(Self::GetMessagesProcessed),

            "input_queue_size" => Some(Self::InputQueueSize),
            "output_queue_size" => Some(Self::OutputQueueSize),
            "input_queue_empty" => Some(Self::InputQueueEmpty),
            "output_queue_empty" => Some(Self::OutputQueueEmpty),
            "enqueue_response" | "enqueue" => Some(Self::EnqueueResponse),

            "hlc_tick" => Some(Self::HlcTick),
            "hlc_update" => Some(Self::HlcUpdate),
            "hlc_now" => Some(Self::HlcNow),

            "k2k_send" => Some(Self::K2kSend),
            "k2k_try_recv" => Some(Self::K2kTryRecv),
            "k2k_has_message" => Some(Self::K2kHasMessage),
            "k2k_peek" => Some(Self::K2kPeek),
            "k2k_pending_count" | "k2k_pending" => Some(Self::K2kPendingCount),

            _ => None,
        }
    }

    /// Check if this intrinsic requires K2K support (the `k2k_*` parameters).
    pub fn requires_k2k(&self) -> bool {
        matches!(
            self,
            Self::K2kSend
                | Self::K2kTryRecv
                | Self::K2kHasMessage
                | Self::K2kPeek
                | Self::K2kPendingCount
        )
    }

    /// Check if this intrinsic requires HLC support (the `hlc_*` locals).
    pub fn requires_hlc(&self) -> bool {
        matches!(self, Self::HlcTick | Self::HlcUpdate | Self::HlcNow)
    }

    /// Check if this intrinsic requires control block access.
    pub fn requires_control_block(&self) -> bool {
        matches!(
            self,
            Self::IsActive
                | Self::ShouldTerminate
                | Self::MarkTerminated
                | Self::GetMessagesProcessed
                | Self::InputQueueSize
                | Self::OutputQueueSize
                | Self::InputQueueEmpty
                | Self::OutputQueueEmpty
                | Self::EnqueueResponse
        )
    }
}
741
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `code` contains every entry of `needles`.
    fn assert_contains_all(code: &str, needles: &[&str]) {
        for needle in needles {
            assert!(code.contains(needle), "missing {:?} in:\n{}", needle, code);
        }
    }

    /// Assert `(needle, message)` pairs, reporting `message` on failure.
    fn assert_contains_with(code: &str, checks: &[(&str, &str)]) {
        for &(needle, message) in checks {
            assert!(code.contains(needle), "{}", message);
        }
    }

    #[test]
    fn test_default_config() {
        let RingKernelConfig {
            block_size,
            queue_capacity,
            enable_hlc,
            enable_k2k,
            ..
        } = RingKernelConfig::default();

        assert_eq!(block_size, 128);
        assert_eq!(queue_capacity, 1024);
        assert!(enable_hlc);
        assert!(!enable_k2k);
    }

    #[test]
    fn test_config_builder() {
        let cfg = RingKernelConfig::new("processor")
            .with_block_size(256)
            .with_queue_capacity(2048)
            .with_k2k(true)
            .with_hlc(true);

        assert_eq!(cfg.id, "processor");
        assert_eq!(cfg.kernel_name(), "ring_kernel_processor");
        assert_eq!((cfg.block_size, cfg.queue_capacity), (256, 2048));
        assert!(cfg.enable_k2k && cfg.enable_hlc);
    }

    #[test]
    fn test_kernel_signature() {
        let sig = RingKernelConfig::new("test").generate_signature();
        assert_contains_all(
            &sig,
            &[
                "extern \"C\" __global__ void ring_kernel_test",
                "ControlBlock* __restrict__ control",
                "input_buffer",
                "output_buffer",
                "shared_state",
            ],
        );
    }

    #[test]
    fn test_kernel_signature_with_k2k() {
        let sig = RingKernelConfig::new("k2k_test")
            .with_k2k(true)
            .generate_signature();
        assert_contains_all(&sig, &["K2KRoutingTable", "k2k_inbox", "k2k_outbox"]);
    }

    #[test]
    fn test_preamble_generation() {
        let preamble = RingKernelConfig::new("test")
            .with_hlc(true)
            .generate_preamble("    ");
        assert_contains_all(
            &preamble,
            &[
                "int tid = threadIdx.x + blockIdx.x * blockDim.x",
                "int lane_id",
                "int warp_id",
                "MSG_SIZE",
                "hlc_physical",
                "hlc_logical",
            ],
        );
    }

    #[test]
    fn test_loop_header() {
        let header = RingKernelConfig::new("test").generate_loop_header("    ");
        assert_contains_all(
            &header,
            &[
                "while (true)",
                "should_terminate",
                "is_active",
                "input_head",
                "input_tail",
                "msg_ptr",
            ],
        );
    }

    #[test]
    fn test_epilogue() {
        let epilogue = RingKernelConfig::new("test")
            .with_hlc(true)
            .generate_epilogue("    ");
        assert_contains_all(
            &epilogue,
            &["has_terminated", "hlc_state.physical", "hlc_state.logical"],
        );
    }

    #[test]
    fn test_control_block_struct() {
        assert_contains_all(
            &generate_control_block_struct(),
            &[
                "struct __align__(128) ControlBlock",
                "is_active",
                "should_terminate",
                "has_terminated",
                "messages_processed",
                "input_head",
                "input_tail",
                "hlc_state",
            ],
        );
    }

    #[test]
    fn test_full_kernel_wrapper() {
        let kernel = RingKernelConfig::new("example")
            .with_block_size(128)
            .with_hlc(true)
            .generate_kernel_wrapper("// Process message here");

        assert_contains_all(
            &kernel,
            &[
                "struct __align__(128) ControlBlock",
                "extern \"C\" __global__ void ring_kernel_example",
                "while (true)",
                "// Process message here",
                "has_terminated",
            ],
        );

        println!("Generated kernel:\n{}", kernel);
    }

    #[test]
    fn test_intrinsic_lookup() {
        let cases = [
            ("is_active", Some(RingKernelIntrinsic::IsActive)),
            ("should_terminate", Some(RingKernelIntrinsic::ShouldTerminate)),
            ("hlc_tick", Some(RingKernelIntrinsic::HlcTick)),
            ("unknown", None),
        ];
        for (name, expected) in cases {
            assert_eq!(RingKernelIntrinsic::from_name(name), expected);
        }
    }

    #[test]
    fn test_intrinsic_cuda_output() {
        let expectations = [
            (RingKernelIntrinsic::IsActive, "is_active"),
            (RingKernelIntrinsic::ShouldTerminate, "should_terminate"),
            (RingKernelIntrinsic::HlcTick, "hlc_logical++"),
        ];
        for (intrinsic, needle) in expectations {
            assert!(intrinsic.to_cuda(&[]).contains(needle));
        }
    }

    #[test]
    fn test_k2k_structs_generation() {
        let k2k_code = generate_k2k_structs();

        // Struct definitions, then helper functions
        assert_contains_with(
            &k2k_code,
            &[
                ("struct K2KRoute", "Should have K2KRoute struct"),
                ("struct K2KRoutingTable", "Should have K2KRoutingTable struct"),
                ("struct K2KInboxHeader", "Should have K2KInboxHeader struct"),
                ("__device__ inline int k2k_send", "Should have k2k_send function"),
                (
                    "__device__ inline int k2k_has_message",
                    "Should have k2k_has_message function",
                ),
                (
                    "__device__ inline void* k2k_try_recv",
                    "Should have k2k_try_recv function",
                ),
                ("__device__ inline void* k2k_peek", "Should have k2k_peek function"),
                (
                    "__device__ inline unsigned int k2k_pending_count",
                    "Should have k2k_pending_count function",
                ),
            ],
        );

        println!("K2K code:\n{}", k2k_code);
    }

    #[test]
    fn test_full_k2k_kernel() {
        let kernel = RingKernelConfig::new("k2k_processor")
            .with_block_size(128)
            .with_k2k(true)
            .with_hlc(true)
            .generate_kernel_wrapper("// K2K handler code");

        assert_contains_with(
            &kernel,
            &[
                ("K2KRoutingTable", "Should have K2KRoutingTable"),
                ("K2KRoute", "Should have K2KRoute struct"),
                ("K2KInboxHeader", "Should have K2KInboxHeader"),
                ("k2k_routes", "Should have k2k_routes param"),
                ("k2k_inbox", "Should have k2k_inbox param"),
                ("k2k_outbox", "Should have k2k_outbox param"),
                ("k2k_send", "Should have k2k_send function"),
                ("k2k_try_recv", "Should have k2k_try_recv function"),
            ],
        );

        println!("Full K2K kernel:\n{}", kernel);
    }

    #[test]
    fn test_k2k_intrinsic_requirements() {
        let k2k_ops = [
            RingKernelIntrinsic::K2kSend,
            RingKernelIntrinsic::K2kTryRecv,
            RingKernelIntrinsic::K2kHasMessage,
            RingKernelIntrinsic::K2kPeek,
            RingKernelIntrinsic::K2kPendingCount,
        ];
        assert!(k2k_ops.iter().all(|op| op.requires_k2k()));

        let non_k2k = [RingKernelIntrinsic::HlcTick, RingKernelIntrinsic::IsActive];
        assert!(non_k2k.iter().all(|op| !op.requires_k2k()));
    }
}