ringkernel_cuda_codegen/
types.rs

//! Type mapping from Rust to CUDA.
//!
//! This module translates Rust types into their CUDA equivalents.
4
5use crate::{Result, TranspileError};
6use syn::{Type, TypePath, TypeReference};
7
/// CUDA type representation.
///
/// Each variant stands for a type that can be spelled in generated CUDA C
/// source; [`CudaType::to_cuda_string`] produces that spelling.
#[derive(Debug, Clone, PartialEq)]
pub enum CudaType {
    /// Rendered as `float`.
    Float,
    /// Rendered as `double`.
    Double,
    /// Rendered as `int`.
    Int,
    /// Rendered as `unsigned int`.
    UnsignedInt,
    /// Rendered as `short`.
    Short,
    /// Rendered as `unsigned short`.
    UnsignedShort,
    /// Rendered as `char`.
    Char,
    /// Rendered as `unsigned char`.
    UnsignedChar,
    /// Rendered as `long long`.
    LongLong,
    /// Rendered as `unsigned long long`.
    UnsignedLongLong,
    /// Boolean; rendered as `int` (see `to_cuda_string`).
    Bool,
    /// Rendered as `void`.
    Void,

    /// Pointer types.
    Pointer {
        /// The pointee type.
        inner: Box<CudaType>,
        /// Whether the pointee is `const`-qualified.
        is_const: bool,
        /// Whether the pointer itself carries `__restrict__`.
        restrict: bool,
    },

    /// Custom struct type, rendered as its name.
    Struct(String),
}

