use std::fmt::Write;

use crate::reduction_intrinsics::ReductionOp;

/// Configuration for the optional in-kernel reduction stage.
#[derive(Debug, Clone, Default)]
pub struct KernelReductionConfig {
    /// Whether reduction code generation is enabled.
    pub enabled: bool,
    /// Reduction operator to apply.
    pub op: ReductionOp,
    /// CUDA type of the accumulator (e.g. "double").
    pub accumulator_type: String,
    /// Whether to use cooperative groups for the cross-block step.
    pub use_cooperative: bool,
    /// Name of the per-block shared-memory staging array.
    pub shared_array_name: String,
    /// Name of the global accumulator kernel parameter.
    pub accumulator_name: String,
}

impl KernelReductionConfig {
    /// Create a disabled reduction config with conventional defaults.
    pub fn new() -> Self {
        Self {
            enabled: false,
            op: ReductionOp::Sum,
            accumulator_type: "double".to_string(),
            use_cooperative: true,
            shared_array_name: "__reduction_shared".to_string(),
            accumulator_name: "__reduction_accumulator".to_string(),
        }
    }

    /// Set the reduction operator; this also enables the reduction.
    pub fn with_op(mut self, op: ReductionOp) -> Self {
        self.enabled = true;
        self.op = op;
        self
    }

    /// Set the CUDA accumulator type.
    pub fn with_type(mut self, ty: &str) -> Self {
        self.accumulator_type = ty.to_string();
        self
    }

    /// Enable or disable cooperative groups for the final step.
    pub fn with_cooperative(mut self, use_cooperative: bool) -> Self {
        self.use_cooperative = use_cooperative;
        self
    }

    /// Override the shared-memory array name.
    pub fn with_shared_name(mut self, name: &str) -> Self {
        self.shared_array_name = name.to_string();
        self
    }

    /// Override the accumulator parameter name.
    pub fn with_accumulator_name(mut self, name: &str) -> Self {
        self.accumulator_name = name.to_string();
        self
    }

    /// Emit the `__shared__` staging-array declaration, or an empty
    /// string when the reduction is disabled.
    pub fn generate_shared_declaration(&self, block_size: u32) -> String {
        if !self.enabled {
            return String::new();
        }
        format!(
            "    __shared__ {} {}[{}];",
            self.accumulator_type, self.shared_array_name, block_size
        )
    }

    /// Emit the accumulator kernel parameter, or an empty string when
    /// the reduction is disabled.
    pub fn generate_accumulator_param(&self) -> String {
        if !self.enabled {
            return String::new();
        }
        format!(
            "    {}* __restrict__ {},",
            self.accumulator_type, self.accumulator_name
        )
    }
}
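
// Usage sketch (illustrative, not from the original source): enable a
// float max-reduction and emit its CUDA snippets. The exact leading
// whitespace of the generated strings follows the format strings above.
//
//     let cfg = KernelReductionConfig::new()
//         .with_op(ReductionOp::Max)
//         .with_type("float");
//     // "    __shared__ float __reduction_shared[256];"
//     let decl = cfg.generate_shared_declaration(256);
//     // "    float* __restrict__ __reduction_accumulator,"
//     let param = cfg.generate_accumulator_param();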

/// Top-level configuration for generating a persistent ring kernel.
#[derive(Debug, Clone)]
pub struct RingKernelConfig {
    /// Kernel identifier, used to derive the kernel function name.
    pub id: String,
    /// CUDA block size.
    pub block_size: u32,
    /// Input/output queue capacity; must be a power of two.
    pub queue_capacity: u32,
    /// Enable kernel-to-kernel (K2K) messaging support.
    pub enable_k2k: bool,
    /// Enable hybrid logical clock (HLC) support.
    pub enable_hlc: bool,
    /// Message payload size in bytes.
    pub message_size: usize,
    /// Response payload size in bytes.
    pub response_size: usize,
    /// Launch with cooperative groups.
    pub cooperative_groups: bool,
    /// Nanoseconds to sleep when idle (0 disables `__nanosleep`).
    pub idle_sleep_ns: u32,
    /// Wrap messages in the envelope (header + payload) format.
    pub use_envelope_format: bool,
    /// Numeric kernel identity baked into the generated code.
    pub kernel_id_num: u64,
    /// HLC node identity baked into the generated code.
    pub hlc_node_id: u64,
    /// Optional reduction configuration.
    pub reduction: KernelReductionConfig,
}

impl Default for RingKernelConfig {
    fn default() -> Self {
        Self {
            id: "ring_kernel".to_string(),
            block_size: 128,
            queue_capacity: 1024,
            enable_k2k: false,
            enable_hlc: true,
            message_size: 64,
            response_size: 64,
            cooperative_groups: false,
            idle_sleep_ns: 1000,
            use_envelope_format: true,
            kernel_id_num: 0,
            hlc_node_id: 0,
            reduction: KernelReductionConfig::new(),
        }
    }
}

impl RingKernelConfig {
    /// Create a config with the given kernel id and default settings.
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            ..Default::default()
        }
    }

    /// Set the CUDA block size.
    pub fn with_block_size(mut self, size: u32) -> Self {
        self.block_size = size;
        self
    }

    /// Set the queue capacity (must be a power of two).
    pub fn with_queue_capacity(mut self, capacity: u32) -> Self {
        debug_assert!(
            capacity.is_power_of_two(),
            "Queue capacity must be a power of two"
        );
        self.queue_capacity = capacity;
        self
    }
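
    // Why the power-of-two requirement matters (illustrative note): the
    // generated kernel indexes ring slots with a bitmask rather than a
    // modulo. For capacity 1024 the emitted constant is QUEUE_MASK = 1023
    // (0x3FF), and a slot index is computed as
    //
    //     unsigned int msg_idx = (unsigned int)(tail & QUEUE_MASK);
    //
    // which equals `tail % capacity` only when capacity is a power of two.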

    /// Enable or disable kernel-to-kernel messaging.
    pub fn with_k2k(mut self, enabled: bool) -> Self {
        self.enable_k2k = enabled;
        self
    }

    /// Enable or disable the hybrid logical clock.
    pub fn with_hlc(mut self, enabled: bool) -> Self {
        self.enable_hlc = enabled;
        self
    }

    /// Set message and response payload sizes in bytes.
    pub fn with_message_sizes(mut self, message: usize, response: usize) -> Self {
        self.message_size = message;
        self.response_size = response;
        self
    }

    /// Set the idle sleep duration in nanoseconds.
    pub fn with_idle_sleep(mut self, ns: u32) -> Self {
        self.idle_sleep_ns = ns;
        self
    }

    /// Enable or disable the message envelope format.
    pub fn with_envelope_format(mut self, enabled: bool) -> Self {
        self.use_envelope_format = enabled;
        self
    }

    /// Set the numeric kernel identity.
    pub fn with_kernel_id(mut self, id: u64) -> Self {
        self.kernel_id_num = id;
        self
    }

    /// Set the HLC node identity.
    pub fn with_hlc_node_id(mut self, node_id: u64) -> Self {
        self.hlc_node_id = node_id;
        self
    }

    /// Attach a reduction configuration. A cooperative reduction forces
    /// a cooperative-groups launch for the whole kernel.
    pub fn with_reduction(mut self, reduction: KernelReductionConfig) -> Self {
        self.reduction = reduction;
        if self.reduction.enabled && self.reduction.use_cooperative {
            self.cooperative_groups = true;
        }
        self
    }

    /// Convenience: enable a cooperative double-precision sum reduction.
    pub fn with_sum_reduction(mut self) -> Self {
        self.reduction = KernelReductionConfig::new()
            .with_op(ReductionOp::Sum)
            .with_type("double");
        self.cooperative_groups = true;
        self
    }

    /// Full kernel function name, `ring_kernel_{id}`.
    pub fn kernel_name(&self) -> String {
        format!("ring_kernel_{}", self.id)
    }

    /// Generate the CUDA kernel signature.
    pub fn generate_signature(&self) -> String {
        let mut sig = String::new();

        writeln!(sig, "extern \"C\" __global__ void {}(", self.kernel_name()).unwrap();
        writeln!(sig, "    ControlBlock* __restrict__ control,").unwrap();
        writeln!(sig, "    unsigned char* __restrict__ input_buffer,").unwrap();
        writeln!(sig, "    unsigned char* __restrict__ output_buffer,").unwrap();

        if self.enable_k2k {
            writeln!(sig, "    K2KRoutingTable* __restrict__ k2k_routes,").unwrap();
            writeln!(sig, "    unsigned char* __restrict__ k2k_inbox,").unwrap();
            writeln!(sig, "    unsigned char* __restrict__ k2k_outbox,").unwrap();
        }

        write!(sig, "    void* __restrict__ shared_state").unwrap();
        write!(sig, "\n)").unwrap();

        sig
    }
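
    // For `RingKernelConfig::new("example")` (K2K disabled), the signature
    // generated above is:
    //
    //     extern "C" __global__ void ring_kernel_example(
    //         ControlBlock* __restrict__ control,
    //         unsigned char* __restrict__ input_buffer,
    //         unsigned char* __restrict__ output_buffer,
    //         void* __restrict__ shared_state
    //     )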

    /// Generate the kernel preamble: thread ids, identity constants,
    /// message-size constants, and optional HLC state.
    pub fn generate_preamble(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code, "{}// Thread identification", indent).unwrap();
        writeln!(
            code,
            "{}int tid = threadIdx.x + blockIdx.x * blockDim.x;",
            indent
        )
        .unwrap();
        writeln!(code, "{}int lane_id = threadIdx.x % 32;", indent).unwrap();
        writeln!(code, "{}int warp_id = threadIdx.x / 32;", indent).unwrap();
        writeln!(code).unwrap();

        writeln!(code, "{}// Kernel identity", indent).unwrap();
        writeln!(
            code,
            "{}const unsigned long long KERNEL_ID = {}ULL;",
            indent, self.kernel_id_num
        )
        .unwrap();
        writeln!(
            code,
            "{}const unsigned long long HLC_NODE_ID = {}ULL;",
            indent, self.hlc_node_id
        )
        .unwrap();
        writeln!(code).unwrap();

        writeln!(code, "{}// Message buffer constants", indent).unwrap();
        if self.use_envelope_format {
            writeln!(
                code,
                "{}const unsigned int PAYLOAD_SIZE = {}; // User payload size",
                indent, self.message_size
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int MSG_SIZE = MESSAGE_HEADER_SIZE + PAYLOAD_SIZE; // Full envelope",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int RESP_PAYLOAD_SIZE = {};",
                indent, self.response_size
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int RESP_SIZE = MESSAGE_HEADER_SIZE + RESP_PAYLOAD_SIZE;",
                indent
            )
            .unwrap();
        } else {
            writeln!(
                code,
                "{}const unsigned int MSG_SIZE = {};",
                indent, self.message_size
            )
            .unwrap();
            writeln!(
                code,
                "{}const unsigned int RESP_SIZE = {};",
                indent, self.response_size
            )
            .unwrap();
        }
        writeln!(
            code,
            "{}const unsigned int QUEUE_MASK = {};",
            indent,
            self.queue_capacity - 1
        )
        .unwrap();
        writeln!(code).unwrap();

        if self.enable_hlc {
            writeln!(code, "{}// HLC clock state", indent).unwrap();
            writeln!(code, "{}unsigned long long hlc_physical = 0;", indent).unwrap();
            writeln!(code, "{}unsigned long long hlc_logical = 0;", indent).unwrap();
            writeln!(code).unwrap();
        }

        code
    }

    /// Generate the top of the persistent processing loop: termination
    /// and activity checks, queue polling, and message acquisition.
    pub fn generate_loop_header(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code, "{}// Persistent message processing loop", indent).unwrap();
        writeln!(code, "{}while (true) {{", indent).unwrap();
        writeln!(code, "{}    // Check for termination signal", indent).unwrap();
        writeln!(
            code,
            "{}    if (atomicAdd(&control->should_terminate, 0) != 0) {{",
            indent
        )
        .unwrap();
        writeln!(code, "{}        break;", indent).unwrap();
        writeln!(code, "{}    }}", indent).unwrap();
        writeln!(code).unwrap();

        writeln!(code, "{}    // Check if kernel is active", indent).unwrap();
        writeln!(
            code,
            "{}    if (atomicAdd(&control->is_active, 0) == 0) {{",
            indent
        )
        .unwrap();
        if self.idle_sleep_ns > 0 {
            writeln!(
                code,
                "{}        __nanosleep({});",
                indent, self.idle_sleep_ns
            )
            .unwrap();
        }
        writeln!(code, "{}        continue;", indent).unwrap();
        writeln!(code, "{}    }}", indent).unwrap();
        writeln!(code).unwrap();

        writeln!(code, "{}    // Check input queue for messages", indent).unwrap();
        writeln!(
            code,
            "{}    unsigned long long head = atomicAdd(&control->input_head, 0);",
            indent
        )
        .unwrap();
        writeln!(
            code,
            "{}    unsigned long long tail = atomicAdd(&control->input_tail, 0);",
            indent
        )
        .unwrap();
        writeln!(code).unwrap();
        writeln!(code, "{}    if (head == tail) {{", indent).unwrap();
        writeln!(code, "{}        // No messages, yield", indent).unwrap();
        if self.idle_sleep_ns > 0 {
            writeln!(
                code,
                "{}        __nanosleep({});",
                indent, self.idle_sleep_ns
            )
            .unwrap();
        }
        writeln!(code, "{}        continue;", indent).unwrap();
        writeln!(code, "{}    }}", indent).unwrap();
        writeln!(code).unwrap();

        writeln!(code, "{}    // Get message from queue", indent).unwrap();
        writeln!(
            code,
            "{}    unsigned int msg_idx = (unsigned int)(tail & QUEUE_MASK);",
            indent
        )
        .unwrap();
        writeln!(
            code,
            "{}    unsigned char* envelope_ptr = &input_buffer[msg_idx * MSG_SIZE];",
            indent
        )
        .unwrap();

        if self.use_envelope_format {
            writeln!(code).unwrap();
            writeln!(code, "{}    // Parse message envelope", indent).unwrap();
            writeln!(
                code,
                "{}    MessageHeader* msg_header = message_get_header(envelope_ptr);",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}    unsigned char* msg_ptr = message_get_payload(envelope_ptr);",
                indent
            )
            .unwrap();
            writeln!(code).unwrap();
            writeln!(code, "{}    // Validate message (skip invalid)", indent).unwrap();
            writeln!(
                code,
                "{}    if (!message_header_validate(msg_header)) {{",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}        atomicAdd(&control->input_tail, 1);",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}        atomicAdd(&control->last_error, 1); // Track errors",
                indent
            )
            .unwrap();
            writeln!(code, "{}        continue;", indent).unwrap();
            writeln!(code, "{}    }}", indent).unwrap();

            if self.enable_hlc {
                writeln!(code).unwrap();
                writeln!(code, "{}    // Update HLC from message timestamp", indent).unwrap();
                writeln!(
                    code,
                    "{}    if (msg_header->timestamp.physical > hlc_physical) {{",
                    indent
                )
                .unwrap();
                writeln!(
                    code,
                    "{}        hlc_physical = msg_header->timestamp.physical;",
                    indent
                )
                .unwrap();
                writeln!(code, "{}        hlc_logical = 0;", indent).unwrap();
                writeln!(code, "{}    }}", indent).unwrap();
                writeln!(code, "{}    hlc_logical++;", indent).unwrap();
            }
        } else {
            writeln!(
                code,
                "{}    unsigned char* msg_ptr = envelope_ptr; // Raw message data",
                indent
            )
            .unwrap();
        }
        writeln!(code).unwrap();

        code
    }

    /// Generate the code that marks the current message as consumed.
    pub fn generate_message_complete(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code).unwrap();
        writeln!(code, "{}    // Mark message as processed", indent).unwrap();
        writeln!(code, "{}    atomicAdd(&control->input_tail, 1);", indent).unwrap();
        writeln!(
            code,
            "{}    atomicAdd(&control->messages_processed, 1);",
            indent
        )
        .unwrap();

        if self.enable_hlc {
            writeln!(code).unwrap();
            writeln!(code, "{}    // Update HLC", indent).unwrap();
            writeln!(code, "{}    hlc_logical++;", indent).unwrap();
        }

        code
    }

    /// Generate the loop footer: block synchronization and loop close.
    pub fn generate_loop_footer(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code, "{}    __syncthreads();", indent).unwrap();
        writeln!(code, "{}}}", indent).unwrap();

        code
    }

    /// Generate the kernel epilogue: persist HLC state and mark the
    /// kernel as terminated (thread 0 only).
    pub fn generate_epilogue(&self, indent: &str) -> String {
        let mut code = String::new();

        writeln!(code).unwrap();
        writeln!(code, "{}// Mark kernel as terminated", indent).unwrap();
        writeln!(code, "{}if (tid == 0) {{", indent).unwrap();

        if self.enable_hlc {
            writeln!(code, "{}    // Store final HLC state", indent).unwrap();
            writeln!(
                code,
                "{}    control->hlc_state.physical = hlc_physical;",
                indent
            )
            .unwrap();
            writeln!(
                code,
                "{}    control->hlc_state.logical = hlc_logical;",
                indent
            )
            .unwrap();
        }

        writeln!(
            code,
            "{}    atomicExch(&control->has_terminated, 1);",
            indent
        )
        .unwrap();
        writeln!(code, "{}}}", indent).unwrap();

        code
    }

    /// Assemble a complete kernel: supporting structs, signature,
    /// preamble, processing loop with the user handler spliced in, and
    /// epilogue.
    pub fn generate_kernel_wrapper(&self, handler_placeholder: &str) -> String {
        let mut code = String::new();

        code.push_str(&generate_control_block_struct());
        code.push('\n');

        if self.enable_hlc {
            code.push_str(&generate_hlc_struct());
            code.push('\n');
        }

        if self.use_envelope_format {
            code.push_str(&generate_message_envelope_structs());
            code.push('\n');
        }

        if self.enable_k2k {
            code.push_str(&generate_k2k_structs());
            code.push('\n');
        }

        code.push_str(&self.generate_signature());
        code.push_str(" {\n");

        code.push_str(&self.generate_preamble("    "));

        code.push_str(&self.generate_loop_header("    "));

        writeln!(code, "        // === USER HANDLER CODE ===").unwrap();
        for line in handler_placeholder.lines() {
            writeln!(code, "        {}", line).unwrap();
        }
        writeln!(code, "        // === END HANDLER CODE ===").unwrap();

        code.push_str(&self.generate_message_complete("    "));

        code.push_str(&self.generate_loop_footer("    "));

        code.push_str(&self.generate_epilogue("    "));

        code.push_str("}\n");

        code
    }
}
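
// End-to-end usage (illustrative sketch): generate a complete kernel and
// hand the source to a CUDA toolchain. Compilation (NVRTC/nvcc) happens
// outside this module and is only hinted at here.
//
//     let config = RingKernelConfig::new("echo")
//         .with_queue_capacity(512)
//         .with_message_sizes(64, 64);
//     let source = config.generate_kernel_wrapper(
//         "// copy the request payload into the response buffer here",
//     );
//     // `source` now holds the supporting structs plus the complete
//     // `ring_kernel_echo` __global__ function.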

/// Emit the C definition of the kernel control block.
pub fn generate_control_block_struct() -> String {
    r#"// Control block for kernel state management (128 bytes, cache-line aligned)
struct __align__(128) ControlBlock {
    // Lifecycle state
    unsigned int is_active;
    unsigned int should_terminate;
    unsigned int has_terminated;
    unsigned int _pad1;

    // Counters
    unsigned long long messages_processed;
    unsigned long long messages_in_flight;

    // Queue pointers
    unsigned long long input_head;
    unsigned long long input_tail;
    unsigned long long output_head;
    unsigned long long output_tail;

    // Queue metadata
    unsigned int input_capacity;
    unsigned int output_capacity;
    unsigned int input_mask;
    unsigned int output_mask;

    // HLC state
    struct {
        unsigned long long physical;
        unsigned long long logical;
    } hlc_state;

    // Error state
    unsigned int last_error;
    unsigned int error_count;

    // Reserved padding
    unsigned char _reserved[24];
};
"#
    .to_string()
}
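
// Host-side mirror (illustrative sketch, not part of the original source):
// a `#[repr(C, align(128))]` struct with the same field order lets a host
// crate poll counters without layout drift. The anonymous HLC struct is
// flattened into two u64 fields; field sizes total 128 bytes, matching the
// comment above (16 + 16 + 32 + 16 + 16 + 8 + 24 = 128).
//
//     #[repr(C, align(128))]
//     struct HostControlBlock {
//         is_active: u32,
//         should_terminate: u32,
//         has_terminated: u32,
//         _pad1: u32,
//         messages_processed: u64,
//         messages_in_flight: u64,
//         input_head: u64,
//         input_tail: u64,
//         output_head: u64,
//         output_tail: u64,
//         input_capacity: u32,
//         output_capacity: u32,
//         input_mask: u32,
//         output_mask: u32,
//         hlc_physical: u64,
//         hlc_logical: u64,
//         last_error: u32,
//         error_count: u32,
//         _reserved: [u8; 24],
//     }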

/// Emit the C definitions of the HLC state and timestamp structs.
pub fn generate_hlc_struct() -> String {
    r#"// Hybrid Logical Clock state
struct HlcState {
    unsigned long long physical;
    unsigned long long logical;
};

// HLC timestamp (24 bytes)
struct __align__(8) HlcTimestamp {
    unsigned long long physical;
    unsigned long long logical;
    unsigned long long node_id;
};
"#
    .to_string()
}

/// Emit the message envelope structs and device helper functions.
pub fn generate_message_envelope_structs() -> String {
    r#"// Magic number for message validation
#define MESSAGE_MAGIC 0x52494E474B45524EULL // "RINGKERN"
#define MESSAGE_VERSION 1
#define MESSAGE_HEADER_SIZE 256
#define MAX_PAYLOAD_SIZE (64 * 1024)

// Message priority levels
#define PRIORITY_LOW 0
#define PRIORITY_NORMAL 1
#define PRIORITY_HIGH 2
#define PRIORITY_CRITICAL 3

// Message header structure (256 bytes, cache-line aligned)
// Matches ringkernel_core::message::MessageHeader exactly
struct __align__(64) MessageHeader {
    // Magic number for validation ("RINGKERN" in ASCII)
    unsigned long long magic;
    // Header version
    unsigned int version;
    // Message flags
    unsigned int flags;
    // Unique message identifier
    unsigned long long message_id;
    // Correlation ID for request-response
    unsigned long long correlation_id;
    // Source kernel ID (0 for host)
    unsigned long long source_kernel;
    // Destination kernel ID (0 for host)
    unsigned long long dest_kernel;
    // Message type discriminator
    unsigned long long message_type;
    // Priority level
    unsigned char priority;
    // Reserved for alignment
    unsigned char _reserved1[7];
    // Payload size in bytes
    unsigned long long payload_size;
    // Checksum of payload (CRC32)
    unsigned int checksum;
    // Reserved for alignment
    unsigned int _reserved2;
    // HLC timestamp when message was created
    HlcTimestamp timestamp;
    // Deadline timestamp (0 = no deadline)
    HlcTimestamp deadline;
    // Reserved for future use (104 bytes total: 32+32+32+8)
    unsigned char _reserved3[104];
};

// Validate message header
__device__ inline int message_header_validate(const MessageHeader* header) {
    return header->magic == MESSAGE_MAGIC &&
           header->version <= MESSAGE_VERSION &&
           header->payload_size <= MAX_PAYLOAD_SIZE;
}

// Get payload pointer from header
__device__ inline unsigned char* message_get_payload(unsigned char* envelope_ptr) {
    return envelope_ptr + MESSAGE_HEADER_SIZE;
}

// Get header from envelope pointer
__device__ inline MessageHeader* message_get_header(unsigned char* envelope_ptr) {
    return (MessageHeader*)envelope_ptr;
}

// CRC32 computation for payload checksums
// Uses CRC32-C (Castagnoli) polynomial 0x1EDC6F41
__device__ inline unsigned int message_compute_checksum(const unsigned char* data, unsigned long long size) {
    unsigned int crc = 0xFFFFFFFF;
    for (unsigned long long i = 0; i < size; i++) {
        crc ^= data[i];
        for (int j = 0; j < 8; j++) {
            crc = (crc >> 1) ^ (0x82F63B78 * (crc & 1)); // reversed CRC32-C polynomial
        }
    }
    return ~crc;
}

// Verify payload checksum
__device__ inline int message_verify_checksum(const MessageHeader* header, const unsigned char* payload) {
    if (header->payload_size == 0) {
        return header->checksum == 0; // Empty payload should have zero checksum
    }
    unsigned int computed = message_compute_checksum(payload, header->payload_size);
    return computed == header->checksum;
}

// Create a response header based on a request
// Note: payload_ptr is optional; pass NULL if the checksum is not needed yet
__device__ inline void message_create_response_header(
    MessageHeader* response,
    const MessageHeader* request,
    unsigned long long this_kernel_id,
    unsigned long long payload_size,
    unsigned long long hlc_physical,
    unsigned long long hlc_logical,
    unsigned long long hlc_node_id,
    const unsigned char* payload_ptr // Optional: for checksum computation
) {
    response->magic = MESSAGE_MAGIC;
    response->version = MESSAGE_VERSION;
    response->flags = 0;
    // Generate new message ID (simple increment from request)
    response->message_id = request->message_id + 0x100000000ULL;
    // Preserve correlation ID for request-response matching
    response->correlation_id = request->correlation_id != 0
        ? request->correlation_id
        : request->message_id;
    response->source_kernel = this_kernel_id;
    response->dest_kernel = request->source_kernel; // Response goes back to sender
    response->message_type = request->message_type + 1; // Convention: response type = request + 1
    response->priority = request->priority;
    response->payload_size = payload_size;
    // Compute checksum if payload provided
    if (payload_ptr != NULL && payload_size > 0) {
        response->checksum = message_compute_checksum(payload_ptr, payload_size);
    } else {
        response->checksum = 0;
    }
    response->timestamp.physical = hlc_physical;
    response->timestamp.logical = hlc_logical;
    response->timestamp.node_id = hlc_node_id;
    response->deadline.physical = 0;
    response->deadline.logical = 0;
    response->deadline.node_id = 0;
}

// Calculate total envelope size
__device__ inline unsigned int message_envelope_size(const MessageHeader* header) {
    return MESSAGE_HEADER_SIZE + (unsigned int)header->payload_size;
}
"#
    .to_string()
}
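
// Host-side checksum mirror (illustrative sketch, not part of the original
// source): the same bitwise CRC32-C as `message_compute_checksum` above,
// useful for checking payload checksums on the host.
//
//     fn crc32c(data: &[u8]) -> u32 {
//         let mut crc: u32 = 0xFFFF_FFFF;
//         for &byte in data {
//             crc ^= byte as u32;
//             for _ in 0..8 {
//                 // 0x82F63B78 is the reversed CRC32-C polynomial
//                 crc = (crc >> 1) ^ (0x82F6_3B78 * (crc & 1));
//             }
//         }
//         !crc
//     }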

/// Emit the K2K routing structs and device helper functions.
pub fn generate_k2k_structs() -> String {
    r#"// Kernel-to-kernel routing table entry (48 bytes)
// Matches ringkernel_cuda::k2k_gpu::K2KRouteEntry
struct K2KRoute {
    unsigned long long target_kernel_id;
    unsigned long long target_inbox; // device pointer
    unsigned long long target_head;  // device pointer to head
    unsigned long long target_tail;  // device pointer to tail
    unsigned int capacity;
    unsigned int mask;
    unsigned int msg_size;
    unsigned int _pad;
};

// K2K routing table (8 + 48*16 = 776 bytes)
struct K2KRoutingTable {
    unsigned int num_routes;
    unsigned int _pad;
    K2KRoute routes[16]; // Max 16 K2K connections
};

// K2K inbox header (64 bytes, cache-line aligned)
// Matches ringkernel_cuda::k2k_gpu::K2KInboxHeader
struct __align__(64) K2KInboxHeader {
    unsigned long long head;
    unsigned long long tail;
    unsigned int capacity;
    unsigned int mask;
    unsigned int msg_size;
    unsigned int _pad;
};

// Send a message envelope to another kernel via K2K
// The entire envelope (header + payload) is copied to the target's inbox
__device__ inline int k2k_send_envelope(
    K2KRoutingTable* routes,
    unsigned long long target_id,
    unsigned long long source_kernel_id,
    const void* payload_ptr,
    unsigned int payload_size,
    unsigned long long message_type,
    unsigned long long hlc_physical,
    unsigned long long hlc_logical,
    unsigned long long hlc_node_id
) {
    // Find route for target
    for (unsigned int i = 0; i < routes->num_routes; i++) {
        if (routes->routes[i].target_kernel_id == target_id) {
            K2KRoute* route = &routes->routes[i];

            // Calculate total envelope size
            unsigned int envelope_size = MESSAGE_HEADER_SIZE + payload_size;
            if (envelope_size > route->msg_size) {
                return -1; // Message too large
            }

            // Atomically claim a slot in target's inbox
            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
            unsigned long long slot = atomicAdd(target_head_ptr, 1);
            unsigned int idx = (unsigned int)(slot & route->mask);

            // Calculate destination pointer
            unsigned char* dest = ((unsigned char*)route->target_inbox) +
                                  sizeof(K2KInboxHeader) + idx * route->msg_size;

            // Build message header
            MessageHeader* header = (MessageHeader*)dest;
            header->magic = MESSAGE_MAGIC;
            header->version = MESSAGE_VERSION;
            header->flags = 0;
            header->message_id = (source_kernel_id << 32) | (slot & 0xFFFFFFFF);
            header->correlation_id = 0;
            header->source_kernel = source_kernel_id;
            header->dest_kernel = target_id;
            header->message_type = message_type;
            header->priority = PRIORITY_NORMAL;
            header->payload_size = payload_size;
            header->timestamp.physical = hlc_physical;
            header->timestamp.logical = hlc_logical;
            header->timestamp.node_id = hlc_node_id;
            header->deadline.physical = 0;
            header->deadline.logical = 0;
            header->deadline.node_id = 0;

            // Copy payload after header and compute checksum
            if (payload_size > 0 && payload_ptr != NULL) {
                memcpy(dest + MESSAGE_HEADER_SIZE, payload_ptr, payload_size);
                header->checksum = message_compute_checksum(
                    (const unsigned char*)payload_ptr, payload_size);
            } else {
                header->checksum = 0;
            }

            __threadfence(); // Ensure write is visible
            return 1; // Success
        }
    }
    return 0; // Route not found
}

// Legacy k2k_send for raw message data (no envelope)
__device__ inline int k2k_send(
    K2KRoutingTable* routes,
    unsigned long long target_id,
    const void* msg_ptr,
    unsigned int msg_size
) {
    // Find route for target
    for (unsigned int i = 0; i < routes->num_routes; i++) {
        if (routes->routes[i].target_kernel_id == target_id) {
            K2KRoute* route = &routes->routes[i];

            // Atomically claim a slot in target's inbox
            unsigned long long* target_head_ptr = (unsigned long long*)route->target_head;
            unsigned long long slot = atomicAdd(target_head_ptr, 1);
            unsigned int idx = (unsigned int)(slot & route->mask);

            // Copy message to target inbox
            unsigned char* dest = ((unsigned char*)route->target_inbox) +
                                  sizeof(K2KInboxHeader) + idx * route->msg_size;
            memcpy(dest, msg_ptr, msg_size < route->msg_size ? msg_size : route->msg_size);

            __threadfence();
            return 1; // Success
        }
    }
    return 0; // Route not found
}

// Check if there are K2K messages in inbox
__device__ inline int k2k_has_message(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);
    return head != tail;
}

// Try to receive a K2K message envelope
// Returns pointer to MessageHeader if available, NULL otherwise
__device__ inline MessageHeader* k2k_try_recv_envelope(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    // Get message pointer
    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);
    MessageHeader* msg_header = (MessageHeader*)(data_start + idx * header->msg_size);

    // Validate header
    if (!message_header_validate(msg_header)) {
        // Invalid message, skip it
        atomicAdd(&header->tail, 1);
        return NULL;
    }

    // Advance tail (consume message)
    atomicAdd(&header->tail, 1);

    return msg_header;
}

// Legacy k2k_try_recv for raw data
__device__ inline void* k2k_try_recv(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    // Get message pointer
    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    // Advance tail (consume message)
    atomicAdd(&header->tail, 1);

    return data_start + idx * header->msg_size;
}

// Peek at next K2K message without consuming
__device__ inline MessageHeader* k2k_peek_envelope(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    return (MessageHeader*)(data_start + idx * header->msg_size);
}

// Legacy k2k_peek
__device__ inline void* k2k_peek(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;

    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);

    if (head == tail) {
        return NULL; // No messages
    }

    unsigned int idx = (unsigned int)(tail & header->mask);
    unsigned char* data_start = k2k_inbox + sizeof(K2KInboxHeader);

    return data_start + idx * header->msg_size;
}

// Get number of pending K2K messages
__device__ inline unsigned int k2k_pending_count(unsigned char* k2k_inbox) {
    K2KInboxHeader* header = (K2KInboxHeader*)k2k_inbox;
    unsigned long long head = atomicAdd(&header->head, 0);
    unsigned long long tail = atomicAdd(&header->tail, 0);
    return (unsigned int)(head - tail);
}
"#
    .to_string()
}

/// Intrinsic operations available to handler code inside a ring kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RingKernelIntrinsic {
    /// Kernel activity flag.
    IsActive,
    /// Termination request flag.
    ShouldTerminate,
    /// Mark this kernel as terminated.
    MarkTerminated,
    /// Number of messages processed so far.
    GetMessagesProcessed,

    /// Input queue occupancy.
    InputQueueSize,
    /// Output queue occupancy.
    OutputQueueSize,
    /// Whether the input queue is empty.
    InputQueueEmpty,
    /// Whether the output queue is empty.
    OutputQueueEmpty,
    /// Enqueue a response into the output queue.
    EnqueueResponse,

    /// Advance the HLC logical counter.
    HlcTick,
    /// Merge an observed timestamp into the HLC.
    HlcUpdate,
    /// Read the current HLC as a packed 64-bit value.
    HlcNow,

    /// Send a raw K2K message.
    K2kSend,
    /// Try to receive a K2K message.
    K2kTryRecv,
    /// Check for pending K2K messages.
    K2kHasMessage,
    /// Peek at the next K2K message without consuming it.
    K2kPeek,
    /// Count pending K2K messages.
    K2kPendingCount,
}

impl RingKernelIntrinsic {
    /// Translate the intrinsic into a CUDA expression or statement.
    pub fn to_cuda(&self, args: &[String]) -> String {
        match self {
            Self::IsActive => "atomicAdd(&control->is_active, 0) != 0".to_string(),
            Self::ShouldTerminate => "atomicAdd(&control->should_terminate, 0) != 0".to_string(),
            Self::MarkTerminated => "atomicExch(&control->has_terminated, 1)".to_string(),
            Self::GetMessagesProcessed => "atomicAdd(&control->messages_processed, 0)".to_string(),

            Self::InputQueueSize => {
                "(atomicAdd(&control->input_head, 0) - atomicAdd(&control->input_tail, 0))"
                    .to_string()
            }
            Self::OutputQueueSize => {
                "(atomicAdd(&control->output_head, 0) - atomicAdd(&control->output_tail, 0))"
                    .to_string()
            }
            Self::InputQueueEmpty => {
                "(atomicAdd(&control->input_head, 0) == atomicAdd(&control->input_tail, 0))"
                    .to_string()
            }
            Self::OutputQueueEmpty => {
                "(atomicAdd(&control->output_head, 0) == atomicAdd(&control->output_tail, 0))"
                    .to_string()
            }
            Self::EnqueueResponse => {
                if !args.is_empty() {
                    format!(
                        "{{ unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask; \
                         memcpy(&output_buffer[_out_idx * RESP_SIZE], {}, RESP_SIZE); }}",
                        args[0]
                    )
                } else {
                    "/* enqueue_response requires response pointer */".to_string()
                }
            }

            Self::HlcTick => "hlc_logical++".to_string(),
            Self::HlcUpdate => {
                if !args.is_empty() {
                    format!(
                        "{{ if ({} > hlc_physical) {{ hlc_physical = {}; hlc_logical = 0; }} else {{ hlc_logical++; }} }}",
                        args[0], args[0]
                    )
                } else {
                    "hlc_logical++".to_string()
                }
            }
            Self::HlcNow => "(hlc_physical << 32) | (hlc_logical & 0xFFFFFFFF)".to_string(),

            Self::K2kSend => {
                if args.len() >= 2 {
                    format!(
                        "k2k_send(k2k_routes, {}, {}, sizeof(*{}))",
                        args[0], args[1], args[1]
                    )
                } else {
                    "/* k2k_send requires target_id and msg_ptr */".to_string()
                }
            }
            Self::K2kTryRecv => "k2k_try_recv(k2k_inbox)".to_string(),
            Self::K2kHasMessage => "k2k_has_message(k2k_inbox)".to_string(),
            Self::K2kPeek => "k2k_peek(k2k_inbox)".to_string(),
            Self::K2kPendingCount => "k2k_pending_count(k2k_inbox)".to_string(),
        }
    }

    /// Look up an intrinsic by its handler-facing name.
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "is_active" | "is_kernel_active" => Some(Self::IsActive),
            "should_terminate" => Some(Self::ShouldTerminate),
            "mark_terminated" => Some(Self::MarkTerminated),
            "messages_processed" | "get_messages_processed" => Some(Self::GetMessagesProcessed),

            "input_queue_size" => Some(Self::InputQueueSize),
            "output_queue_size" => Some(Self::OutputQueueSize),
            "input_queue_empty" => Some(Self::InputQueueEmpty),
            "output_queue_empty" => Some(Self::OutputQueueEmpty),
            "enqueue_response" | "enqueue" => Some(Self::EnqueueResponse),

            "hlc_tick" => Some(Self::HlcTick),
            "hlc_update" => Some(Self::HlcUpdate),
            "hlc_now" => Some(Self::HlcNow),

            "k2k_send" => Some(Self::K2kSend),
            "k2k_try_recv" => Some(Self::K2kTryRecv),
            "k2k_has_message" => Some(Self::K2kHasMessage),
            "k2k_peek" => Some(Self::K2kPeek),
            "k2k_pending_count" | "k2k_pending" => Some(Self::K2kPendingCount),

            _ => None,
        }
    }

    /// Whether this intrinsic requires K2K support in the kernel.
    pub fn requires_k2k(&self) -> bool {
        matches!(
            self,
            Self::K2kSend
                | Self::K2kTryRecv
                | Self::K2kHasMessage
                | Self::K2kPeek
                | Self::K2kPendingCount
        )
    }

    /// Whether this intrinsic requires HLC support in the kernel.
    pub fn requires_hlc(&self) -> bool {
        matches!(self, Self::HlcTick | Self::HlcUpdate | Self::HlcNow)
    }

    /// Whether this intrinsic reads or writes the control block.
    pub fn requires_control_block(&self) -> bool {
        matches!(
            self,
            Self::IsActive
                | Self::ShouldTerminate
                | Self::MarkTerminated
                | Self::GetMessagesProcessed
                | Self::InputQueueSize
                | Self::OutputQueueSize
                | Self::InputQueueEmpty
                | Self::OutputQueueEmpty
                | Self::EnqueueResponse
        )
    }
}
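
// Intrinsic round trip (illustrative sketch): resolve a handler-facing name
// and lower it to CUDA. The argument name `msg_ts` is hypothetical.
//
//     let intr = RingKernelIntrinsic::from_name("hlc_update").unwrap();
//     assert!(intr.requires_hlc());
//     let cuda = intr.to_cuda(&["msg_ts".to_string()]);
//     // "{ if (msg_ts > hlc_physical) { hlc_physical = msg_ts;
//     //    hlc_logical = 0; } else { hlc_logical++; } }"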

/// Metadata describing one user handler for dispatch-table generation.
#[derive(Debug, Clone)]
pub struct CudaHandlerInfo {
    /// Numeric handler ID used as the `switch` case label.
    pub handler_id: u32,
    /// Name of the CUDA device function implementing the handler.
    pub func_name: String,
    /// Human-readable message type name (for comments).
    pub message_type_name: String,
    /// Numeric message type discriminator.
    pub message_type_id: u64,
    /// Optional inline CUDA body; when present it replaces the call to
    /// `func_name`.
    pub cuda_body: Option<String>,
    /// Whether the handler writes a response.
    pub produces_response: bool,
}

impl CudaHandlerInfo {
    /// Create a handler entry with the given ID and function name.
    pub fn new(handler_id: u32, func_name: impl Into<String>) -> Self {
        Self {
            handler_id,
            func_name: func_name.into(),
            message_type_name: String::new(),
            message_type_id: 0,
            cuda_body: None,
            produces_response: false,
        }
    }

    /// Attach a message type name and discriminator.
    pub fn with_message_type(mut self, name: &str, type_id: u64) -> Self {
        self.message_type_name = name.to_string();
        self.message_type_id = type_id;
        self
    }

    /// Supply an inline CUDA body instead of a function call.
    pub fn with_cuda_body(mut self, body: impl Into<String>) -> Self {
        self.cuda_body = Some(body.into());
        self
    }

    /// Mark the handler as producing a response.
    pub fn with_response(mut self) -> Self {
        self.produces_response = true;
        self
    }
}
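
// Handler registration (illustrative sketch): a call-based handler next to
// an inline-body handler. The names and payload layout are hypothetical.
//
//     let by_call = CudaHandlerInfo::new(1, "handle_fraud_check")
//         .with_message_type("FraudCheckRequest", 1001)
//         .with_response();
//     let inlined = CudaHandlerInfo::new(2, "handle_scale")
//         .with_cuda_body("response->value = msg->payload[0] * 2;");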

/// Dispatch table mapping handler IDs to handler metadata.
#[derive(Debug, Clone, Default)]
pub struct CudaDispatchTable {
    handlers: Vec<CudaHandlerInfo>,
    /// Name of the C message struct the dispatch code casts to.
    pub message_struct_name: String,
    /// Field on the message struct that carries the handler ID.
    pub handler_id_field: String,
    /// Control-block counter bumped when no handler matches.
    pub unknown_counter_field: String,
}

impl CudaDispatchTable {
    /// Create an empty table with default field names.
    pub fn new() -> Self {
        Self {
            handlers: Vec::new(),
            message_struct_name: "ExtendedH2KMessage".to_string(),
            handler_id_field: "handler_id".to_string(),
            unknown_counter_field: "unknown_handler_count".to_string(),
        }
    }

    /// Register a handler.
    pub fn add_handler(&mut self, handler: CudaHandlerInfo) {
        self.handlers.push(handler);
    }

    /// Builder-style variant of [`add_handler`](Self::add_handler).
    pub fn with_handler(mut self, handler: CudaHandlerInfo) -> Self {
        self.add_handler(handler);
        self
    }

    /// Override the message struct name.
    pub fn with_message_struct(mut self, name: &str) -> Self {
        self.message_struct_name = name.to_string();
        self
    }

    /// Registered handlers, in insertion order.
    pub fn handlers(&self) -> &[CudaHandlerInfo] {
        &self.handlers
    }

    /// Number of registered handlers.
    pub fn len(&self) -> usize {
        self.handlers.len()
    }

    /// Whether the table has no handlers.
    pub fn is_empty(&self) -> bool {
        self.handlers.is_empty()
    }
}

/// Generate the `switch`-based dispatch code for a handler table.
pub fn generate_handler_dispatch_code(table: &CudaDispatchTable, indent: &str) -> String {
    let mut code = String::new();

    if table.is_empty() {
        writeln!(
            code,
            "{}// No handlers registered - dispatch table empty",
            indent
        )
        .unwrap();
        return code;
    }

    writeln!(code, "{}// Handler dispatch based on handler_id", indent).unwrap();
    writeln!(
        code,
        "{}uint32_t handler_id = msg->{};",
        indent, table.handler_id_field
    )
    .unwrap();
    writeln!(code, "{}switch (handler_id) {{", indent).unwrap();

    for handler in &table.handlers {
        writeln!(code, "{}    case {}: {{", indent, handler.handler_id).unwrap();

        if !handler.message_type_name.is_empty() {
            writeln!(
                code,
                "{}        // Handler: {} (type_id: {})",
                indent, handler.message_type_name, handler.message_type_id
            )
            .unwrap();
        } else {
            writeln!(code, "{}        // Handler: {}", indent, handler.func_name).unwrap();
        }

        if let Some(body) = &handler.cuda_body {
            for line in body.lines() {
                writeln!(code, "{}        {}", indent, line).unwrap();
            }
        } else if handler.produces_response {
            writeln!(
                code,
                "{}        {}(msg, state, response);",
                indent, handler.func_name
            )
            .unwrap();
        } else {
            writeln!(code, "{}        {}(msg, state);", indent, handler.func_name).unwrap();
        }

        writeln!(code, "{}        break;", indent).unwrap();
        writeln!(code, "{}    }}", indent).unwrap();
    }

    writeln!(code, "{}    default:", indent).unwrap();
    writeln!(
        code,
        "{}        atomicAdd(&ctrl->{}, 1);",
        indent, table.unknown_counter_field
    )
    .unwrap();
    writeln!(code, "{}        break;", indent).unwrap();
    writeln!(code, "{}}}", indent).unwrap();

    code
}
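
// Generated dispatch shape (illustrative): for a single call-based handler
// registered under ID 1 with message type "FraudCheck"/1001, the code
// emitted above looks like
//
//     // Handler dispatch based on handler_id
//     uint32_t handler_id = msg->handler_id;
//     switch (handler_id) {
//         case 1: {
//             // Handler: FraudCheck (type_id: 1001)
//             handle_fraud(msg, state);
//             break;
//         }
//         default:
//             atomicAdd(&ctrl->unknown_handler_count, 1);
//             break;
//     }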

/// Emit the C definition of the extended host-to-kernel message struct.
pub fn generate_extended_h2k_message_struct() -> String {
    r#"// Extended H2K message for handler dispatch (64 bytes, cache-line aligned)
// This format supports type-based routing within persistent kernels
struct __align__(64) ExtendedH2KMessage {
    uint32_t handler_id;  // Handler ID for dispatch (0-255)
    uint32_t flags;       // Message flags (see message_flags)
    uint64_t cmd_id;      // Command/correlation ID for request tracking
    uint64_t timestamp;   // HLC timestamp
    uint8_t payload[40];  // Inline payload (typically up to 32 bytes used)
};

// Message flags for ExtendedH2KMessage
#define EXT_MSG_FLAG_EXTENDED 0x01
#define EXT_MSG_FLAG_HIGH_PRIORITY 0x02
#define EXT_MSG_FLAG_EXTERNAL_BUF 0x04
#define EXT_MSG_FLAG_REQUIRES_RESP 0x08
"#
    .to_string()
}

/// Assemble a complete multi-handler kernel: supporting structs plus the
/// processing loop with dispatch-table routing spliced in.
pub fn generate_multi_handler_kernel(
    config: &RingKernelConfig,
    table: &CudaDispatchTable,
) -> String {
    let mut code = String::new();

    code.push_str(&generate_control_block_struct());
    code.push('\n');

    if config.enable_hlc {
        code.push_str(&generate_hlc_struct());
        code.push('\n');
    }

    code.push_str(&generate_extended_h2k_message_struct());
    code.push('\n');

    if config.use_envelope_format {
        code.push_str(&generate_message_envelope_structs());
        code.push('\n');
    }

    if config.enable_k2k {
        code.push_str(&generate_k2k_structs());
        code.push('\n');
    }

    code.push_str(&config.generate_signature());
    code.push_str(" {\n");

    code.push_str(&config.generate_preamble("    "));

    code.push_str(&config.generate_loop_header("    "));

    writeln!(code, "        // === MULTI-HANDLER DISPATCH ===").unwrap();
    writeln!(
        code,
        "        ExtendedH2KMessage* msg = (ExtendedH2KMessage*)(input_buffer + (msg_idx * MSG_SIZE));"
    )
    .unwrap();
    code.push_str(&generate_handler_dispatch_code(table, "        "));
    writeln!(code, "        // === END DISPATCH ===").unwrap();

    code.push_str(&config.generate_message_complete("    "));

    code.push_str(&config.generate_loop_footer("    "));

    code.push_str(&config.generate_epilogue("    "));

    code.push_str("}\n");

    code
}
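
// Multi-handler usage (illustrative sketch): register two handlers and
// generate the full kernel. The handler function names are hypothetical.
//
//     let config = RingKernelConfig::new("router").with_hlc(true);
//     let table = CudaDispatchTable::new()
//         .with_handler(CudaHandlerInfo::new(1, "handle_ping"))
//         .with_handler(CudaHandlerInfo::new(2, "handle_stats").with_response());
//     let source = generate_multi_handler_kernel(&config, &table);
//     assert!(source.contains("ring_kernel_router"));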

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = RingKernelConfig::default();
        assert_eq!(config.block_size, 128);
        assert_eq!(config.queue_capacity, 1024);
        assert!(config.enable_hlc);
        assert!(!config.enable_k2k);
    }

    #[test]
    fn test_config_builder() {
        let config = RingKernelConfig::new("processor")
            .with_block_size(256)
            .with_queue_capacity(2048)
            .with_k2k(true)
            .with_hlc(true);

        assert_eq!(config.id, "processor");
        assert_eq!(config.kernel_name(), "ring_kernel_processor");
        assert_eq!(config.block_size, 256);
        assert_eq!(config.queue_capacity, 2048);
        assert!(config.enable_k2k);
        assert!(config.enable_hlc);
    }

    #[test]
    fn test_kernel_signature() {
        let config = RingKernelConfig::new("test");
        let sig = config.generate_signature();

        assert!(sig.contains("extern \"C\" __global__ void ring_kernel_test"));
        assert!(sig.contains("ControlBlock* __restrict__ control"));
        assert!(sig.contains("input_buffer"));
        assert!(sig.contains("output_buffer"));
        assert!(sig.contains("shared_state"));
    }

    #[test]
    fn test_kernel_signature_with_k2k() {
        let config = RingKernelConfig::new("k2k_test").with_k2k(true);
        let sig = config.generate_signature();

        assert!(sig.contains("K2KRoutingTable"));
        assert!(sig.contains("k2k_inbox"));
        assert!(sig.contains("k2k_outbox"));
    }

    #[test]
    fn test_preamble_generation() {
        let config = RingKernelConfig::new("test").with_hlc(true);
        let preamble = config.generate_preamble("    ");

        assert!(preamble.contains("int tid = threadIdx.x + blockIdx.x * blockDim.x"));
        assert!(preamble.contains("int lane_id"));
        assert!(preamble.contains("int warp_id"));
        assert!(preamble.contains("MSG_SIZE"));
        assert!(preamble.contains("hlc_physical"));
        assert!(preamble.contains("hlc_logical"));
    }

    #[test]
    fn test_loop_header() {
        let config = RingKernelConfig::new("test");
        let header = config.generate_loop_header("    ");

        assert!(header.contains("while (true)"));
        assert!(header.contains("should_terminate"));
        assert!(header.contains("is_active"));
        assert!(header.contains("input_head"));
        assert!(header.contains("input_tail"));
        assert!(header.contains("msg_ptr"));
    }

    #[test]
    fn test_epilogue() {
        let config = RingKernelConfig::new("test").with_hlc(true);
        let epilogue = config.generate_epilogue("    ");

        assert!(epilogue.contains("has_terminated"));
        assert!(epilogue.contains("hlc_state.physical"));
        assert!(epilogue.contains("hlc_state.logical"));
    }

    #[test]
    fn test_control_block_struct() {
        let code = generate_control_block_struct();

        assert!(code.contains("struct __align__(128) ControlBlock"));
        assert!(code.contains("is_active"));
        assert!(code.contains("should_terminate"));
        assert!(code.contains("has_terminated"));
        assert!(code.contains("messages_processed"));
        assert!(code.contains("input_head"));
        assert!(code.contains("input_tail"));
        assert!(code.contains("hlc_state"));
    }

    #[test]
    fn test_full_kernel_wrapper() {
        let config = RingKernelConfig::new("example")
            .with_block_size(128)
            .with_hlc(true);

        let kernel = config.generate_kernel_wrapper("// Process message here");

        assert!(kernel.contains("struct __align__(128) ControlBlock"));
        assert!(kernel.contains("extern \"C\" __global__ void ring_kernel_example"));
        assert!(kernel.contains("while (true)"));
        assert!(kernel.contains("// Process message here"));
        assert!(kernel.contains("has_terminated"));

        println!("Generated kernel:\n{}", kernel);
    }

    #[test]
    fn test_intrinsic_lookup() {
        assert_eq!(
            RingKernelIntrinsic::from_name("is_active"),
            Some(RingKernelIntrinsic::IsActive)
        );
        assert_eq!(
            RingKernelIntrinsic::from_name("should_terminate"),
            Some(RingKernelIntrinsic::ShouldTerminate)
        );
        assert_eq!(
            RingKernelIntrinsic::from_name("hlc_tick"),
            Some(RingKernelIntrinsic::HlcTick)
        );
        assert_eq!(RingKernelIntrinsic::from_name("unknown"), None);
    }

    #[test]
    fn test_intrinsic_cuda_output() {
        assert!(RingKernelIntrinsic::IsActive
            .to_cuda(&[])
            .contains("is_active"));
        assert!(RingKernelIntrinsic::ShouldTerminate
            .to_cuda(&[])
            .contains("should_terminate"));
        assert!(RingKernelIntrinsic::HlcTick
            .to_cuda(&[])
            .contains("hlc_logical++"));
    }

    #[test]
    fn test_k2k_structs_generation() {
        let k2k_code = generate_k2k_structs();

        assert!(
            k2k_code.contains("struct K2KRoute"),
            "Should have K2KRoute struct"
        );
        assert!(
            k2k_code.contains("struct K2KRoutingTable"),
            "Should have K2KRoutingTable struct"
        );
        assert!(
            k2k_code.contains("K2KInboxHeader"),
            "Should have K2KInboxHeader struct"
        );

        assert!(
            k2k_code.contains("__device__ inline int k2k_send"),
            "Should have k2k_send function"
        );
        assert!(
            k2k_code.contains("__device__ inline int k2k_has_message"),
            "Should have k2k_has_message function"
        );
        assert!(
            k2k_code.contains("__device__ inline void* k2k_try_recv"),
            "Should have k2k_try_recv function"
        );
        assert!(
            k2k_code.contains("__device__ inline void* k2k_peek"),
            "Should have k2k_peek function"
        );
        assert!(
            k2k_code.contains("__device__ inline unsigned int k2k_pending_count"),
            "Should have k2k_pending_count function"
        );

        println!("K2K code:\n{}", k2k_code);
    }

    #[test]
    fn test_full_k2k_kernel() {
        let config = RingKernelConfig::new("k2k_processor")
            .with_block_size(128)
            .with_k2k(true)
            .with_hlc(true);

        let kernel = config.generate_kernel_wrapper("// K2K handler code");

        assert!(
            kernel.contains("K2KRoutingTable"),
            "Should have K2KRoutingTable"
        );
        assert!(kernel.contains("K2KRoute"), "Should have K2KRoute struct");
        assert!(
            kernel.contains("K2KInboxHeader"),
            "Should have K2KInboxHeader"
        );
        assert!(
            kernel.contains("k2k_routes"),
            "Should have k2k_routes param"
        );
        assert!(kernel.contains("k2k_inbox"), "Should have k2k_inbox param");
        assert!(
            kernel.contains("k2k_outbox"),
            "Should have k2k_outbox param"
        );
        assert!(kernel.contains("k2k_send"), "Should have k2k_send function");
        assert!(
            kernel.contains("k2k_try_recv"),
            "Should have k2k_try_recv function"
        );

        println!("Full K2K kernel:\n{}", kernel);
    }

    #[test]
    fn test_k2k_intrinsic_requirements() {
        assert!(RingKernelIntrinsic::K2kSend.requires_k2k());
        assert!(RingKernelIntrinsic::K2kTryRecv.requires_k2k());
        assert!(RingKernelIntrinsic::K2kHasMessage.requires_k2k());
        assert!(RingKernelIntrinsic::K2kPeek.requires_k2k());
        assert!(RingKernelIntrinsic::K2kPendingCount.requires_k2k());

        assert!(!RingKernelIntrinsic::HlcTick.requires_k2k());
        assert!(!RingKernelIntrinsic::IsActive.requires_k2k());
    }

    #[test]
    fn test_kernel_reduction_config_default() {
        let config = KernelReductionConfig::new();
        assert!(!config.enabled);
        assert_eq!(config.op, ReductionOp::Sum);
        assert_eq!(config.accumulator_type, "double");
        assert!(config.use_cooperative);
    }

    #[test]
    fn test_kernel_reduction_config_builder() {
        let config = KernelReductionConfig::new()
            .with_op(ReductionOp::Max)
            .with_type("float")
            .with_cooperative(false)
            .with_shared_name("my_shared")
            .with_accumulator_name("my_accumulator");

        assert!(config.enabled);
        assert_eq!(config.op, ReductionOp::Max);
        assert_eq!(config.accumulator_type, "float");
        assert!(!config.use_cooperative);
        assert_eq!(config.shared_array_name, "my_shared");
        assert_eq!(config.accumulator_name, "my_accumulator");
    }

    #[test]
    fn test_kernel_reduction_shared_declaration() {
        let config = KernelReductionConfig::new()
            .with_op(ReductionOp::Sum)
            .with_type("double")
            .with_shared_name("reduction_shared");

        let decl = config.generate_shared_declaration(256);
        assert!(decl.contains("__shared__"));
        assert!(decl.contains("double"));
        assert!(decl.contains("reduction_shared"));
        assert!(decl.contains("[256]"));
    }

    #[test]
    fn test_kernel_reduction_accumulator_param() {
        let config = KernelReductionConfig::new()
            .with_op(ReductionOp::Sum)
            .with_type("double")
            .with_accumulator_name("dangling_sum");

        let param = config.generate_accumulator_param();
        assert!(param.contains("double*"));
        assert!(param.contains("__restrict__"));
        assert!(param.contains("dangling_sum"));
    }

    #[test]
    fn test_kernel_reduction_disabled_generates_empty() {
        let config = KernelReductionConfig::new();
        assert!(config.generate_shared_declaration(256).is_empty());
        assert!(config.generate_accumulator_param().is_empty());
    }

    #[test]
    fn test_ring_kernel_config_with_reduction() {
        let reduction = KernelReductionConfig::new()
            .with_op(ReductionOp::Sum)
            .with_type("double");

        let config = RingKernelConfig::new("pagerank")
            .with_block_size(256)
            .with_reduction(reduction);

        assert!(config.reduction.enabled);
        assert_eq!(config.reduction.op, ReductionOp::Sum);
        assert!(config.cooperative_groups);
    }

    #[test]
    fn test_ring_kernel_config_with_sum_reduction() {
        let config = RingKernelConfig::new("pagerank")
            .with_block_size(256)
            .with_sum_reduction();

        assert!(config.reduction.enabled);
        assert_eq!(config.reduction.op, ReductionOp::Sum);
        assert_eq!(config.reduction.accumulator_type, "double");
        assert!(config.cooperative_groups);
    }

    #[test]
    fn test_ring_kernel_config_reduction_without_cooperative() {
        let reduction = KernelReductionConfig::new()
            .with_op(ReductionOp::Sum)
            .with_cooperative(false);

        let config = RingKernelConfig::new("pagerank").with_reduction(reduction);

        assert!(config.reduction.enabled);
        assert!(!config.cooperative_groups);
    }

    #[test]
    fn test_reduction_op_default() {
        let op = ReductionOp::default();
        assert_eq!(op, ReductionOp::Sum);
    }

    #[test]
    fn test_cuda_handler_info_creation() {
        let handler = CudaHandlerInfo::new(1, "handle_fraud_check")
            .with_message_type("FraudCheckRequest", 1001)
            .with_response();

        assert_eq!(handler.handler_id, 1);
        assert_eq!(handler.func_name, "handle_fraud_check");
        assert_eq!(handler.message_type_name, "FraudCheckRequest");
        assert_eq!(handler.message_type_id, 1001);
        assert!(handler.produces_response);
        assert!(handler.cuda_body.is_none());
    }

    #[test]
    fn test_cuda_handler_info_with_body() {
        let handler = CudaHandlerInfo::new(2, "handle_aggregate")
            .with_cuda_body("float sum = 0.0f;\nfor (int i = 0; i < 10; i++) sum += data[i];");

        assert_eq!(handler.handler_id, 2);
        assert!(handler.cuda_body.is_some());
        assert!(handler.cuda_body.as_ref().unwrap().contains("float sum"));
    }

    #[test]
    fn test_cuda_dispatch_table_creation() {
        let table = CudaDispatchTable::new()
            .with_handler(CudaHandlerInfo::new(1, "handle_a"))
            .with_handler(CudaHandlerInfo::new(2, "handle_b"));

        assert_eq!(table.len(), 2);
        assert!(!table.is_empty());
        assert_eq!(table.handlers()[0].handler_id, 1);
        assert_eq!(table.handlers()[1].handler_id, 2);
    }

    #[test]
    fn test_generate_handler_dispatch_code_empty() {
        let table = CudaDispatchTable::new();
        let code = generate_handler_dispatch_code(&table, "    ");

        assert!(code.contains("No handlers registered"));
        assert!(!code.contains("switch"));
    }

    #[test]
    fn test_generate_handler_dispatch_code_single_handler() {
        let table = CudaDispatchTable::new().with_handler(
            CudaHandlerInfo::new(1, "handle_fraud").with_message_type("FraudCheck", 1001),
        );

        let code = generate_handler_dispatch_code(&table, "    ");

        assert!(code.contains("switch (handler_id)"));
        assert!(code.contains("case 1:"));
        assert!(code.contains("handle_fraud(msg, state)"));
        assert!(code.contains("default:"));
        assert!(code.contains("unknown_handler_count"));
    }

    #[test]
    fn test_generate_handler_dispatch_code_multiple_handlers() {
        let table = CudaDispatchTable::new()
            .with_handler(
                CudaHandlerInfo::new(1, "handle_fraud").with_message_type("FraudCheck", 1001),
            )
            .with_handler(
                CudaHandlerInfo::new(2, "handle_aggregate")
                    .with_message_type("Aggregate", 1002)
                    .with_response(),
            )
            .with_handler(
                CudaHandlerInfo::new(5, "handle_pattern").with_message_type("Pattern", 1005),
            );

        let code = generate_handler_dispatch_code(&table, "    ");

        assert!(code.contains("case 1:"));
        assert!(code.contains("case 2:"));
        assert!(code.contains("case 5:"));
        assert!(code.contains("handle_fraud(msg, state)"));
        assert!(code.contains("handle_aggregate(msg, state, response)"));
        assert!(code.contains("handle_pattern(msg, state)"));
    }

    #[test]
    fn test_generate_handler_dispatch_code_with_inline_body() {
        let table = CudaDispatchTable::new().with_handler(
            CudaHandlerInfo::new(1, "inline_handler")
                .with_cuda_body("int result = msg->payload[0] * 2;\nresponse->result = result;"),
        );

        let code = generate_handler_dispatch_code(&table, "    ");

        assert!(code.contains("case 1:"));
        assert!(code.contains("int result = msg->payload[0] * 2;"));
        assert!(code.contains("response->result = result;"));
        assert!(!code.contains("inline_handler(msg,"));
    }

    #[test]
    fn test_generate_extended_h2k_message_struct() {
        let code = generate_extended_h2k_message_struct();

        assert!(code.contains("struct __align__(64) ExtendedH2KMessage"));
        assert!(code.contains("uint32_t handler_id"));
        assert!(code.contains("uint32_t flags"));
        assert!(code.contains("uint64_t cmd_id"));
        assert!(code.contains("uint64_t timestamp"));
        assert!(code.contains("uint8_t payload[40]"));
        assert!(code.contains("EXT_MSG_FLAG_EXTENDED"));
        assert!(code.contains("EXT_MSG_FLAG_REQUIRES_RESP"));
    }

    #[test]
    fn test_generate_multi_handler_kernel() {
        let config = RingKernelConfig::new("multi_handler")
            .with_block_size(128)
            .with_hlc(true);

        let table = CudaDispatchTable::new()
            .with_handler(CudaHandlerInfo::new(1, "handle_fraud"))
            .with_handler(CudaHandlerInfo::new(2, "handle_aggregate"));

        let code = generate_multi_handler_kernel(&config, &table);

        assert!(code.contains("struct __align__(128) ControlBlock"));
        assert!(code.contains("struct __align__(64) ExtendedH2KMessage"));

        assert!(code.contains("ring_kernel_multi_handler"));

        assert!(code.contains("MULTI-HANDLER DISPATCH"));
        assert!(code.contains("switch (handler_id)"));
        assert!(code.contains("case 1:"));
        assert!(code.contains("case 2:"));
    }
}