Skip to main content

ringkernel_wgpu_codegen/
types.rs

1//! Type mapping from Rust to WGSL.
2//!
3//! This module handles the conversion of Rust types to their WGSL equivalents.
4//!
5//! # Type Mappings
6//!
7//! | Rust Type | WGSL Type | Notes |
8//! |-----------|-----------|-------|
9//! | `f32` | `f32` | Direct mapping |
10//! | `f64` | `f32` | **Warning**: Downcast, WGSL 1.0 has no f64 |
11//! | `i32` | `i32` | Direct mapping |
12//! | `u32` | `u32` | Direct mapping |
13//! | `i64` | `vec2<i32>` | Emulated as lo/hi pair |
14//! | `u64` | `vec2<u32>` | Emulated as lo/hi pair |
15//! | `bool` | `bool` | Direct mapping |
16//! | `&[T]` | `array<T>` | Storage buffer binding |
17//! | `&mut [T]` | `array<T>` | Storage buffer binding (read_write) |
18
19use std::collections::HashMap;
20
/// WGSL address spaces for variable declarations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddressSpace {
    /// `var<function>` - Local function scope
    Function,
    /// `var<private>` - Thread-private global
    Private,
    /// `var<workgroup>` - Shared within workgroup
    Workgroup,
    /// `var<uniform>` - Uniform buffer
    Uniform,
    /// `var<storage>` - Storage buffer
    Storage,
}

impl AddressSpace {
    /// Convert to the WGSL keyword used inside `var<...>` / `ptr<...>`.
    pub fn to_wgsl(&self) -> &'static str {
        match self {
            AddressSpace::Function => "function",
            AddressSpace::Private => "private",
            AddressSpace::Workgroup => "workgroup",
            AddressSpace::Uniform => "uniform",
            AddressSpace::Storage => "storage",
        }
    }
}

/// Access mode for storage/uniform buffers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AccessMode {
    /// Read-only access
    #[default]
    Read,
    /// Write-only access.
    ///
    /// Note: WGSL storage *buffers* only accept `read` and `read_write`;
    /// `write` is meaningful for storage textures.
    Write,
    /// Read and write access
    ReadWrite,
}

impl AccessMode {
    /// Convert to the WGSL access-mode keyword.
    pub fn to_wgsl(&self) -> &'static str {
        match self {
            AccessMode::Read => "read",
            AccessMode::Write => "write",
            AccessMode::ReadWrite => "read_write",
        }
    }
}

/// WGSL type representation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WgslType {
    /// 32-bit float
    F32,
    /// 32-bit signed integer
    I32,
    /// 32-bit unsigned integer
    U32,
    /// Boolean
    Bool,
    /// Void (for functions with no return)
    Void,
    /// 2-component vector
    Vec2(Box<WgslType>),
    /// 3-component vector
    Vec3(Box<WgslType>),
    /// 4-component vector
    Vec4(Box<WgslType>),
    /// 2x2 matrix
    Mat2x2(Box<WgslType>),
    /// 3x3 matrix
    Mat3x3(Box<WgslType>),
    /// 4x4 matrix
    Mat4x4(Box<WgslType>),
    /// Array type
    Array {
        element: Box<WgslType>,
        /// None means runtime-sized array
        size: Option<usize>,
    },
    /// Pointer type (for function parameters).
    ///
    /// `access` is only written out for the `storage` address space; WGSL
    /// forbids an explicit access mode on pointers in any other space.
    Ptr {
        address_space: AddressSpace,
        inner: Box<WgslType>,
        access: AccessMode,
    },
    /// Atomic type
    Atomic(Box<WgslType>),
    /// User-defined struct
    Struct(String),
    /// Emulated 64-bit unsigned (stored as vec2<u32>)
    U64Pair,
    /// Emulated 64-bit signed (stored as vec2<i32>)
    I64Pair,
}

