ringkernel_cuda_codegen/
types.rs

1//! Type mapping from Rust to CUDA.
2//!
3//! This module handles the translation of Rust types to their CUDA equivalents.
4
5use crate::{Result, TranspileError};
6use syn::{Type, TypePath, TypeReference};
7
/// CUDA type representation.
///
/// Each variant is rendered to its CUDA C spelling by
/// [`CudaType::to_cuda_string`].
#[derive(Debug, Clone, PartialEq)]
pub enum CudaType {
    /// Scalar type rendered as `float`.
    Float,
    /// Scalar type rendered as `double`.
    Double,
    /// Scalar type rendered as `int`.
    Int,
    /// Scalar type rendered as `unsigned int`.
    UnsignedInt,
    /// Scalar type rendered as `long long`.
    LongLong,
    /// Scalar type rendered as `unsigned long long`.
    UnsignedLongLong,
    /// Boolean; rendered as `int` rather than `bool`.
    Bool,
    /// Rendered as `void`.
    Void,

    /// Pointer types.
    Pointer {
        /// The pointee type.
        inner: Box<CudaType>,
        /// Whether the pointee is `const`-qualified.
        is_const: bool,
        /// Whether `__restrict__` is appended after the `*`.
        restrict: bool,
    },

    /// Custom struct type, rendered as the contained name verbatim.
    Struct(String),
}

