use std::fmt::Write;

/// Configuration for generating a persistent ring-buffer CUDA kernel.
#[derive(Debug, Clone)]
pub struct RingKernelConfig {
    /// Identifier used to derive the generated kernel's name.
    pub id: String,
    /// CUDA thread block size.
    pub block_size: u32,
    /// Ring queue capacity in slots (must be a power of two).
    pub queue_capacity: u32,
    /// Enable kernel-to-kernel (K2K) messaging support.
    pub enable_k2k: bool,
    /// Enable Hybrid Logical Clock (HLC) support.
    pub enable_hlc: bool,
    /// Message payload size in bytes.
    pub message_size: usize,
    /// Response payload size in bytes.
    pub response_size: usize,
    /// Whether to use cooperative groups.
    pub cooperative_groups: bool,
    /// Nanoseconds to sleep via `__nanosleep` when idle (0 disables sleeping).
    pub idle_sleep_ns: u32,
    /// Wrap payloads in a `MessageHeader` envelope.
    pub use_envelope_format: bool,
    /// Numeric kernel identifier emitted as `KERNEL_ID` in the generated code.
    pub kernel_id_num: u64,
    /// HLC node identifier emitted as `HLC_NODE_ID` in the generated code.
    pub hlc_node_id: u64,
}

impl Default for RingKernelConfig {
    fn default() -> Self {
        Self {
            id: "ring_kernel".to_string(),
            block_size: 128,
            queue_capacity: 1024,
            enable_k2k: false,
            enable_hlc: true,
            message_size: 64,
            response_size: 64,
            cooperative_groups: false,
            idle_sleep_ns: 1000,
            use_envelope_format: true,
            kernel_id_num: 0,
            hlc_node_id: 0,
        }
    }
}

impl RingKernelConfig {
    /// Create a configuration with the given id and default settings.
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            ..Default::default()
        }
    }

    /// Set the CUDA thread block size.
    pub fn with_block_size(mut self, size: u32) -> Self {
        self.block_size = size;
        self
    }

    /// Set the queue capacity (must be a power of two).
    pub fn with_queue_capacity(mut self, capacity: u32) -> Self {
        debug_assert!(
            capacity.is_power_of_two(),
            "Queue capacity must be power of 2"
        );
        self.queue_capacity = capacity;
        self
    }

    /// Enable or disable kernel-to-kernel messaging.
    pub fn with_k2k(mut self, enabled: bool) -> Self {
        self.enable_k2k = enabled;
        self
    }

    /// Enable or disable Hybrid Logical Clock support.
    pub fn with_hlc(mut self, enabled: bool) -> Self {
        self.enable_hlc = enabled;
        self
    }

    /// Set the message and response payload sizes in bytes.
    pub fn with_message_sizes(mut self, message: usize, response: usize) -> Self {
        self.message_size = message;
        self.response_size = response;
        self
    }

    /// Set the idle sleep duration in nanoseconds (0 disables sleeping).
    pub fn with_idle_sleep(mut self, ns: u32) -> Self {
        self.idle_sleep_ns = ns;
        self
    }

    /// Enable or disable the message envelope format.
    pub fn with_envelope_format(mut self, enabled: bool) -> Self {
        self.use_envelope_format = enabled;
        self
    }

    /// Set the numeric kernel identifier.
    pub fn with_kernel_id(mut self, id: u64) -> Self {
        self.kernel_id_num = id;
        self
    }

    /// Set the HLC node identifier.
    pub fn with_hlc_node_id(mut self, node_id: u64) -> Self {
        self.hlc_node_id = node_id;
        self
    }

    /// Name of the generated `__global__` kernel function.
    pub fn kernel_name(&self) -> String {
        format!("ring_kernel_{}", self.id)
    }

    /// Generate the `extern "C" __global__` kernel signature.
    pub fn generate_signature(&self) -> String {
        let mut sig = String::new();

        writeln!(sig, "extern \"C\" __global__ void {}(", self.kernel_name()).unwrap();
        writeln!(sig, "    ControlBlock* __restrict__ control,").unwrap();
        writeln!(sig, "    unsigned char* __restrict__ input_buffer,").unwrap();
        writeln!(sig, "    unsigned char* __restrict__ output_buffer,").unwrap();

        if self.enable_k2k {
            writeln!(sig, "    K2KRoutingTable* __restrict__ k2k_routes,").unwrap();
            writeln!(sig, "    unsigned char* __restrict__ k2k_inbox,").unwrap();
            writeln!(sig, "    unsigned char* __restrict__ k2k_outbox,").unwrap();
        }

        write!(sig, "    void* __restrict__ shared_state").unwrap();
        write!(sig, "\n)").unwrap();

        sig
    }
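
    // For reference, the default configuration produces a signature along these
    // lines (a sketch derived from the writes above; exact whitespace may differ):
    //
    //     extern "C" __global__ void ring_kernel_ring_kernel(
    //         ControlBlock* __restrict__ control,
    //         unsigned char* __restrict__ input_buffer,
    //         unsigned char* __restrict__ output_buffer,
    //         void* __restrict__ shared_state
    //     )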

    /// Generate the preamble: thread ids, kernel identity constants, message
    /// size constants, and (optionally) HLC clock state.
    pub fn generate_preamble(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code, "{}// Thread identification", indent).unwrap();
        writeln!(
            code,
            "{}int tid = threadIdx.x + blockIdx.x * blockDim.x;",
            indent
        )
        .unwrap();
        writeln!(code, "{}int lane_id = threadIdx.x % 32;", indent).unwrap();
        writeln!(code, "{}int warp_id = threadIdx.x / 32;", indent).unwrap();
        writeln!(code).unwrap();

        writeln!(code, "{}// Kernel identity", indent).unwrap();
        writeln!(
            code,
            "{}const unsigned long long KERNEL_ID = {}ULL;",
            indent, self.kernel_id_num
        )
        .unwrap();
        writeln!(
            code,
            "{}const unsigned long long HLC_NODE_ID = {}ULL;",
            indent, self.hlc_node_id
        )
        .unwrap();
        writeln!(code).unwrap();

        writeln!(code, "{}// Message buffer constants", indent).unwrap();
        if self.use_envelope_format {
            writeln!(
                code,
                "{}const unsigned int PAYLOAD_SIZE = {}; // User payload size",
                indent, self.message_size
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int MSG_SIZE = MESSAGE_HEADER_SIZE + PAYLOAD_SIZE; // Full envelope",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int RESP_PAYLOAD_SIZE = {};",
                indent, self.response_size
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int RESP_SIZE = MESSAGE_HEADER_SIZE + RESP_PAYLOAD_SIZE;",
                indent
            )
            .unwrap();
        } else {
            writeln!(
                code,
                "{}const unsigned int MSG_SIZE = {};",
                indent, self.message_size
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int RESP_SIZE = {};",
                indent, self.response_size
            )
            .unwrap();
        }
        writeln!(
            code,
            "{}const unsigned int QUEUE_MASK = {};",
            indent,
            self.queue_capacity - 1
        )
        .unwrap();
        writeln!(code).unwrap();

        if self.enable_hlc {
            writeln!(code, "{}// HLC clock state", indent).unwrap();
            writeln!(code, "{}unsigned long long hlc_physical = 0;", indent).unwrap();
            writeln!(code, "{}unsigned long long hlc_logical = 0;", indent).unwrap();
            writeln!(code).unwrap();
        }

        code
    }

    /// Generate the persistent processing loop: termination/activity checks,
    /// queue polling, and message (or envelope) extraction.
    pub fn generate_loop_header(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code, "{}// Persistent message processing loop", indent).unwrap();
        writeln!(code, "{}while (true) {{", indent).unwrap();
        writeln!(code, "{}    // Check for termination signal", indent).unwrap();
        writeln!(
            code,
            "{}    if (atomicAdd(&control->should_terminate, 0) != 0) {{",
            indent
        )
        .unwrap();
        writeln!(code, "{}        break;", indent).unwrap();
        writeln!(code, "{}    }}", indent).unwrap();
        writeln!(code).unwrap();

        writeln!(code, "{}    // Check if kernel is active", indent).unwrap();
        writeln!(
            code,
            "{}    if (atomicAdd(&control->is_active, 0) == 0) {{",
            indent
        )
        .unwrap();
        if self.idle_sleep_ns > 0 {
            writeln!(
                code,
                "{}        __nanosleep({});",
                indent, self.idle_sleep_ns
            )
            .unwrap();
        }
        writeln!(code, "{}        continue;", indent).unwrap();
        writeln!(code, "{}    }}", indent).unwrap();
        writeln!(code).unwrap();

        writeln!(code, "{}    // Check input queue for messages", indent).unwrap();
        writeln!(
            code,
            "{}    unsigned long long head = atomicAdd(&control->input_head, 0);",
            indent
        )
        .unwrap();
        writeln!(
            code,
            "{}    unsigned long long tail = atomicAdd(&control->input_tail, 0);",
            indent
        )
        .unwrap();
        writeln!(code).unwrap();
        writeln!(code, "{}    if (head == tail) {{", indent).unwrap();
        writeln!(code, "{}        // No messages, yield", indent).unwrap();
        if self.idle_sleep_ns > 0 {
            writeln!(
                code,
                "{}        __nanosleep({});",
                indent, self.idle_sleep_ns
            )
            .unwrap();
        }
        writeln!(code, "{}        continue;", indent).unwrap();
        writeln!(code, "{}    }}", indent).unwrap();
        writeln!(code).unwrap();

        writeln!(code, "{}    // Get message from queue", indent).unwrap();
        writeln!(
            code,
            "{}    unsigned int msg_idx = (unsigned int)(tail & QUEUE_MASK);",
            indent
        )
        .unwrap();
        writeln!(
            code,
            "{}    unsigned char* envelope_ptr = &input_buffer[msg_idx * MSG_SIZE];",
            indent
        )
        .unwrap();

        if self.use_envelope_format {
            writeln!(code).unwrap();
            writeln!(code, "{}    // Parse message envelope", indent).unwrap();
            writeln!(
                code,
                "{}    MessageHeader* msg_header = message_get_header(envelope_ptr);",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}    unsigned char* msg_ptr = message_get_payload(envelope_ptr);",
                indent
            )
            .unwrap();
            writeln!(code).unwrap();
            writeln!(code, "{}    // Validate message (skip invalid)", indent).unwrap();
            writeln!(
                code,
                "{}    if (!message_header_validate(msg_header)) {{",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}        atomicAdd(&control->input_tail, 1);",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}        atomicAdd(&control->last_error, 1); // Track errors",
                indent
            )
            .unwrap();
            writeln!(code, "{}        continue;", indent).unwrap();
            writeln!(code, "{}    }}", indent).unwrap();

            if self.enable_hlc {
                writeln!(code).unwrap();
                writeln!(code, "{}    // Update HLC from message timestamp", indent).unwrap();
                writeln!(
                    code,
                    "{}    if (msg_header->timestamp.physical > hlc_physical) {{",
                    indent
                )
                .unwrap();
                writeln!(
                    code,
                    "{}        hlc_physical = msg_header->timestamp.physical;",
                    indent
                )
                .unwrap();
                writeln!(code, "{}        hlc_logical = 0;", indent).unwrap();
                writeln!(code, "{}    }}", indent).unwrap();
                writeln!(code, "{}    hlc_logical++;", indent).unwrap();
            }
        } else {
            writeln!(
                code,
                "{}    unsigned char* msg_ptr = envelope_ptr; // Raw message data",
                indent
            )
            .unwrap();
        }
        writeln!(code).unwrap();

        code
    }

    /// Generate the code that marks the current message as processed.
    pub fn generate_message_complete(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code).unwrap();
        writeln!(code, "{}    // Mark message as processed", indent).unwrap();
        writeln!(code, "{}    atomicAdd(&control->input_tail, 1);", indent).unwrap();
        writeln!(
            code,
            "{}    atomicAdd(&control->messages_processed, 1);",
            indent
        )
        .unwrap();

        if self.enable_hlc {
            writeln!(code).unwrap();
            writeln!(code, "{}    // Update HLC", indent).unwrap();
            writeln!(code, "{}    hlc_logical++;", indent).unwrap();
        }

        code
    }

    /// Close the persistent processing loop.
    pub fn generate_loop_footer(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code, "{}    __syncthreads();", indent).unwrap();
        writeln!(code, "{}}}", indent).unwrap();

        code
    }

    /// Generate the epilogue: persist final HLC state and set the terminated flag.
    pub fn generate_epilogue(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code).unwrap();
        writeln!(code, "{}// Mark kernel as terminated", indent).unwrap();
        writeln!(code, "{}if (tid == 0) {{", indent).unwrap();

        if self.enable_hlc {
            writeln!(code, "{}    // Store final HLC state", indent).unwrap();
            writeln!(
                code,
                "{}    control->hlc_state.physical = hlc_physical;",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}    control->hlc_state.logical = hlc_logical;",
                indent
            )
            .unwrap();
        }

        writeln!(
            code,
            "{}    atomicExch(&control->has_terminated, 1);",
            indent
        )
        .unwrap();
        writeln!(code, "{}}}", indent).unwrap();

        code
    }

    /// Assemble the complete kernel source: struct definitions, signature,
    /// preamble, processing loop, user handler body, and epilogue.
    pub fn generate_kernel_wrapper(&self, handler_placeholder: &str) -> String {
        let mut code = String::new();

        code.push_str(&generate_control_block_struct());
        code.push('\n');

        // MessageHeader embeds HlcTimestamp, and the K2K helpers build MessageHeader
        // envelopes, so emit these definitions whenever anything depends on them.
        let needs_envelope = self.use_envelope_format || self.enable_k2k;

        if self.enable_hlc || needs_envelope {
            code.push_str(&generate_hlc_struct());
            code.push('\n');
        }

        if needs_envelope {
            code.push_str(&generate_message_envelope_structs());
            code.push('\n');
        }

        if self.enable_k2k {
            code.push_str(&generate_k2k_structs());
            code.push('\n');
        }

        code.push_str(&self.generate_signature());
        code.push_str(" {\n");

        code.push_str(&self.generate_preamble("    "));

        code.push_str(&self.generate_loop_header("    "));

        writeln!(code, "        // === USER HANDLER CODE ===").unwrap();
        for line in handler_placeholder.lines() {
            writeln!(code, "        {}", line).unwrap();
        }
        writeln!(code, "        // === END HANDLER CODE ===").unwrap();

        code.push_str(&self.generate_message_complete("    "));

        code.push_str(&self.generate_loop_footer("    "));

        code.push_str(&self.generate_epilogue("    "));

        code.push_str("}\n");

        code
    }
}

/// CUDA source for the 128-byte, cache-line-aligned control block shared with the host.
pub fn generate_control_block_struct() -> String {
    r#"// Control block for kernel state management (128 bytes, cache-line aligned)
struct __align__(128) ControlBlock {
    // Lifecycle state
    unsigned int is_active;
    unsigned int should_terminate;
    unsigned int has_terminated;
    unsigned int _pad1;

    // Counters
    unsigned long long messages_processed;
    unsigned long long messages_in_flight;

    // Queue pointers
    unsigned long long input_head;
    unsigned long long input_tail;
    unsigned long long output_head;
    unsigned long long output_tail;

    // Queue metadata
    unsigned int input_capacity;
    unsigned int output_capacity;
    unsigned int input_mask;
    unsigned int output_mask;

    // HLC state
    struct {
        unsigned long long physical;
        unsigned long long logical;
    } hlc_state;

    // Error state
    unsigned int last_error;
    unsigned int error_count;

    // Reserved padding
    unsigned char _reserved[24];
};
"#
    .to_string()
}

/// CUDA definitions for HLC state and timestamps.
pub fn generate_hlc_struct() -> String {
    r#"// Hybrid Logical Clock state
struct HlcState {
    unsigned long long physical;
    unsigned long long logical;
};

// HLC timestamp (24 bytes)
struct __align__(8) HlcTimestamp {
    unsigned long long physical;
    unsigned long long logical;
    unsigned long long node_id;
};
"#
    .to_string()
}

/// CUDA definitions and device helpers for the message envelope format.
pub fn generate_message_envelope_structs() -> String {
    r#"// Magic number for message validation
#define MESSAGE_MAGIC 0x52494E474B45524EULL // "RINGKERN"
#define MESSAGE_VERSION 1
#define MESSAGE_HEADER_SIZE 256
#define MAX_PAYLOAD_SIZE (64 * 1024)

// Message priority levels
#define PRIORITY_LOW 0
#define PRIORITY_NORMAL 1
#define PRIORITY_HIGH 2
#define PRIORITY_CRITICAL 3

// Message header structure (256 bytes, cache-line aligned)
// Matches ringkernel_core::message::MessageHeader exactly
struct __align__(64) MessageHeader {
    // Magic number for validation ("RINGKERN" in ASCII)
    unsigned long long magic;
    // Header version
    unsigned int version;
    // Message flags
    unsigned int flags;
    // Unique message identifier
    unsigned long long message_id;
    // Correlation ID for request-response
    unsigned long long correlation_id;
    // Source kernel ID (0 for host)
    unsigned long long source_kernel;
    // Destination kernel ID (0 for host)
    unsigned long long dest_kernel;
    // Message type discriminator
    unsigned long long message_type;
    // Priority level
    unsigned char priority;
    // Reserved for alignment
    unsigned char _reserved1[7];
    // Payload size in bytes
    unsigned long long payload_size;
    // Checksum of payload (CRC32)
    unsigned int checksum;
    // Reserved for alignment
    unsigned int _reserved2;
    // HLC timestamp when message was created
    HlcTimestamp timestamp;
    // Deadline timestamp (0 = no deadline)
    HlcTimestamp deadline;
    // Reserved for future use (104 bytes total: 32+32+32+8)
    unsigned char _reserved3[104];
};

// Validate message header
__device__ inline int message_header_validate(const MessageHeader* header) {
    return header->magic == MESSAGE_MAGIC &&
           header->version <= MESSAGE_VERSION &&
           header->payload_size <= MAX_PAYLOAD_SIZE;
}

// Get payload pointer from header
__device__ inline unsigned char* message_get_payload(unsigned char* envelope_ptr) {
    return envelope_ptr + MESSAGE_HEADER_SIZE;
}

// Get header from envelope pointer
__device__ inline MessageHeader* message_get_header(unsigned char* envelope_ptr) {
    return (MessageHeader*)envelope_ptr;
}

// CRC32 computation for payload checksums
// Uses CRC32-C (Castagnoli), polynomial 0x1EDC6F41 (0x82F63B78 reflected)
__device__ inline unsigned int message_compute_checksum(const unsigned char* data, unsigned long long size) {
    unsigned int crc = 0xFFFFFFFF;
    for (unsigned long long i = 0; i < size; i++) {
        crc ^= data[i];
        for (int j = 0; j < 8; j++) {
            crc = (crc >> 1) ^ (0x82F63B78 * (crc & 1)); // reflected CRC32-C polynomial
        }
    }
    return ~crc;
}

// Verify payload checksum
__device__ inline int message_verify_checksum(const MessageHeader* header, const unsigned char* payload) {
    if (header->payload_size == 0) {
        return header->checksum == 0; // Empty payload should have zero checksum
    }
    unsigned int computed = message_compute_checksum(payload, header->payload_size);
    return computed == header->checksum;
}

// Create a response header based on a request
// Note: payload_ptr is optional, pass NULL if checksum not needed yet
__device__ inline void message_create_response_header(
    MessageHeader* response,
    const MessageHeader* request,
    unsigned long long this_kernel_id,
    unsigned long long payload_size,
    unsigned long long hlc_physical,
    unsigned long long hlc_logical,
    unsigned long long hlc_node_id,
    const unsigned char* payload_ptr // Optional: for checksum computation
) {
    response->magic = MESSAGE_MAGIC;
    response->version = MESSAGE_VERSION;
    response->flags = 0;
    // Generate new message ID (simple increment from request)
    response->message_id = request->message_id + 0x100000000ULL;
    // Preserve correlation ID for request-response matching
    response->correlation_id = request->correlation_id != 0
        ? request->correlation_id
        : request->message_id;
    response->source_kernel = this_kernel_id;
    response->dest_kernel = request->source_kernel; // Response goes back to sender
    response->message_type = request->message_type + 1; // Convention: response type = request + 1
    response->priority = request->priority;
    response->payload_size = payload_size;
    // Compute checksum if payload provided
    if (payload_ptr != NULL && payload_size > 0) {
        response->checksum = message_compute_checksum(payload_ptr, payload_size);
    } else {
        response->checksum = 0;
    }
    response->timestamp.physical = hlc_physical;
    response->timestamp.logical = hlc_logical;
    response->timestamp.node_id = hlc_node_id;
    response->deadline.physical = 0;
    response->deadline.logical = 0;
    response->deadline.node_id = 0;
}

// Calculate total envelope size
__device__ inline unsigned int message_envelope_size(const MessageHeader* header) {
    return MESSAGE_HEADER_SIZE + (unsigned int)header->payload_size;
}
"#
    .to_string()
}

/// CUDA definitions and device helpers for kernel-to-kernel (K2K) messaging.
pub fn generate_k2k_structs() -> String {
    r#"// Kernel-to-kernel routing table entry (48 bytes)
// Matches ringkernel_cuda::k2k_gpu::K2KRouteEntry
struct K2KRoute {
    unsigned long long target_kernel_id;
    unsigned long long target_inbox; // device pointer
    unsigned long long target_head;  // device pointer to head
    unsigned long long target_tail;  // device pointer to tail
    unsigned int capacity;
    unsigned int mask;
    unsigned int msg_size;
    unsigned int _pad;
};

// K2K routing table (8 + 48*16 = 776 bytes)
struct K2KRoutingTable {
    unsigned int num_routes;
    unsigned int _pad;
    K2KRoute routes[16]; // Max 16 K2K connections
};

// K2K inbox header (64 bytes, cache-line aligned)
// Matches ringkernel_cuda::k2k_gpu::K2KInboxHeader
struct __align__(64) K2KInboxHeader {
    unsigned long long head;
    unsigned long long tail;
    unsigned int capacity;
    unsigned int mask;
    unsigned int msg_size;
    unsigned int _pad;
};

// Send a message envelope to another kernel via K2K
// The entire envelope (header + payload) is copied to the target's inbox
__device__ inline int k2k_send_envelope(
    K2KRoutingTable* routes,
    unsigned long long target_id,
    unsigned long long source_kernel_id,
    const void* payload_ptr,
    unsigned int payload_size,
    unsigned long long message_type,
    unsigned long long hlc_physical,
    unsigned long long hlc_logical,
    unsigned long long hlc_node_id
) {
    // Find route for target
    for (unsigned int i = 0; i < routes->num_routes; i++) {
        if (routes->routes[i].target_kernel_id == target_id) {
            K2KRoute* route = &routes->routes[i];

            // Calculate total envelope size
            unsigned int envelope_size = MESSAGE_HEADER_SIZE + payload_size;
            if (envelope_size > route->msg_size) {
                return -1; // Message too large
            }

            // Atomically claim a slot in target's inbox
            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
            unsigned long long slot = atomicAdd(target_head_ptr, 1);
            unsigned int idx = (unsigned int)(slot & route->mask);

            // Calculate destination pointer
            unsigned char* dest = ((unsigned char*)route->target_inbox) +
                                  sizeof(K2KInboxHeader) + idx * route->msg_size;

            // Build message header
            MessageHeader* header = (MessageHeader*)dest;
            header->magic = MESSAGE_MAGIC;
            header->version = MESSAGE_VERSION;
            header->flags = 0;
            header->message_id = (source_kernel_id << 32) | (slot & 0xFFFFFFFF);
            header->correlation_id = 0;
            header->source_kernel = source_kernel_id;
            header->dest_kernel = target_id;
            header->message_type = message_type;
            header->priority = PRIORITY_NORMAL;
            header->payload_size = payload_size;
            header->timestamp.physical = hlc_physical;
            header->timestamp.logical = hlc_logical;
            header->timestamp.node_id = hlc_node_id;
            header->deadline.physical = 0;
            header->deadline.logical = 0;
            header->deadline.node_id = 0;

            // Copy payload after header and compute checksum
            if (payload_size > 0 && payload_ptr != NULL) {
                memcpy(dest + MESSAGE_HEADER_SIZE, payload_ptr, payload_size);
                header->checksum = message_compute_checksum(
                    (const unsigned char*)payload_ptr, payload_size);
            } else {
                header->checksum = 0;
            }

            __threadfence(); // Ensure write is visible
            return 1; // Success
        }
    }
    return 0; // Route not found
}

// Legacy k2k_send for raw message data (no envelope)
__device__ inline int k2k_send(
    K2KRoutingTable* routes,
    unsigned long long target_id,
    const void* msg_ptr,
    unsigned int msg_size
) {
    // Find route for target
    for (unsigned int i = 0; i < routes->num_routes; i++) {
        if (routes->routes[i].target_kernel_id == target_id) {
            K2KRoute* route = &routes->routes[i];

            // Atomically claim a slot in target's inbox
            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
            unsigned long long slot = atomicAdd(target_head_ptr, 1);
            unsigned int idx = (unsigned int)(slot & route->mask);

            // Copy message to target inbox
            unsigned char* dest = ((unsigned char*)route->target_inbox) +
                                  sizeof(K2KInboxHeader) + idx * route->msg_size;
            memcpy(dest, msg_ptr, msg_size < route->msg_size ? msg_size : route->msg_size);

            __threadfence();
            return 1; // Success
        }
    }
    return 0; // Route not found
}

// Check if there are K2K messages in inbox
__device__ inline int k2k_has_message(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);
    return head != tail;
}

// Try to receive a K2K message envelope
// Returns pointer to MessageHeader if available, NULL otherwise
__device__ inline MessageHeader* k2k_try_recv_envelope(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    // Get message pointer
    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
    MessageHeader* msg_header = (MessageHeader*)(data_start + idx * header->msg_size);

    // Validate header
    if (!message_header_validate(msg_header)) {
        // Invalid message, skip it
        atomicAdd(&header->tail, 1);
        return NULL;
    }

    // Advance tail (consume message)
    atomicAdd(&header->tail, 1);

    return msg_header;
}

// Legacy k2k_try_recv for raw data
__device__ inline void* k2k_try_recv(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    // Get message pointer
    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    // Advance tail (consume message)
    atomicAdd(&header->tail, 1);

    return data_start + idx * header->msg_size;
}

// Peek at next K2K message without consuming
__device__ inline MessageHeader* k2k_peek_envelope(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    return (MessageHeader*)(data_start + idx * header->msg_size);
}

// Legacy k2k_peek
__device__ inline void* k2k_peek(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    return data_start + idx * header->msg_size;
}

// Get number of pending K2K messages
__device__ inline unsigned int k2k_pending_count(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);
    return (unsigned int)(head - tail);
}
"#
    .to_string()
}

/// Intrinsic operations that handler code can invoke inside a ring kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RingKernelIntrinsic {
    // Lifecycle
    IsActive,
    ShouldTerminate,
    MarkTerminated,
    GetMessagesProcessed,

    // Queues
    InputQueueSize,
    OutputQueueSize,
    InputQueueEmpty,
    OutputQueueEmpty,
    EnqueueResponse,

    // Hybrid Logical Clock
    HlcTick,
    HlcUpdate,
    HlcNow,

    // Kernel-to-kernel messaging
    K2kSend,
    K2kTryRecv,
    K2kHasMessage,
    K2kPeek,
    K2kPendingCount,
}

impl RingKernelIntrinsic {
    /// Lower the intrinsic to a CUDA expression or statement, substituting `args`.
    pub fn to_cuda(&self, args: &[String]) -> String {
        match self {
            Self::IsActive => "atomicAdd(&control->is_active, 0) != 0".to_string(),
            Self::ShouldTerminate => "atomicAdd(&control->should_terminate, 0) != 0".to_string(),
            Self::MarkTerminated => "atomicExch(&control->has_terminated, 1)".to_string(),
            Self::GetMessagesProcessed => "atomicAdd(&control->messages_processed, 0)".to_string(),

            Self::InputQueueSize => {
                "(atomicAdd(&control->input_head, 0) - atomicAdd(&control->input_tail, 0))"
                    .to_string()
            }
            Self::OutputQueueSize => {
                "(atomicAdd(&control->output_head, 0) - atomicAdd(&control->output_tail, 0))"
                    .to_string()
            }
            Self::InputQueueEmpty => {
                "(atomicAdd(&control->input_head, 0) == atomicAdd(&control->input_tail, 0))"
                    .to_string()
            }
            Self::OutputQueueEmpty => {
                "(atomicAdd(&control->output_head, 0) == atomicAdd(&control->output_tail, 0))"
                    .to_string()
            }
            Self::EnqueueResponse => {
                if !args.is_empty() {
                    format!(
                        "{{ unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask; \
                         memcpy(&output_buffer[_out_idx * RESP_SIZE], {}, RESP_SIZE); }}",
                        args[0]
                    )
                } else {
                    "/* enqueue_response requires response pointer */".to_string()
                }
            }

            Self::HlcTick => "hlc_logical++".to_string(),
            Self::HlcUpdate => {
                if !args.is_empty() {
                    format!(
                        "{{ if ({} > hlc_physical) {{ hlc_physical = {}; hlc_logical = 0; }} else {{ hlc_logical++; }} }}",
                        args[0], args[0]
                    )
                } else {
                    "hlc_logical++".to_string()
                }
            }
            Self::HlcNow => "(hlc_physical << 32) | (hlc_logical & 0xFFFFFFFF)".to_string(),

            Self::K2kSend => {
                if args.len() >= 2 {
                    format!(
                        "k2k_send(k2k_routes, {}, {}, sizeof(*{}))",
                        args[0], args[1], args[1]
                    )
                } else {
                    "/* k2k_send requires target_id and msg_ptr */".to_string()
                }
            }
            Self::K2kTryRecv => "k2k_try_recv(k2k_inbox)".to_string(),
            Self::K2kHasMessage => "k2k_has_message(k2k_inbox)".to_string(),
            Self::K2kPeek => "k2k_peek(k2k_inbox)".to_string(),
            Self::K2kPendingCount => "k2k_pending_count(k2k_inbox)".to_string(),
        }
    }

    /// Look up an intrinsic by its handler-facing name.
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "is_active" | "is_kernel_active" => Some(Self::IsActive),
            "should_terminate" => Some(Self::ShouldTerminate),
            "mark_terminated" => Some(Self::MarkTerminated),
            "messages_processed" | "get_messages_processed" => Some(Self::GetMessagesProcessed),

            "input_queue_size" => Some(Self::InputQueueSize),
            "output_queue_size" => Some(Self::OutputQueueSize),
            "input_queue_empty" => Some(Self::InputQueueEmpty),
            "output_queue_empty" => Some(Self::OutputQueueEmpty),
            "enqueue_response" | "enqueue" => Some(Self::EnqueueResponse),

            "hlc_tick" => Some(Self::HlcTick),
            "hlc_update" => Some(Self::HlcUpdate),
            "hlc_now" => Some(Self::HlcNow),

            "k2k_send" => Some(Self::K2kSend),
            "k2k_try_recv" => Some(Self::K2kTryRecv),
            "k2k_has_message" => Some(Self::K2kHasMessage),
            "k2k_peek" => Some(Self::K2kPeek),
            "k2k_pending_count" | "k2k_pending" => Some(Self::K2kPendingCount),

            _ => None,
        }
    }
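
    // Typical lowering flow in a frontend (a sketch; only `from_name` and
    // `to_cuda` come from this module, the surrounding machinery is assumed):
    //
    //     if let Some(intrinsic) = RingKernelIntrinsic::from_name("hlc_tick") {
    //         let cuda = intrinsic.to_cuda(&[]); // yields "hlc_logical++"
    //     }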

    /// Whether this intrinsic needs the K2K kernel parameters.
    pub fn requires_k2k(&self) -> bool {
        matches!(
            self,
            Self::K2kSend
                | Self::K2kTryRecv
                | Self::K2kHasMessage
                | Self::K2kPeek
                | Self::K2kPendingCount
        )
    }

    /// Whether this intrinsic needs HLC state in the preamble.
    pub fn requires_hlc(&self) -> bool {
        matches!(self, Self::HlcTick | Self::HlcUpdate | Self::HlcNow)
    }

    /// Whether this intrinsic reads or writes the control block.
    pub fn requires_control_block(&self) -> bool {
        matches!(
            self,
            Self::IsActive
                | Self::ShouldTerminate
                | Self::MarkTerminated
                | Self::GetMessagesProcessed
                | Self::InputQueueSize
                | Self::OutputQueueSize
                | Self::InputQueueEmpty
                | Self::OutputQueueEmpty
                | Self::EnqueueResponse
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = RingKernelConfig::default();
        assert_eq!(config.block_size, 128);
        assert_eq!(config.queue_capacity, 1024);
        assert!(config.enable_hlc);
        assert!(!config.enable_k2k);
    }

    #[test]
    fn test_config_builder() {
        let config = RingKernelConfig::new("processor")
            .with_block_size(256)
            .with_queue_capacity(2048)
            .with_k2k(true)
            .with_hlc(true);

        assert_eq!(config.id, "processor");
        assert_eq!(config.kernel_name(), "ring_kernel_processor");
        assert_eq!(config.block_size, 256);
        assert_eq!(config.queue_capacity, 2048);
        assert!(config.enable_k2k);
        assert!(config.enable_hlc);
    }
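
    // Additional coverage for the remaining builder methods (a small sketch;
    // the concrete values below are arbitrary).
    #[test]
    fn test_message_size_and_id_builders() {
        let config = RingKernelConfig::new("sized")
            .with_message_sizes(128, 256)
            .with_envelope_format(false)
            .with_idle_sleep(0)
            .with_kernel_id(42)
            .with_hlc_node_id(7);

        assert_eq!(config.message_size, 128);
        assert_eq!(config.response_size, 256);
        assert!(!config.use_envelope_format);
        assert_eq!(config.idle_sleep_ns, 0);
        assert_eq!(config.kernel_id_num, 42);
        assert_eq!(config.hlc_node_id, 7);
    }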

    #[test]
    fn test_kernel_signature() {
        let config = RingKernelConfig::new("test");
        let sig = config.generate_signature();

        assert!(sig.contains("extern \"C\" __global__ void ring_kernel_test"));
        assert!(sig.contains("ControlBlock* __restrict__ control"));
        assert!(sig.contains("input_buffer"));
        assert!(sig.contains("output_buffer"));
        assert!(sig.contains("shared_state"));
    }

    #[test]
    fn test_kernel_signature_with_k2k() {
        let config = RingKernelConfig::new("k2k_test").with_k2k(true);
        let sig = config.generate_signature();

        assert!(sig.contains("K2KRoutingTable"));
        assert!(sig.contains("k2k_inbox"));
        assert!(sig.contains("k2k_outbox"));
    }

    #[test]
    fn test_preamble_generation() {
        let config = RingKernelConfig::new("test").with_hlc(true);
        let preamble = config.generate_preamble("    ");

        assert!(preamble.contains("int tid = threadIdx.x + blockIdx.x * blockDim.x"));
        assert!(preamble.contains("int lane_id"));
        assert!(preamble.contains("int warp_id"));
        assert!(preamble.contains("MSG_SIZE"));
        assert!(preamble.contains("hlc_physical"));
        assert!(preamble.contains("hlc_logical"));
    }
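
    // The preamble derives QUEUE_MASK as capacity - 1; check that directly
    // (a small sketch using an arbitrary power-of-two capacity).
    #[test]
    fn test_preamble_queue_mask() {
        let config = RingKernelConfig::new("test").with_queue_capacity(256);
        let preamble = config.generate_preamble("    ");

        assert!(preamble.contains("const unsigned int QUEUE_MASK = 255;"));
    }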

    #[test]
    fn test_loop_header() {
        let config = RingKernelConfig::new("test");
        let header = config.generate_loop_header("    ");

        assert!(header.contains("while (true)"));
        assert!(header.contains("should_terminate"));
        assert!(header.contains("is_active"));
        assert!(header.contains("input_head"));
        assert!(header.contains("input_tail"));
        assert!(header.contains("msg_ptr"));
    }

    #[test]
    fn test_epilogue() {
        let config = RingKernelConfig::new("test").with_hlc(true);
        let epilogue = config.generate_epilogue("    ");

        assert!(epilogue.contains("has_terminated"));
        assert!(epilogue.contains("hlc_state.physical"));
        assert!(epilogue.contains("hlc_state.logical"));
    }

    #[test]
    fn test_control_block_struct() {
        let code = generate_control_block_struct();

        assert!(code.contains("struct __align__(128) ControlBlock"));
        assert!(code.contains("is_active"));
        assert!(code.contains("should_terminate"));
        assert!(code.contains("has_terminated"));
        assert!(code.contains("messages_processed"));
        assert!(code.contains("input_head"));
        assert!(code.contains("input_tail"));
        assert!(code.contains("hlc_state"));
    }

    #[test]
    fn test_full_kernel_wrapper() {
        let config = RingKernelConfig::new("example")
            .with_block_size(128)
            .with_hlc(true);

        let kernel = config.generate_kernel_wrapper("// Process message here");

        assert!(kernel.contains("struct __align__(128) ControlBlock"));
        assert!(kernel.contains("extern \"C\" __global__ void ring_kernel_example"));
        assert!(kernel.contains("while (true)"));
        assert!(kernel.contains("// Process message here"));
        assert!(kernel.contains("has_terminated"));

        println!("Generated kernel:\n{}", kernel);
    }

    #[test]
    fn test_intrinsic_lookup() {
        assert_eq!(
            RingKernelIntrinsic::from_name("is_active"),
            Some(RingKernelIntrinsic::IsActive)
        );
        assert_eq!(
            RingKernelIntrinsic::from_name("should_terminate"),
            Some(RingKernelIntrinsic::ShouldTerminate)
        );
        assert_eq!(
            RingKernelIntrinsic::from_name("hlc_tick"),
            Some(RingKernelIntrinsic::HlcTick)
        );
        assert_eq!(RingKernelIntrinsic::from_name("unknown"), None);
    }

    #[test]
    fn test_intrinsic_cuda_output() {
        assert!(RingKernelIntrinsic::IsActive
            .to_cuda(&[])
            .contains("is_active"));
        assert!(RingKernelIntrinsic::ShouldTerminate
            .to_cuda(&[])
            .contains("should_terminate"));
        assert!(RingKernelIntrinsic::HlcTick
            .to_cuda(&[])
            .contains("hlc_logical++"));
    }
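
    // Argument-taking intrinsics splice their arguments into the generated CUDA;
    // a small sketch with placeholder argument names.
    #[test]
    fn test_intrinsic_cuda_output_with_args() {
        let send = RingKernelIntrinsic::K2kSend.to_cuda(&["7".to_string(), "msg".to_string()]);
        assert_eq!(send, "k2k_send(k2k_routes, 7, msg, sizeof(*msg))");

        let enqueue = RingKernelIntrinsic::EnqueueResponse.to_cuda(&["resp".to_string()]);
        assert!(enqueue.contains("output_head"));
        assert!(enqueue.contains("RESP_SIZE"));
    }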

    #[test]
    fn test_k2k_structs_generation() {
        let k2k_code = generate_k2k_structs();

        assert!(
            k2k_code.contains("struct K2KRoute"),
            "Should have K2KRoute struct"
        );
        assert!(
            k2k_code.contains("struct K2KRoutingTable"),
            "Should have K2KRoutingTable struct"
        );
        assert!(
            k2k_code.contains("K2KInboxHeader"),
            "Should have K2KInboxHeader struct"
        );

        assert!(
            k2k_code.contains("__device__ inline int k2k_send"),
            "Should have k2k_send function"
        );
        assert!(
            k2k_code.contains("__device__ inline int k2k_has_message"),
            "Should have k2k_has_message function"
        );
        assert!(
            k2k_code.contains("__device__ inline void* k2k_try_recv"),
            "Should have k2k_try_recv function"
        );
        assert!(
            k2k_code.contains("__device__ inline void* k2k_peek"),
            "Should have k2k_peek function"
        );
        assert!(
            k2k_code.contains("__device__ inline unsigned int k2k_pending_count"),
            "Should have k2k_pending_count function"
        );

        println!("K2K code:\n{}", k2k_code);
    }

    #[test]
    fn test_full_k2k_kernel() {
        let config = RingKernelConfig::new("k2k_processor")
            .with_block_size(128)
            .with_k2k(true)
            .with_hlc(true);

        let kernel = config.generate_kernel_wrapper("// K2K handler code");

        assert!(
            kernel.contains("K2KRoutingTable"),
            "Should have K2KRoutingTable"
        );
        assert!(kernel.contains("K2KRoute"), "Should have K2KRoute struct");
        assert!(
            kernel.contains("K2KInboxHeader"),
            "Should have K2KInboxHeader"
        );
        assert!(
            kernel.contains("k2k_routes"),
            "Should have k2k_routes param"
        );
        assert!(kernel.contains("k2k_inbox"), "Should have k2k_inbox param");
        assert!(
            kernel.contains("k2k_outbox"),
            "Should have k2k_outbox param"
        );
        assert!(kernel.contains("k2k_send"), "Should have k2k_send function");
        assert!(
            kernel.contains("k2k_try_recv"),
            "Should have k2k_try_recv function"
        );

        println!("Full K2K kernel:\n{}", kernel);
    }

    #[test]
    fn test_k2k_intrinsic_requirements() {
        assert!(RingKernelIntrinsic::K2kSend.requires_k2k());
        assert!(RingKernelIntrinsic::K2kTryRecv.requires_k2k());
        assert!(RingKernelIntrinsic::K2kHasMessage.requires_k2k());
        assert!(RingKernelIntrinsic::K2kPeek.requires_k2k());
        assert!(RingKernelIntrinsic::K2kPendingCount.requires_k2k());

        assert!(!RingKernelIntrinsic::HlcTick.requires_k2k());
        assert!(!RingKernelIntrinsic::IsActive.requires_k2k());
    }
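
    // Complementary checks for the HLC and control-block requirement helpers.
    #[test]
    fn test_hlc_and_control_block_intrinsic_requirements() {
        assert!(RingKernelIntrinsic::HlcTick.requires_hlc());
        assert!(RingKernelIntrinsic::HlcUpdate.requires_hlc());
        assert!(RingKernelIntrinsic::HlcNow.requires_hlc());
        assert!(!RingKernelIntrinsic::K2kSend.requires_hlc());

        assert!(RingKernelIntrinsic::IsActive.requires_control_block());
        assert!(RingKernelIntrinsic::EnqueueResponse.requires_control_block());
        assert!(!RingKernelIntrinsic::HlcTick.requires_control_block());
        assert!(!RingKernelIntrinsic::K2kPeek.requires_control_block());
    }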
}