ringkernel_cuda_codegen/
handler.rs

1//! Handler function integration for ring kernel transpilation.
2//!
3//! This module provides utilities for parsing handler function signatures,
4//! extracting message/response types, and generating the code that binds
5//! handlers to the ring kernel message loop.
6//!
7//! # Handler Signature Patterns
8//!
9//! Ring kernel handlers follow specific patterns:
10//!
11//! ```ignore
12//! // Pattern 1: Message in, response out
13//! fn handle(ctx: &RingContext, msg: &MyMessage) -> MyResponse { ... }
14//!
15//! // Pattern 2: Message in, no response (fire-and-forget)
16//! fn handle(ctx: &RingContext, msg: &MyMessage) { ... }
17//!
18//! // Pattern 3: Simple value processing
19//! fn process(value: f32) -> f32 { ... }
20//! ```
21//!
//! # Generated Code
//!
//! The handler body is embedded within the message loop, between the
//! message deserialization emitted by [`generate_message_deser`] and the
//! response serialization emitted by [`generate_response_ser`]:
//!
//! ```cuda
//! // Deserialize message from buffer
//! MyMessage* msg = (MyMessage*)msg_ptr;
//!
//! // === Handler body (transpiled) ===
//! float result = msg->value * 2.0f;
//! MyResponse response;
//! response.value = result;
//! // ================================
//!
//! // Serialize response to output buffer
//! unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;
//! memcpy(&output_buffer[_out_idx * RESP_SIZE], &response, sizeof(MyResponse));
//! ```
40
41use crate::types::{is_ring_context_type, CudaType, TypeMapper};
42use crate::Result;
43use std::fmt::Write;
44use syn::{FnArg, ItemFn, Pat, ReturnType, Type};
45
46/// Information about a handler function parameter.
47#[derive(Debug, Clone)]
48pub struct HandlerParam {
49    /// Parameter name.
50    pub name: String,
51    /// Parameter type (Rust).
52    pub rust_type: String,
53    /// Parameter type (CUDA).
54    pub cuda_type: String,
55    /// Kind of parameter.
56    pub kind: HandlerParamKind,
57}
58
/// How a handler parameter is treated during transpilation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandlerParamKind {
    /// A `RingContext` parameter; dropped from the generated kernel.
    Context,
    /// Shared reference to the incoming message (read from the input buffer).
    Message,
    /// Mutable reference to the incoming message.
    MessageMut,
    /// Ordinary parameter passed by value.
    Value,
    /// Shared slice; becomes a pointer in CUDA.
    Slice,
    /// Mutable slice; becomes a pointer in CUDA.
    SliceMut,
}
75
76/// Parsed handler function signature.
77#[derive(Debug, Clone)]
78pub struct HandlerSignature {
79    /// Handler function name.
80    pub name: String,
81    /// All parameters.
82    pub params: Vec<HandlerParam>,
83    /// Return type (if any).
84    pub return_type: Option<HandlerReturnType>,
85    /// Whether the handler has a RingContext parameter.
86    pub has_context: bool,
87    /// The message parameter (if any).
88    pub message_param: Option<HandlerParam>,
89}
90
/// Return type of a handler, in both Rust and CUDA spellings.
#[derive(Clone, Debug)]
pub struct HandlerReturnType {
    /// Rust type rendered as token text.
    pub rust_type: String,
    /// CUDA type string (the Rust spelling is used when the type cannot be mapped).
    pub cuda_type: String,
    /// `true` when the type maps to a CUDA struct rather than a primitive.
    pub is_struct: bool,
}
101
102impl HandlerSignature {
103    /// Parse a handler function signature.
104    pub fn parse(func: &ItemFn, type_mapper: &TypeMapper) -> Result<Self> {
105        let name = func.sig.ident.to_string();
106        let mut params = Vec::new();
107        let mut has_context = false;
108        let mut message_param = None;
109
110        for param in &func.sig.inputs {
111            if let FnArg::Typed(pat_type) = param {
112                let param_name = match pat_type.pat.as_ref() {
113                    Pat::Ident(ident) => ident.ident.to_string(),
114                    _ => continue,
115                };
116
117                let rust_type = quote::quote!(#pat_type.ty).to_string();
118                let kind = Self::classify_param(&param_name, &pat_type.ty);
119
120                if kind == HandlerParamKind::Context {
121                    has_context = true;
122                    continue; // Skip context params in output
123                }
124
125                let cuda_type = match type_mapper.map_type(&pat_type.ty) {
126                    Ok(ct) => ct.to_cuda_string(),
127                    Err(_) => "void*".to_string(), // Fallback for unknown types
128                };
129
130                let param = HandlerParam {
131                    name: param_name,
132                    rust_type,
133                    cuda_type,
134                    kind,
135                };
136
137                if kind == HandlerParamKind::Message || kind == HandlerParamKind::MessageMut {
138                    message_param = Some(param.clone());
139                }
140
141                params.push(param);
142            }
143        }
144
145        let return_type = Self::parse_return_type(&func.sig.output, type_mapper)?;
146
147        Ok(Self {
148            name,
149            params,
150            return_type,
151            has_context,
152            message_param,
153        })
154    }
155
156    /// Classify a parameter based on its name and type.
157    fn classify_param(name: &str, ty: &Type) -> HandlerParamKind {
158        // Check for RingContext
159        if is_ring_context_type(ty) {
160            return HandlerParamKind::Context;
161        }
162
163        // Check name patterns
164        let name_lower = name.to_lowercase();
165        if name_lower == "ctx" || name_lower == "context" {
166            return HandlerParamKind::Context;
167        }
168
169        // Check for references
170        if let Type::Reference(reference) = ty {
171            let is_mut = reference.mutability.is_some();
172
173            // Check for slice
174            if let Type::Slice(_) = reference.elem.as_ref() {
175                return if is_mut {
176                    HandlerParamKind::SliceMut
177                } else {
178                    HandlerParamKind::Slice
179                };
180            }
181
182            // Message-like parameters (references to structs)
183            if name_lower == "msg" || name_lower == "message" || name_lower.starts_with("msg_") {
184                return if is_mut {
185                    HandlerParamKind::MessageMut
186                } else {
187                    HandlerParamKind::Message
188                };
189            }
190
191            // Generic reference - treat as message if it's a struct
192            if let Type::Path(path) = reference.elem.as_ref() {
193                let type_name = path
194                    .path
195                    .segments
196                    .last()
197                    .map(|s| s.ident.to_string())
198                    .unwrap_or_default();
199
200                // Heuristic: if type name ends with "Message" or "Request", treat as message
201                if type_name.ends_with("Message") || type_name.ends_with("Request") {
202                    return if is_mut {
203                        HandlerParamKind::MessageMut
204                    } else {
205                        HandlerParamKind::Message
206                    };
207                }
208            }
209        }
210
211        HandlerParamKind::Value
212    }
213
214    /// Parse the return type.
215    fn parse_return_type(
216        output: &ReturnType,
217        type_mapper: &TypeMapper,
218    ) -> Result<Option<HandlerReturnType>> {
219        match output {
220            ReturnType::Default => Ok(None),
221            ReturnType::Type(_, ty) => {
222                // Check for unit type
223                if let Type::Tuple(tuple) = ty.as_ref() {
224                    if tuple.elems.is_empty() {
225                        return Ok(None);
226                    }
227                }
228
229                let rust_type = quote::quote!(#ty).to_string();
230                let cuda_type = type_mapper
231                    .map_type(ty)
232                    .map(|ct| ct.to_cuda_string())
233                    .unwrap_or_else(|_| rust_type.clone());
234
235                let is_struct = matches!(type_mapper.map_type(ty), Ok(CudaType::Struct(_)));
236
237                Ok(Some(HandlerReturnType {
238                    rust_type,
239                    cuda_type,
240                    is_struct,
241                }))
242            }
243        }
244    }
245
246    /// Check if the handler produces a response.
247    pub fn has_response(&self) -> bool {
248        self.return_type.is_some()
249    }
250
251    /// Get non-context, non-message parameters (additional kernel params).
252    pub fn extra_params(&self) -> Vec<&HandlerParam> {
253        self.params
254            .iter()
255            .filter(|p| {
256                p.kind != HandlerParamKind::Context
257                    && p.kind != HandlerParamKind::Message
258                    && p.kind != HandlerParamKind::MessageMut
259            })
260            .collect()
261    }
262}
263
/// Settings that control how handler glue code is emitted.
#[derive(Clone, Debug)]
pub struct HandlerCodegenConfig {
    /// Identifier used for the message pointer in generated code.
    pub message_var: String,
    /// Identifier of the response variable that gets serialized.
    pub response_var: String,
    /// Leading whitespace prepended to generated code lines.
    pub indent: String,
    /// Whether message deserialization code should be generated.
    pub generate_deser: bool,
    /// Whether response serialization code should be generated.
    pub generate_ser: bool,
}

impl Default for HandlerCodegenConfig {
    /// `msg` / `response` variable names, an 8-space indent (two nesting
    /// levels: kernel function + message loop), and both deserialization and
    /// serialization enabled.
    fn default() -> Self {
        const INDENT_UNIT: &str = "    ";
        Self {
            message_var: String::from("msg"),
            response_var: String::from("response"),
            indent: INDENT_UNIT.repeat(2), // function body + message loop
            generate_deser: true,
            generate_ser: true,
        }
    }
}
290
291/// Generate message deserialization code.
292///
293/// This generates code to cast the input buffer pointer to the message type.
294pub fn generate_message_deser(message_type: &str, config: &HandlerCodegenConfig) -> String {
295    let mut code = String::new();
296    let indent = &config.indent;
297
298    writeln!(code, "{}// Deserialize message from buffer", indent).unwrap();
299    writeln!(
300        code,
301        "{}{}* {} = ({}*)msg_ptr;",
302        indent, message_type, config.message_var, message_type
303    )
304    .unwrap();
305
306    code
307}
308
309/// Generate response serialization code.
310///
311/// This generates code to copy the response to the output buffer.
312pub fn generate_response_ser(response_type: &str, config: &HandlerCodegenConfig) -> String {
313    let mut code = String::new();
314    let indent = &config.indent;
315
316    writeln!(code).unwrap();
317    writeln!(code, "{}// Serialize response to output buffer", indent).unwrap();
318    writeln!(
319        code,
320        "{}unsigned long long _out_idx = atomicAdd(&control->output_head, 1) & control->output_mask;",
321        indent
322    ).unwrap();
323    writeln!(
324        code,
325        "{}memcpy(&output_buffer[_out_idx * RESP_SIZE], &{}, sizeof({}));",
326        indent, config.response_var, response_type
327    )
328    .unwrap();
329
330    code
331}
332
/// Generate a CUDA struct definition from field information.
///
/// Each `(field_name, cuda_type)` pair becomes one member line, indented by
/// four spaces, in the order given. For example, `("id", "int")` becomes
/// `    int id;`.
pub fn generate_cuda_struct(
    name: &str,
    fields: &[(String, String)], // (field_name, cuda_type)
) -> String {
    let members: String = fields
        .iter()
        .map(|(member, ty)| format!("    {} {};\n", ty, member))
        .collect();

    let mut out = String::new();
    writeln!(out, "struct {} {{\n{}}};", name, members)
        .expect("writing to a String cannot fail");
    out
}
348
349/// Message type information for code generation.
350#[derive(Debug, Clone)]
351pub struct MessageTypeInfo {
352    /// Type name.
353    pub name: String,
354    /// Size in bytes.
355    pub size: usize,
356    /// Fields (name, cuda_type).
357    pub fields: Vec<(String, String)>,
358}
359
360impl MessageTypeInfo {
361    /// Create a simple message type with a single value field.
362    pub fn simple(name: &str, value_type: &str) -> Self {
363        Self {
364            name: name.to_string(),
365            size: Self::type_size(value_type),
366            fields: vec![("value".to_string(), value_type.to_string())],
367        }
368    }
369
370    /// Get the size of a CUDA type in bytes.
371    fn type_size(cuda_type: &str) -> usize {
372        match cuda_type {
373            "float" => 4,
374            "double" => 8,
375            "int" => 4,
376            "unsigned int" => 4,
377            "long long" | "unsigned long long" => 8,
378            "short" | "unsigned short" => 2,
379            "char" | "unsigned char" => 1,
380            _ => 8, // Default assumption
381        }
382    }
383
384    /// Generate CUDA struct definition.
385    pub fn to_cuda_struct(&self) -> String {
386        generate_cuda_struct(&self.name, &self.fields)
387    }
388}
389
390/// Registry of message types for a kernel.
391#[derive(Debug, Clone, Default)]
392pub struct MessageTypeRegistry {
393    /// Registered message types.
394    pub messages: Vec<MessageTypeInfo>,
395    /// Registered response types.
396    pub responses: Vec<MessageTypeInfo>,
397}
398
399impl MessageTypeRegistry {
400    /// Create a new empty registry.
401    pub fn new() -> Self {
402        Self::default()
403    }
404
405    /// Register a message type.
406    pub fn register_message(&mut self, info: MessageTypeInfo) {
407        self.messages.push(info);
408    }
409
410    /// Register a response type.
411    pub fn register_response(&mut self, info: MessageTypeInfo) {
412        self.responses.push(info);
413    }
414
415    /// Generate all struct definitions.
416    pub fn generate_structs(&self) -> String {
417        let mut code = String::new();
418
419        for msg in &self.messages {
420            code.push_str(&msg.to_cuda_struct());
421            code.push('\n');
422        }
423
424        for resp in &self.responses {
425            code.push_str(&resp.to_cuda_struct());
426            code.push('\n');
427        }
428
429        code
430    }
431}
432
/// RingContext method mappings to CUDA intrinsics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextMethod {
    /// ctx.thread_id() -> threadIdx.x
    ThreadId,
    /// ctx.block_id() -> blockIdx.x
    BlockId,
    /// ctx.global_thread_id() -> blockIdx.x * blockDim.x + threadIdx.x
    GlobalThreadId,
    /// ctx.sync_threads() -> __syncthreads()
    SyncThreads,
    /// ctx.now() -> HLC timestamp: `hlc_physical` in the upper 32 bits,
    /// the low 32 bits of `hlc_logical` in the lower 32 bits.
    Now,
    /// ctx.atomic_add(ptr, val) -> atomicAdd(ptr, val)
    AtomicAdd,
    /// ctx.warp_shuffle(val, lane) -> __shfl_sync(0xFFFFFFFF, val, lane)
    WarpShuffle,
    /// ctx.lane_id() -> threadIdx.x % 32
    LaneId,
    /// ctx.warp_id() -> threadIdx.x / 32
    WarpId,
}

impl ContextMethod {
    /// Parse a method name (or one of its aliases) into a context method.
    ///
    /// Returns `None` for names that are not context intrinsics.
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "thread_id" | "thread_idx" => Some(Self::ThreadId),
            "block_id" | "block_idx" => Some(Self::BlockId),
            "global_thread_id" | "global_id" => Some(Self::GlobalThreadId),
            "sync_threads" | "synchronize" => Some(Self::SyncThreads),
            "now" | "timestamp" => Some(Self::Now),
            "atomic_add" => Some(Self::AtomicAdd),
            "warp_shuffle" | "shuffle" => Some(Self::WarpShuffle),
            "lane_id" => Some(Self::LaneId),
            "warp_id" => Some(Self::WarpId),
            _ => None,
        }
    }

    /// Generate CUDA code for this context method.
    ///
    /// Compound expressions are fully parenthesised so they can be spliced
    /// into larger expressions without precedence surprises. `AtomicAdd` and
    /// `WarpShuffle` use the first two `args`; with fewer than two, a
    /// placeholder comment is emitted instead.
    pub fn to_cuda(&self, args: &[String]) -> String {
        match self {
            Self::ThreadId => "threadIdx.x".to_string(),
            Self::BlockId => "blockIdx.x".to_string(),
            Self::GlobalThreadId => "(blockIdx.x * blockDim.x + threadIdx.x)".to_string(),
            Self::SyncThreads => "__syncthreads()".to_string(),
            // Wrap the whole expression: `|` binds looser than arithmetic, so
            // an unwrapped `(a << 32) | (b & M)` spliced into `now + 1` would
            // parse as `(a << 32) | ((b & M) + 1)`. The cast keeps the 32-bit
            // shift well-defined in case `hlc_physical` is declared narrower
            // than 64 bits.
            Self::Now => {
                "(((unsigned long long)hlc_physical << 32) | (hlc_logical & 0xFFFFFFFF))"
                    .to_string()
            }
            Self::AtomicAdd => {
                if args.len() >= 2 {
                    format!("atomicAdd({}, {})", args[0], args[1])
                } else {
                    "/* atomic_add requires ptr and val */".to_string()
                }
            }
            Self::WarpShuffle => {
                if args.len() >= 2 {
                    format!("__shfl_sync(0xFFFFFFFF, {}, {})", args[0], args[1])
                } else {
                    "/* warp_shuffle requires val and lane */".to_string()
                }
            }
            Self::LaneId => "(threadIdx.x % 32)".to_string(),
            Self::WarpId => "(threadIdx.x / 32)".to_string(),
        }
    }

    /// Check if this method is a statement (vs expression).
    pub fn is_statement(&self) -> bool {
        matches!(self, Self::SyncThreads)
    }
}
505
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_handler_signature_simple() {
        let func: ItemFn = parse_quote! {
            fn process(value: f32) -> f32 {
                value * 2.0
            }
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper).unwrap();

        assert_eq!(sig.name, "process");
        assert_eq!(sig.params.len(), 1);
        assert_eq!(sig.params[0].name, "value");
        assert_eq!(sig.params[0].kind, HandlerParamKind::Value);
        assert!(sig.has_response());
        assert!(!sig.has_context);
    }

    #[test]
    fn test_handler_signature_with_context() {
        let func: ItemFn = parse_quote! {
            fn handle(ctx: &RingContext, msg: &Message) -> Response {
                Response { value: msg.value * 2.0 }
            }
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper).unwrap();

        assert_eq!(sig.name, "handle");
        assert!(sig.has_context);
        assert!(sig.message_param.is_some());
        assert_eq!(sig.message_param.as_ref().unwrap().name, "msg");
        assert!(sig.has_response());
    }

    #[test]
    fn test_handler_signature_no_response() {
        let func: ItemFn = parse_quote! {
            fn fire_and_forget(msg: &Event) {
                process(msg);
            }
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper).unwrap();

        assert!(!sig.has_response());
    }

    #[test]
    fn test_explicit_unit_return_has_no_response() {
        let func: ItemFn = parse_quote! {
            fn handle(msg: &Event) -> () {}
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper).unwrap();

        assert!(!sig.has_response());
        assert!(sig.return_type.is_none());
    }

    #[test]
    fn test_extra_params_exclude_context_and_message() {
        let func: ItemFn = parse_quote! {
            fn handle(ctx: &RingContext, msg: &MyMessage, scale: f32) -> f32 {
                msg.value * scale
            }
        };

        let mapper = TypeMapper::new();
        let sig = HandlerSignature::parse(&func, &mapper).unwrap();

        // The context parameter is dropped from `params` entirely.
        assert_eq!(sig.params.len(), 2);
        let extra: Vec<&str> = sig
            .extra_params()
            .iter()
            .map(|p| p.name.as_str())
            .collect();
        assert_eq!(extra, vec!["scale"]);
    }

    #[test]
    fn test_message_deser_generation() {
        let config = HandlerCodegenConfig::default();
        let code = generate_message_deser("MyMessage", &config);

        assert!(code.contains("MyMessage* msg"));
        assert!(code.contains("(MyMessage*)msg_ptr"));
    }

    #[test]
    fn test_message_deser_exact_output() {
        let config = HandlerCodegenConfig {
            message_var: "m".to_string(),
            indent: "  ".to_string(),
            ..HandlerCodegenConfig::default()
        };
        let code = generate_message_deser("Req", &config);

        assert_eq!(
            code,
            "  // Deserialize message from buffer\n  Req* m = (Req*)msg_ptr;\n"
        );
    }

    #[test]
    fn test_response_ser_generation() {
        let config = HandlerCodegenConfig::default();
        let code = generate_response_ser("MyResponse", &config);

        assert!(code.contains("memcpy"));
        assert!(code.contains("output_buffer"));
        assert!(code.contains("RESP_SIZE"));
    }

    #[test]
    fn test_cuda_struct_generation() {
        let fields = vec![
            ("id".to_string(), "unsigned long long".to_string()),
            ("value".to_string(), "float".to_string()),
        ];

        let code = generate_cuda_struct("MyMessage", &fields);

        assert!(code.contains("struct MyMessage"));
        assert!(code.contains("unsigned long long id"));
        assert!(code.contains("float value"));
    }

    #[test]
    fn test_message_type_info() {
        let info = MessageTypeInfo::simple("FloatMsg", "float");
        assert_eq!(info.name, "FloatMsg");
        assert_eq!(info.size, 4);
        assert_eq!(info.fields.len(), 1);

        let cuda = info.to_cuda_struct();
        assert!(cuda.contains("struct FloatMsg"));
        assert!(cuda.contains("float value"));
    }

    #[test]
    fn test_context_method_lookup() {
        assert_eq!(
            ContextMethod::from_name("thread_id"),
            Some(ContextMethod::ThreadId)
        );
        assert_eq!(
            ContextMethod::from_name("sync_threads"),
            Some(ContextMethod::SyncThreads)
        );
        assert_eq!(
            ContextMethod::from_name("global_thread_id"),
            Some(ContextMethod::GlobalThreadId)
        );
        assert_eq!(ContextMethod::from_name("unknown"), None);
    }

    #[test]
    fn test_context_method_cuda_output() {
        assert_eq!(ContextMethod::ThreadId.to_cuda(&[]), "threadIdx.x");
        assert_eq!(ContextMethod::SyncThreads.to_cuda(&[]), "__syncthreads()");
        assert_eq!(
            ContextMethod::GlobalThreadId.to_cuda(&[]),
            "(blockIdx.x * blockDim.x + threadIdx.x)"
        );

        let args = vec!["ptr".to_string(), "1".to_string()];
        assert_eq!(ContextMethod::AtomicAdd.to_cuda(&args), "atomicAdd(ptr, 1)");
    }

    #[test]
    fn test_message_type_registry() {
        let mut registry = MessageTypeRegistry::new();

        registry.register_message(MessageTypeInfo {
            name: "Request".to_string(),
            size: 8,
            fields: vec![("value".to_string(), "float".to_string())],
        });

        registry.register_response(MessageTypeInfo {
            name: "Response".to_string(),
            size: 8,
            fields: vec![("result".to_string(), "float".to_string())],
        });

        let structs = registry.generate_structs();
        assert!(structs.contains("struct Request"));
        assert!(structs.contains("struct Response"));

        // Message structs are emitted before response structs.
        let request_at = structs.find("struct Request").unwrap();
        let response_at = structs.find("struct Response").unwrap();
        assert!(request_at < response_at);
    }

    #[test]
    fn test_handler_param_kind_classification() {
        let ctx_ty: Type = parse_quote!(&RingContext);
        assert_eq!(
            HandlerSignature::classify_param("ctx", &ctx_ty),
            HandlerParamKind::Context
        );

        let msg_ty: Type = parse_quote!(&MyMessage);
        assert_eq!(
            HandlerSignature::classify_param("msg", &msg_ty),
            HandlerParamKind::Message
        );

        let slice_ty: Type = parse_quote!(&[f32]);
        assert_eq!(
            HandlerSignature::classify_param("data", &slice_ty),
            HandlerParamKind::Slice
        );

        let value_ty: Type = parse_quote!(f32);
        assert_eq!(
            HandlerSignature::classify_param("value", &value_ty),
            HandlerParamKind::Value
        );
    }

    #[test]
    fn test_handler_param_kind_mutable_and_heuristic_variants() {
        let out_ty: Type = parse_quote!(&mut [f32]);
        assert_eq!(
            HandlerSignature::classify_param("out", &out_ty),
            HandlerParamKind::SliceMut
        );

        let msg_mut_ty: Type = parse_quote!(&mut MyMessage);
        assert_eq!(
            HandlerSignature::classify_param("msg", &msg_mut_ty),
            HandlerParamKind::MessageMut
        );

        // Type-name heuristic: `&...Request` is a message even without a `msg` name.
        let req_ty: Type = parse_quote!(&AddRequest);
        assert_eq!(
            HandlerSignature::classify_param("req", &req_ty),
            HandlerParamKind::Message
        );

        // A plain reference to a non-message type stays a value.
        let cfg_ty: Type = parse_quote!(&Config);
        assert_eq!(
            HandlerSignature::classify_param("cfg", &cfg_ty),
            HandlerParamKind::Value
        );
    }
}