Skip to main content

ringkernel_cuda_codegen/
handler.rs

//! Handler function integration for ring kernel transpilation.
//!
//! This module provides utilities for parsing handler function signatures,
//! extracting message/response types, and generating the code that binds
//! handlers to the ring kernel message loop.
//!
//! # Handler Signature Patterns
//!
//! Ring kernel handlers follow specific patterns:
//!
//! ```ignore
//! // Pattern 1: Message in, response out
//! fn handle(ctx: &RingContext, msg: &MyMessage) -> MyResponse { ... }
//!
//! // Pattern 2: Message in, no response (fire-and-forget)
//! fn handle(ctx: &RingContext, msg: &MyMessage) { ... }
//!
//! // Pattern 3: Simple value processing
//! fn process(value: f32) -> f32 { ... }
//! ```
//!
//! # Generated Code
//!
//! The handler body is embedded within the message loop, surrounded by
//! message deserialization and response serialization code (raw format shown):
//!
//! ```cuda
//! // Typed view of the incoming message payload
//! MyMessage* msg = (MyMessage*)msg_ptr;
//!
//! // === Handler body (transpiled) ===
//! float result = msg->value * 2.0f;
//! MyResponse response;
//! response.value = result;
//! // ================================
//!
//! // Reserve an output slot and enqueue the response
//! unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;
//! memcpy(&output_buffer[_out_idx * RESP_SIZE], &response, sizeof(MyResponse));
//! ```

