1use crate::{Result, TranspileError};
6use syn::{Type, TypePath, TypeReference};
7
/// A CUDA C++ type that a Rust type is lowered to.
#[derive(Debug, Clone, PartialEq)]
pub enum CudaType {
    /// `float`.
    Float,
    /// `double`.
    Double,
    /// `int`.
    Int,
    /// `unsigned int`.
    UnsignedInt,
    /// `long long`.
    LongLong,
    /// `unsigned long long`.
    UnsignedLongLong,
    /// Rust `bool`; rendered as `int`.
    Bool,
    /// `void`.
    Void,

    /// A pointer to `inner`.
    Pointer {
        /// The pointee type.
        inner: Box<CudaType>,
        /// Whether the pointee is `const`-qualified.
        is_const: bool,
        /// Whether the pointer is emitted with `__restrict__`.
        restrict: bool,
    },

    /// A named type, emitted verbatim.
    Struct(String),
}

impl CudaType {
    /// Renders this type as CUDA C++ source text.
    ///
    /// Scalars map to their C spelling (`Bool` is emitted as `int`), `Struct` emits its
    /// name verbatim. A `Pointer` renders as `<pointee>*`, followed by ` __restrict__`
    /// when `restrict` is set. `is_const` qualifies the pointee: for a non-pointer
    /// pointee it is written in front (`const float*`); for a pointer pointee it is
    /// written after that inner pointer (`const float* const*`), so nested pointers
    /// never get a duplicated or misplaced `const`.
    pub fn to_cuda_string(&self) -> String {
        match self {
            CudaType::Float => "float".to_string(),
            CudaType::Double => "double".to_string(),
            CudaType::Int => "int".to_string(),
            CudaType::UnsignedInt => "unsigned int".to_string(),
            CudaType::LongLong => "long long".to_string(),
            CudaType::UnsignedLongLong => "unsigned long long".to_string(),
            CudaType::Bool => "int".to_string(),
            CudaType::Void => "void".to_string(),
            CudaType::Pointer {
                inner,
                is_const,
                restrict,
            } => {
                let pointee = inner.to_cuda_string();
                let pointee_is_pointer = matches!(**inner, CudaType::Pointer { .. });
                let mut s = String::with_capacity(pointee.len() + 20);
                if *is_const && !pointee_is_pointer {
                    s.push_str("const ");
                    s.push_str(&pointee);
                } else {
                    s.push_str(&pointee);
                    if *is_const {
                        // A leading `const` here would re-qualify the innermost type
                        // (`const const float**`, ill-formed C++) instead of the inner
                        // pointer. Placing it after the inner `*` qualifies that pointer.
                        s.push_str(" const");
                    }
                }
                s.push('*');
                if *restrict {
                    s.push_str(" __restrict__");
                }
                s
            }
            CudaType::Struct(name) => name.clone(),
        }
    }
}
64
65#[derive(Debug, Default)]
67pub struct TypeMapper {
68 custom_types: std::collections::HashMap<String, CudaType>,
70}
71
72impl TypeMapper {
73 pub fn new() -> Self {
75 Self::default()
76 }
77
78 pub fn register_type(&mut self, rust_name: &str, cuda_type: CudaType) {
80 self.custom_types.insert(rust_name.to_string(), cuda_type);
81 }
82
83 pub fn map_type(&self, ty: &Type) -> Result<CudaType> {
85 match ty {
86 Type::Path(path) => self.map_type_path(path),
87 Type::Reference(reference) => self.map_reference(reference),
88 Type::Tuple(tuple) if tuple.elems.is_empty() => Ok(CudaType::Void),
89 _ => Err(TranspileError::Type(format!(
90 "Unsupported type: {}",
91 quote::quote!(#ty)
92 ))),
93 }
94 }
95
96 fn map_type_path(&self, path: &TypePath) -> Result<CudaType> {
98 let segments: Vec<_> = path.path.segments.iter().collect();
99
100 if segments.len() != 1 {
101 let full_path = path
103 .path
104 .segments
105 .iter()
106 .map(|s| s.ident.to_string())
107 .collect::<Vec<_>>()
108 .join("::");
109
110 if let Some(cuda_type) = self.custom_types.get(&full_path) {
111 return Ok(cuda_type.clone());
112 }
113
114 return Err(TranspileError::Type(format!(
115 "Complex path types not supported: {}",
116 quote::quote!(#path)
117 )));
118 }
119
120 let ident = &segments[0].ident;
121 let type_name = ident.to_string();
122
123 match type_name.as_str() {
125 "f32" => Ok(CudaType::Float),
126 "f64" => Ok(CudaType::Double),
127 "i32" => Ok(CudaType::Int),
128 "u32" => Ok(CudaType::UnsignedInt),
129 "i64" => Ok(CudaType::LongLong),
130 "u64" => Ok(CudaType::UnsignedLongLong),
131 "bool" => Ok(CudaType::Bool),
132 "usize" => Ok(CudaType::UnsignedLongLong), "isize" => Ok(CudaType::LongLong),
134 "GridPos" => Ok(CudaType::Void),
136 _ => {
137 if let Some(cuda_type) = self.custom_types.get(&type_name) {
139 Ok(cuda_type.clone())
140 } else {
141 Ok(CudaType::Struct(type_name))
143 }
144 }
145 }
146 }
147
148 fn map_reference(&self, reference: &TypeReference) -> Result<CudaType> {
150 let is_mutable = reference.mutability.is_some();
151
152 match reference.elem.as_ref() {
153 Type::Slice(slice) => {
154 let inner = self.map_type(&slice.elem)?;
157 Ok(CudaType::Pointer {
158 inner: Box::new(inner),
159 is_const: !is_mutable,
160 restrict: true,
161 })
162 }
163 Type::Path(path) => {
164 let inner = self.map_type_path(path)?;
167 Ok(CudaType::Pointer {
168 inner: Box::new(inner),
169 is_const: !is_mutable,
170 restrict: false,
171 })
172 }
173 _ => Err(TranspileError::Type(format!(
174 "Unsupported reference type: {}",
175 quote::quote!(#reference)
176 ))),
177 }
178 }
179}
180
181pub fn get_slice_element_type(ty: &Type) -> Option<&Type> {
183 if let Type::Reference(reference) = ty {
184 if let Type::Slice(slice) = reference.elem.as_ref() {
185 return Some(&slice.elem);
186 }
187 }
188 None
189}
190
191pub fn is_mutable_reference(ty: &Type) -> bool {
193 matches!(ty, Type::Reference(r) if r.mutability.is_some())
194}
195
196pub fn is_grid_pos_type(ty: &Type) -> bool {
198 if let Type::Path(path) = ty {
199 if let Some(segment) = path.path.segments.last() {
200 return segment.ident == "GridPos";
201 }
202 }
203 false
204}
205
206pub fn is_control_block_type(ty: &Type) -> bool {
208 if let Type::Path(path) = ty {
209 if let Some(segment) = path.path.segments.last() {
210 return segment.ident == "ControlBlock";
211 }
212 }
213 if let Type::Reference(reference) = ty {
215 return is_control_block_type(&reference.elem);
216 }
217 false
218}
219
220pub fn is_ring_context_type(ty: &Type) -> bool {
222 if let Type::Path(path) = ty {
223 if let Some(segment) = path.path.segments.last() {
224 return segment.ident == "RingContext";
225 }
226 }
227 if let Type::Reference(reference) = ty {
229 return is_ring_context_type(&reference.elem);
230 }
231 false
232}
233
234#[allow(dead_code)]
236pub fn is_hlc_state_type(ty: &Type) -> bool {
237 if let Type::Path(path) = ty {
238 if let Some(segment) = path.path.segments.last() {
239 return segment.ident == "HlcState";
240 }
241 }
242 false
243}
244
/// The role a ring-kernel parameter plays, as inferred by
/// `RingKernelParamKind::from_param` from the parameter's type and name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RingKernelParamKind {
    /// The kernel's control block (by type, or a name containing "control").
    ControlBlock,
    /// An input buffer (name contains "input", or is exactly "inbox").
    InputBuffer,
    /// An output buffer (name contains "output", or is exactly "outbox").
    OutputBuffer,
    /// Shared kernel state (name contains "state" or "shared").
    SharedState,
    /// Kernel-to-kernel routing data (name contains "k2k" or "route").
    K2KRoutes,
    /// The ring context (by type, or a name of exactly "ctx" / "context").
    RingContext,
    /// Any parameter that matches none of the above.
    Regular,
}
263
264impl RingKernelParamKind {
265 pub fn from_param(name: &str, ty: &Type) -> Self {
267 if is_control_block_type(ty) {
269 return Self::ControlBlock;
270 }
271 if is_ring_context_type(ty) {
272 return Self::RingContext;
273 }
274
275 let name_lower = name.to_lowercase();
277 if name_lower.contains("control") {
278 return Self::ControlBlock;
279 }
280 if name_lower.contains("input") || name_lower == "inbox" {
281 return Self::InputBuffer;
282 }
283 if name_lower.contains("output") || name_lower == "outbox" {
284 return Self::OutputBuffer;
285 }
286 if name_lower.contains("k2k") || name_lower.contains("route") {
287 return Self::K2KRoutes;
288 }
289 if name_lower.contains("state") || name_lower.contains("shared") {
290 return Self::SharedState;
291 }
292 if name_lower == "ctx" || name_lower == "context" {
293 return Self::RingContext;
294 }
295
296 Self::Regular
297 }
298}
299
300pub fn ring_kernel_type_mapper() -> TypeMapper {
302 let mut mapper = TypeMapper::new();
303
304 mapper.register_type("ControlBlock", CudaType::Struct("ControlBlock".to_string()));
306
307 mapper.register_type("HlcState", CudaType::Struct("HlcState".to_string()));
309
310 mapper.register_type(
312 "K2KRoutingTable",
313 CudaType::Struct("K2KRoutingTable".to_string()),
314 );
315 mapper.register_type("K2KRoute", CudaType::Struct("K2KRoute".to_string()));
316
317 mapper.register_type("RingContext", CudaType::Void);
319
320 mapper
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_primitive_types() {
        let mapper = TypeMapper::new();

        let f32_ty: Type = parse_quote!(f32);
        assert_eq!(mapper.map_type(&f32_ty).unwrap().to_cuda_string(), "float");

        let i32_ty: Type = parse_quote!(i32);
        assert_eq!(mapper.map_type(&i32_ty).unwrap().to_cuda_string(), "int");

        let bool_ty: Type = parse_quote!(bool);
        assert_eq!(mapper.map_type(&bool_ty).unwrap().to_cuda_string(), "int");
    }

    #[test]
    fn test_slice_types() {
        let mapper = TypeMapper::new();

        let slice_ty: Type = parse_quote!(&[f32]);
        assert_eq!(
            mapper.map_type(&slice_ty).unwrap().to_cuda_string(),
            "const float* __restrict__"
        );

        let mut_slice_ty: Type = parse_quote!(&mut [f32]);
        assert_eq!(
            mapper.map_type(&mut_slice_ty).unwrap().to_cuda_string(),
            "float* __restrict__"
        );
    }

    #[test]
    fn test_grid_pos_type() {
        let ty: Type = parse_quote!(GridPos);
        assert!(is_grid_pos_type(&ty));

        let ty: Type = parse_quote!(f32);
        assert!(!is_grid_pos_type(&ty));
    }

    #[test]
    fn test_custom_types() {
        let mut mapper = TypeMapper::new();
        mapper.register_type("WaveParams", CudaType::Struct("WaveParams".to_string()));

        let ty: Type = parse_quote!(WaveParams);
        assert_eq!(mapper.map_type(&ty).unwrap().to_cuda_string(), "WaveParams");
    }

    #[test]
    fn test_control_block_type() {
        let ty: Type = parse_quote!(ControlBlock);
        assert!(is_control_block_type(&ty));

        let ref_ty: Type = parse_quote!(&ControlBlock);
        assert!(is_control_block_type(&ref_ty));

        let mut_ref_ty: Type = parse_quote!(&mut ControlBlock);
        assert!(is_control_block_type(&mut_ref_ty));

        let ty: Type = parse_quote!(f32);
        assert!(!is_control_block_type(&ty));
    }

    #[test]
    fn test_ring_context_type() {
        let ty: Type = parse_quote!(RingContext);
        assert!(is_ring_context_type(&ty));

        let ref_ty: Type = parse_quote!(&RingContext);
        assert!(is_ring_context_type(&ref_ty));

        let ty: Type = parse_quote!(f32);
        assert!(!is_ring_context_type(&ty));
    }

    #[test]
    fn test_ring_kernel_param_kind() {
        let ctrl_ty: Type = parse_quote!(&mut ControlBlock);
        assert_eq!(
            RingKernelParamKind::from_param("control", &ctrl_ty),
            RingKernelParamKind::ControlBlock
        );

        let ctx_ty: Type = parse_quote!(&RingContext);
        assert_eq!(
            RingKernelParamKind::from_param("ctx", &ctx_ty),
            RingKernelParamKind::RingContext
        );

        let input_ty: Type = parse_quote!(&[u8]);
        assert_eq!(
            RingKernelParamKind::from_param("input_buffer", &input_ty),
            RingKernelParamKind::InputBuffer
        );

        let output_ty: Type = parse_quote!(&mut [u8]);
        assert_eq!(
            RingKernelParamKind::from_param("output_buffer", &output_ty),
            RingKernelParamKind::OutputBuffer
        );

        let regular_ty: Type = parse_quote!(f32);
        assert_eq!(
            RingKernelParamKind::from_param("value", &regular_ty),
            RingKernelParamKind::Regular
        );
    }

    #[test]
    fn test_ring_kernel_type_mapper() {
        let mapper = ring_kernel_type_mapper();

        let ctrl_ty: Type = parse_quote!(ControlBlock);
        assert_eq!(
            mapper.map_type(&ctrl_ty).unwrap().to_cuda_string(),
            "ControlBlock"
        );

        let hlc_ty: Type = parse_quote!(HlcState);
        assert_eq!(
            mapper.map_type(&hlc_ty).unwrap().to_cuda_string(),
            "HlcState"
        );

        let ctx_ty: Type = parse_quote!(RingContext);
        assert_eq!(mapper.map_type(&ctx_ty).unwrap().to_cuda_string(), "void");
    }
}