1use crate::types::{is_ring_context_type, CudaType, TypeMapper};
42use crate::Result;
43use std::fmt::Write;
44use syn::{FnArg, ItemFn, Pat, ReturnType, Type};
45
46#[derive(Debug, Clone)]
48pub struct HandlerParam {
49 pub name: String,
51 pub rust_type: String,
53 pub cuda_type: String,
55 pub kind: HandlerParamKind,
57}
58
/// Role assigned to a handler parameter during signature classification.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandlerParamKind {
    /// The ring context (context-typed, or named `ctx` / `context`).
    Context,
    /// Shared reference recognised as the incoming message.
    Message,
    /// Mutable reference recognised as the incoming message.
    MessageMut,
    /// Any parameter that matches none of the other roles.
    Value,
    /// Shared reference to a slice (`&[T]`).
    Slice,
    /// Mutable reference to a slice (`&mut [T]`).
    SliceMut,
}
75
76#[derive(Debug, Clone)]
78pub struct HandlerSignature {
79 pub name: String,
81 pub params: Vec<HandlerParam>,
83 pub return_type: Option<HandlerReturnType>,
85 pub has_context: bool,
87 pub message_param: Option<HandlerParam>,
89}
90
/// Return type of a handler, in both its Rust and CUDA spellings.
#[derive(Clone, Debug)]
pub struct HandlerReturnType {
    /// Textual rendering of the Rust return type.
    pub rust_type: String,
    /// CUDA type string; falls back to the Rust spelling when the type could
    /// not be mapped.
    pub cuda_type: String,
    /// Whether the type mapper resolved the return type to a struct.
    pub is_struct: bool,
}
101
102impl HandlerSignature {
103 pub fn parse(func: &ItemFn, type_mapper: &TypeMapper) -> Result<Self> {
105 let name = func.sig.ident.to_string();
106 let mut params = Vec::new();
107 let mut has_context = false;
108 let mut message_param = None;
109
110 for param in &func.sig.inputs {
111 if let FnArg::Typed(pat_type) = param {
112 let param_name = match pat_type.pat.as_ref() {
113 Pat::Ident(ident) => ident.ident.to_string(),
114 _ => continue,
115 };
116
117 let rust_type = quote::quote!(#pat_type.ty).to_string();
118 let kind = Self::classify_param(¶m_name, &pat_type.ty);
119
120 if kind == HandlerParamKind::Context {
121 has_context = true;
122 continue; }
124
125 let cuda_type = match type_mapper.map_type(&pat_type.ty) {
126 Ok(ct) => ct.to_cuda_string(),
127 Err(_) => "void*".to_string(), };
129
130 let param = HandlerParam {
131 name: param_name,
132 rust_type,
133 cuda_type,
134 kind,
135 };
136
137 if kind == HandlerParamKind::Message || kind == HandlerParamKind::MessageMut {
138 message_param = Some(param.clone());
139 }
140
141 params.push(param);
142 }
143 }
144
145 let return_type = Self::parse_return_type(&func.sig.output, type_mapper)?;
146
147 Ok(Self {
148 name,
149 params,
150 return_type,
151 has_context,
152 message_param,
153 })
154 }
155
156 fn classify_param(name: &str, ty: &Type) -> HandlerParamKind {
158 if is_ring_context_type(ty) {
160 return HandlerParamKind::Context;
161 }
162
163 let name_lower = name.to_lowercase();
165 if name_lower == "ctx" || name_lower == "context" {
166 return HandlerParamKind::Context;
167 }
168
169 if let Type::Reference(reference) = ty {
171 let is_mut = reference.mutability.is_some();
172
173 if let Type::Slice(_) = reference.elem.as_ref() {
175 return if is_mut {
176 HandlerParamKind::SliceMut
177 } else {
178 HandlerParamKind::Slice
179 };
180 }
181
182 if name_lower == "msg" || name_lower == "message" || name_lower.starts_with("msg_") {
184 return if is_mut {
185 HandlerParamKind::MessageMut
186 } else {
187 HandlerParamKind::Message
188 };
189 }
190
191 if let Type::Path(path) = reference.elem.as_ref() {
193 let type_name = path
194 .path
195 .segments
196 .last()
197 .map(|s| s.ident.to_string())
198 .unwrap_or_default();
199
200 if type_name.ends_with("Message") || type_name.ends_with("Request") {
202 return if is_mut {
203 HandlerParamKind::MessageMut
204 } else {
205 HandlerParamKind::Message
206 };
207 }
208 }
209 }
210
211 HandlerParamKind::Value
212 }
213
214 fn parse_return_type(
216 output: &ReturnType,
217 type_mapper: &TypeMapper,
218 ) -> Result<Option<HandlerReturnType>> {
219 match output {
220 ReturnType::Default => Ok(None),
221 ReturnType::Type(_, ty) => {
222 if let Type::Tuple(tuple) = ty.as_ref() {
224 if tuple.elems.is_empty() {
225 return Ok(None);
226 }
227 }
228
229 let rust_type = quote::quote!(#ty).to_string();
230 let cuda_type = type_mapper
231 .map_type(ty)
232 .map(|ct| ct.to_cuda_string())
233 .unwrap_or_else(|_| rust_type.clone());
234
235 let is_struct = matches!(type_mapper.map_type(ty), Ok(CudaType::Struct(_)));
236
237 Ok(Some(HandlerReturnType {
238 rust_type,
239 cuda_type,
240 is_struct,
241 }))
242 }
243 }
244 }
245
246 pub fn has_response(&self) -> bool {
248 self.return_type.is_some()
249 }
250
251 pub fn extra_params(&self) -> Vec<&HandlerParam> {
253 self.params
254 .iter()
255 .filter(|p| {
256 p.kind != HandlerParamKind::Context
257 && p.kind != HandlerParamKind::Message
258 && p.kind != HandlerParamKind::MessageMut
259 })
260 .collect()
261 }
262}
263
/// Settings consumed by the handler code generation helpers.
#[derive(Clone, Debug)]
pub struct HandlerCodegenConfig {
    /// Name of the variable bound to the typed message pointer.
    pub message_var: String,
    /// Name of the variable holding the response to serialize.
    pub response_var: String,
    /// Prefix written at the start of every generated line.
    pub indent: String,
    /// Deserialization toggle.
    // NOTE(review): not read by the generators in this file; presumably
    // consulted by callers - confirm.
    pub generate_deser: bool,
    /// Serialization toggle.
    // NOTE(review): not read by the generators in this file; presumably
    // consulted by callers - confirm.
    pub generate_ser: bool,
}

impl Default for HandlerCodegenConfig {
    /// Uses `msg` / `response` as variable names and enables both toggles.
    fn default() -> Self {
        let message_var = String::from("msg");
        let response_var = String::from("response");
        let indent = String::from(" ");

        Self {
            message_var,
            response_var,
            indent,
            generate_deser: true,
            generate_ser: true,
        }
    }
}
290
291pub fn generate_message_deser(message_type: &str, config: &HandlerCodegenConfig) -> String {
295 let mut code = String::new();
296 let indent = &config.indent;
297
298 writeln!(code, "{}// Deserialize message from buffer", indent).unwrap();
299 writeln!(
300 code,
301 "{}{}* {} = ({}*)msg_ptr;",
302 indent, message_type, config.message_var, message_type
303 )
304 .unwrap();
305
306 code
307}
308
309pub fn generate_response_ser(response_type: &str, config: &HandlerCodegenConfig) -> String {
313 let mut code = String::new();
314 let indent = &config.indent;
315
316 writeln!(code).unwrap();
317 writeln!(code, "{}// Serialize response to output buffer", indent).unwrap();
318 writeln!(
319 code,
320 "{}unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;",
321 indent
322 ).unwrap();
323 writeln!(
324 code,
325 "{}memcpy(&output_buffer[_out_idx * RESP_SIZE], &{}, sizeof({}));",
326 indent, config.response_var, response_type
327 )
328 .unwrap();
329
330 code
331}
332
/// Renders a CUDA `struct` definition.
///
/// `fields` holds `(field_name, cuda_type)` pairs; each becomes one
/// `<cuda_type> <field_name>;` member line, in the given order. The result
/// ends with `};` and a newline.
pub fn generate_cuda_struct(
    name: &str,
    fields: &[(String, String)],
) -> String {
    let members: String = fields
        .iter()
        .map(|(field_name, cuda_type)| format!(" {} {};\n", cuda_type, field_name))
        .collect();

    format!("struct {} {{\n{}}};\n", name, members)
}
348
/// Description of a message type for which a CUDA struct can be generated.
#[derive(Clone, Debug)]
pub struct MessageTypeInfo {
    /// Name used for the generated struct.
    pub name: String,
    /// Size recorded for the message.
    // NOTE(review): presumably bytes - confirm against callers.
    pub size: usize,
    /// Struct members as `(field_name, cuda_type)` pairs, in order.
    pub fields: Vec<(String, String)>,
}
359
360impl MessageTypeInfo {
361 pub fn simple(name: &str, value_type: &str) -> Self {
363 Self {
364 name: name.to_string(),
365 size: Self::type_size(value_type),
366 fields: vec![("value".to_string(), value_type.to_string())],
367 }
368 }
369
370 fn type_size(cuda_type: &str) -> usize {
372 match cuda_type {
373 "float" => 4,
374 "double" => 8,
375 "int" => 4,
376 "unsigned int" => 4,
377 "long long" | "unsigned long long" => 8,
378 "short" | "unsigned short" => 2,
379 "char" | "unsigned char" => 1,
380 _ => 8, }
382 }
383
384 pub fn to_cuda_struct(&self) -> String {
386 generate_cuda_struct(&self.name, &self.fields)
387 }
388}
389
390#[derive(Debug, Clone, Default)]
392pub struct MessageTypeRegistry {
393 pub messages: Vec<MessageTypeInfo>,
395 pub responses: Vec<MessageTypeInfo>,
397}
398
399impl MessageTypeRegistry {
400 pub fn new() -> Self {
402 Self::default()
403 }
404
405 pub fn register_message(&mut self, info: MessageTypeInfo) {
407 self.messages.push(info);
408 }
409
410 pub fn register_response(&mut self, info: MessageTypeInfo) {
412 self.responses.push(info);
413 }
414
415 pub fn generate_structs(&self) -> String {
417 let mut code = String::new();
418
419 for msg in &self.messages {
420 code.push_str(&msg.to_cuda_struct());
421 code.push('\n');
422 }
423
424 for resp in &self.responses {
425 code.push_str(&resp.to_cuda_struct());
426 code.push('\n');
427 }
428
429 code
430 }
431}
432
/// Context methods that are translated to CUDA intrinsics or expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextMethod {
    /// Emitted as `threadIdx.x`.
    ThreadId,
    /// Emitted as `blockIdx.x`.
    BlockId,
    /// Emitted as `(blockIdx.x * blockDim.x + threadIdx.x)`.
    GlobalThreadId,
    /// Emitted as `__syncthreads()`.
    SyncThreads,
    /// Emitted as an expression combining `hlc_physical` and `hlc_logical`.
    Now,
    /// Emitted as `atomicAdd(ptr, val)`.
    AtomicAdd,
    /// Emitted as `__shfl_sync(0xFFFFFFFF, val, lane)`.
    WarpShuffle,
    /// Emitted as `(threadIdx.x % 32)`.
    LaneId,
    /// Emitted as `(threadIdx.x / 32)`.
    WarpId,
}

impl ContextMethod {
    /// Looks up a context method by its Rust-side name.
    ///
    /// Several methods accept an alias (for example `thread_idx` for
    /// `thread_id`). Returns `None` for names that are not recognised.
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "thread_id" | "thread_idx" => Some(Self::ThreadId),
            "block_id" | "block_idx" => Some(Self::BlockId),
            "global_thread_id" | "global_id" => Some(Self::GlobalThreadId),
            "sync_threads" | "synchronize" => Some(Self::SyncThreads),
            "now" | "timestamp" => Some(Self::Now),
            "atomic_add" => Some(Self::AtomicAdd),
            "warp_shuffle" | "shuffle" => Some(Self::WarpShuffle),
            "lane_id" => Some(Self::LaneId),
            "warp_id" => Some(Self::WarpId),
            _ => None,
        }
    }

    /// Renders the CUDA source text for this method.
    ///
    /// `args` are the already-translated call arguments. Only `AtomicAdd`
    /// and `WarpShuffle` read them, using the first two and ignoring the
    /// rest; with fewer than two they emit a CUDA comment describing what is
    /// missing. Every compound expression is wrapped in parentheses so it
    /// can be embedded in a larger expression without changing its meaning.
    pub fn to_cuda(&self, args: &[String]) -> String {
        match self {
            Self::ThreadId => "threadIdx.x".to_string(),
            Self::BlockId => "blockIdx.x".to_string(),
            Self::GlobalThreadId => "(blockIdx.x * blockDim.x + threadIdx.x)".to_string(),
            Self::SyncThreads => "__syncthreads()".to_string(),
            // Outer parentheses matter: `|` binds more loosely than the
            // arithmetic and comparison operators, so an unwrapped
            // `a | b` followed by `+ 1` would parse as `a | (b + 1)`.
            Self::Now => "((hlc_physical << 32) | (hlc_logical & 0xFFFFFFFF))".to_string(),
            Self::AtomicAdd => match args {
                [ptr, val, ..] => format!("atomicAdd({}, {})", ptr, val),
                _ => "/* atomic_add requires ptr and val */".to_string(),
            },
            Self::WarpShuffle => match args {
                [val, lane, ..] => format!("__shfl_sync(0xFFFFFFFF, {}, {})", val, lane),
                _ => "/* warp_shuffle requires val and lane */".to_string(),
            },
            Self::LaneId => "(threadIdx.x % 32)".to_string(),
            Self::WarpId => "(threadIdx.x / 32)".to_string(),
        }
    }

    /// Returns `true` for methods emitted as a statement rather than a
    /// value-producing expression; only `SyncThreads` qualifies.
    pub fn is_statement(&self) -> bool {
        matches!(self, Self::SyncThreads)
    }
}
505
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Parses `func` with a freshly constructed type mapper.
    fn parse_sig(func: &ItemFn) -> HandlerSignature {
        let mapper = TypeMapper::new();
        HandlerSignature::parse(func, &mapper).unwrap()
    }

    /// Shorthand for an owned `(field_name, cuda_type)` pair.
    fn field(name: &str, cuda_type: &str) -> (String, String) {
        (name.to_string(), cuda_type.to_string())
    }

    #[test]
    fn test_handler_signature_simple() {
        let func: ItemFn = parse_quote! {
            fn process(value: f32) -> f32 {
                value * 2.0
            }
        };

        let sig = parse_sig(&func);

        assert_eq!(sig.name, "process");
        assert_eq!(sig.params.len(), 1);
        let first = &sig.params[0];
        assert_eq!(first.name, "value");
        assert_eq!(first.kind, HandlerParamKind::Value);
        assert!(sig.has_response());
        assert!(!sig.has_context);
    }

    #[test]
    fn test_handler_signature_with_context() {
        let func: ItemFn = parse_quote! {
            fn handle(ctx: &RingContext, msg: &Message) -> Response {
                Response { value: msg.value * 2.0 }
            }
        };

        let sig = parse_sig(&func);

        assert_eq!(sig.name, "handle");
        assert!(sig.has_context);
        assert!(sig.message_param.is_some());
        assert_eq!(sig.message_param.as_ref().unwrap().name, "msg");
        assert!(sig.has_response());
    }

    #[test]
    fn test_handler_signature_no_response() {
        let func: ItemFn = parse_quote! {
            fn fire_and_forget(msg: &Event) {
                process(msg);
            }
        };

        let sig = parse_sig(&func);

        assert!(!sig.has_response());
    }

    #[test]
    fn test_message_deser_generation() {
        let generated = generate_message_deser("MyMessage", &HandlerCodegenConfig::default());

        assert!(generated.contains("MyMessage* msg"));
        assert!(generated.contains("(MyMessage*)msg_ptr"));
    }

    #[test]
    fn test_response_ser_generation() {
        let generated = generate_response_ser("MyResponse", &HandlerCodegenConfig::default());

        for needle in ["memcpy", "output_buffer", "RESP_SIZE"].iter() {
            assert!(generated.contains(needle));
        }
    }

    #[test]
    fn test_cuda_struct_generation() {
        let fields = vec![field("id", "unsigned long long"), field("value", "float")];

        let generated = generate_cuda_struct("MyMessage", &fields);

        assert!(generated.contains("struct MyMessage"));
        assert!(generated.contains("unsigned long long id"));
        assert!(generated.contains("float value"));
    }

    #[test]
    fn test_message_type_info() {
        let info = MessageTypeInfo::simple("FloatMsg", "float");
        assert_eq!(info.name, "FloatMsg");
        assert_eq!(info.size, 4);
        assert_eq!(info.fields.len(), 1);

        let generated = info.to_cuda_struct();
        assert!(generated.contains("struct FloatMsg"));
        assert!(generated.contains("float value"));
    }

    #[test]
    fn test_context_method_lookup() {
        let known = [
            ("thread_id", ContextMethod::ThreadId),
            ("sync_threads", ContextMethod::SyncThreads),
            ("global_thread_id", ContextMethod::GlobalThreadId),
        ];
        for (name, expected) in known.iter() {
            assert_eq!(ContextMethod::from_name(name), Some(*expected));
        }

        assert_eq!(ContextMethod::from_name("unknown"), None);
    }

    #[test]
    fn test_context_method_cuda_output() {
        assert_eq!(ContextMethod::ThreadId.to_cuda(&[]), "threadIdx.x");
        assert_eq!(ContextMethod::SyncThreads.to_cuda(&[]), "__syncthreads()");
        assert_eq!(
            ContextMethod::GlobalThreadId.to_cuda(&[]),
            "(blockIdx.x * blockDim.x + threadIdx.x)"
        );

        let call_args = vec!["ptr".to_string(), "1".to_string()];
        assert_eq!(
            ContextMethod::AtomicAdd.to_cuda(&call_args),
            "atomicAdd(ptr, 1)"
        );
    }

    #[test]
    fn test_message_type_registry() {
        let mut registry = MessageTypeRegistry::new();

        registry.register_message(MessageTypeInfo {
            name: "Request".to_string(),
            size: 8,
            fields: vec![field("value", "float")],
        });
        registry.register_response(MessageTypeInfo {
            name: "Response".to_string(),
            size: 8,
            fields: vec![field("result", "float")],
        });

        let generated = registry.generate_structs();
        assert!(generated.contains("struct Request"));
        assert!(generated.contains("struct Response"));
    }

    #[test]
    fn test_handler_param_kind_classification() {
        // Context: recognised through the type and/or the `ctx` name.
        let ctx_ty: Type = parse_quote!(&RingContext);
        assert_eq!(
            HandlerSignature::classify_param("ctx", &ctx_ty),
            HandlerParamKind::Context
        );

        // Message: shared reference named `msg`.
        let msg_ty: Type = parse_quote!(&MyMessage);
        assert_eq!(
            HandlerSignature::classify_param("msg", &msg_ty),
            HandlerParamKind::Message
        );

        // Slice: shared reference to `[T]`.
        let slice_ty: Type = parse_quote!(&[f32]);
        assert_eq!(
            HandlerSignature::classify_param("data", &slice_ty),
            HandlerParamKind::Slice
        );

        // Value: plain by-value scalar.
        let value_ty: Type = parse_quote!(f32);
        assert_eq!(
            HandlerSignature::classify_param("value", &value_ty),
            HandlerParamKind::Value
        );
    }
}