1use std::fmt::Write;
46
/// Configuration for generating a persistent "ring" CUDA kernel that
/// services a host-visible message ring buffer.
#[derive(Debug, Clone)]
pub struct RingKernelConfig {
    /// Kernel identifier; the generated entry point is named `ring_kernel_{id}`.
    pub id: String,
    /// CUDA threads per block. Stored for launch configuration; not read by
    /// the code generator in this file.
    pub block_size: u32,
    /// Input queue capacity in slots. Must be a power of two: the generated
    /// code indexes the ring buffer with `tail & (capacity - 1)`.
    pub queue_capacity: u32,
    /// When true, kernel-to-kernel routing structs and extra kernel
    /// parameters (`k2k_routes`, `k2k_inbox`, `k2k_outbox`) are emitted.
    pub enable_k2k: bool,
    /// When true, hybrid-logical-clock state and update logic are emitted.
    pub enable_hlc: bool,
    /// Message payload size in bytes (when envelope framing is disabled this
    /// is the full slot size).
    pub message_size: usize,
    /// Response payload size in bytes (same envelope rule as `message_size`).
    pub response_size: usize,
    /// Cooperative-groups flag; not referenced by the generator in this
    /// file — presumably consumed by the kernel launcher. TODO confirm.
    pub cooperative_groups: bool,
    /// Nanoseconds passed to `__nanosleep` when the kernel is idle;
    /// 0 removes the sleep calls from the generated loop.
    pub idle_sleep_ns: u32,
    /// When true, each queue slot is a `MessageHeader` envelope followed by
    /// the payload; when false, slots hold raw payload bytes.
    pub use_envelope_format: bool,
    /// Numeric kernel id emitted as the `KERNEL_ID` device constant.
    pub kernel_id_num: u64,
    /// Node id emitted as the `HLC_NODE_ID` device constant.
    pub hlc_node_id: u64,
}
76
77impl Default for RingKernelConfig {
78 fn default() -> Self {
79 Self {
80 id: "ring_kernel".to_string(),
81 block_size: 128,
82 queue_capacity: 1024,
83 enable_k2k: false,
84 enable_hlc: true,
85 message_size: 64,
86 response_size: 64,
87 cooperative_groups: false,
88 idle_sleep_ns: 1000,
89 use_envelope_format: true, kernel_id_num: 0,
91 hlc_node_id: 0,
92 }
93 }
94}
95
96impl RingKernelConfig {
97 pub fn new(id: impl Into<String>) -> Self {
99 Self {
100 id: id.into(),
101 ..Default::default()
102 }
103 }
104
105 pub fn with_block_size(mut self, size: u32) -> Self {
107 self.block_size = size;
108 self
109 }
110
111 pub fn with_queue_capacity(mut self, capacity: u32) -> Self {
113 debug_assert!(
114 capacity.is_power_of_two(),
115 "Queue capacity must be power of 2"
116 );
117 self.queue_capacity = capacity;
118 self
119 }
120
121 pub fn with_k2k(mut self, enabled: bool) -> Self {
123 self.enable_k2k = enabled;
124 self
125 }
126
127 pub fn with_hlc(mut self, enabled: bool) -> Self {
129 self.enable_hlc = enabled;
130 self
131 }
132
133 pub fn with_message_sizes(mut self, message: usize, response: usize) -> Self {
135 self.message_size = message;
136 self.response_size = response;
137 self
138 }
139
140 pub fn with_idle_sleep(mut self, ns: u32) -> Self {
142 self.idle_sleep_ns = ns;
143 self
144 }
145
146 pub fn with_envelope_format(mut self, enabled: bool) -> Self {
149 self.use_envelope_format = enabled;
150 self
151 }
152
153 pub fn with_kernel_id(mut self, id: u64) -> Self {
155 self.kernel_id_num = id;
156 self
157 }
158
159 pub fn with_hlc_node_id(mut self, node_id: u64) -> Self {
161 self.hlc_node_id = node_id;
162 self
163 }
164
165 pub fn kernel_name(&self) -> String {
167 format!("ring_kernel_{}", self.id)
168 }
169
170 pub fn generate_signature(&self) -> String {
172 let mut sig = String::new();
173
174 writeln!(sig, "extern \"C\" __global__ void {}(", self.kernel_name()).unwrap();
175 writeln!(sig, " ControlBlock* __restrict__ control,").unwrap();
176 writeln!(sig, " unsigned char* __restrict__ input_buffer,").unwrap();
177 writeln!(sig, " unsigned char* __restrict__ output_buffer,").unwrap();
178
179 if self.enable_k2k {
180 writeln!(sig, " K2KRoutingTable* __restrict__ k2k_routes,").unwrap();
181 writeln!(sig, " unsigned char* __restrict__ k2k_inbox,").unwrap();
182 writeln!(sig, " unsigned char* __restrict__ k2k_outbox,").unwrap();
183 }
184
185 write!(sig, " void* __restrict__ shared_state").unwrap();
186 write!(sig, "\n)").unwrap();
187
188 sig
189 }
190
191 pub fn generate_preamble(&self, indent: &str) -> String {
193 let mut code = String::new();
194
195 writeln!(code, "{}// Thread identification", indent).unwrap();
197 writeln!(
198 code,
199 "{}int tid = threadIdx.x + blockIdx.x * blockDim.x;",
200 indent
201 )
202 .unwrap();
203 writeln!(code, "{}int lane_id = threadIdx.x % 32;", indent).unwrap();
204 writeln!(code, "{}int warp_id = threadIdx.x / 32;", indent).unwrap();
205 writeln!(code).unwrap();
206
207 writeln!(code, "{}// Kernel identity", indent).unwrap();
209 writeln!(
210 code,
211 "{}const unsigned long long KERNEL_ID = {}ULL;",
212 indent, self.kernel_id_num
213 )
214 .unwrap();
215 writeln!(
216 code,
217 "{}const unsigned long long HLC_NODE_ID = {}ULL;",
218 indent, self.hlc_node_id
219 )
220 .unwrap();
221 writeln!(code).unwrap();
222
223 writeln!(code, "{}// Message buffer constants", indent).unwrap();
225 if self.use_envelope_format {
226 writeln!(
228 code,
229 "{}const unsigned int PAYLOAD_SIZE = {}; // User payload size",
230 indent, self.message_size
231 )
232 .unwrap();
233 writeln!(
234 code,
235 "{}const unsigned int MSG_SIZE = MESSAGE_HEADER_SIZE + PAYLOAD_SIZE; // Full envelope",
236 indent
237 )
238 .unwrap();
239 writeln!(
240 code,
241 "{}const unsigned int RESP_PAYLOAD_SIZE = {};",
242 indent, self.response_size
243 )
244 .unwrap();
245 writeln!(
246 code,
247 "{}const unsigned int RESP_SIZE = MESSAGE_HEADER_SIZE + RESP_PAYLOAD_SIZE;",
248 indent
249 )
250 .unwrap();
251 } else {
252 writeln!(
254 code,
255 "{}const unsigned int MSG_SIZE = {};",
256 indent, self.message_size
257 )
258 .unwrap();
259 writeln!(
260 code,
261 "{}const unsigned int RESP_SIZE = {};",
262 indent, self.response_size
263 )
264 .unwrap();
265 }
266 writeln!(
267 code,
268 "{}const unsigned int QUEUE_MASK = {};",
269 indent,
270 self.queue_capacity - 1
271 )
272 .unwrap();
273 writeln!(code).unwrap();
274
275 if self.enable_hlc {
277 writeln!(code, "{}// HLC clock state", indent).unwrap();
278 writeln!(code, "{}unsigned long long hlc_physical = 0;", indent).unwrap();
279 writeln!(code, "{}unsigned long long hlc_logical = 0;", indent).unwrap();
280 writeln!(code).unwrap();
281 }
282
283 code
284 }
285
286 pub fn generate_loop_header(&self, indent: &str) -> String {
288 let mut code = String::new();
289
290 writeln!(code, "{}// Persistent message processing loop", indent).unwrap();
291 writeln!(code, "{}while (true) {{", indent).unwrap();
292 writeln!(code, "{} // Check for termination signal", indent).unwrap();
293 writeln!(
294 code,
295 "{} if (atomicAdd(&control->should_terminate, 0) != 0) {{",
296 indent
297 )
298 .unwrap();
299 writeln!(code, "{} break;", indent).unwrap();
300 writeln!(code, "{} }}", indent).unwrap();
301 writeln!(code).unwrap();
302
303 writeln!(code, "{} // Check if kernel is active", indent).unwrap();
305 writeln!(
306 code,
307 "{} if (atomicAdd(&control->is_active, 0) == 0) {{",
308 indent
309 )
310 .unwrap();
311 if self.idle_sleep_ns > 0 {
312 writeln!(
313 code,
314 "{} __nanosleep({});",
315 indent, self.idle_sleep_ns
316 )
317 .unwrap();
318 }
319 writeln!(code, "{} continue;", indent).unwrap();
320 writeln!(code, "{} }}", indent).unwrap();
321 writeln!(code).unwrap();
322
323 writeln!(code, "{} // Check input queue for messages", indent).unwrap();
325 writeln!(
326 code,
327 "{} unsigned long long head = atomicAdd(&control->input_head, 0);",
328 indent
329 )
330 .unwrap();
331 writeln!(
332 code,
333 "{} unsigned long long tail = atomicAdd(&control->input_tail, 0);",
334 indent
335 )
336 .unwrap();
337 writeln!(code).unwrap();
338 writeln!(code, "{} if (head == tail) {{", indent).unwrap();
339 writeln!(code, "{} // No messages, yield", indent).unwrap();
340 if self.idle_sleep_ns > 0 {
341 writeln!(
342 code,
343 "{} __nanosleep({});",
344 indent, self.idle_sleep_ns
345 )
346 .unwrap();
347 }
348 writeln!(code, "{} continue;", indent).unwrap();
349 writeln!(code, "{} }}", indent).unwrap();
350 writeln!(code).unwrap();
351
352 writeln!(code, "{} // Get message from queue", indent).unwrap();
354 writeln!(
355 code,
356 "{} unsigned int msg_idx = (unsigned int)(tail & QUEUE_MASK);",
357 indent
358 )
359 .unwrap();
360 writeln!(
361 code,
362 "{} unsigned char* envelope_ptr = &input_buffer[msg_idx * MSG_SIZE];",
363 indent
364 )
365 .unwrap();
366
367 if self.use_envelope_format {
368 writeln!(code).unwrap();
370 writeln!(code, "{} // Parse message envelope", indent).unwrap();
371 writeln!(
372 code,
373 "{} MessageHeader* msg_header = message_get_header(envelope_ptr);",
374 indent
375 )
376 .unwrap();
377 writeln!(
378 code,
379 "{} unsigned char* msg_ptr = message_get_payload(envelope_ptr);",
380 indent
381 )
382 .unwrap();
383 writeln!(code).unwrap();
384 writeln!(code, "{} // Validate message (skip invalid)", indent).unwrap();
385 writeln!(
386 code,
387 "{} if (!message_header_validate(msg_header)) {{",
388 indent
389 )
390 .unwrap();
391 writeln!(
392 code,
393 "{} atomicAdd(&control->input_tail, 1);",
394 indent
395 )
396 .unwrap();
397 writeln!(
398 code,
399 "{} atomicAdd(&control->last_error, 1); // Track errors",
400 indent
401 )
402 .unwrap();
403 writeln!(code, "{} continue;", indent).unwrap();
404 writeln!(code, "{} }}", indent).unwrap();
405
406 if self.enable_hlc {
408 writeln!(code).unwrap();
409 writeln!(code, "{} // Update HLC from message timestamp", indent).unwrap();
410 writeln!(
411 code,
412 "{} if (msg_header->timestamp.physical > hlc_physical) {{",
413 indent
414 )
415 .unwrap();
416 writeln!(
417 code,
418 "{} hlc_physical = msg_header->timestamp.physical;",
419 indent
420 )
421 .unwrap();
422 writeln!(code, "{} hlc_logical = 0;", indent).unwrap();
423 writeln!(code, "{} }}", indent).unwrap();
424 writeln!(code, "{} hlc_logical++;", indent).unwrap();
425 }
426 } else {
427 writeln!(
429 code,
430 "{} unsigned char* msg_ptr = envelope_ptr; // Raw message data",
431 indent
432 )
433 .unwrap();
434 }
435 writeln!(code).unwrap();
436
437 code
438 }
439
440 pub fn generate_message_complete(&self, indent: &str) -> String {
442 let mut code = String::new();
443
444 writeln!(code).unwrap();
445 writeln!(code, "{} // Mark message as processed", indent).unwrap();
446 writeln!(code, "{} atomicAdd(&control->input_tail, 1);", indent).unwrap();
447 writeln!(
448 code,
449 "{} atomicAdd(&control->messages_processed, 1);",
450 indent
451 )
452 .unwrap();
453
454 if self.enable_hlc {
455 writeln!(code).unwrap();
456 writeln!(code, "{} // Update HLC", indent).unwrap();
457 writeln!(code, "{} hlc_logical++;", indent).unwrap();
458 }
459
460 code
461 }
462
463 pub fn generate_loop_footer(&self, indent: &str) -> String {
465 let mut code = String::new();
466
467 writeln!(code, "{} __syncthreads();", indent).unwrap();
468 writeln!(code, "{}}}", indent).unwrap();
469
470 code
471 }
472
473 pub fn generate_epilogue(&self, indent: &str) -> String {
475 let mut code = String::new();
476
477 writeln!(code).unwrap();
478 writeln!(code, "{}// Mark kernel as terminated", indent).unwrap();
479 writeln!(code, "{}if (tid == 0) {{", indent).unwrap();
480
481 if self.enable_hlc {
482 writeln!(code, "{} // Store final HLC state", indent).unwrap();
483 writeln!(
484 code,
485 "{} control->hlc_state.physical = hlc_physical;",
486 indent
487 )
488 .unwrap();
489 writeln!(
490 code,
491 "{} control->hlc_state.logical = hlc_logical;",
492 indent
493 )
494 .unwrap();
495 }
496
497 writeln!(
498 code,
499 "{} atomicExch(&control->has_terminated, 1);",
500 indent
501 )
502 .unwrap();
503 writeln!(code, "{}}}", indent).unwrap();
504
505 code
506 }
507
508 pub fn generate_kernel_wrapper(&self, handler_placeholder: &str) -> String {
510 let mut code = String::new();
511
512 code.push_str(&generate_control_block_struct());
514 code.push('\n');
515
516 if self.enable_hlc {
517 code.push_str(&generate_hlc_struct());
518 code.push('\n');
519 }
520
521 if self.use_envelope_format {
523 code.push_str(&generate_message_envelope_structs());
524 code.push('\n');
525 }
526
527 if self.enable_k2k {
528 code.push_str(&generate_k2k_structs());
529 code.push('\n');
530 }
531
532 code.push_str(&self.generate_signature());
534 code.push_str(" {\n");
535
536 code.push_str(&self.generate_preamble(" "));
538
539 code.push_str(&self.generate_loop_header(" "));
541
542 writeln!(code, " // === USER HANDLER CODE ===").unwrap();
544 for line in handler_placeholder.lines() {
545 writeln!(code, " {}", line).unwrap();
546 }
547 writeln!(code, " // === END HANDLER CODE ===").unwrap();
548
549 code.push_str(&self.generate_message_complete(" "));
551
552 code.push_str(&self.generate_loop_footer(" "));
554
555 code.push_str(&self.generate_epilogue(" "));
557
558 code.push_str("}\n");
559
560 code
561 }
562}
563
/// Returns the CUDA definition of the `ControlBlock` struct shared between
/// host and device: lifecycle flags, message counters, ring-buffer cursors,
/// queue metadata, an HLC snapshot, and error counters, padded to 128 bytes.
pub fn generate_control_block_struct() -> String {
    String::from(
        r#"// Control block for kernel state management (128 bytes, cache-line aligned)
struct __align__(128) ControlBlock {
    // Lifecycle state
    unsigned int is_active;
    unsigned int should_terminate;
    unsigned int has_terminated;
    unsigned int _pad1;

    // Counters
    unsigned long long messages_processed;
    unsigned long long messages_in_flight;

    // Queue pointers
    unsigned long long input_head;
    unsigned long long input_tail;
    unsigned long long output_head;
    unsigned long long output_tail;

    // Queue metadata
    unsigned int input_capacity;
    unsigned int output_capacity;
    unsigned int input_mask;
    unsigned int output_mask;

    // HLC state
    struct {
        unsigned long long physical;
        unsigned long long logical;
    } hlc_state;

    // Error state
    unsigned int last_error;
    unsigned int error_count;

    // Reserved padding
    unsigned char _reserved[24];
};
"#,
    )
}
606
/// Returns the CUDA definitions for the hybrid-logical-clock types:
/// `HlcState` (physical/logical pair) and the 24-byte `HlcTimestamp`
/// (physical, logical, node id) embedded in message headers.
pub fn generate_hlc_struct() -> String {
    String::from(
        r#"// Hybrid Logical Clock state
struct HlcState {
    unsigned long long physical;
    unsigned long long logical;
};

// HLC timestamp (24 bytes)
struct __align__(8) HlcTimestamp {
    unsigned long long physical;
    unsigned long long logical;
    unsigned long long node_id;
};
"#,
    )
}
624
/// Returns the CUDA definitions for the message envelope format: validation
/// constants, the `MessageHeader` struct, and device helpers for parsing,
/// validating, and building envelopes.
///
/// The declared fields sum to 232 bytes; `__align__(64)` padding rounds
/// `sizeof(MessageHeader)` up to 256, matching `MESSAGE_HEADER_SIZE`.
///
/// Fix: the `MESSAGE_MAGIC` literal previously had a space before the `ULL`
/// suffix (`0x... ULL`), which made the macro expand to two tokens and every
/// use of it a CUDA syntax error.
pub fn generate_message_envelope_structs() -> String {
    r#"// Magic number for message validation
#define MESSAGE_MAGIC 0x52494E474B45524EULL // "RINGKERN"
#define MESSAGE_VERSION 1
#define MESSAGE_HEADER_SIZE 256
#define MAX_PAYLOAD_SIZE (64 * 1024)

// Message priority levels
#define PRIORITY_LOW 0
#define PRIORITY_NORMAL 1
#define PRIORITY_HIGH 2
#define PRIORITY_CRITICAL 3

// Message header structure (256 bytes, cache-line aligned)
// Matches ringkernel_core::message::MessageHeader exactly
struct __align__(64) MessageHeader {
    // Magic number for validation (0xRINGKERN)
    unsigned long long magic;
    // Header version
    unsigned int version;
    // Message flags
    unsigned int flags;
    // Unique message identifier
    unsigned long long message_id;
    // Correlation ID for request-response
    unsigned long long correlation_id;
    // Source kernel ID (0 for host)
    unsigned long long source_kernel;
    // Destination kernel ID (0 for host)
    unsigned long long dest_kernel;
    // Message type discriminator
    unsigned long long message_type;
    // Priority level
    unsigned char priority;
    // Reserved for alignment
    unsigned char _reserved1[7];
    // Payload size in bytes
    unsigned long long payload_size;
    // Checksum of payload (CRC32)
    unsigned int checksum;
    // Reserved for alignment
    unsigned int _reserved2;
    // HLC timestamp when message was created
    HlcTimestamp timestamp;
    // Deadline timestamp (0 = no deadline)
    HlcTimestamp deadline;
    // Reserved for future use (104 bytes total: 32+32+32+8)
    unsigned char _reserved3[104];
};

// Validate message header
__device__ inline int message_header_validate(const MessageHeader* header) {
    return header->magic == MESSAGE_MAGIC &&
           header->version <= MESSAGE_VERSION &&
           header->payload_size <= MAX_PAYLOAD_SIZE;
}

// Get payload pointer from header
__device__ inline unsigned char* message_get_payload(unsigned char* envelope_ptr) {
    return envelope_ptr + MESSAGE_HEADER_SIZE;
}

// Get header from envelope pointer
__device__ inline MessageHeader* message_get_header(unsigned char* envelope_ptr) {
    return (MessageHeader*)envelope_ptr;
}

// Create a response header based on request
__device__ inline void message_create_response_header(
    MessageHeader* response,
    const MessageHeader* request,
    unsigned long long this_kernel_id,
    unsigned long long payload_size,
    unsigned long long hlc_physical,
    unsigned long long hlc_logical,
    unsigned long long hlc_node_id
) {
    response->magic = MESSAGE_MAGIC;
    response->version = MESSAGE_VERSION;
    response->flags = 0;
    // Generate new message ID (simple increment from request)
    response->message_id = request->message_id + 0x100000000ULL;
    // Preserve correlation ID for request-response matching
    response->correlation_id = request->correlation_id != 0
        ? request->correlation_id
        : request->message_id;
    response->source_kernel = this_kernel_id;
    response->dest_kernel = request->source_kernel; // Response goes back to sender
    response->message_type = request->message_type + 1; // Convention: response type = request + 1
    response->priority = request->priority;
    response->payload_size = payload_size;
    response->checksum = 0; // TODO: compute checksum
    response->timestamp.physical = hlc_physical;
    response->timestamp.logical = hlc_logical;
    response->timestamp.node_id = hlc_node_id;
    response->deadline.physical = 0;
    response->deadline.logical = 0;
    response->deadline.node_id = 0;
}

// Calculate total envelope size
__device__ inline unsigned int message_envelope_size(const MessageHeader* header) {
    return MESSAGE_HEADER_SIZE + (unsigned int)header->payload_size;
}
"#
    .to_string()
}
736
/// Returns the CUDA definitions for kernel-to-kernel messaging: the routing
/// table and inbox structs plus device helpers for sending, receiving,
/// peeking, and counting K2K messages.
///
/// The emitted helpers reference `MessageHeader` and the `MESSAGE_*` /
/// `PRIORITY_*` constants, so the envelope definitions must be emitted
/// before this block.
pub fn generate_k2k_structs() -> String {
    String::from(
        r#"// Kernel-to-kernel routing table entry (48 bytes)
// Matches ringkernel_cuda::k2k_gpu::K2KRouteEntry
struct K2KRoute {
    unsigned long long target_kernel_id;
    unsigned long long target_inbox; // device pointer
    unsigned long long target_head; // device pointer to head
    unsigned long long target_tail; // device pointer to tail
    unsigned int capacity;
    unsigned int mask;
    unsigned int msg_size;
    unsigned int _pad;
};

// K2K routing table (8 + 48*16 = 776 bytes)
struct K2KRoutingTable {
    unsigned int num_routes;
    unsigned int _pad;
    K2KRoute routes[16]; // Max 16 K2K connections
};

// K2K inbox header (64 bytes, cache-line aligned)
// Matches ringkernel_cuda::k2k_gpu::K2KInboxHeader
struct __align__(64) K2KInboxHeader {
    unsigned long long head;
    unsigned long long tail;
    unsigned int capacity;
    unsigned int mask;
    unsigned int msg_size;
    unsigned int _pad;
};

// Send a message envelope to another kernel via K2K
// The entire envelope (header + payload) is copied to the target's inbox
__device__ inline int k2k_send_envelope(
    K2KRoutingTable* routes,
    unsigned long long target_id,
    unsigned long long source_kernel_id,
    const void* payload_ptr,
    unsigned int payload_size,
    unsigned long long message_type,
    unsigned long long hlc_physical,
    unsigned long long hlc_logical,
    unsigned long long hlc_node_id
) {
    // Find route for target
    for (unsigned int i = 0; i < routes->num_routes; i++) {
        if (routes->routes[i].target_kernel_id == target_id) {
            K2KRoute* route = &routes->routes[i];

            // Calculate total envelope size
            unsigned int envelope_size = MESSAGE_HEADER_SIZE + payload_size;
            if (envelope_size > route->msg_size) {
                return -1; // Message too large
            }

            // Atomically claim a slot in target's inbox
            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
            unsigned long long slot = atomicAdd(target_head_ptr, 1);
            unsigned int idx = (unsigned int)(slot & route->mask);

            // Calculate destination pointer
            unsigned char* dest = ((unsigned char*)route->target_inbox) +
                sizeof(K2KInboxHeader) + idx * route->msg_size;

            // Build message header
            MessageHeader* header = (MessageHeader*)dest;
            header->magic = MESSAGE_MAGIC;
            header->version = MESSAGE_VERSION;
            header->flags = 0;
            header->message_id = (source_kernel_id << 32) | (slot & 0xFFFFFFFF);
            header->correlation_id = 0;
            header->source_kernel = source_kernel_id;
            header->dest_kernel = target_id;
            header->message_type = message_type;
            header->priority = PRIORITY_NORMAL;
            header->payload_size = payload_size;
            header->checksum = 0;
            header->timestamp.physical = hlc_physical;
            header->timestamp.logical = hlc_logical;
            header->timestamp.node_id = hlc_node_id;
            header->deadline.physical = 0;
            header->deadline.logical = 0;
            header->deadline.node_id = 0;

            // Copy payload after header
            if (payload_size > 0 && payload_ptr != NULL) {
                memcpy(dest + MESSAGE_HEADER_SIZE, payload_ptr, payload_size);
            }

            __threadfence(); // Ensure write is visible
            return 1; // Success
        }
    }
    return 0; // Route not found
}

// Legacy k2k_send for raw message data (no envelope)
__device__ inline int k2k_send(
    K2KRoutingTable* routes,
    unsigned long long target_id,
    const void* msg_ptr,
    unsigned int msg_size
) {
    // Find route for target
    for (unsigned int i = 0; i < routes->num_routes; i++) {
        if (routes->routes[i].target_kernel_id == target_id) {
            K2KRoute* route = &routes->routes[i];

            // Atomically claim a slot in target's inbox
            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
            unsigned long long slot = atomicAdd(target_head_ptr, 1);
            unsigned int idx = (unsigned int)(slot & route->mask);

            // Copy message to target inbox
            unsigned char* dest = ((unsigned char*)route->target_inbox) +
                sizeof(K2KInboxHeader) + idx * route->msg_size;
            memcpy(dest, msg_ptr, msg_size < route->msg_size ? msg_size : route->msg_size);

            __threadfence();
            return 1; // Success
        }
    }
    return 0; // Route not found
}

// Check if there are K2K messages in inbox
__device__ inline int k2k_has_message(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);
    return head != tail;
}

// Try to receive a K2K message envelope
// Returns pointer to MessageHeader if available, NULL otherwise
__device__ inline MessageHeader* k2k_try_recv_envelope(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    // Get message pointer
    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
    MessageHeader* msg_header = (MessageHeader*)(data_start + idx * header->msg_size);

    // Validate header
    if (!message_header_validate(msg_header)) {
        // Invalid message, skip it
        atomicAdd(&header->tail, 1);
        return NULL;
    }

    // Advance tail (consume message)
    atomicAdd(&header->tail, 1);

    return msg_header;
}

// Legacy k2k_try_recv for raw data
__device__ inline void* k2k_try_recv(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    // Get message pointer
    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    // Advance tail (consume message)
    atomicAdd(&header->tail, 1);

    return data_start + idx * header->msg_size;
}

// Peek at next K2K message without consuming
__device__ inline MessageHeader* k2k_peek_envelope(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    return (MessageHeader*)(data_start + idx * header->msg_size);
}

// Legacy k2k_peek
__device__ inline void* k2k_peek(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    return data_start + idx * header->msg_size;
}

// Get number of pending K2K messages
__device__ inline unsigned int k2k_pending_count(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);
    return (unsigned int)(head - tail);
}
"#,
    )
}
967
/// Intrinsic operations recognized in ring-kernel handler code and lowered
/// to CUDA snippets by `to_cuda`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RingKernelIntrinsic {
    /// Atomic read of `control->is_active` (non-zero test).
    IsActive,
    /// Atomic read of `control->should_terminate` (non-zero test).
    ShouldTerminate,
    /// Set `control->has_terminated` via `atomicExch`.
    MarkTerminated,
    /// Atomic read of `control->messages_processed`.
    GetMessagesProcessed,

    /// `input_head - input_tail` (pending input messages).
    InputQueueSize,
    /// `output_head - output_tail` (pending output messages).
    OutputQueueSize,
    /// Whether the input queue is empty (`head == tail`).
    InputQueueEmpty,
    /// Whether the output queue is empty (`head == tail`).
    OutputQueueEmpty,
    /// Claim an output slot and `memcpy` a response into it.
    EnqueueResponse,

    /// Increment the logical HLC component.
    HlcTick,
    /// Merge an observed physical timestamp into the local HLC.
    HlcUpdate,
    /// Pack the current HLC into a single 64-bit value.
    HlcNow,

    /// Send a message to another kernel via the routing table.
    K2kSend,
    /// Non-blocking receive from the K2K inbox.
    K2kTryRecv,
    /// Whether the K2K inbox has pending messages.
    K2kHasMessage,
    /// Peek at the next K2K message without consuming it.
    K2kPeek,
    /// Number of pending K2K messages.
    K2kPendingCount,
}
996
impl RingKernelIntrinsic {
    /// Lowers the intrinsic to a CUDA expression/statement string.
    ///
    /// `args` supplies already-lowered argument expressions; intrinsics that
    /// require arguments emit a `/* ... */` placeholder comment when too few
    /// are given. The emitted code references identifiers declared by the
    /// generated kernel (`control`, `output_buffer`, `RESP_SIZE`,
    /// `hlc_physical`/`hlc_logical`, `k2k_routes`, `k2k_inbox`).
    pub fn to_cuda(&self, args: &[String]) -> String {
        match self {
            // atomicAdd(..., 0) is used as an atomic load throughout.
            Self::IsActive => "atomicAdd(&control->is_active, 0) != 0".to_string(),
            Self::ShouldTerminate => "atomicAdd(&control->should_terminate, 0) != 0".to_string(),
            Self::MarkTerminated => "atomicExch(&control->has_terminated, 1)".to_string(),
            Self::GetMessagesProcessed => "atomicAdd(&control->messages_processed, 0)".to_string(),

            Self::InputQueueSize => {
                "(atomicAdd(&control->input_head, 0) - atomicAdd(&control->input_tail, 0))"
                    .to_string()
            }
            Self::OutputQueueSize => {
                "(atomicAdd(&control->output_head, 0) - atomicAdd(&control->output_tail, 0))"
                    .to_string()
            }
            Self::InputQueueEmpty => {
                "(atomicAdd(&control->input_head, 0) == atomicAdd(&control->input_tail, 0))"
                    .to_string()
            }
            Self::OutputQueueEmpty => {
                "(atomicAdd(&control->output_head, 0) == atomicAdd(&control->output_tail, 0))"
                    .to_string()
            }
            Self::EnqueueResponse => {
                // args[0]: pointer to the response bytes to copy out.
                if !args.is_empty() {
                    format!(
                        "{{ unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask; \
                        memcpy(&output_buffer[_out_idx * RESP_SIZE], {}, RESP_SIZE); }}",
                        args[0]
                    )
                } else {
                    "/* enqueue_response requires response pointer */".to_string()
                }
            }

            Self::HlcTick => "hlc_logical++".to_string(),
            Self::HlcUpdate => {
                // args[0]: observed physical timestamp; falls back to a
                // plain tick when no argument is supplied.
                if !args.is_empty() {
                    format!(
                        "{{ if ({} > hlc_physical) {{ hlc_physical = {}; hlc_logical = 0; }} else {{ hlc_logical++; }} }}",
                        args[0], args[0]
                    )
                } else {
                    "hlc_logical++".to_string()
                }
            }
            Self::HlcNow => "(hlc_physical << 32) | (hlc_logical & 0xFFFFFFFF)".to_string(),

            Self::K2kSend => {
                // args[0]: target kernel id; args[1]: message pointer (the
                // size is derived with sizeof(*ptr)).
                if args.len() >= 2 {
                    format!(
                        "k2k_send(k2k_routes, {}, {}, sizeof(*{}))",
                        args[0], args[1], args[1]
                    )
                } else {
                    "/* k2k_send requires target_id and msg_ptr */".to_string()
                }
            }
            Self::K2kTryRecv => "k2k_try_recv(k2k_inbox)".to_string(),
            Self::K2kHasMessage => "k2k_has_message(k2k_inbox)".to_string(),
            Self::K2kPeek => "k2k_peek(k2k_inbox)".to_string(),
            Self::K2kPendingCount => "k2k_pending_count(k2k_inbox)".to_string(),
        }
    }