impl CudaType {
    /// Convert to CUDA C type string.
    ///
    /// For [`CudaType::Pointer`], `is_const` qualifies the *pointee*:
    ///
    /// * non-pointer pointee: `const T*` (leading `const`),
    /// * pointer pointee: `U* const*` (trailing `const`), because a leading
    ///   `const` would bind to the innermost type instead of the pointer
    ///   being pointed at, and would duplicate an existing `const`.
    ///
    /// `__restrict__` is appended after the `*` when `restrict` is set.
    pub fn to_cuda_string(&self) -> String {
        match self {
            CudaType::Float => "float".to_string(),
            CudaType::Double => "double".to_string(),
            CudaType::Int => "int".to_string(),
            CudaType::UnsignedInt => "unsigned int".to_string(),
            CudaType::Short => "short".to_string(),
            CudaType::UnsignedShort => "unsigned short".to_string(),
            CudaType::Char => "char".to_string(),
            CudaType::UnsignedChar => "unsigned char".to_string(),
            CudaType::LongLong => "long long".to_string(),
            CudaType::UnsignedLongLong => "unsigned long long".to_string(),
            CudaType::Bool => "int".to_string(), // CUDA bool quirks
            CudaType::Void => "void".to_string(),
            CudaType::Pointer {
                inner,
                is_const,
                restrict,
            } => {
                let pointee = inner.to_cuda_string();
                let pointee_is_pointer = matches!(inner.as_ref(), CudaType::Pointer { .. });

                let mut s = String::new();
                if *is_const && !pointee_is_pointer {
                    // `const T*`: the qualifier applies to the scalar/struct pointee.
                    s.push_str("const ");
                }
                s.push_str(&pointee);
                if *is_const && pointee_is_pointer {
                    // `U* const*`: the qualifier must follow the inner pointer's
                    // declarator to apply to that pointer rather than to `U`.
                    s.push_str(" const");
                }
                s.push('*');
                if *restrict {
                    s.push_str(" __restrict__");
                }
                s
            }
            CudaType::Struct(name) => name.clone(),
        }
    }
}
72
73/// Type mapper for Rust to CUDA conversions.
74#[derive(Debug, Default)]
75pub struct TypeMapper {
76    /// Custom type mappings for user-defined structs.
77    custom_types: std::collections::HashMap<String, CudaType>,
78}
79
80impl TypeMapper {
81    /// Create a new type mapper.
82    pub fn new() -> Self {
83        Self::default()
84    }
85
86    /// Register a custom type mapping.
87    pub fn register_type(&mut self, rust_name: &str, cuda_type: CudaType) {
88        self.custom_types.insert(rust_name.to_string(), cuda_type);
89    }
90
91    /// Map a Rust type to CUDA.
92    pub fn map_type(&self, ty: &Type) -> Result<CudaType> {
93        match ty {
94            Type::Path(path) => self.map_type_path(path),
95            Type::Reference(reference) => self.map_reference(reference),
96            Type::Tuple(tuple) if tuple.elems.is_empty() => Ok(CudaType::Void),
97            _ => Err(TranspileError::Type(format!(
98                "Unsupported type: {}",
99                quote::quote!(#ty)
100            ))),
101        }
102    }
103
104    /// Map a type path (e.g., `f32`, `MyStruct`).
105    fn map_type_path(&self, path: &TypePath) -> Result<CudaType> {
106        let segments: Vec<_> = path.path.segments.iter().collect();
107
108        if segments.len() != 1 {
109            // Check for custom types first
110            let full_path = path
111                .path
112                .segments
113                .iter()
114                .map(|s| s.ident.to_string())
115                .collect::<Vec<_>>()
116                .join("::");
117
118            if let Some(cuda_type) = self.custom_types.get(&full_path) {
119                return Ok(cuda_type.clone());
120            }
121
122            return Err(TranspileError::Type(format!(
123                "Complex path types not supported: {}",
124                quote::quote!(#path)
125            )));
126        }
127
128        let ident = &segments[0].ident;
129        let type_name = ident.to_string();
130
131        // Check primitive types
132        match type_name.as_str() {
133            "f32" => Ok(CudaType::Float),
134            "f64" => Ok(CudaType::Double),
135            "i32" => Ok(CudaType::Int),
136            "u32" => Ok(CudaType::UnsignedInt),
137            "i16" => Ok(CudaType::Short),
138            "u16" => Ok(CudaType::UnsignedShort),
139            "i8" => Ok(CudaType::Char),
140            "u8" => Ok(CudaType::UnsignedChar),
141            "i64" => Ok(CudaType::LongLong),
142            "u64" => Ok(CudaType::UnsignedLongLong),
143            "bool" => Ok(CudaType::Bool),
144            "usize" => Ok(CudaType::UnsignedLongLong), // Assume 64-bit
145            "isize" => Ok(CudaType::LongLong),
146            // GridPos is a special marker type - we don't emit it
147            "GridPos" => Ok(CudaType::Void),
148            _ => {
149                // Check custom types
150                if let Some(cuda_type) = self.custom_types.get(&type_name) {
151                    Ok(cuda_type.clone())
152                } else {
153                    // Assume it's a user struct
154                    Ok(CudaType::Struct(type_name))
155                }
156            }
157        }
158    }
159
160    /// Map a reference type (e.g., `&[f32]`, `&mut [f32]`).
161    fn map_reference(&self, reference: &TypeReference) -> Result<CudaType> {
162        let is_mutable = reference.mutability.is_some();
163
164        match reference.elem.as_ref() {
165            Type::Slice(slice) => {
166                // &[T] -> const T* __restrict__
167                // &mut [T] -> T* __restrict__
168                let inner = self.map_type(&slice.elem)?;
169                Ok(CudaType::Pointer {
170                    inner: Box::new(inner),
171                    is_const: !is_mutable,
172                    restrict: true,
173                })
174            }
175            Type::Path(path) => {
176                // &T -> const T*
177                // &mut T -> T*
178                let inner = self.map_type_path(path)?;
179                Ok(CudaType::Pointer {
180                    inner: Box::new(inner),
181                    is_const: !is_mutable,
182                    restrict: false,
183                })
184            }
185            _ => Err(TranspileError::Type(format!(
186                "Unsupported reference type: {}",
187                quote::quote!(#reference)
188            ))),
189        }
190    }
191}
192
193/// Extract the inner element type from a slice reference.
194pub fn get_slice_element_type(ty: &Type) -> Option<&Type> {
195    if let Type::Reference(reference) = ty {
196        if let Type::Slice(slice) = reference.elem.as_ref() {
197            return Some(&slice.elem);
198        }
199    }
200    None
201}
202
203/// Check if a type is a mutable reference.
204pub fn is_mutable_reference(ty: &Type) -> bool {
205    matches!(ty, Type::Reference(r) if r.mutability.is_some())
206}
207
208/// Check if a type is the GridPos context type.
209pub fn is_grid_pos_type(ty: &Type) -> bool {
210    if let Type::Path(path) = ty {
211        if let Some(segment) = path.path.segments.last() {
212            return segment.ident == "GridPos";
213        }
214    }
215    false
216}
217
218/// Check if a type is the ControlBlock type.
219pub fn is_control_block_type(ty: &Type) -> bool {
220    if let Type::Path(path) = ty {
221        if let Some(segment) = path.path.segments.last() {
222            return segment.ident == "ControlBlock";
223        }
224    }
225    // Also check for reference to ControlBlock
226    if let Type::Reference(reference) = ty {
227        return is_control_block_type(&reference.elem);
228    }
229    false
230}
231
232/// Check if a type is the RingContext type.
233pub fn is_ring_context_type(ty: &Type) -> bool {
234    if let Type::Path(path) = ty {
235        if let Some(segment) = path.path.segments.last() {
236            return segment.ident == "RingContext";
237        }
238    }
239    // Also check for reference to RingContext
240    if let Type::Reference(reference) = ty {
241        return is_ring_context_type(&reference.elem);
242    }
243    false
244}
245
246/// Check if a type is the HlcState type.
247#[allow(dead_code)]
248pub fn is_hlc_state_type(ty: &Type) -> bool {
249    if let Type::Path(path) = ty {
250        if let Some(segment) = path.path.segments.last() {
251            return segment.ident == "HlcState";
252        }
253    }
254    false
255}
256
257/// Known ring kernel parameter types.
258#[derive(Debug, Clone, Copy, PartialEq, Eq)]
259pub enum RingKernelParamKind {
260    /// ControlBlock pointer.
261    ControlBlock,
262    /// Input message buffer.
263    InputBuffer,
264    /// Output response buffer.
265    OutputBuffer,
266    /// Shared state pointer.
267    SharedState,
268    /// K2K routing table.
269    K2KRoutes,
270    /// RingContext (marker, removed in transpilation).
271    RingContext,
272    /// Regular parameter.
273    Regular,
274}
275
276impl RingKernelParamKind {
277    /// Detect the parameter kind from type and name.
278    pub fn from_param(name: &str, ty: &Type) -> Self {
279        // Check type first
280        if is_control_block_type(ty) {
281            return Self::ControlBlock;
282        }
283        if is_ring_context_type(ty) {
284            return Self::RingContext;
285        }
286
287        // Check name patterns
288        let name_lower = name.to_lowercase();
289        if name_lower.contains("control") {
290            return Self::ControlBlock;
291        }
292        if name_lower.contains("input") || name_lower == "inbox" {
293            return Self::InputBuffer;
294        }
295        if name_lower.contains("output") || name_lower == "outbox" {
296            return Self::OutputBuffer;
297        }
298        if name_lower.contains("k2k") || name_lower.contains("route") {
299            return Self::K2KRoutes;
300        }
301        if name_lower.contains("state") || name_lower.contains("shared") {
302            return Self::SharedState;
303        }
304        if name_lower == "ctx" || name_lower == "context" {
305            return Self::RingContext;
306        }
307
308        Self::Regular
309    }
310}
311
312/// Create a TypeMapper with ring kernel types pre-registered.
313pub fn ring_kernel_type_mapper() -> TypeMapper {
314    let mut mapper = TypeMapper::new();
315
316    // Register ControlBlock
317    mapper.register_type("ControlBlock", CudaType::Struct("ControlBlock".to_string()));
318
319    // Register HlcState
320    mapper.register_type("HlcState", CudaType::Struct("HlcState".to_string()));
321
322    // Register K2K types
323    mapper.register_type(
324        "K2KRoutingTable",
325        CudaType::Struct("K2KRoutingTable".to_string()),
326    );
327    mapper.register_type("K2KRoute", CudaType::Struct("K2KRoute".to_string()));
328
329    // RingContext is a marker type (removed in transpilation)
330    mapper.register_type("RingContext", CudaType::Void);
331
332    mapper
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Map `ty` through `mapper` and render the resulting CUDA spelling.
    fn render(mapper: &TypeMapper, ty: &Type) -> String {
        mapper.map_type(ty).unwrap().to_cuda_string()
    }

    #[test]
    fn test_primitive_types() {
        let mapper = TypeMapper::new();
        let cases: [(Type, &str); 3] = [
            (parse_quote!(f32), "float"),
            (parse_quote!(i32), "int"),
            (parse_quote!(bool), "int"),
        ];

        for (ty, expected) in &cases {
            assert_eq!(render(&mapper, ty), *expected);
        }
    }

    #[test]
    fn test_slice_types() {
        let mapper = TypeMapper::new();
        let cases: [(Type, &str); 2] = [
            // Shared slices become const, restrict-qualified pointers.
            (parse_quote!(&[f32]), "const float* __restrict__"),
            // Mutable slices drop the const qualifier.
            (parse_quote!(&mut [f32]), "float* __restrict__"),
        ];

        for (ty, expected) in &cases {
            assert_eq!(render(&mapper, ty), *expected);
        }
    }

    #[test]
    fn test_grid_pos_type() {
        let grid_pos: Type = parse_quote!(GridPos);
        let scalar: Type = parse_quote!(f32);

        assert!(is_grid_pos_type(&grid_pos));
        assert!(!is_grid_pos_type(&scalar));
    }

    #[test]
    fn test_custom_types() {
        let mut mapper = TypeMapper::new();
        mapper.register_type("WaveParams", CudaType::Struct("WaveParams".to_string()));

        let wave_params: Type = parse_quote!(WaveParams);
        assert_eq!(render(&mapper, &wave_params), "WaveParams");
    }

    #[test]
    fn test_control_block_type() {
        let matching: [Type; 3] = [
            parse_quote!(ControlBlock),
            parse_quote!(&ControlBlock),
            parse_quote!(&mut ControlBlock),
        ];
        for ty in &matching {
            assert!(is_control_block_type(ty));
        }

        let scalar: Type = parse_quote!(f32);
        assert!(!is_control_block_type(&scalar));
    }

    #[test]
    fn test_ring_context_type() {
        let matching: [Type; 2] = [parse_quote!(RingContext), parse_quote!(&RingContext)];
        for ty in &matching {
            assert!(is_ring_context_type(ty));
        }

        let scalar: Type = parse_quote!(f32);
        assert!(!is_ring_context_type(&scalar));
    }

    #[test]
    fn test_ring_kernel_param_kind() {
        let cases: [(&str, Type, RingKernelParamKind); 5] = [
            (
                "control",
                parse_quote!(&mut ControlBlock),
                RingKernelParamKind::ControlBlock,
            ),
            (
                "ctx",
                parse_quote!(&RingContext),
                RingKernelParamKind::RingContext,
            ),
            (
                "input_buffer",
                parse_quote!(&[u8]),
                RingKernelParamKind::InputBuffer,
            ),
            (
                "output_buffer",
                parse_quote!(&mut [u8]),
                RingKernelParamKind::OutputBuffer,
            ),
            ("value", parse_quote!(f32), RingKernelParamKind::Regular),
        ];

        for (name, ty, expected) in &cases {
            assert_eq!(RingKernelParamKind::from_param(name, ty), *expected);
        }
    }

    #[test]
    fn test_ring_kernel_type_mapper() {
        let mapper = ring_kernel_type_mapper();
        let cases: [(Type, &str); 3] = [
            (parse_quote!(ControlBlock), "ControlBlock"),
            (parse_quote!(HlcState), "HlcState"),
            // The marker type is registered as void.
            (parse_quote!(RingContext), "void"),
        ];

        for (ty, expected) in &cases {
            assert_eq!(render(&mapper, ty), *expected);
        }
    }
}