1use crate::{Result, TranspileError};
6use syn::{Type, TypePath, TypeReference};
7
/// A CUDA C++ type that a Rust kernel type is lowered to.
#[derive(Debug, Clone, PartialEq)]
pub enum CudaType {
    /// `float`
    Float,
    /// `double`
    Double,
    /// `int`
    Int,
    /// `unsigned int`
    UnsignedInt,
    /// `short`
    Short,
    /// `unsigned short`
    UnsignedShort,
    /// `char`
    Char,
    /// `unsigned char`
    UnsignedChar,
    /// `long long`
    LongLong,
    /// `unsigned long long`
    UnsignedLongLong,
    /// Boolean; rendered as `int` by [`CudaType::to_cuda_string`].
    Bool,
    /// `void`
    Void,

    /// A pointer to `inner`.
    Pointer {
        /// The pointee type.
        inner: Box<CudaType>,
        /// Whether the pointee is `const`-qualified.
        is_const: bool,
        /// Whether the pointer itself is marked `__restrict__`.
        restrict: bool,
    },

    /// A named struct, emitted verbatim.
    Struct(String),
}

impl CudaType {
    /// Renders this type as CUDA C++ source text, e.g. `"unsigned int"` or
    /// `"const float* __restrict__"`.
    pub fn to_cuda_string(&self) -> String {
        match self {
            CudaType::Float => "float".to_string(),
            CudaType::Double => "double".to_string(),
            CudaType::Int => "int".to_string(),
            CudaType::UnsignedInt => "unsigned int".to_string(),
            CudaType::Short => "short".to_string(),
            CudaType::UnsignedShort => "unsigned short".to_string(),
            CudaType::Char => "char".to_string(),
            CudaType::UnsignedChar => "unsigned char".to_string(),
            CudaType::LongLong => "long long".to_string(),
            CudaType::UnsignedLongLong => "unsigned long long".to_string(),
            // Booleans are emitted as `int`. NOTE(review): presumably to avoid
            // depending on the device-side size of C++ `bool` — confirm.
            CudaType::Bool => "int".to_string(),
            CudaType::Void => "void".to_string(),
            CudaType::Pointer {
                inner,
                is_const,
                restrict,
            } => {
                let pointee = inner.to_cuda_string();
                // `const` qualifies the pointee. For a scalar or struct pointee the
                // prefix form ("const float*") is correct. When the pointee is itself
                // a pointer, a prefix would attach to the wrong level ("const float**")
                // or duplicate an existing qualifier ("const const float**", which C++
                // rejects), so the qualifier must follow the pointee ("float* const*").
                let mut rendered = match (*is_const, &**inner) {
                    (false, _) => pointee,
                    (true, CudaType::Pointer { .. }) => format!("{} const", pointee),
                    (true, _) => format!("const {}", pointee),
                };
                rendered.push('*');
                if *restrict {
                    rendered.push_str(" __restrict__");
                }
                rendered
            }
            CudaType::Struct(name) => name.clone(),
        }
    }
}
72
73#[derive(Debug, Default)]
75pub struct TypeMapper {
76 custom_types: std::collections::HashMap<String, CudaType>,
78}
79
80impl TypeMapper {
81 pub fn new() -> Self {
83 Self::default()
84 }
85
86 pub fn register_type(&mut self, rust_name: &str, cuda_type: CudaType) {
88 self.custom_types.insert(rust_name.to_string(), cuda_type);
89 }
90
91 pub fn map_type(&self, ty: &Type) -> Result<CudaType> {
93 match ty {
94 Type::Path(path) => self.map_type_path(path),
95 Type::Reference(reference) => self.map_reference(reference),
96 Type::Tuple(tuple) if tuple.elems.is_empty() => Ok(CudaType::Void),
97 _ => Err(TranspileError::Type(format!(
98 "Unsupported type: {}",
99 quote::quote!(#ty)
100 ))),
101 }
102 }
103
104 fn map_type_path(&self, path: &TypePath) -> Result<CudaType> {
106 let segments: Vec<_> = path.path.segments.iter().collect();
107
108 if segments.len() != 1 {
109 let full_path = path
111 .path
112 .segments
113 .iter()
114 .map(|s| s.ident.to_string())
115 .collect::<Vec<_>>()
116 .join("::");
117
118 if let Some(cuda_type) = self.custom_types.get(&full_path) {
119 return Ok(cuda_type.clone());
120 }
121
122 return Err(TranspileError::Type(format!(
123 "Complex path types not supported: {}",
124 quote::quote!(#path)
125 )));
126 }
127
128 let ident = &segments[0].ident;
129 let type_name = ident.to_string();
130
131 match type_name.as_str() {
133 "f32" => Ok(CudaType::Float),
134 "f64" => Ok(CudaType::Double),
135 "i32" => Ok(CudaType::Int),
136 "u32" => Ok(CudaType::UnsignedInt),
137 "i16" => Ok(CudaType::Short),
138 "u16" => Ok(CudaType::UnsignedShort),
139 "i8" => Ok(CudaType::Char),
140 "u8" => Ok(CudaType::UnsignedChar),
141 "i64" => Ok(CudaType::LongLong),
142 "u64" => Ok(CudaType::UnsignedLongLong),
143 "bool" => Ok(CudaType::Bool),
144 "usize" => Ok(CudaType::UnsignedLongLong), "isize" => Ok(CudaType::LongLong),
146 "GridPos" => Ok(CudaType::Void),
148 _ => {
149 if let Some(cuda_type) = self.custom_types.get(&type_name) {
151 Ok(cuda_type.clone())
152 } else {
153 Ok(CudaType::Struct(type_name))
155 }
156 }
157 }
158 }
159
160 fn map_reference(&self, reference: &TypeReference) -> Result<CudaType> {
162 let is_mutable = reference.mutability.is_some();
163
164 match reference.elem.as_ref() {
165 Type::Slice(slice) => {
166 let inner = self.map_type(&slice.elem)?;
169 Ok(CudaType::Pointer {
170 inner: Box::new(inner),
171 is_const: !is_mutable,
172 restrict: true,
173 })
174 }
175 Type::Path(path) => {
176 let inner = self.map_type_path(path)?;
179 Ok(CudaType::Pointer {
180 inner: Box::new(inner),
181 is_const: !is_mutable,
182 restrict: false,
183 })
184 }
185 _ => Err(TranspileError::Type(format!(
186 "Unsupported reference type: {}",
187 quote::quote!(#reference)
188 ))),
189 }
190 }
191}
192
193pub fn get_slice_element_type(ty: &Type) -> Option<&Type> {
195 if let Type::Reference(reference) = ty {
196 if let Type::Slice(slice) = reference.elem.as_ref() {
197 return Some(&slice.elem);
198 }
199 }
200 None
201}
202
203pub fn is_mutable_reference(ty: &Type) -> bool {
205 matches!(ty, Type::Reference(r) if r.mutability.is_some())
206}
207
208pub fn is_grid_pos_type(ty: &Type) -> bool {
210 if let Type::Path(path) = ty {
211 if let Some(segment) = path.path.segments.last() {
212 return segment.ident == "GridPos";
213 }
214 }
215 false
216}
217
218pub fn is_control_block_type(ty: &Type) -> bool {
220 if let Type::Path(path) = ty {
221 if let Some(segment) = path.path.segments.last() {
222 return segment.ident == "ControlBlock";
223 }
224 }
225 if let Type::Reference(reference) = ty {
227 return is_control_block_type(&reference.elem);
228 }
229 false
230}
231
232pub fn is_ring_context_type(ty: &Type) -> bool {
234 if let Type::Path(path) = ty {
235 if let Some(segment) = path.path.segments.last() {
236 return segment.ident == "RingContext";
237 }
238 }
239 if let Type::Reference(reference) = ty {
241 return is_ring_context_type(&reference.elem);
242 }
243 false
244}
245
246#[allow(dead_code)]
248pub fn is_hlc_state_type(ty: &Type) -> bool {
249 if let Type::Path(path) = ty {
250 if let Some(segment) = path.path.segments.last() {
251 return segment.ident == "HlcState";
252 }
253 }
254 false
255}
256
/// The role a ring-kernel parameter plays, as classified by
/// [`RingKernelParamKind::from_param`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RingKernelParamKind {
    /// The kernel's control block.
    ControlBlock,
    /// An incoming message buffer.
    InputBuffer,
    /// An outgoing message buffer.
    OutputBuffer,
    /// State shared across kernel invocations.
    SharedState,
    /// Kernel-to-kernel routing information.
    K2KRoutes,
    /// The ring runtime context.
    RingContext,
    /// Any parameter not recognised as one of the roles above.
    Regular,
}
275
276impl RingKernelParamKind {
277 pub fn from_param(name: &str, ty: &Type) -> Self {
279 if is_control_block_type(ty) {
281 return Self::ControlBlock;
282 }
283 if is_ring_context_type(ty) {
284 return Self::RingContext;
285 }
286
287 let name_lower = name.to_lowercase();
289 if name_lower.contains("control") {
290 return Self::ControlBlock;
291 }
292 if name_lower.contains("input") || name_lower == "inbox" {
293 return Self::InputBuffer;
294 }
295 if name_lower.contains("output") || name_lower == "outbox" {
296 return Self::OutputBuffer;
297 }
298 if name_lower.contains("k2k") || name_lower.contains("route") {
299 return Self::K2KRoutes;
300 }
301 if name_lower.contains("state") || name_lower.contains("shared") {
302 return Self::SharedState;
303 }
304 if name_lower == "ctx" || name_lower == "context" {
305 return Self::RingContext;
306 }
307
308 Self::Regular
309 }
310}
311
312pub fn ring_kernel_type_mapper() -> TypeMapper {
314 let mut mapper = TypeMapper::new();
315
316 mapper.register_type("ControlBlock", CudaType::Struct("ControlBlock".to_string()));
318
319 mapper.register_type("HlcState", CudaType::Struct("HlcState".to_string()));
321
322 mapper.register_type(
324 "K2KRoutingTable",
325 CudaType::Struct("K2KRoutingTable".to_string()),
326 );
327 mapper.register_type("K2KRoute", CudaType::Struct("K2KRoute".to_string()));
328
329 mapper.register_type("RingContext", CudaType::Void);
331
332 mapper
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_primitive_types() {
        let mapper = TypeMapper::new();

        let f32_ty: Type = parse_quote!(f32);
        assert_eq!(mapper.map_type(&f32_ty).unwrap().to_cuda_string(), "float");

        let i32_ty: Type = parse_quote!(i32);
        assert_eq!(mapper.map_type(&i32_ty).unwrap().to_cuda_string(), "int");

        let bool_ty: Type = parse_quote!(bool);
        assert_eq!(mapper.map_type(&bool_ty).unwrap().to_cuda_string(), "int");

        let usize_ty: Type = parse_quote!(usize);
        assert_eq!(
            mapper.map_type(&usize_ty).unwrap().to_cuda_string(),
            "unsigned long long"
        );
    }

    #[test]
    fn test_slice_types() {
        let mapper = TypeMapper::new();

        let slice_ty: Type = parse_quote!(&[f32]);
        assert_eq!(
            mapper.map_type(&slice_ty).unwrap().to_cuda_string(),
            "const float* __restrict__"
        );

        let mut_slice_ty: Type = parse_quote!(&mut [f32]);
        assert_eq!(
            mapper.map_type(&mut_slice_ty).unwrap().to_cuda_string(),
            "float* __restrict__"
        );
    }

    #[test]
    fn test_unit_and_struct_reference_types() {
        let mapper = TypeMapper::new();

        let unit_ty: Type = parse_quote!(());
        assert_eq!(mapper.map_type(&unit_ty).unwrap().to_cuda_string(), "void");

        let struct_ref_ty: Type = parse_quote!(&mut ControlBlock);
        assert_eq!(
            mapper.map_type(&struct_ref_ty).unwrap().to_cuda_string(),
            "ControlBlock*"
        );
    }

    #[test]
    fn test_reference_helpers() {
        let mut_slice_ty: Type = parse_quote!(&mut [f32]);
        assert!(is_mutable_reference(&mut_slice_ty));
        assert!(get_slice_element_type(&mut_slice_ty).is_some());

        let shared_slice_ty: Type = parse_quote!(&[u8]);
        assert!(!is_mutable_reference(&shared_slice_ty));
        assert!(get_slice_element_type(&shared_slice_ty).is_some());

        let scalar_ty: Type = parse_quote!(f32);
        assert!(!is_mutable_reference(&scalar_ty));
        assert!(get_slice_element_type(&scalar_ty).is_none());
    }

    #[test]
    fn test_grid_pos_type() {
        let ty: Type = parse_quote!(GridPos);
        assert!(is_grid_pos_type(&ty));

        let ty: Type = parse_quote!(f32);
        assert!(!is_grid_pos_type(&ty));
    }

    #[test]
    fn test_custom_types() {
        let mut mapper = TypeMapper::new();
        mapper.register_type("WaveParams", CudaType::Struct("WaveParams".to_string()));

        let ty: Type = parse_quote!(WaveParams);
        assert_eq!(mapper.map_type(&ty).unwrap().to_cuda_string(), "WaveParams");
    }

    #[test]
    fn test_control_block_type() {
        let ty: Type = parse_quote!(ControlBlock);
        assert!(is_control_block_type(&ty));

        let ref_ty: Type = parse_quote!(&ControlBlock);
        assert!(is_control_block_type(&ref_ty));

        let mut_ref_ty: Type = parse_quote!(&mut ControlBlock);
        assert!(is_control_block_type(&mut_ref_ty));

        let ty: Type = parse_quote!(f32);
        assert!(!is_control_block_type(&ty));
    }

    #[test]
    fn test_ring_context_type() {
        let ty: Type = parse_quote!(RingContext);
        assert!(is_ring_context_type(&ty));

        let ref_ty: Type = parse_quote!(&RingContext);
        assert!(is_ring_context_type(&ref_ty));

        let ty: Type = parse_quote!(f32);
        assert!(!is_ring_context_type(&ty));
    }

    #[test]
    fn test_ring_kernel_param_kind() {
        let ctrl_ty: Type = parse_quote!(&mut ControlBlock);
        assert_eq!(
            RingKernelParamKind::from_param("control", &ctrl_ty),
            RingKernelParamKind::ControlBlock
        );

        let ctx_ty: Type = parse_quote!(&RingContext);
        assert_eq!(
            RingKernelParamKind::from_param("ctx", &ctx_ty),
            RingKernelParamKind::RingContext
        );

        let input_ty: Type = parse_quote!(&[u8]);
        assert_eq!(
            RingKernelParamKind::from_param("input_buffer", &input_ty),
            RingKernelParamKind::InputBuffer
        );

        let output_ty: Type = parse_quote!(&mut [u8]);
        assert_eq!(
            RingKernelParamKind::from_param("output_buffer", &output_ty),
            RingKernelParamKind::OutputBuffer
        );

        let regular_ty: Type = parse_quote!(f32);
        assert_eq!(
            RingKernelParamKind::from_param("value", &regular_ty),
            RingKernelParamKind::Regular
        );
    }

    #[test]
    fn test_ring_kernel_type_mapper() {
        let mapper = ring_kernel_type_mapper();

        let ctrl_ty: Type = parse_quote!(ControlBlock);
        assert_eq!(
            mapper.map_type(&ctrl_ty).unwrap().to_cuda_string(),
            "ControlBlock"
        );

        let hlc_ty: Type = parse_quote!(HlcState);
        assert_eq!(
            mapper.map_type(&hlc_ty).unwrap().to_cuda_string(),
            "HlcState"
        );

        let ctx_ty: Type = parse_quote!(RingContext);
        assert_eq!(mapper.map_type(&ctx_ty).unwrap().to_cuda_string(), "void");
    }
}