41use crate::types::{is_ring_context_type, CudaType, TypeMapper};
42use crate::Result;
43use std::fmt::Write;
44use syn::{FnArg, ItemFn, Pat, ReturnType, Type};
45
46/// Information about a handler function parameter.
47#[derive(Debug, Clone)]
48pub struct HandlerParam {
49    /// Parameter name.
50    pub name: String,
51    /// Parameter type (Rust).
52    pub rust_type: String,
53    /// Parameter type (CUDA).
54    pub cuda_type: String,
55    /// Kind of parameter.
56    pub kind: HandlerParamKind,
57}
58
/// Role a handler parameter plays once the handler is lowered to CUDA.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandlerParamKind {
    /// `RingContext` reference; dropped during transpilation.
    Context,
    /// Shared reference to the incoming message (read from the input buffer).
    Message,
    /// Exclusive (`&mut`) reference to the incoming message.
    MessageMut,
    /// Ordinary by-value parameter.
    Value,
    /// Shared slice (`&[T]`), which becomes a pointer in CUDA.
    Slice,
    /// Exclusive slice (`&mut [T]`), which becomes a pointer in CUDA.
    SliceMut,
}
75
76/// Parsed handler function signature.
77#[derive(Debug, Clone)]
78pub struct HandlerSignature {
79    /// Handler function name.
80    pub name: String,
81    /// All parameters.
82    pub params: Vec<HandlerParam>,
83    /// Return type (if any).
84    pub return_type: Option<HandlerReturnType>,
85    /// Whether the handler has a RingContext parameter.
86    pub has_context: bool,
87    /// The message parameter (if any).
88    pub message_param: Option<HandlerParam>,
89}
90
/// Description of the value a handler returns as its response.
#[derive(Debug, Clone)]
pub struct HandlerReturnType {
    /// Rust spelling of the return type.
    pub rust_type: String,
    /// CUDA spelling of the return type.
    pub cuda_type: String,
    /// `true` when the type maps to a CUDA struct instead of a primitive.
    pub is_struct: bool,
}
101
102impl HandlerSignature {
103    /// Parse a handler function signature.
104    pub fn parse(func: &ItemFn, type_mapper: &TypeMapper) -> Result<Self> {
105        let name = func.sig.ident.to_string();
106        let mut params = Vec::new();
107        let mut has_context = false;
108        let mut message_param = None;
109
110        for param in &func.sig.inputs {
111            if let FnArg::Typed(pat_type) = param {
112                let param_name = match pat_type.pat.as_ref() {
113                    Pat::Ident(ident) => ident.ident.to_string(),
114                    _ => continue,
115                };
116
117                let ty = &pat_type.ty;
118                let rust_type = quote::quote!(#ty).to_string();
119                let kind = Self::classify_param(&param_name, &pat_type.ty);
120
121                if kind == HandlerParamKind::Context {
122                    has_context = true;
123                    continue; // Skip context params in output
124                }
125
126                let cuda_type = match type_mapper.map_type(&pat_type.ty) {
127                    Ok(ct) => ct.to_cuda_string(),
128                    Err(_) => "void*".to_string(), // Fallback for unknown types
129                };
130
131                let param = HandlerParam {
132                    name: param_name,
133                    rust_type,
134                    cuda_type,
135                    kind,
136                };
137
138                if kind == HandlerParamKind::Message || kind == HandlerParamKind::MessageMut {
139                    message_param = Some(param.clone());
140                }
141
142                params.push(param);
143            }
144        }
145
146        let return_type = Self::parse_return_type(&func.sig.output, type_mapper)?;
147
148        Ok(Self {
149            name,
150            params,
151            return_type,
152            has_context,
153            message_param,
154        })
155    }
156
157    /// Classify a parameter based on its name and type.
158    fn classify_param(name: &str, ty: &Type) -> HandlerParamKind {
159        // Check for RingContext
160        if is_ring_context_type(ty) {
161            return HandlerParamKind::Context;
162        }
163
164        // Check name patterns
165        let name_lower = name.to_lowercase();
166        if name_lower == "ctx" || name_lower == "context" {
167            return HandlerParamKind::Context;
168        }
169
170        // Check for references
171        if let Type::Reference(reference) = ty {
172            let is_mut = reference.mutability.is_some();
173
174            // Check for slice
175            if let Type::Slice(_) = reference.elem.as_ref() {
176                return if is_mut {
177                    HandlerParamKind::SliceMut
178                } else {
179                    HandlerParamKind::Slice
180                };
181            }
182
183            // Message-like parameters (references to structs)
184            if name_lower == "msg" || name_lower == "message" || name_lower.starts_with("msg_") {
185                return if is_mut {
186                    HandlerParamKind::MessageMut
187                } else {
188                    HandlerParamKind::Message
189                };
190            }
191
192            // Generic reference - treat as message if it's a struct
193            if let Type::Path(path) = reference.elem.as_ref() {
194                let type_name = path
195                    .path
196                    .segments
197                    .last()
198                    .map(|s| s.ident.to_string())
199                    .unwrap_or_default();
200
201                // Heuristic: if type name ends with "Message" or "Request", treat as message
202                if type_name.ends_with("Message") || type_name.ends_with("Request") {
203                    return if is_mut {
204                        HandlerParamKind::MessageMut
205                    } else {
206                        HandlerParamKind::Message
207                    };
208                }
209            }
210        }
211
212        HandlerParamKind::Value
213    }
214
215    /// Parse the return type.
216    fn parse_return_type(
217        output: &ReturnType,
218        type_mapper: &TypeMapper,
219    ) -> Result<Option<HandlerReturnType>> {
220        match output {
221            ReturnType::Default => Ok(None),
222            ReturnType::Type(_, ty) => {
223                // Check for unit type
224                if let Type::Tuple(tuple) = ty.as_ref() {
225                    if tuple.elems.is_empty() {
226                        return Ok(None);
227                    }
228                }
229
230                let rust_type = quote::quote!(#ty).to_string();
231                let cuda_type = type_mapper
232                    .map_type(ty)
233                    .map(|ct| ct.to_cuda_string())
234                    .unwrap_or_else(|_| rust_type.clone());
235
236                let is_struct = matches!(type_mapper.map_type(ty), Ok(CudaType::Struct(_)));
237
238                Ok(Some(HandlerReturnType {
239                    rust_type,
240                    cuda_type,
241                    is_struct,
242                }))
243            }
244        }
245    }
246
247    /// Check if the handler produces a response.
248    pub fn has_response(&self) -> bool {
249        self.return_type.is_some()
250    }
251
252    /// Get non-context, non-message parameters (additional kernel params).
253    pub fn extra_params(&self) -> Vec<&HandlerParam> {
254        self.params
255            .iter()
256            .filter(|p| {
257                p.kind != HandlerParamKind::Context
258                    && p.kind != HandlerParamKind::Message
259                    && p.kind != HandlerParamKind::MessageMut
260            })
261            .collect()
262    }
263}
264
/// Options controlling how handler glue code is emitted.
#[derive(Debug, Clone)]
pub struct HandlerCodegenConfig {
    /// Variable name bound to the typed message pointer.
    pub message_var: String,
    /// Variable name that holds the handler's response value.
    pub response_var: String,
    /// Indentation prefix for generated lines.
    pub indent: String,
    /// Whether message deserialization code should be generated.
    pub generate_deser: bool,
    /// Whether response serialization code should be generated.
    pub generate_ser: bool,
}

impl Default for HandlerCodegenConfig {
    /// `msg` / `response` variable names, an eight-space indent (function body
    /// plus message loop), and both (de)serialization steps enabled.
    fn default() -> Self {
        let message_var = String::from("msg");
        let response_var = String::from("response");
        let indent = String::from("        ");
        HandlerCodegenConfig {
            message_var,
            response_var,
            indent,
            generate_deser: true,
            generate_ser: true,
        }
    }
}
291
292/// Generate message deserialization code.
293///
294/// This generates code to cast the input buffer pointer to the message type.
295/// When envelope format is used, msg_ptr points to the payload (after header).
296pub fn generate_message_deser(message_type: &str, config: &HandlerCodegenConfig) -> String {
297    let mut code = String::new();
298    let indent = &config.indent;
299
300    writeln!(code, "{}// Deserialize message from buffer", indent).unwrap();
301    writeln!(
302        code,
303        "{}// msg_ptr points to payload data (after MessageHeader when using envelopes)",
304        indent
305    )
306    .unwrap();
307    writeln!(
308        code,
309        "{}{}* {} = ({}*)msg_ptr;",
310        indent, message_type, config.message_var, message_type
311    )
312    .unwrap();
313
314    code
315}
316
317/// Generate message deserialization code for envelope format with header access.
318///
319/// This provides access to both the message header and typed payload.
320pub fn generate_envelope_message_deser(
321    message_type: &str,
322    config: &HandlerCodegenConfig,
323) -> String {
324    let mut code = String::new();
325    let indent = &config.indent;
326
327    writeln!(code, "{}// Message envelope deserialization", indent).unwrap();
328    writeln!(
329        code,
330        "{}// msg_header provides: message_id, correlation_id, source_kernel, timestamp, etc.",
331        indent
332    )
333    .unwrap();
334    writeln!(
335        code,
336        "{}{}* {} = ({}*)msg_ptr;  // Typed payload",
337        indent, message_type, config.message_var, message_type
338    )
339    .unwrap();
340    writeln!(code).unwrap();
341    writeln!(code, "{}// Access header fields:", indent).unwrap();
342    writeln!(
343        code,
344        "{}// - msg_header->message_id     (unique message ID)",
345        indent
346    )
347    .unwrap();
348    writeln!(
349        code,
350        "{}// - msg_header->correlation_id (for request-response matching)",
351        indent
352    )
353    .unwrap();
354    writeln!(
355        code,
356        "{}// - msg_header->source_kernel  (sender kernel ID, 0 = host)",
357        indent
358    )
359    .unwrap();
360    writeln!(
361        code,
362        "{}// - msg_header->timestamp      (HLC timestamp of send)",
363        indent
364    )
365    .unwrap();
366
367    code
368}
369
370/// Generate response serialization code (legacy raw format).
371///
372/// This generates code to copy the response to the output buffer.
373pub fn generate_response_ser(response_type: &str, config: &HandlerCodegenConfig) -> String {
374    let mut code = String::new();
375    let indent = &config.indent;
376
377    writeln!(code).unwrap();
378    writeln!(
379        code,
380        "{}// Serialize response to output buffer (raw format)",
381        indent
382    )
383    .unwrap();
384    writeln!(
385        code,
386        "{}unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;",
387        indent
388    ).unwrap();
389    writeln!(
390        code,
391        "{}memcpy(&output_buffer[_out_idx * RESP_SIZE], &{}, sizeof({}));",
392        indent, config.response_var, response_type
393    )
394    .unwrap();
395
396    code
397}
398
399/// Generate response serialization code for envelope format.
400///
401/// This creates a complete response envelope with header and payload,
402/// properly routing the response back to the sender.
403pub fn generate_envelope_response_ser(
404    response_type: &str,
405    config: &HandlerCodegenConfig,
406) -> String {
407    let mut code = String::new();
408    let indent = &config.indent;
409
410    writeln!(code).unwrap();
411    writeln!(
412        code,
413        "{}// Serialize response envelope to output buffer",
414        indent
415    )
416    .unwrap();
417    writeln!(
418        code,
419        "{}unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;",
420        indent
421    ).unwrap();
422    writeln!(
423        code,
424        "{}unsigned char* resp_envelope = &output_buffer[_out_idx * RESP_SIZE];",
425        indent
426    )
427    .unwrap();
428    writeln!(code).unwrap();
429    writeln!(
430        code,
431        "{}// Build response header from request header",
432        indent
433    )
434    .unwrap();
435    writeln!(
436        code,
437        "{}MessageHeader* resp_header = (MessageHeader*)resp_envelope;",
438        indent
439    )
440    .unwrap();
441    writeln!(code, "{}message_create_response_header(", indent).unwrap();
442    writeln!(code, "{}    resp_header,", indent).unwrap();
443    writeln!(
444        code,
445        "{}    msg_header,              // Request header (for correlation)",
446        indent
447    )
448    .unwrap();
449    writeln!(
450        code,
451        "{}    KERNEL_ID,               // This kernel's ID",
452        indent
453    )
454    .unwrap();
455    writeln!(
456        code,
457        "{}    sizeof({}),   // Payload size",
458        indent, response_type
459    )
460    .unwrap();
461    writeln!(
462        code,
463        "{}    hlc_physical,            // Current HLC",
464        indent
465    )
466    .unwrap();
467    writeln!(code, "{}    hlc_logical,", indent).unwrap();
468    writeln!(code, "{}    HLC_NODE_ID", indent).unwrap();
469    writeln!(code, "{});", indent).unwrap();
470    writeln!(code).unwrap();
471    writeln!(code, "{}// Copy response payload after header", indent).unwrap();
472    writeln!(
473        code,
474        "{}memcpy(resp_envelope + MESSAGE_HEADER_SIZE, &{}, sizeof({}));",
475        indent, config.response_var, response_type
476    )
477    .unwrap();
478    writeln!(code).unwrap();
479    writeln!(
480        code,
481        "{}__threadfence();  // Ensure write is visible",
482        indent
483    )
484    .unwrap();
485
486    code
487}
488
/// Render a CUDA `struct` definition.
///
/// Each `(field_name, cuda_type)` pair becomes one four-space-indented
/// member declaration, in the order given.
pub fn generate_cuda_struct(
    name: &str,
    fields: &[(String, String)], // (field_name, cuda_type)
) -> String {
    let members: String = fields
        .iter()
        .map(|(field_name, cuda_type)| format!("    {} {};\n", cuda_type, field_name))
        .collect();
    format!("struct {} {{\n{}}};\n", name, members)
}
504
/// Description of a message or response type to be emitted as a CUDA struct.
#[derive(Debug, Clone)]
pub struct MessageTypeInfo {
    /// Struct name used in the generated CUDA code.
    pub name: String,
    /// Size of the type in bytes.
    pub size: usize,
    /// Member list as `(name, cuda_type)` pairs, in declaration order.
    pub fields: Vec<(String, String)>,
}
515
516impl MessageTypeInfo {
517    /// Create a simple message type with a single value field.
518    pub fn simple(name: &str, value_type: &str) -> Self {
519        Self {
520            name: name.to_string(),
521            size: Self::type_size(value_type),
522            fields: vec![("value".to_string(), value_type.to_string())],
523        }
524    }
525
526    /// Get the size of a CUDA type in bytes.
527    fn type_size(cuda_type: &str) -> usize {
528        match cuda_type {
529            "float" => 4,
530            "double" => 8,
531            "int" => 4,
532            "unsigned int" => 4,
533            "long long" | "unsigned long long" => 8,
534            "short" | "unsigned short" => 2,
535            "char" | "unsigned char" => 1,
536            _ => 8, // Default assumption
537        }
538    }
539
540    /// Generate CUDA struct definition.
541    pub fn to_cuda_struct(&self) -> String {
542        generate_cuda_struct(&self.name, &self.fields)
543    }
544}
545
546/// Registry of message types for a kernel.
547#[derive(Debug, Clone, Default)]
548pub struct MessageTypeRegistry {
549    /// Registered message types.
550    pub messages: Vec<MessageTypeInfo>,
551    /// Registered response types.
552    pub responses: Vec<MessageTypeInfo>,
553}
554
555impl MessageTypeRegistry {
556    /// Create a new empty registry.
557    pub fn new() -> Self {
558        Self::default()
559    }
560
561    /// Register a message type.
562    pub fn register_message(&mut self, info: MessageTypeInfo) {
563        self.messages.push(info);
564    }
565
566    /// Register a response type.
567    pub fn register_response(&mut self, info: MessageTypeInfo) {
568        self.responses.push(info);
569    }
570
571    /// Generate all struct definitions.
572    pub fn generate_structs(&self) -> String {
573        let mut code = String::new();
574
575        for msg in &self.messages {
576            code.push_str(&msg.to_cuda_struct());
577            code.push('\n');
578        }
579
580        for resp in &self.responses {
581            code.push_str(&resp.to_cuda_struct());
582            code.push('\n');
583        }
584
585        code
586    }
587}
588
/// RingContext method mappings to CUDA intrinsics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextMethod {
    /// ctx.thread_id() -> threadIdx.x
    ThreadId,
    /// ctx.block_id() -> blockIdx.x
    BlockId,
    /// ctx.global_thread_id() -> blockIdx.x * blockDim.x + threadIdx.x
    GlobalThreadId,
    /// ctx.sync_threads() -> __syncthreads()
    SyncThreads,
    /// ctx.now() -> HLC timestamp packed as `(hlc_physical << 32) | logical`.
    ///
    /// The generated expression reads `hlc_physical` and `hlc_logical`, which
    /// must be in scope in the generated kernel.
    Now,
    /// ctx.atomic_add(ptr, val) -> atomicAdd(ptr, val)
    AtomicAdd,
    /// ctx.warp_shuffle(val, lane) -> __shfl_sync(0xFFFFFFFF, val, lane)
    WarpShuffle,
    /// ctx.lane_id() -> threadIdx.x % 32
    LaneId,
    /// ctx.warp_id() -> threadIdx.x / 32
    WarpId,
}

impl ContextMethod {
    /// Parse a method name (including accepted aliases) to a context method.
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "thread_id" | "thread_idx" => Some(Self::ThreadId),
            "block_id" | "block_idx" => Some(Self::BlockId),
            "global_thread_id" | "global_id" => Some(Self::GlobalThreadId),
            "sync_threads" | "synchronize" => Some(Self::SyncThreads),
            "now" | "timestamp" => Some(Self::Now),
            "atomic_add" => Some(Self::AtomicAdd),
            "warp_shuffle" | "shuffle" => Some(Self::WarpShuffle),
            "lane_id" => Some(Self::LaneId),
            "warp_id" => Some(Self::WarpId),
            _ => None,
        }
    }

    /// Generate CUDA code for this context method.
    ///
    /// Compound expressions are wrapped in parentheses, so the result can be
    /// embedded in a larger expression (e.g. `ctx.now() + 1`) without C
    /// operator precedence regrouping it.
    ///
    /// When too few `args` are supplied, `AtomicAdd` and `WarpShuffle` produce
    /// a placeholder comment instead of a call.
    pub fn to_cuda(&self, args: &[String]) -> String {
        match self {
            Self::ThreadId => "threadIdx.x".to_string(),
            Self::BlockId => "blockIdx.x".to_string(),
            Self::GlobalThreadId => "(blockIdx.x * blockDim.x + threadIdx.x)".to_string(),
            Self::SyncThreads => "__syncthreads()".to_string(),
            // The outer parentheses matter: `|` binds looser than `+`, `*`,
            // `==`, etc.
            Self::Now => "((hlc_physical << 32) | (hlc_logical & 0xFFFFFFFF))".to_string(),
            Self::AtomicAdd => {
                if args.len() >= 2 {
                    format!("atomicAdd({}, {})", args[0], args[1])
                } else {
                    "/* atomic_add requires ptr and val */".to_string()
                }
            }
            Self::WarpShuffle => {
                if args.len() >= 2 {
                    format!("__shfl_sync(0xFFFFFFFF, {}, {})", args[0], args[1])
                } else {
                    "/* warp_shuffle requires val and lane */".to_string()
                }
            }
            Self::LaneId => "(threadIdx.x % 32)".to_string(),
            Self::WarpId => "(threadIdx.x / 32)".to_string(),
        }
    }

    /// Check if this method is a statement (vs. an expression).
    pub fn is_statement(&self) -> bool {
        matches!(self, Self::SyncThreads)
    }
}
661
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_handler_signature_simple() {
        let func: ItemFn = parse_quote! {
            fn process(value: f32) -> f32 {
                value * 2.0
            }
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper).unwrap();

        assert_eq!(sig.name, "process");
        assert_eq!(sig.params.len(), 1);
        assert_eq!(sig.params[0].name, "value");
        assert_eq!(sig.params[0].kind, HandlerParamKind::Value);
        assert!(sig.has_response());
        assert!(!sig.has_context);
    }

    #[test]
    fn test_handler_signature_with_context() {
        let func: ItemFn = parse_quote! {
            fn handle(ctx: &RingContext, msg: &Message) -> Response {
                Response { value: msg.value * 2.0 }
            }
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper).unwrap();

        assert_eq!(sig.name, "handle");
        assert!(sig.has_context);
        assert!(sig.message_param.is_some());
        assert_eq!(sig.message_param.as_ref().unwrap().name, "msg");
        assert!(sig.has_response());
    }

    #[test]
    fn test_handler_signature_no_response() {
        let func: ItemFn = parse_quote! {
            fn fire_and_forget(msg: &Event) {
                process(msg);
            }
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper).unwrap();

        assert!(!sig.has_response());
        // `msg` is classified as the message by name alone.
        assert!(sig.message_param.is_some());
    }

    #[test]
    fn test_unit_return_is_no_response() {
        let func: ItemFn = parse_quote! {
            fn tick(value: f32) -> () {}
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper).unwrap();

        assert!(!sig.has_response());
        assert!(sig.return_type.is_none());
    }

    #[test]
    fn test_extra_params_exclude_context_and_message() {
        let func: ItemFn = parse_quote! {
            fn handle(ctx: &RingContext, msg: &MyMessage, scale: f32) -> f32 {
                msg.value * scale
            }
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper).unwrap();

        // The context parameter is never recorded.
        assert_eq!(sig.params.len(), 2);
        let extras = sig.extra_params();
        assert_eq!(extras.len(), 1);
        assert_eq!(extras[0].name, "scale");
    }

    #[test]
    fn test_message_deser_generation() {
        let config = HandlerCodegenConfig::default();
        let code = generate_message_deser("MyMessage", &config);

        assert!(code.contains("MyMessage* msg"));
        assert!(code.contains("(MyMessage*)msg_ptr"));
    }

    #[test]
    fn test_generated_lines_use_configured_indent() {
        let config = HandlerCodegenConfig {
            indent: "\t".to_string(),
            ..HandlerCodegenConfig::default()
        };
        let code = generate_message_deser("M", &config);

        assert!(code.lines().all(|line| line.starts_with('\t')));
    }

    #[test]
    fn test_response_ser_generation() {
        let config = HandlerCodegenConfig::default();
        let code = generate_response_ser("MyResponse", &config);

        assert!(code.contains("memcpy"));
        assert!(code.contains("output_buffer"));
        assert!(code.contains("RESP_SIZE"));
    }

    #[test]
    fn test_envelope_codegen() {
        let config = HandlerCodegenConfig::default();

        let deser = generate_envelope_message_deser("MyMessage", &config);
        assert!(deser.contains("MyMessage* msg = (MyMessage*)msg_ptr;"));
        assert!(deser.contains("msg_header->correlation_id"));

        let ser = generate_envelope_response_ser("MyResponse", &config);
        assert!(ser.contains("message_create_response_header("));
        assert!(ser.contains("sizeof(MyResponse)"));
        assert!(ser.contains(
            "memcpy(resp_envelope + MESSAGE_HEADER_SIZE, &response, sizeof(MyResponse));"
        ));
        assert!(ser.contains("__threadfence();"));
    }

    #[test]
    fn test_cuda_struct_generation() {
        let fields = vec![
            ("id".to_string(), "unsigned long long".to_string()),
            ("value".to_string(), "float".to_string()),
        ];

        let code = generate_cuda_struct("MyMessage", &fields);

        assert!(code.contains("struct MyMessage"));
        assert!(code.contains("unsigned long long id"));
        assert!(code.contains("float value"));
    }

    #[test]
    fn test_message_type_info() {
        let info = MessageTypeInfo::simple("FloatMsg", "float");
        assert_eq!(info.name, "FloatMsg");
        assert_eq!(info.size, 4);
        assert_eq!(info.fields.len(), 1);

        let cuda = info.to_cuda_struct();
        assert!(cuda.contains("struct FloatMsg"));
        assert!(cuda.contains("float value"));
    }

    #[test]
    fn test_context_method_lookup() {
        assert_eq!(
            ContextMethod::from_name("thread_id"),
            Some(ContextMethod::ThreadId)
        );
        assert_eq!(
            ContextMethod::from_name("sync_threads"),
            Some(ContextMethod::SyncThreads)
        );
        assert_eq!(
            ContextMethod::from_name("global_thread_id"),
            Some(ContextMethod::GlobalThreadId)
        );
        assert_eq!(ContextMethod::from_name("unknown"), None);
    }

    #[test]
    fn test_context_method_cuda_output() {
        assert_eq!(ContextMethod::ThreadId.to_cuda(&[]), "threadIdx.x");
        assert_eq!(ContextMethod::SyncThreads.to_cuda(&[]), "__syncthreads()");
        assert_eq!(
            ContextMethod::GlobalThreadId.to_cuda(&[]),
            "(blockIdx.x * blockDim.x + threadIdx.x)"
        );

        let args = vec!["ptr".to_string(), "1".to_string()];
        assert_eq!(ContextMethod::AtomicAdd.to_cuda(&args), "atomicAdd(ptr, 1)");
    }

    #[test]
    fn test_message_type_registry() {
        let mut registry = MessageTypeRegistry::new();

        registry.register_message(MessageTypeInfo {
            name: "Request".to_string(),
            size: 8,
            fields: vec![("value".to_string(), "float".to_string())],
        });

        registry.register_response(MessageTypeInfo {
            name: "Response".to_string(),
            size: 8,
            fields: vec![("result".to_string(), "float".to_string())],
        });

        let structs = registry.generate_structs();
        assert!(structs.contains("struct Request"));
        assert!(structs.contains("struct Response"));
    }

    #[test]
    fn test_handler_param_kind_classification() {
        let ctx_ty: Type = parse_quote!(&RingContext);
        assert_eq!(
            HandlerSignature::classify_param("ctx", &ctx_ty),
            HandlerParamKind::Context
        );

        let msg_ty: Type = parse_quote!(&MyMessage);
        assert_eq!(
            HandlerSignature::classify_param("msg", &msg_ty),
            HandlerParamKind::Message
        );

        let slice_ty: Type = parse_quote!(&[f32]);
        assert_eq!(
            HandlerSignature::classify_param("data", &slice_ty),
            HandlerParamKind::Slice
        );

        let value_ty: Type = parse_quote!(f32);
        assert_eq!(
            HandlerSignature::classify_param("value", &value_ty),
            HandlerParamKind::Value
        );
    }

    #[test]
    fn test_handler_param_kind_mutable_and_heuristics() {
        let msg_mut: Type = parse_quote!(&mut MyMessage);
        assert_eq!(
            HandlerSignature::classify_param("msg", &msg_mut),
            HandlerParamKind::MessageMut
        );

        let slice_mut: Type = parse_quote!(&mut [f32]);
        assert_eq!(
            HandlerSignature::classify_param("out", &slice_mut),
            HandlerParamKind::SliceMut
        );

        // Type-name heuristic: references to `*Request` types are messages.
        let request: Type = parse_quote!(&PingRequest);
        assert_eq!(
            HandlerSignature::classify_param("payload", &request),
            HandlerParamKind::Message
        );

        // `msg_` name prefix marks a message regardless of the type name.
        let event: Type = parse_quote!(&Event);
        assert_eq!(
            HandlerSignature::classify_param("msg_in", &event),
            HandlerParamKind::Message
        );

        // A reference to a non-message type with an ordinary name is a value.
        assert_eq!(
            HandlerSignature::classify_param("cfg", &event),
            HandlerParamKind::Value
        );
    }
}