1use crate::types::{is_ring_context_type, CudaType, TypeMapper};
42use crate::Result;
43use std::fmt::Write;
44use syn::{FnArg, ItemFn, Pat, ReturnType, Type};
45
/// One parameter of a handler function, as recorded by `HandlerSignature::parse`.
#[derive(Debug, Clone)]
pub struct HandlerParam {
    /// Identifier the parameter is bound to in the Rust signature.
    pub name: String,
    /// The Rust type rendered to text with `quote!`.
    pub rust_type: String,
    /// CUDA spelling produced by the `TypeMapper`; `parse` stores `"void*"`
    /// when the mapper returns an error.
    pub cuda_type: String,
    /// Classification assigned by `HandlerSignature::classify_param`.
    pub kind: HandlerParamKind,
}
58
/// Role a handler parameter plays, as decided by `HandlerSignature::classify_param`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandlerParamKind {
    /// Ring context: the type satisfies `is_ring_context_type`, or the
    /// parameter is named `ctx` / `context` (case-insensitive).
    Context,
    /// Shared reference recognised as a message by parameter name
    /// (`msg`, `message`, `msg_*`) or by a type name ending in `Message` / `Request`.
    Message,
    /// Same recognition rules as `Message`, but the reference is `&mut`.
    MessageMut,
    /// Anything that matched none of the other rules.
    Value,
    /// Shared reference to a slice (`&[T]`).
    Slice,
    /// Mutable reference to a slice (`&mut [T]`).
    SliceMut,
}
75
/// Summary of a handler function produced by `HandlerSignature::parse`.
#[derive(Debug, Clone)]
pub struct HandlerSignature {
    /// The function's identifier.
    pub name: String,
    /// Parameters retained by `parse`, in declaration order. Context
    /// parameters, receivers and non-identifier patterns are not included.
    pub params: Vec<HandlerParam>,
    /// Return type details; `None` when the function declares no return type
    /// or returns the unit type `()`.
    pub return_type: Option<HandlerReturnType>,
    /// `true` when at least one parameter was classified as `HandlerParamKind::Context`.
    pub has_context: bool,
    /// Copy of the last parameter classified as `Message` or `MessageMut`, if any.
    pub message_param: Option<HandlerParam>,
}
90
/// Return type of a handler, as recorded by `HandlerSignature::parse_return_type`.
#[derive(Debug, Clone)]
pub struct HandlerReturnType {
    /// The Rust return type rendered to text with `quote!`.
    pub rust_type: String,
    /// CUDA spelling from the `TypeMapper`; falls back to the `rust_type`
    /// text when the mapper returns an error.
    pub cuda_type: String,
    /// `true` when the mapper resolved the type to `CudaType::Struct`.
    pub is_struct: bool,
}
101
102impl HandlerSignature {
103 pub fn parse(func: &ItemFn, type_mapper: &TypeMapper) -> Result<Self> {
105 let name = func.sig.ident.to_string();
106 let mut params = Vec::new();
107 let mut has_context = false;
108 let mut message_param = None;
109
110 for param in &func.sig.inputs {
111 if let FnArg::Typed(pat_type) = param {
112 let param_name = match pat_type.pat.as_ref() {
113 Pat::Ident(ident) => ident.ident.to_string(),
114 _ => continue,
115 };
116
117 let ty = &pat_type.ty;
118 let rust_type = quote::quote!(#ty).to_string();
119 let kind = Self::classify_param(¶m_name, &pat_type.ty);
120
121 if kind == HandlerParamKind::Context {
122 has_context = true;
123 continue; }
125
126 let cuda_type = match type_mapper.map_type(&pat_type.ty) {
127 Ok(ct) => ct.to_cuda_string(),
128 Err(_) => "void*".to_string(), };
130
131 let param = HandlerParam {
132 name: param_name,
133 rust_type,
134 cuda_type,
135 kind,
136 };
137
138 if kind == HandlerParamKind::Message || kind == HandlerParamKind::MessageMut {
139 message_param = Some(param.clone());
140 }
141
142 params.push(param);
143 }
144 }
145
146 let return_type = Self::parse_return_type(&func.sig.output, type_mapper)?;
147
148 Ok(Self {
149 name,
150 params,
151 return_type,
152 has_context,
153 message_param,
154 })
155 }
156
157 fn classify_param(name: &str, ty: &Type) -> HandlerParamKind {
159 if is_ring_context_type(ty) {
161 return HandlerParamKind::Context;
162 }
163
164 let name_lower = name.to_lowercase();
166 if name_lower == "ctx" || name_lower == "context" {
167 return HandlerParamKind::Context;
168 }
169
170 if let Type::Reference(reference) = ty {
172 let is_mut = reference.mutability.is_some();
173
174 if let Type::Slice(_) = reference.elem.as_ref() {
176 return if is_mut {
177 HandlerParamKind::SliceMut
178 } else {
179 HandlerParamKind::Slice
180 };
181 }
182
183 if name_lower == "msg" || name_lower == "message" || name_lower.starts_with("msg_") {
185 return if is_mut {
186 HandlerParamKind::MessageMut
187 } else {
188 HandlerParamKind::Message
189 };
190 }
191
192 if let Type::Path(path) = reference.elem.as_ref() {
194 let type_name = path
195 .path
196 .segments
197 .last()
198 .map(|s| s.ident.to_string())
199 .unwrap_or_default();
200
201 if type_name.ends_with("Message") || type_name.ends_with("Request") {
203 return if is_mut {
204 HandlerParamKind::MessageMut
205 } else {
206 HandlerParamKind::Message
207 };
208 }
209 }
210 }
211
212 HandlerParamKind::Value
213 }
214
215 fn parse_return_type(
217 output: &ReturnType,
218 type_mapper: &TypeMapper,
219 ) -> Result<Option<HandlerReturnType>> {
220 match output {
221 ReturnType::Default => Ok(None),
222 ReturnType::Type(_, ty) => {
223 if let Type::Tuple(tuple) = ty.as_ref() {
225 if tuple.elems.is_empty() {
226 return Ok(None);
227 }
228 }
229
230 let rust_type = quote::quote!(#ty).to_string();
231 let cuda_type = type_mapper
232 .map_type(ty)
233 .map(|ct| ct.to_cuda_string())
234 .unwrap_or_else(|_| rust_type.clone());
235
236 let is_struct = matches!(type_mapper.map_type(ty), Ok(CudaType::Struct(_)));
237
238 Ok(Some(HandlerReturnType {
239 rust_type,
240 cuda_type,
241 is_struct,
242 }))
243 }
244 }
245 }
246
247 pub fn has_response(&self) -> bool {
249 self.return_type.is_some()
250 }
251
252 pub fn extra_params(&self) -> Vec<&HandlerParam> {
254 self.params
255 .iter()
256 .filter(|p| {
257 p.kind != HandlerParamKind::Context
258 && p.kind != HandlerParamKind::Message
259 && p.kind != HandlerParamKind::MessageMut
260 })
261 .collect()
262 }
263}
264
/// Knobs shared by the handler code generators in this file.
#[derive(Debug, Clone)]
pub struct HandlerCodegenConfig {
    /// Name given to the typed message pointer emitted by the
    /// deserialization generators.
    pub message_var: String,
    /// Name of the variable whose bytes the serialization generators copy
    /// into the output buffer.
    pub response_var: String,
    /// Text placed in front of every non-blank generated line.
    pub indent: String,
    /// Not read anywhere in this file; presumably tells callers whether to
    /// emit message deserialization — confirm against call sites.
    pub generate_deser: bool,
    /// Not read anywhere in this file; presumably tells callers whether to
    /// emit response serialization — confirm against call sites.
    pub generate_ser: bool,
}

impl Default for HandlerCodegenConfig {
    /// Uses `msg` / `response` as variable names and enables both
    /// generation flags.
    fn default() -> Self {
        let message_var = String::from("msg");
        let response_var = String::from("response");
        let indent = String::from(" ");

        Self {
            message_var,
            response_var,
            indent,
            generate_deser: true,
            generate_ser: true,
        }
    }
}
291
292pub fn generate_message_deser(message_type: &str, config: &HandlerCodegenConfig) -> String {
297 let mut code = String::new();
298 let indent = &config.indent;
299
300 writeln!(code, "{}// Deserialize message from buffer", indent).unwrap();
301 writeln!(
302 code,
303 "{}// msg_ptr points to payload data (after MessageHeader when using envelopes)",
304 indent
305 )
306 .unwrap();
307 writeln!(
308 code,
309 "{}{}* {} = ({}*)msg_ptr;",
310 indent, message_type, config.message_var, message_type
311 )
312 .unwrap();
313
314 code
315}
316
317pub fn generate_envelope_message_deser(
321 message_type: &str,
322 config: &HandlerCodegenConfig,
323) -> String {
324 let mut code = String::new();
325 let indent = &config.indent;
326
327 writeln!(code, "{}// Message envelope deserialization", indent).unwrap();
328 writeln!(
329 code,
330 "{}// msg_header provides: message_id, correlation_id, source_kernel, timestamp, etc.",
331 indent
332 )
333 .unwrap();
334 writeln!(
335 code,
336 "{}{}* {} = ({}*)msg_ptr; // Typed payload",
337 indent, message_type, config.message_var, message_type
338 )
339 .unwrap();
340 writeln!(code).unwrap();
341 writeln!(code, "{}// Access header fields:", indent).unwrap();
342 writeln!(
343 code,
344 "{}// - msg_header->message_id (unique message ID)",
345 indent
346 )
347 .unwrap();
348 writeln!(
349 code,
350 "{}// - msg_header->correlation_id (for request-response matching)",
351 indent
352 )
353 .unwrap();
354 writeln!(
355 code,
356 "{}// - msg_header->source_kernel (sender kernel ID, 0 = host)",
357 indent
358 )
359 .unwrap();
360 writeln!(
361 code,
362 "{}// - msg_header->timestamp (HLC timestamp of send)",
363 indent
364 )
365 .unwrap();
366
367 code
368}
369
370pub fn generate_response_ser(response_type: &str, config: &HandlerCodegenConfig) -> String {
374 let mut code = String::new();
375 let indent = &config.indent;
376
377 writeln!(code).unwrap();
378 writeln!(
379 code,
380 "{}// Serialize response to output buffer (raw format)",
381 indent
382 )
383 .unwrap();
384 writeln!(
385 code,
386 "{}unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;",
387 indent
388 ).unwrap();
389 writeln!(
390 code,
391 "{}memcpy(&output_buffer[_out_idx * RESP_SIZE], &{}, sizeof({}));",
392 indent, config.response_var, response_type
393 )
394 .unwrap();
395
396 code
397}
398
399pub fn generate_envelope_response_ser(
404 response_type: &str,
405 config: &HandlerCodegenConfig,
406) -> String {
407 let mut code = String::new();
408 let indent = &config.indent;
409
410 writeln!(code).unwrap();
411 writeln!(
412 code,
413 "{}// Serialize response envelope to output buffer",
414 indent
415 )
416 .unwrap();
417 writeln!(
418 code,
419 "{}unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;",
420 indent
421 ).unwrap();
422 writeln!(
423 code,
424 "{}unsigned char* resp_envelope = &output_buffer[_out_idx * RESP_SIZE];",
425 indent
426 )
427 .unwrap();
428 writeln!(code).unwrap();
429 writeln!(
430 code,
431 "{}// Build response header from request header",
432 indent
433 )
434 .unwrap();
435 writeln!(
436 code,
437 "{}MessageHeader* resp_header = (MessageHeader*)resp_envelope;",
438 indent
439 )
440 .unwrap();
441 writeln!(code, "{}message_create_response_header(", indent).unwrap();
442 writeln!(code, "{} resp_header,", indent).unwrap();
443 writeln!(
444 code,
445 "{} msg_header, // Request header (for correlation)",
446 indent
447 )
448 .unwrap();
449 writeln!(
450 code,
451 "{} KERNEL_ID, // This kernel's ID",
452 indent
453 )
454 .unwrap();
455 writeln!(
456 code,
457 "{} sizeof({}), // Payload size",
458 indent, response_type
459 )
460 .unwrap();
461 writeln!(
462 code,
463 "{} hlc_physical, // Current HLC",
464 indent
465 )
466 .unwrap();
467 writeln!(code, "{} hlc_logical,", indent).unwrap();
468 writeln!(code, "{} HLC_NODE_ID", indent).unwrap();
469 writeln!(code, "{});", indent).unwrap();
470 writeln!(code).unwrap();
471 writeln!(code, "{}// Copy response payload after header", indent).unwrap();
472 writeln!(
473 code,
474 "{}memcpy(resp_envelope + MESSAGE_HEADER_SIZE, &{}, sizeof({}));",
475 indent, config.response_var, response_type
476 )
477 .unwrap();
478 writeln!(code).unwrap();
479 writeln!(
480 code,
481 "{}__threadfence(); // Ensure write is visible",
482 indent
483 )
484 .unwrap();
485
486 code
487}
488
/// Renders a CUDA `struct` definition.
///
/// `fields` holds `(field_name, cuda_type)` pairs; each becomes one
/// `cuda_type field_name;` member line, in the given order. The result ends
/// with `};` and a trailing newline. An empty `fields` slice yields an empty
/// struct body.
pub fn generate_cuda_struct(name: &str, fields: &[(String, String)]) -> String {
    let mut out = format!("struct {} {{\n", name);

    for (member, member_type) in fields {
        writeln!(out, " {} {};", member_type, member)
            .expect("formatting into a String cannot fail");
    }

    out.push_str("};\n");
    out
}
504
/// Description of a message (or response) type to be emitted as a CUDA struct.
#[derive(Debug, Clone)]
pub struct MessageTypeInfo {
    /// Struct name used in the generated CUDA definition.
    pub name: String,
    /// Size recorded for the type; `simple` derives it from a lookup on the
    /// CUDA type name. Presumably bytes — confirm against consumers.
    pub size: usize,
    /// `(field_name, cuda_type)` pairs, in declaration order.
    pub fields: Vec<(String, String)>,
}
515
516impl MessageTypeInfo {
517 pub fn simple(name: &str, value_type: &str) -> Self {
519 Self {
520 name: name.to_string(),
521 size: Self::type_size(value_type),
522 fields: vec![("value".to_string(), value_type.to_string())],
523 }
524 }
525
526 fn type_size(cuda_type: &str) -> usize {
528 match cuda_type {
529 "float" => 4,
530 "double" => 8,
531 "int" => 4,
532 "unsigned int" => 4,
533 "long long" | "unsigned long long" => 8,
534 "short" | "unsigned short" => 2,
535 "char" | "unsigned char" => 1,
536 _ => 8, }
538 }
539
540 pub fn to_cuda_struct(&self) -> String {
542 generate_cuda_struct(&self.name, &self.fields)
543 }
544}
545
/// Collection of message and response types for which CUDA structs are generated.
#[derive(Debug, Clone, Default)]
pub struct MessageTypeRegistry {
    /// Registered message types, in registration order.
    pub messages: Vec<MessageTypeInfo>,
    /// Registered response types, in registration order.
    pub responses: Vec<MessageTypeInfo>,
}
554
555impl MessageTypeRegistry {
556 pub fn new() -> Self {
558 Self::default()
559 }
560
561 pub fn register_message(&mut self, info: MessageTypeInfo) {
563 self.messages.push(info);
564 }
565
566 pub fn register_response(&mut self, info: MessageTypeInfo) {
568 self.responses.push(info);
569 }
570
571 pub fn generate_structs(&self) -> String {
573 let mut code = String::new();
574
575 for msg in &self.messages {
576 code.push_str(&msg.to_cuda_struct());
577 code.push('\n');
578 }
579
580 for resp in &self.responses {
581 code.push_str(&resp.to_cuda_struct());
582 code.push('\n');
583 }
584
585 code
586 }
587}
588
/// Context operations that can be translated to CUDA source text.
///
/// See [`ContextMethod::to_cuda`] for the exact text each variant produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextMethod {
    /// Emitted as `threadIdx.x`.
    ThreadId,
    /// Emitted as `blockIdx.x`.
    BlockId,
    /// Emitted as `(blockIdx.x * blockDim.x + threadIdx.x)`.
    GlobalThreadId,
    /// Emitted as `__syncthreads()`.
    SyncThreads,
    /// Emitted as an expression combining `hlc_physical` and `hlc_logical`.
    Now,
    /// Emitted as `atomicAdd(ptr, val)`; needs two arguments.
    AtomicAdd,
    /// Emitted as `__shfl_sync(0xFFFFFFFF, val, lane)`; needs two arguments.
    WarpShuffle,
    /// Emitted as `(threadIdx.x % 32)`.
    LaneId,
    /// Emitted as `(threadIdx.x / 32)`.
    WarpId,
}

impl ContextMethod {
    /// Resolves a Rust-side method name to a [`ContextMethod`].
    ///
    /// Several aliases are accepted per method (for example `thread_id` and
    /// `thread_idx`). Matching is exact and case-sensitive; unknown names
    /// yield `None`.
    pub fn from_name(name: &str) -> Option<Self> {
        let method = match name {
            "thread_id" | "thread_idx" => Self::ThreadId,
            "block_id" | "block_idx" => Self::BlockId,
            "global_thread_id" | "global_id" => Self::GlobalThreadId,
            "sync_threads" | "synchronize" => Self::SyncThreads,
            "now" | "timestamp" => Self::Now,
            "atomic_add" => Self::AtomicAdd,
            "warp_shuffle" | "shuffle" => Self::WarpShuffle,
            "lane_id" => Self::LaneId,
            "warp_id" => Self::WarpId,
            _ => return None,
        };
        Some(method)
    }

    /// Renders this method as CUDA source text.
    ///
    /// `args` are the already-translated argument expressions. Only
    /// `AtomicAdd` and `WarpShuffle` read them, using the first two and
    /// ignoring any extras; with fewer than two arguments they return a C
    /// comment describing what is missing instead of a call.
    ///
    /// Every compound expression is wrapped in parentheses so it keeps its
    /// meaning when spliced into a larger expression.
    pub fn to_cuda(&self, args: &[String]) -> String {
        match self {
            Self::ThreadId => "threadIdx.x".to_string(),
            Self::BlockId => "blockIdx.x".to_string(),
            Self::GlobalThreadId => "(blockIdx.x * blockDim.x + threadIdx.x)".to_string(),
            Self::SyncThreads => "__syncthreads()".to_string(),
            // The outer parentheses matter: `|` binds more loosely than
            // arithmetic and comparison operators in C, so without them
            // `now + 1` or `x == now` would regroup around the `|`.
            Self::Now => "((hlc_physical << 32) | (hlc_logical & 0xFFFFFFFF))".to_string(),
            Self::AtomicAdd => match args {
                [ptr, val, ..] => format!("atomicAdd({}, {})", ptr, val),
                _ => "/* atomic_add requires ptr and val */".to_string(),
            },
            Self::WarpShuffle => match args {
                [val, lane, ..] => format!("__shfl_sync(0xFFFFFFFF, {}, {})", val, lane),
                _ => "/* warp_shuffle requires val and lane */".to_string(),
            },
            Self::LaneId => "(threadIdx.x % 32)".to_string(),
            Self::WarpId => "(threadIdx.x / 32)".to_string(),
        }
    }

    /// Returns `true` only for `SyncThreads`.
    ///
    /// Presumably callers use this to emit the call as a standalone
    /// statement rather than as a value — confirm against call sites.
    pub fn is_statement(&self) -> bool {
        matches!(self, Self::SyncThreads)
    }
}
661
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Parses `func` with a fresh `TypeMapper`, panicking if parsing fails.
    fn signature_of(func: &ItemFn) -> HandlerSignature {
        let mapper = TypeMapper::new();
        HandlerSignature::parse(func, &mapper).unwrap()
    }

    // A plain value-in, value-out function has one `Value` parameter and a response.
    #[test]
    fn test_handler_signature_simple() {
        let handler: ItemFn = parse_quote! {
            fn process(value: f32) -> f32 {
                value * 2.0
            }
        };

        let parsed = signature_of(&handler);

        assert_eq!(parsed.name, "process");
        assert_eq!(parsed.params.len(), 1);
        let only = &parsed.params[0];
        assert_eq!(only.name, "value");
        assert_eq!(only.kind, HandlerParamKind::Value);
        assert!(parsed.has_response());
        assert!(!parsed.has_context);
    }

    // The context parameter sets `has_context`; `msg` is picked up as the message.
    #[test]
    fn test_handler_signature_with_context() {
        let handler: ItemFn = parse_quote! {
            fn handle(ctx: &RingContext, msg: &Message) -> Response {
                Response { value: msg.value * 2.0 }
            }
        };

        let parsed = signature_of(&handler);

        assert_eq!(parsed.name, "handle");
        assert!(parsed.has_context);
        assert!(parsed.message_param.is_some());
        assert_eq!(parsed.message_param.as_ref().unwrap().name, "msg");
        assert!(parsed.has_response());
    }

    // No `->` clause means no response.
    #[test]
    fn test_handler_signature_no_response() {
        let handler: ItemFn = parse_quote! {
            fn fire_and_forget(msg: &Event) {
                process(msg);
            }
        };

        let parsed = signature_of(&handler);

        assert!(!parsed.has_response());
    }

    #[test]
    fn test_message_deser_generation() {
        let generated = generate_message_deser("MyMessage", &HandlerCodegenConfig::default());

        assert!(generated.contains("MyMessage* msg"));
        assert!(generated.contains("(MyMessage*)msg_ptr"));
    }

    #[test]
    fn test_response_ser_generation() {
        let generated = generate_response_ser("MyResponse", &HandlerCodegenConfig::default());

        for needle in &["memcpy", "output_buffer", "RESP_SIZE"] {
            assert!(generated.contains(needle));
        }
    }

    #[test]
    fn test_cuda_struct_generation() {
        let members = vec![
            (String::from("id"), String::from("unsigned long long")),
            (String::from("value"), String::from("float")),
        ];

        let generated = generate_cuda_struct("MyMessage", &members);

        assert!(generated.contains("struct MyMessage"));
        assert!(generated.contains("unsigned long long id"));
        assert!(generated.contains("float value"));
    }

    #[test]
    fn test_message_type_info() {
        let float_msg = MessageTypeInfo::simple("FloatMsg", "float");
        assert_eq!(float_msg.name, "FloatMsg");
        assert_eq!(float_msg.size, 4);
        assert_eq!(float_msg.fields.len(), 1);

        let rendered = float_msg.to_cuda_struct();
        assert!(rendered.contains("struct FloatMsg"));
        assert!(rendered.contains("float value"));
    }

    #[test]
    fn test_context_method_lookup() {
        let known = [
            ("thread_id", ContextMethod::ThreadId),
            ("sync_threads", ContextMethod::SyncThreads),
            ("global_thread_id", ContextMethod::GlobalThreadId),
        ];
        for &(name, expected) in known.iter() {
            assert_eq!(ContextMethod::from_name(name), Some(expected));
        }

        assert_eq!(ContextMethod::from_name("unknown"), None);
    }

    #[test]
    fn test_context_method_cuda_output() {
        let no_args: [String; 0] = [];
        assert_eq!(ContextMethod::ThreadId.to_cuda(&no_args), "threadIdx.x");
        assert_eq!(
            ContextMethod::SyncThreads.to_cuda(&no_args),
            "__syncthreads()"
        );
        assert_eq!(
            ContextMethod::GlobalThreadId.to_cuda(&no_args),
            "(blockIdx.x * blockDim.x + threadIdx.x)"
        );

        let add_args = vec![String::from("ptr"), String::from("1")];
        assert_eq!(
            ContextMethod::AtomicAdd.to_cuda(&add_args),
            "atomicAdd(ptr, 1)"
        );
    }

    #[test]
    fn test_message_type_registry() {
        let single_float = |struct_name: &str, field: &str| MessageTypeInfo {
            name: struct_name.to_string(),
            size: 8,
            fields: vec![(field.to_string(), "float".to_string())],
        };

        let mut registry = MessageTypeRegistry::new();
        registry.register_message(single_float("Request", "value"));
        registry.register_response(single_float("Response", "result"));

        let rendered = registry.generate_structs();
        assert!(rendered.contains("struct Request"));
        assert!(rendered.contains("struct Response"));
    }

    #[test]
    fn test_handler_param_kind_classification() {
        let context_ty: Type = parse_quote!(&RingContext);
        let message_ty: Type = parse_quote!(&MyMessage);
        let slice_ty: Type = parse_quote!(&[f32]);
        let scalar_ty: Type = parse_quote!(f32);

        let cases = [
            ("ctx", &context_ty, HandlerParamKind::Context),
            ("msg", &message_ty, HandlerParamKind::Message),
            ("data", &slice_ty, HandlerParamKind::Slice),
            ("value", &scalar_ty, HandlerParamKind::Value),
        ];

        for &(name, ty, expected) in cases.iter() {
            assert_eq!(HandlerSignature::classify_param(name, ty), expected);
        }
    }
}