    /// Resolves a handler-source identifier (including aliases) to its
    /// intrinsic, or `None` for unknown names.
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "is_active" | "is_kernel_active" => Some(Self::IsActive),
            "should_terminate" => Some(Self::ShouldTerminate),
            "mark_terminated" => Some(Self::MarkTerminated),
            "messages_processed" | "get_messages_processed" => Some(Self::GetMessagesProcessed),

            "input_queue_size" => Some(Self::InputQueueSize),
            "output_queue_size" => Some(Self::OutputQueueSize),
            "input_queue_empty" => Some(Self::InputQueueEmpty),
            "output_queue_empty" => Some(Self::OutputQueueEmpty),
            "enqueue_response" | "enqueue" => Some(Self::EnqueueResponse),

            "hlc_tick" => Some(Self::HlcTick),
            "hlc_update" => Some(Self::HlcUpdate),
            "hlc_now" => Some(Self::HlcNow),

            "k2k_send" => Some(Self::K2kSend),
            "k2k_try_recv" => Some(Self::K2kTryRecv),
            "k2k_has_message" => Some(Self::K2kHasMessage),
            "k2k_peek" => Some(Self::K2kPeek),
            "k2k_pending_count" | "k2k_pending" => Some(Self::K2kPendingCount),

            _ => None,
        }
    }

    /// Whether lowering this intrinsic needs the K2K parameters/structs
    /// (i.e. the config must have `enable_k2k` set).
    pub fn requires_k2k(&self) -> bool {
        matches!(
            self,
            Self::K2kSend
                | Self::K2kTryRecv
                | Self::K2kHasMessage
                | Self::K2kPeek
                | Self::K2kPendingCount
        )
    }

    /// Whether lowering this intrinsic needs the HLC registers
    /// (`hlc_physical`/`hlc_logical` from the preamble).
    pub fn requires_hlc(&self) -> bool {
        matches!(self, Self::HlcTick | Self::HlcUpdate | Self::HlcNow)
    }

    /// Whether lowering this intrinsic reads or writes the `control` block.
    pub fn requires_control_block(&self) -> bool {
        matches!(
            self,
            Self::IsActive
                | Self::ShouldTerminate
                | Self::MarkTerminated
                | Self::GetMessagesProcessed
                | Self::InputQueueSize
                | Self::OutputQueueSize
                | Self::InputQueueEmpty
                | Self::OutputQueueEmpty
                | Self::EnqueueResponse
        )
    }
}
1126
1127#[cfg(test)]
1128mod tests {
1129 use super::*;
1130
1131 #[test]
1132 fn test_default_config() {
1133 let config = RingKernelConfig::default();
1134 assert_eq!(config.block_size, 128);
1135 assert_eq!(config.queue_capacity, 1024);
1136 assert!(config.enable_hlc);
1137 assert!(!config.enable_k2k);
1138 }
1139
1140 #[test]
1141 fn test_config_builder() {
1142 let config = RingKernelConfig::new("processor")
1143 .with_block_size(256)
1144 .with_queue_capacity(2048)
1145 .with_k2k(true)
1146 .with_hlc(true);
1147
1148 assert_eq!(config.id, "processor");
1149 assert_eq!(config.kernel_name(), "ring_kernel_processor");
1150 assert_eq!(config.block_size, 256);
1151 assert_eq!(config.queue_capacity, 2048);
1152 assert!(config.enable_k2k);
1153 assert!(config.enable_hlc);
1154 }
1155
1156 #[test]
1157 fn test_kernel_signature() {
1158 let config = RingKernelConfig::new("test");
1159 let sig = config.generate_signature();
1160
1161 assert!(sig.contains("extern \"C\" __global__ void ring_kernel_test"));
1162 assert!(sig.contains("ControlBlock* __restrict__ control"));
1163 assert!(sig.contains("input_buffer"));
1164 assert!(sig.contains("output_buffer"));
1165 assert!(sig.contains("shared_state"));
1166 }
1167
1168 #[test]
1169 fn test_kernel_signature_with_k2k() {
1170 let config = RingKernelConfig::new("k2k_test").with_k2k(true);
1171 let sig = config.generate_signature();
1172
1173 assert!(sig.contains("K2KRoutingTable"));
1174 assert!(sig.contains("k2k_inbox"));
1175 assert!(sig.contains("k2k_outbox"));
1176 }
1177
1178 #[test]
1179 fn test_preamble_generation() {
1180 let config = RingKernelConfig::new("test").with_hlc(true);
1181 let preamble = config.generate_preamble(" ");
1182
1183 assert!(preamble.contains("int tid = threadIdx.x + blockIdx.x * blockDim.x"));
1184 assert!(preamble.contains("int lane_id"));
1185 assert!(preamble.contains("int warp_id"));
1186 assert!(preamble.contains("MSG_SIZE"));
1187 assert!(preamble.contains("hlc_physical"));
1188 assert!(preamble.contains("hlc_logical"));
1189 }
1190
1191 #[test]
1192 fn test_loop_header() {
1193 let config = RingKernelConfig::new("test");
1194 let header = config.generate_loop_header(" ");
1195
1196 assert!(header.contains("while (true)"));
1197 assert!(header.contains("should_terminate"));
1198 assert!(header.contains("is_active"));
1199 assert!(header.contains("input_head"));
1200 assert!(header.contains("input_tail"));
1201 assert!(header.contains("msg_ptr"));
1202 }
1203
1204 #[test]
1205 fn test_epilogue() {
1206 let config = RingKernelConfig::new("test").with_hlc(true);
1207 let epilogue = config.generate_epilogue(" ");
1208
1209 assert!(epilogue.contains("has_terminated"));
1210 assert!(epilogue.contains("hlc_state.physical"));
1211 assert!(epilogue.contains("hlc_state.logical"));
1212 }
1213
1214 #[test]
1215 fn test_control_block_struct() {
1216 let code = generate_control_block_struct();
1217
1218 assert!(code.contains("struct __align__(128) ControlBlock"));
1219 assert!(code.contains("is_active"));
1220 assert!(code.contains("should_terminate"));
1221 assert!(code.contains("has_terminated"));
1222 assert!(code.contains("messages_processed"));
1223 assert!(code.contains("input_head"));
1224 assert!(code.contains("input_tail"));
1225 assert!(code.contains("hlc_state"));
1226 }
1227
1228 #[test]
1229 fn test_full_kernel_wrapper() {
1230 let config = RingKernelConfig::new("example")
1231 .with_block_size(128)
1232 .with_hlc(true);
1233
1234 let kernel = config.generate_kernel_wrapper("// Process message here");
1235
1236 assert!(kernel.contains("struct __align__(128) ControlBlock"));
1237 assert!(kernel.contains("extern \"C\" __global__ void ring_kernel_example"));
1238 assert!(kernel.contains("while (true)"));
1239 assert!(kernel.contains("// Process message here"));
1240 assert!(kernel.contains("has_terminated"));
1241
1242 println!("Generated kernel:\n{}", kernel);
1243 }
1244
1245 #[test]
1246 fn test_intrinsic_lookup() {
1247 assert_eq!(
1248 RingKernelIntrinsic::from_name("is_active"),
1249 Some(RingKernelIntrinsic::IsActive)
1250 );
1251 assert_eq!(
1252 RingKernelIntrinsic::from_name("should_terminate"),
1253 Some(RingKernelIntrinsic::ShouldTerminate)
1254 );
1255 assert_eq!(
1256 RingKernelIntrinsic::from_name("hlc_tick"),
1257 Some(RingKernelIntrinsic::HlcTick)
1258 );
1259 assert_eq!(RingKernelIntrinsic::from_name("unknown"), None);
1260 }
1261
1262 #[test]
1263 fn test_intrinsic_cuda_output() {
1264 assert!(RingKernelIntrinsic::IsActive
1265 .to_cuda(&[])
1266 .contains("is_active"));
1267 assert!(RingKernelIntrinsic::ShouldTerminate
1268 .to_cuda(&[])
1269 .contains("should_terminate"));
1270 assert!(RingKernelIntrinsic::HlcTick
1271 .to_cuda(&[])
1272 .contains("hlc_logical++"));
1273 }
1274
1275 #[test]
1276 fn test_k2k_structs_generation() {
1277 let k2k_code = generate_k2k_structs();
1278
1279 assert!(
1281 k2k_code.contains("struct K2KRoute"),
1282 "Should have K2KRoute struct"
1283 );
1284 assert!(
1285 k2k_code.contains("struct K2KRoutingTable"),
1286 "Should have K2KRoutingTable struct"
1287 );
1288 assert!(
1289 k2k_code.contains("K2KInboxHeader"),
1290 "Should have K2KInboxHeader struct"
1291 );
1292
1293 assert!(
1295 k2k_code.contains("__device__ inline int k2k_send"),
1296 "Should have k2k_send function"
1297 );
1298 assert!(
1299 k2k_code.contains("__device__ inline int k2k_has_message"),
1300 "Should have k2k_has_message function"
1301 );
1302 assert!(
1303 k2k_code.contains("__device__ inline void* k2k_try_recv"),
1304 "Should have k2k_try_recv function"
1305 );
1306 assert!(
1307 k2k_code.contains("__device__ inline void* k2k_peek"),
1308 "Should have k2k_peek function"
1309 );
1310 assert!(
1311 k2k_code.contains("__device__ inline unsigned int k2k_pending_count"),
1312 "Should have k2k_pending_count function"
1313 );
1314
1315 println!("K2K code:\n{}", k2k_code);
1316 }
1317
1318 #[test]
1319 fn test_full_k2k_kernel() {
1320 let config = RingKernelConfig::new("k2k_processor")
1321 .with_block_size(128)
1322 .with_k2k(true)
1323 .with_hlc(true);
1324
1325 let kernel = config.generate_kernel_wrapper("// K2K handler code");
1326
1327 assert!(
1329 kernel.contains("K2KRoutingTable"),
1330 "Should have K2KRoutingTable"
1331 );
1332 assert!(kernel.contains("K2KRoute"), "Should have K2KRoute struct");
1333 assert!(
1334 kernel.contains("K2KInboxHeader"),
1335 "Should have K2KInboxHeader"
1336 );
1337 assert!(
1338 kernel.contains("k2k_routes"),
1339 "Should have k2k_routes param"
1340 );
1341 assert!(kernel.contains("k2k_inbox"), "Should have k2k_inbox param");
1342 assert!(
1343 kernel.contains("k2k_outbox"),
1344 "Should have k2k_outbox param"
1345 );
1346 assert!(kernel.contains("k2k_send"), "Should have k2k_send function");
1347 assert!(
1348 kernel.contains("k2k_try_recv"),
1349 "Should have k2k_try_recv function"
1350 );
1351
1352 println!("Full K2K kernel:\n{}", kernel);
1353 }
1354
1355 #[test]
1356 fn test_k2k_intrinsic_requirements() {
1357 assert!(RingKernelIntrinsic::K2kSend.requires_k2k());
1358 assert!(RingKernelIntrinsic::K2kTryRecv.requires_k2k());
1359 assert!(RingKernelIntrinsic::K2kHasMessage.requires_k2k());
1360 assert!(RingKernelIntrinsic::K2kPeek.requires_k2k());
1361 assert!(RingKernelIntrinsic::K2kPendingCount.requires_k2k());
1362
1363 assert!(!RingKernelIntrinsic::HlcTick.requires_k2k());
1364 assert!(!RingKernelIntrinsic::IsActive.requires_k2k());
1365 }
1366}