impl WgslType {
    /// Convert to WGSL type syntax.
    ///
    /// [`WgslType::Void`] renders as the empty string, because a WGSL
    /// function without a return value omits the `-> T` clause entirely.
    pub fn to_wgsl(&self) -> String {
        match self {
            WgslType::F32 => "f32".to_string(),
            WgslType::I32 => "i32".to_string(),
            WgslType::U32 => "u32".to_string(),
            WgslType::Bool => "bool".to_string(),
            WgslType::Void => "".to_string(), // Functions with no return type
            WgslType::Vec2(inner) => format!("vec2<{}>", inner.to_wgsl()),
            WgslType::Vec3(inner) => format!("vec3<{}>", inner.to_wgsl()),
            WgslType::Vec4(inner) => format!("vec4<{}>", inner.to_wgsl()),
            WgslType::Mat2x2(inner) => format!("mat2x2<{}>", inner.to_wgsl()),
            WgslType::Mat3x3(inner) => format!("mat3x3<{}>", inner.to_wgsl()),
            WgslType::Mat4x4(inner) => format!("mat4x4<{}>", inner.to_wgsl()),
            WgslType::Array { element, size } => match size {
                Some(n) => format!("array<{}, {}>", element.to_wgsl(), n),
                None => format!("array<{}>", element.to_wgsl()),
            },
            WgslType::Ptr {
                address_space,
                inner,
                access,
            } => {
                // WGSL only permits an explicit access mode on `storage`
                // pointers. For function/private/workgroup/uniform the mode is
                // implied and must NOT be written, so `ptr<function, f32, read>`
                // would be rejected by the shader compiler.
                if *address_space == AddressSpace::Storage {
                    format!(
                        "ptr<{}, {}, {}>",
                        address_space.to_wgsl(),
                        inner.to_wgsl(),
                        access.to_wgsl()
                    )
                } else {
                    format!("ptr<{}, {}>", address_space.to_wgsl(), inner.to_wgsl())
                }
            }
            WgslType::Atomic(inner) => format!("atomic<{}>", inner.to_wgsl()),
            WgslType::Struct(name) => name.clone(),
            WgslType::U64Pair => "vec2<u32>".to_string(),
            WgslType::I64Pair => "vec2<i32>".to_string(),
        }
    }

    /// Check if this type is a 64-bit emulated type.
    pub fn is_emulated_64bit(&self) -> bool {
        matches!(self, WgslType::U64Pair | WgslType::I64Pair)
    }

    /// Check if this type is a scalar (`f32`, `i32`, `u32` or `bool`).
    ///
    /// The emulated 64-bit pair types are not considered scalars here.
    pub fn is_scalar(&self) -> bool {
        matches!(
            self,
            WgslType::F32 | WgslType::I32 | WgslType::U32 | WgslType::Bool
        )
    }

    /// Check if this type is a vector.
    pub fn is_vector(&self) -> bool {
        matches!(
            self,
            WgslType::Vec2(_) | WgslType::Vec3(_) | WgslType::Vec4(_)
        )
    }