impl CudaType {
    /// Convert to CUDA C type string.
    ///
    /// For pointers the result is `[const ]<inner>*[ __restrict__]`. When the
    /// pointee is itself a pointer, a `const` qualifier is written *after*
    /// the pointee (`float* const*`) instead of before it: a leading `const`
    /// would bind to the innermost type and could produce a duplicated
    /// qualifier such as `const const float**`.
    pub fn to_cuda_string(&self) -> String {
        match self {
            CudaType::Float => "float".to_string(),
            CudaType::Double => "double".to_string(),
            CudaType::Int => "int".to_string(),
            CudaType::UnsignedInt => "unsigned int".to_string(),
            CudaType::LongLong => "long long".to_string(),
            CudaType::UnsignedLongLong => "unsigned long long".to_string(),
            CudaType::Bool => "int".to_string(), // CUDA bool quirks
            CudaType::Void => "void".to_string(),
            CudaType::Pointer {
                inner,
                is_const,
                restrict,
            } => {
                let pointee_is_pointer = matches!(**inner, CudaType::Pointer { .. });
                let mut s = String::new();
                // Leading `const` is only correct for non-pointer pointees.
                if *is_const && !pointee_is_pointer {
                    s.push_str("const ");
                }
                s.push_str(&inner.to_cuda_string());
                // For pointer pointees, qualify the inner pointer itself.
                if *is_const && pointee_is_pointer {
                    s.push_str(" const");
                }
                s.push('*');
                if *restrict {
                    s.push_str(" __restrict__");
                }
                s
            }
            CudaType::Struct(name) => name.clone(),
        }
    }
}
64
65/// Type mapper for Rust to CUDA conversions.
66#[derive(Debug, Default)]
67pub struct TypeMapper {
68    /// Custom type mappings for user-defined structs.
69    custom_types: std::collections::HashMap<String, CudaType>,
70}
71
72impl TypeMapper {
73    /// Create a new type mapper.
74    pub fn new() -> Self {
75        Self::default()
76    }
77
78    /// Register a custom type mapping.
79    pub fn register_type(&mut self, rust_name: &str, cuda_type: CudaType) {
80        self.custom_types.insert(rust_name.to_string(), cuda_type);
81    }
82
83    /// Map a Rust type to CUDA.
84    pub fn map_type(&self, ty: &Type) -> Result<CudaType> {
85        match ty {
86            Type::Path(path) => self.map_type_path(path),
87            Type::Reference(reference) => self.map_reference(reference),
88            Type::Tuple(tuple) if tuple.elems.is_empty() => Ok(CudaType::Void),
89            _ => Err(TranspileError::Type(format!(
90                "Unsupported type: {}",
91                quote::quote!(#ty)
92            ))),
93        }
94    }
95
96    /// Map a type path (e.g., `f32`, `MyStruct`).
97    fn map_type_path(&self, path: &TypePath) -> Result<CudaType> {
98        let segments: Vec<_> = path.path.segments.iter().collect();
99
100        if segments.len() != 1 {
101            // Check for custom types first
102            let full_path = path
103                .path
104                .segments
105                .iter()
106                .map(|s| s.ident.to_string())
107                .collect::<Vec<_>>()
108                .join("::");
109
110            if let Some(cuda_type) = self.custom_types.get(&full_path) {
111                return Ok(cuda_type.clone());
112            }
113
114            return Err(TranspileError::Type(format!(
115                "Complex path types not supported: {}",
116                quote::quote!(#path)
117            )));
118        }
119
120        let ident = &segments[0].ident;
121        let type_name = ident.to_string();
122
123        // Check primitive types
124        match type_name.as_str() {
125            "f32" => Ok(CudaType::Float),
126            "f64" => Ok(CudaType::Double),
127            "i32" => Ok(CudaType::Int),
128            "u32" => Ok(CudaType::UnsignedInt),
129            "i64" => Ok(CudaType::LongLong),
130            "u64" => Ok(CudaType::UnsignedLongLong),
131            "bool" => Ok(CudaType::Bool),
132            "usize" => Ok(CudaType::UnsignedLongLong), // Assume 64-bit
133            "isize" => Ok(CudaType::LongLong),
134            // GridPos is a special marker type - we don't emit it
135            "GridPos" => Ok(CudaType::Void),
136            _ => {
137                // Check custom types
138                if let Some(cuda_type) = self.custom_types.get(&type_name) {
139                    Ok(cuda_type.clone())
140                } else {
141                    // Assume it's a user struct
142                    Ok(CudaType::Struct(type_name))
143                }
144            }
145        }
146    }
147
148    /// Map a reference type (e.g., `&[f32]`, `&mut [f32]`).
149    fn map_reference(&self, reference: &TypeReference) -> Result<CudaType> {
150        let is_mutable = reference.mutability.is_some();
151
152        match reference.elem.as_ref() {
153            Type::Slice(slice) => {
154                // &[T] -> const T* __restrict__
155                // &mut [T] -> T* __restrict__
156                let inner = self.map_type(&slice.elem)?;
157                Ok(CudaType::Pointer {
158                    inner: Box::new(inner),
159                    is_const: !is_mutable,
160                    restrict: true,
161                })
162            }
163            Type::Path(path) => {
164                // &T -> const T*
165                // &mut T -> T*
166                let inner = self.map_type_path(path)?;
167                Ok(CudaType::Pointer {
168                    inner: Box::new(inner),
169                    is_const: !is_mutable,
170                    restrict: false,
171                })
172            }
173            _ => Err(TranspileError::Type(format!(
174                "Unsupported reference type: {}",
175                quote::quote!(#reference)
176            ))),
177        }
178    }
179}
180
181/// Extract the inner element type from a slice reference.
182pub fn get_slice_element_type(ty: &Type) -> Option<&Type> {
183    if let Type::Reference(reference) = ty {
184        if let Type::Slice(slice) = reference.elem.as_ref() {
185            return Some(&slice.elem);
186        }
187    }
188    None
189}
190
191/// Check if a type is a mutable reference.
192pub fn is_mutable_reference(ty: &Type) -> bool {
193    matches!(ty, Type::Reference(r) if r.mutability.is_some())
194}
195
196/// Check if a type is the GridPos context type.
197pub fn is_grid_pos_type(ty: &Type) -> bool {
198    if let Type::Path(path) = ty {
199        if let Some(segment) = path.path.segments.last() {
200            return segment.ident == "GridPos";
201        }
202    }
203    false
204}
205
206/// Check if a type is the ControlBlock type.
207pub fn is_control_block_type(ty: &Type) -> bool {
208    if let Type::Path(path) = ty {
209        if let Some(segment) = path.path.segments.last() {
210            return segment.ident == "ControlBlock";
211        }
212    }
213    // Also check for reference to ControlBlock
214    if let Type::Reference(reference) = ty {
215        return is_control_block_type(&reference.elem);
216    }
217    false
218}
219
220/// Check if a type is the RingContext type.
221pub fn is_ring_context_type(ty: &Type) -> bool {
222    if let Type::Path(path) = ty {
223        if let Some(segment) = path.path.segments.last() {
224            return segment.ident == "RingContext";
225        }
226    }
227    // Also check for reference to RingContext
228    if let Type::Reference(reference) = ty {
229        return is_ring_context_type(&reference.elem);
230    }
231    false
232}
233
234/// Check if a type is the HlcState type.
235#[allow(dead_code)]
236pub fn is_hlc_state_type(ty: &Type) -> bool {
237    if let Type::Path(path) = ty {
238        if let Some(segment) = path.path.segments.last() {
239            return segment.ident == "HlcState";
240        }
241    }
242    false
243}
244
/// Classification of a ring kernel parameter.
///
/// Every parameter falls into exactly one of these categories;
/// [`RingKernelParamKind::Regular`] is the catch-all.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RingKernelParamKind {
    /// Pointer to the ControlBlock.
    ControlBlock,
    /// Buffer holding input messages.
    InputBuffer,
    /// Buffer holding output responses.
    OutputBuffer,
    /// Pointer to shared state.
    SharedState,
    /// K2K routing table.
    K2KRoutes,
    /// RingContext marker (removed in transpilation).
    RingContext,
    /// Any parameter that matches none of the above.
    Regular,
}
263
264impl RingKernelParamKind {
265    /// Detect the parameter kind from type and name.
266    pub fn from_param(name: &str, ty: &Type) -> Self {
267        // Check type first
268        if is_control_block_type(ty) {
269            return Self::ControlBlock;
270        }
271        if is_ring_context_type(ty) {
272            return Self::RingContext;
273        }
274
275        // Check name patterns
276        let name_lower = name.to_lowercase();
277        if name_lower.contains("control") {
278            return Self::ControlBlock;
279        }
280        if name_lower.contains("input") || name_lower == "inbox" {
281            return Self::InputBuffer;
282        }
283        if name_lower.contains("output") || name_lower == "outbox" {
284            return Self::OutputBuffer;
285        }
286        if name_lower.contains("k2k") || name_lower.contains("route") {
287            return Self::K2KRoutes;
288        }
289        if name_lower.contains("state") || name_lower.contains("shared") {
290            return Self::SharedState;
291        }
292        if name_lower == "ctx" || name_lower == "context" {
293            return Self::RingContext;
294        }
295
296        Self::Regular
297    }
298}
299
300/// Create a TypeMapper with ring kernel types pre-registered.
301pub fn ring_kernel_type_mapper() -> TypeMapper {
302    let mut mapper = TypeMapper::new();
303
304    // Register ControlBlock
305    mapper.register_type("ControlBlock", CudaType::Struct("ControlBlock".to_string()));
306
307    // Register HlcState
308    mapper.register_type("HlcState", CudaType::Struct("HlcState".to_string()));
309
310    // Register K2K types
311    mapper.register_type(
312        "K2KRoutingTable",
313        CudaType::Struct("K2KRoutingTable".to_string()),
314    );
315    mapper.register_type("K2KRoute", CudaType::Struct("K2KRoute".to_string()));
316
317    // RingContext is a marker type (removed in transpilation)
318    mapper.register_type("RingContext", CudaType::Void);
319
320    mapper
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Map `ty` through `mapper` and render the resulting CUDA spelling.
    fn cuda_spelling(mapper: &TypeMapper, ty: Type) -> String {
        mapper.map_type(&ty).unwrap().to_cuda_string()
    }

    #[test]
    fn test_primitive_types() {
        let mapper = TypeMapper::new();

        assert_eq!(cuda_spelling(&mapper, parse_quote!(f32)), "float");
        assert_eq!(cuda_spelling(&mapper, parse_quote!(i32)), "int");
        assert_eq!(cuda_spelling(&mapper, parse_quote!(bool)), "int");
    }

    #[test]
    fn test_slice_types() {
        let mapper = TypeMapper::new();

        // Shared slices become const restrict pointers.
        assert_eq!(
            cuda_spelling(&mapper, parse_quote!(&[f32])),
            "const float* __restrict__"
        );

        // Mutable slices drop the const qualifier.
        assert_eq!(
            cuda_spelling(&mapper, parse_quote!(&mut [f32])),
            "float* __restrict__"
        );
    }

    #[test]
    fn test_grid_pos_type() {
        let grid_pos: Type = parse_quote!(GridPos);
        let other: Type = parse_quote!(f32);

        assert!(is_grid_pos_type(&grid_pos));
        assert!(!is_grid_pos_type(&other));
    }

    #[test]
    fn test_custom_types() {
        let mut mapper = TypeMapper::new();
        mapper.register_type("WaveParams", CudaType::Struct("WaveParams".to_string()));

        assert_eq!(cuda_spelling(&mapper, parse_quote!(WaveParams)), "WaveParams");
    }

    #[test]
    fn test_control_block_type() {
        let matching: Vec<Type> = vec![
            parse_quote!(ControlBlock),
            parse_quote!(&ControlBlock),
            parse_quote!(&mut ControlBlock),
        ];
        for ty in &matching {
            assert!(is_control_block_type(ty));
        }

        let other: Type = parse_quote!(f32);
        assert!(!is_control_block_type(&other));
    }

    #[test]
    fn test_ring_context_type() {
        let matching: Vec<Type> = vec![parse_quote!(RingContext), parse_quote!(&RingContext)];
        for ty in &matching {
            assert!(is_ring_context_type(ty));
        }

        let other: Type = parse_quote!(f32);
        assert!(!is_ring_context_type(&other));
    }

    #[test]
    fn test_ring_kernel_param_kind() {
        let cases: Vec<(&str, Type, RingKernelParamKind)> = vec![
            (
                "control",
                parse_quote!(&mut ControlBlock),
                RingKernelParamKind::ControlBlock,
            ),
            (
                "ctx",
                parse_quote!(&RingContext),
                RingKernelParamKind::RingContext,
            ),
            (
                "input_buffer",
                parse_quote!(&[u8]),
                RingKernelParamKind::InputBuffer,
            ),
            (
                "output_buffer",
                parse_quote!(&mut [u8]),
                RingKernelParamKind::OutputBuffer,
            ),
            ("value", parse_quote!(f32), RingKernelParamKind::Regular),
        ];

        for (name, ty, expected) in &cases {
            assert_eq!(
                RingKernelParamKind::from_param(name, ty),
                *expected,
                "parameter `{}`",
                name
            );
        }
    }

    #[test]
    fn test_ring_kernel_type_mapper() {
        let mapper = ring_kernel_type_mapper();

        assert_eq!(
            cuda_spelling(&mapper, parse_quote!(ControlBlock)),
            "ControlBlock"
        );
        assert_eq!(cuda_spelling(&mapper, parse_quote!(HlcState)), "HlcState");
        assert_eq!(cuda_spelling(&mapper, parse_quote!(RingContext)), "void");
    }
}
458}