    /// Get the element type for arrays and vectors, or the wrapped type of an
    /// atomic. Returns `None` for every other type (including matrices).
    pub fn element_type(&self) -> Option<&WgslType> {
        match self {
            WgslType::Array { element, .. } => Some(element),
            WgslType::Vec2(e) | WgslType::Vec3(e) | WgslType::Vec4(e) => Some(e),
            WgslType::Atomic(inner) => Some(inner),
            _ => None,
        }
    }
}
188
189/// Type mapper for converting Rust types to WGSL types.
190#[derive(Debug, Clone)]
191pub struct TypeMapper {
192    /// Custom type mappings (Rust type name -> WGSL type)
193    custom_types: HashMap<String, WgslType>,
194    /// Whether to emit warnings for lossy conversions
195    warn_on_lossy: bool,
196}
197
198impl Default for TypeMapper {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204impl TypeMapper {
205    /// Create a new type mapper with default mappings.
206    pub fn new() -> Self {
207        Self {
208            custom_types: HashMap::new(),
209            warn_on_lossy: true,
210        }
211    }
212
213    /// Register a custom type mapping.
214    pub fn register_type(&mut self, rust_name: &str, wgsl_type: WgslType) {
215        self.custom_types.insert(rust_name.to_string(), wgsl_type);
216    }
217
218    /// Disable warnings for lossy conversions (e.g., f64 -> f32).
219    pub fn disable_lossy_warnings(&mut self) {
220        self.warn_on_lossy = false;
221    }
222
223    /// Map a Rust type to a WGSL type.
224    pub fn map_type(&self, ty: &syn::Type) -> Result<WgslType, String> {
225        match ty {
226            syn::Type::Path(type_path) => self.map_type_path(type_path),
227            syn::Type::Reference(type_ref) => self.map_reference(type_ref),
228            syn::Type::Array(type_array) => self.map_array(type_array),
229            syn::Type::Slice(type_slice) => self.map_slice(type_slice),
230            syn::Type::Tuple(tuple) if tuple.elems.is_empty() => Ok(WgslType::Void),
231            _ => Err(format!("Unsupported type: {:?}", ty)),
232        }
233    }
234
235    fn map_type_path(&self, type_path: &syn::TypePath) -> Result<WgslType, String> {
236        let path = &type_path.path;
237
238        // Get the last segment (e.g., "f32" from "std::f32")
239        let segment = path
240            .segments
241            .last()
242            .ok_or_else(|| "Empty type path".to_string())?;
243
244        let ident = segment.ident.to_string();
245
246        // Check custom types first
247        if let Some(wgsl_type) = self.custom_types.get(&ident) {
248            return Ok(wgsl_type.clone());
249        }
250
251        // Built-in type mappings
252        match ident.as_str() {
253            "f32" => Ok(WgslType::F32),
254            "f64" => {
255                if self.warn_on_lossy {
256                    eprintln!("Warning: f64 will be downcast to f32 (WGSL 1.0 has no f64)");
257                }
258                Ok(WgslType::F32) // Downcast with warning
259            }
260            "i32" => Ok(WgslType::I32),
261            "u32" => Ok(WgslType::U32),
262            "i64" => Ok(WgslType::I64Pair), // Emulated
263            "u64" => Ok(WgslType::U64Pair), // Emulated
264            "bool" => Ok(WgslType::Bool),
265            "usize" => Ok(WgslType::U32), // WGSL uses 32-bit addressing
266            "isize" => Ok(WgslType::I32),
267
268            // Vector types (if we support them)
269            "Vec2" | "vec2" => {
270                let inner = self.extract_generic_arg(segment)?;
271                Ok(WgslType::Vec2(Box::new(inner)))
272            }
273            "Vec3" | "vec3" => {
274                let inner = self.extract_generic_arg(segment)?;
275                Ok(WgslType::Vec3(Box::new(inner)))
276            }
277            "Vec4" | "vec4" => {
278                let inner = self.extract_generic_arg(segment)?;
279                Ok(WgslType::Vec4(Box::new(inner)))
280            }
281
282            // Special marker types (removed during transpilation)
283            "GridPos" => Err("GridPos is a marker type".to_string()),
284            "RingContext" => Err("RingContext is a marker type".to_string()),
285
286            // Assume user-defined struct
287            _ => Ok(WgslType::Struct(ident)),
288        }
289    }
290
291    fn map_reference(&self, type_ref: &syn::TypeReference) -> Result<WgslType, String> {
292        let inner = self.map_type(&type_ref.elem)?;
293        let is_mutable = type_ref.mutability.is_some();
294
295        // Check if it's a slice reference
296        if let syn::Type::Slice(_) = type_ref.elem.as_ref() {
297            // Slice references become storage buffer arrays
298            let access = if is_mutable {
299                AccessMode::ReadWrite
300            } else {
301                AccessMode::Read
302            };
303
304            // For function parameters, we use ptr<storage, ...>
305            // For bindings, we use var<storage, ...>
306            Ok(WgslType::Ptr {
307                address_space: AddressSpace::Storage,
308                inner: Box::new(inner),
309                access,
310            })
311        } else {
312            // Regular references become pointers
313            let access = if is_mutable {
314                AccessMode::ReadWrite
315            } else {
316                AccessMode::Read
317            };
318
319            Ok(WgslType::Ptr {
320                address_space: AddressSpace::Function,
321                inner: Box::new(inner),
322                access,
323            })
324        }
325    }
326
327    fn map_array(&self, type_array: &syn::TypeArray) -> Result<WgslType, String> {
328        let element = self.map_type(&type_array.elem)?;
329
330        // Extract the array length
331        let size = match &type_array.len {
332            syn::Expr::Lit(syn::ExprLit {
333                lit: syn::Lit::Int(lit),
334                ..
335            }) => lit.base10_parse::<usize>().map_err(|e| e.to_string())?,
336            _ => return Err("Array length must be a literal integer".to_string()),
337        };
338
339        Ok(WgslType::Array {
340            element: Box::new(element),
341            size: Some(size),
342        })
343    }
344
345    fn map_slice(&self, type_slice: &syn::TypeSlice) -> Result<WgslType, String> {
346        let element = self.map_type(&type_slice.elem)?;
347
348        Ok(WgslType::Array {
349            element: Box::new(element),
350            size: None, // Runtime-sized
351        })
352    }
353
354    fn extract_generic_arg(&self, segment: &syn::PathSegment) -> Result<WgslType, String> {
355        match &segment.arguments {
356            syn::PathArguments::AngleBracketed(args) => {
357                if let Some(syn::GenericArgument::Type(ty)) = args.args.first() {
358                    self.map_type(ty)
359                } else {
360                    Err("Expected type argument".to_string())
361                }
362            }
363            _ => Err("Expected angle-bracketed arguments".to_string()),
364        }
365    }
366}
367
368/// Check if a type is the GridPos marker type.
369pub fn is_grid_pos_type(ty: &syn::Type) -> bool {
370    if let syn::Type::Path(type_path) = ty {
371        if let Some(segment) = type_path.path.segments.last() {
372            return segment.ident == "GridPos";
373        }
374    }
375    false
376}
377
378/// Check if a type is the RingContext marker type.
379pub fn is_ring_context_type(ty: &syn::Type) -> bool {
380    if let syn::Type::Path(type_path) = ty {
381        if let Some(segment) = type_path.path.segments.last() {
382            return segment.ident == "RingContext";
383        }
384    }
385    false
386}
387
388/// Check if a type is a mutable reference.
389pub fn is_mutable_reference(ty: &syn::Type) -> bool {
390    if let syn::Type::Reference(type_ref) = ty {
391        return type_ref.mutability.is_some();
392    }
393    false
394}
395
396/// Get the element type of a slice type.
397pub fn get_slice_element_type(ty: &syn::Type, mapper: &TypeMapper) -> Option<WgslType> {
398    if let syn::Type::Reference(type_ref) = ty {
399        if let syn::Type::Slice(type_slice) = type_ref.elem.as_ref() {
400            return mapper.map_type(&type_slice.elem).ok();
401        }
402    }
403    None
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_primitive_types() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!(f32);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::F32);

        let ty: syn::Type = parse_quote!(i32);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::I32);

        let ty: syn::Type = parse_quote!(u32);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::U32);

        let ty: syn::Type = parse_quote!(bool);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::Bool);
    }

    #[test]
    fn test_pointer_sized_and_lossy_types() {
        let mut mapper = TypeMapper::new();
        mapper.disable_lossy_warnings();

        let ty: syn::Type = parse_quote!(usize);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::U32);

        let ty: syn::Type = parse_quote!(isize);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::I32);

        // f64 is downcast to f32.
        let ty: syn::Type = parse_quote!(f64);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::F32);
    }

    #[test]
    fn test_64bit_emulation() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!(u64);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::U64Pair);

        let ty: syn::Type = parse_quote!(i64);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::I64Pair);
    }

    #[test]
    fn test_fixed_array_and_unit() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!([f32; 4]);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Array {
                element: Box::new(WgslType::F32),
                size: Some(4)
            }
        );

        let ty: syn::Type = parse_quote!(());
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::Void);
    }

    #[test]
    fn test_vector_generic_argument() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!(Vec3<u32>);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Vec3(Box::new(WgslType::U32))
        );
    }

    #[test]
    fn test_marker_types_are_rejected() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!(GridPos);
        assert!(mapper.map_type(&ty).is_err());

        let ty: syn::Type = parse_quote!(RingContext);
        assert!(mapper.map_type(&ty).is_err());
    }

    #[test]
    fn test_slice_types() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!(&[f32]);
        let result = mapper.map_type(&ty).unwrap();
        assert!(matches!(
            result,
            WgslType::Ptr {
                access: AccessMode::Read,
                ..
            }
        ));

        let ty: syn::Type = parse_quote!(&mut [f32]);
        let result = mapper.map_type(&ty).unwrap();
        assert!(matches!(
            result,
            WgslType::Ptr {
                access: AccessMode::ReadWrite,
                ..
            }
        ));
    }

    #[test]
    fn test_scalar_reference_is_function_pointer() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!(&mut f32);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Ptr {
                address_space: AddressSpace::Function,
                inner: Box::new(WgslType::F32),
                access: AccessMode::ReadWrite,
            }
        );
    }

    #[test]
    fn test_wgsl_output() {
        assert_eq!(WgslType::F32.to_wgsl(), "f32");
        assert_eq!(
            WgslType::Vec2(Box::new(WgslType::F32)).to_wgsl(),
            "vec2<f32>"
        );
        assert_eq!(WgslType::U64Pair.to_wgsl(), "vec2<u32>");
        assert_eq!(
            WgslType::Array {
                element: Box::new(WgslType::F32),
                size: Some(16)
            }
            .to_wgsl(),
            "array<f32, 16>"
        );
        assert_eq!(
            WgslType::Array {
                element: Box::new(WgslType::F32),
                size: None
            }
            .to_wgsl(),
            "array<f32>"
        );
        assert_eq!(
            WgslType::Atomic(Box::new(WgslType::U32)).to_wgsl(),
            "atomic<u32>"
        );
    }

    #[test]
    fn test_custom_types() {
        let mut mapper = TypeMapper::new();
        mapper.register_type("MyStruct", WgslType::Struct("MyStruct".to_string()));

        let ty: syn::Type = parse_quote!(MyStruct);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Struct("MyStruct".to_string())
        );
    }

    #[test]
    fn test_grid_pos_detection() {
        let ty: syn::Type = parse_quote!(GridPos);
        assert!(is_grid_pos_type(&ty));

        let ty: syn::Type = parse_quote!(f32);
        assert!(!is_grid_pos_type(&ty));
    }

    #[test]
    fn test_ring_context_detection() {
        let ty: syn::Type = parse_quote!(RingContext);
        assert!(is_ring_context_type(&ty));

        // The detector only inspects bare paths: `&RingContext` itself is not
        // reported, but its referent is. Callers must unwrap references first.
        let ty: syn::Type = parse_quote!(&RingContext);
        assert!(!is_ring_context_type(&ty));
        match &ty {
            syn::Type::Reference(r) => assert!(is_ring_context_type(&r.elem)),
            other => panic!("expected a reference type, got {:?}", other),
        }
    }
}