1use std::collections::HashMap;
20
/// WGSL address spaces; [`AddressSpace::to_wgsl`] gives the keyword used in
/// `ptr<...>` types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddressSpace {
    /// Rendered as `function`.
    Function,
    /// Rendered as `private`.
    Private,
    /// Rendered as `workgroup`.
    Workgroup,
    /// Rendered as `uniform`.
    Uniform,
    /// Rendered as `storage`. This is the only address space whose pointers
    /// carry an explicit access mode in WGSL source.
    Storage,
}

impl AddressSpace {
    /// Returns the WGSL keyword for this address space.
    pub fn to_wgsl(&self) -> &'static str {
        match self {
            AddressSpace::Function => "function",
            AddressSpace::Private => "private",
            AddressSpace::Workgroup => "workgroup",
            AddressSpace::Uniform => "uniform",
            AddressSpace::Storage => "storage",
        }
    }
}

/// Access mode attached to a pointer. Defaults to [`AccessMode::Read`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AccessMode {
    /// Rendered as `read`; the default.
    #[default]
    Read,
    /// Rendered as `write`.
    Write,
    /// Rendered as `read_write`.
    ReadWrite,
}

impl AccessMode {
    /// Returns the WGSL keyword for this access mode.
    pub fn to_wgsl(&self) -> &'static str {
        match self {
            AccessMode::Read => "read",
            AccessMode::Write => "write",
            AccessMode::ReadWrite => "read_write",
        }
    }
}

/// A WGSL type, renderable to source text with [`WgslType::to_wgsl`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WgslType {
    /// `f32`
    F32,
    /// `i32`
    I32,
    /// `u32`
    U32,
    /// `bool`
    Bool,
    /// No value; renders as an empty string (used for `()` return types).
    Void,
    /// `vec2<T>`
    Vec2(Box<WgslType>),
    /// `vec3<T>`
    Vec3(Box<WgslType>),
    /// `vec4<T>`
    Vec4(Box<WgslType>),
    /// `mat2x2<T>`
    Mat2x2(Box<WgslType>),
    /// `mat3x3<T>`
    Mat3x3(Box<WgslType>),
    /// `mat4x4<T>`
    Mat4x4(Box<WgslType>),
    /// `array<T, N>` when `size` is `Some(N)`; `array<T>` (runtime-sized)
    /// when `size` is `None`.
    Array {
        element: Box<WgslType>,
        size: Option<usize>,
    },
    /// `ptr<AS, T>`, or `ptr<storage, T, AM>` for the storage address space.
    Ptr {
        address_space: AddressSpace,
        inner: Box<WgslType>,
        access: AccessMode,
    },
    /// `atomic<T>`
    Atomic(Box<WgslType>),
    /// A named user struct, rendered verbatim.
    Struct(String),
    /// Stand-in for a 64-bit unsigned integer, rendered as `vec2<u32>`.
    U64Pair,
    /// Stand-in for a 64-bit signed integer, rendered as `vec2<i32>`.
    I64Pair,
}

impl WgslType {
    /// Renders this type as WGSL source text.
    pub fn to_wgsl(&self) -> String {
        match self {
            WgslType::F32 => "f32".to_string(),
            WgslType::I32 => "i32".to_string(),
            WgslType::U32 => "u32".to_string(),
            WgslType::Bool => "bool".to_string(),
            WgslType::Void => "".to_string(),
            WgslType::Vec2(inner) => format!("vec2<{}>", inner.to_wgsl()),
            WgslType::Vec3(inner) => format!("vec3<{}>", inner.to_wgsl()),
            WgslType::Vec4(inner) => format!("vec4<{}>", inner.to_wgsl()),
            WgslType::Mat2x2(inner) => format!("mat2x2<{}>", inner.to_wgsl()),
            WgslType::Mat3x3(inner) => format!("mat3x3<{}>", inner.to_wgsl()),
            WgslType::Mat4x4(inner) => format!("mat4x4<{}>", inner.to_wgsl()),
            WgslType::Array { element, size } => match size {
                Some(n) => format!("array<{}, {}>", element.to_wgsl(), n),
                None => format!("array<{}>", element.to_wgsl()),
            },
            WgslType::Ptr {
                address_space,
                inner,
                access,
            } => match address_space {
                // WGSL permits an explicit access mode only on `storage`
                // pointers. For every other address space the mode is implied,
                // and writing one is a shader-creation error.
                AddressSpace::Storage => format!(
                    "ptr<{}, {}, {}>",
                    address_space.to_wgsl(),
                    inner.to_wgsl(),
                    access.to_wgsl()
                ),
                _ => format!("ptr<{}, {}>", address_space.to_wgsl(), inner.to_wgsl()),
            },
            WgslType::Atomic(inner) => format!("atomic<{}>", inner.to_wgsl()),
            WgslType::Struct(name) => name.clone(),
            WgslType::U64Pair => "vec2<u32>".to_string(),
            WgslType::I64Pair => "vec2<i32>".to_string(),
        }
    }

    /// True for the two-lane stand-ins for 64-bit integers.
    pub fn is_emulated_64bit(&self) -> bool {
        matches!(self, WgslType::U64Pair | WgslType::I64Pair)
    }

    /// True for `f32`, `i32`, `u32` and `bool`.
    pub fn is_scalar(&self) -> bool {
        matches!(
            self,
            WgslType::F32 | WgslType::I32 | WgslType::U32 | WgslType::Bool
        )
    }

    /// True for `vec2`, `vec3` and `vec4`.
    pub fn is_vector(&self) -> bool {
        matches!(
            self,
            WgslType::Vec2(_) | WgslType::Vec3(_) | WgslType::Vec4(_)
        )
    }

    /// Returns the component type of arrays, vectors and atomics, or `None`
    /// for every other variant (including matrices).
    pub fn element_type(&self) -> Option<&WgslType> {
        match self {
            WgslType::Array { element, .. } => Some(element),
            WgslType::Vec2(e) | WgslType::Vec3(e) | WgslType::Vec4(e) => Some(e),
            WgslType::Atomic(inner) => Some(inner),
            _ => None,
        }
    }
}
188
189#[derive(Debug, Clone)]
191pub struct TypeMapper {
192 custom_types: HashMap<String, WgslType>,
194 warn_on_lossy: bool,
196}
197
198impl Default for TypeMapper {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl TypeMapper {
205 pub fn new() -> Self {
207 Self {
208 custom_types: HashMap::new(),
209 warn_on_lossy: true,
210 }
211 }
212
213 pub fn register_type(&mut self, rust_name: &str, wgsl_type: WgslType) {
215 self.custom_types.insert(rust_name.to_string(), wgsl_type);
216 }
217
218 pub fn disable_lossy_warnings(&mut self) {
220 self.warn_on_lossy = false;
221 }
222
223 pub fn map_type(&self, ty: &syn::Type) -> Result<WgslType, String> {
225 match ty {
226 syn::Type::Path(type_path) => self.map_type_path(type_path),
227 syn::Type::Reference(type_ref) => self.map_reference(type_ref),
228 syn::Type::Array(type_array) => self.map_array(type_array),
229 syn::Type::Slice(type_slice) => self.map_slice(type_slice),
230 syn::Type::Tuple(tuple) if tuple.elems.is_empty() => Ok(WgslType::Void),
231 _ => Err(format!("Unsupported type: {:?}", ty)),
232 }
233 }
234
235 fn map_type_path(&self, type_path: &syn::TypePath) -> Result<WgslType, String> {
236 let path = &type_path.path;
237
238 let segment = path
240 .segments
241 .last()
242 .ok_or_else(|| "Empty type path".to_string())?;
243
244 let ident = segment.ident.to_string();
245
246 if let Some(wgsl_type) = self.custom_types.get(&ident) {
248 return Ok(wgsl_type.clone());
249 }
250
251 match ident.as_str() {
253 "f32" => Ok(WgslType::F32),
254 "f64" => {
255 if self.warn_on_lossy {
256 eprintln!("Warning: f64 will be downcast to f32 (WGSL 1.0 has no f64)");
257 }
258 Ok(WgslType::F32) }
260 "i32" => Ok(WgslType::I32),
261 "u32" => Ok(WgslType::U32),
262 "i64" => Ok(WgslType::I64Pair), "u64" => Ok(WgslType::U64Pair), "bool" => Ok(WgslType::Bool),
265 "usize" => Ok(WgslType::U32), "isize" => Ok(WgslType::I32),
267
268 "Vec2" | "vec2" => {
270 let inner = self.extract_generic_arg(segment)?;
271 Ok(WgslType::Vec2(Box::new(inner)))
272 }
273 "Vec3" | "vec3" => {
274 let inner = self.extract_generic_arg(segment)?;
275 Ok(WgslType::Vec3(Box::new(inner)))
276 }
277 "Vec4" | "vec4" => {
278 let inner = self.extract_generic_arg(segment)?;
279 Ok(WgslType::Vec4(Box::new(inner)))
280 }
281
282 "GridPos" => Err("GridPos is a marker type".to_string()),
284 "RingContext" => Err("RingContext is a marker type".to_string()),
285
286 _ => Ok(WgslType::Struct(ident)),
288 }
289 }
290
291 fn map_reference(&self, type_ref: &syn::TypeReference) -> Result<WgslType, String> {
292 let inner = self.map_type(&type_ref.elem)?;
293 let is_mutable = type_ref.mutability.is_some();
294
295 if let syn::Type::Slice(_) = type_ref.elem.as_ref() {
297 let access = if is_mutable {
299 AccessMode::ReadWrite
300 } else {
301 AccessMode::Read
302 };
303
304 Ok(WgslType::Ptr {
307 address_space: AddressSpace::Storage,
308 inner: Box::new(inner),
309 access,
310 })
311 } else {
312 let access = if is_mutable {
314 AccessMode::ReadWrite
315 } else {
316 AccessMode::Read
317 };
318
319 Ok(WgslType::Ptr {
320 address_space: AddressSpace::Function,
321 inner: Box::new(inner),
322 access,
323 })
324 }
325 }
326
327 fn map_array(&self, type_array: &syn::TypeArray) -> Result<WgslType, String> {
328 let element = self.map_type(&type_array.elem)?;
329
330 let size = match &type_array.len {
332 syn::Expr::Lit(syn::ExprLit {
333 lit: syn::Lit::Int(lit),
334 ..
335 }) => lit.base10_parse::<usize>().map_err(|e| e.to_string())?,
336 _ => return Err("Array length must be a literal integer".to_string()),
337 };
338
339 Ok(WgslType::Array {
340 element: Box::new(element),
341 size: Some(size),
342 })
343 }
344
345 fn map_slice(&self, type_slice: &syn::TypeSlice) -> Result<WgslType, String> {
346 let element = self.map_type(&type_slice.elem)?;
347
348 Ok(WgslType::Array {
349 element: Box::new(element),
350 size: None, })
352 }
353
354 fn extract_generic_arg(&self, segment: &syn::PathSegment) -> Result<WgslType, String> {
355 match &segment.arguments {
356 syn::PathArguments::AngleBracketed(args) => {
357 if let Some(syn::GenericArgument::Type(ty)) = args.args.first() {
358 self.map_type(ty)
359 } else {
360 Err("Expected type argument".to_string())
361 }
362 }
363 _ => Err("Expected angle-bracketed arguments".to_string()),
364 }
365 }
366}
367
368pub fn is_grid_pos_type(ty: &syn::Type) -> bool {
370 if let syn::Type::Path(type_path) = ty {
371 if let Some(segment) = type_path.path.segments.last() {
372 return segment.ident == "GridPos";
373 }
374 }
375 false
376}
377
378pub fn is_ring_context_type(ty: &syn::Type) -> bool {
380 if let syn::Type::Path(type_path) = ty {
381 if let Some(segment) = type_path.path.segments.last() {
382 return segment.ident == "RingContext";
383 }
384 }
385 false
386}
387
388pub fn is_mutable_reference(ty: &syn::Type) -> bool {
390 if let syn::Type::Reference(type_ref) = ty {
391 return type_ref.mutability.is_some();
392 }
393 false
394}
395
396pub fn get_slice_element_type(ty: &syn::Type, mapper: &TypeMapper) -> Option<WgslType> {
398 if let syn::Type::Reference(type_ref) = ty {
399 if let syn::Type::Slice(type_slice) = type_ref.elem.as_ref() {
400 return mapper.map_type(&type_slice.elem).ok();
401 }
402 }
403 None
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_primitive_types() {
        let mut mapper = TypeMapper::new();
        mapper.disable_lossy_warnings();

        let ty: syn::Type = parse_quote!(f32);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::F32);

        let ty: syn::Type = parse_quote!(i32);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::I32);

        let ty: syn::Type = parse_quote!(u32);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::U32);

        let ty: syn::Type = parse_quote!(bool);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::Bool);

        // Lossy mappings.
        let ty: syn::Type = parse_quote!(f64);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::F32);

        let ty: syn::Type = parse_quote!(usize);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::U32);

        let ty: syn::Type = parse_quote!(isize);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::I32);

        let ty: syn::Type = parse_quote!(());
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::Void);
    }

    #[test]
    fn test_64bit_emulation() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!(u64);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::U64Pair);

        let ty: syn::Type = parse_quote!(i64);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::I64Pair);
    }

    #[test]
    fn test_slice_types() {
        let mapper = TypeMapper::new();
        let runtime_f32_array = || {
            Box::new(WgslType::Array {
                element: Box::new(WgslType::F32),
                size: None,
            })
        };

        let ty: syn::Type = parse_quote!(&[f32]);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Ptr {
                address_space: AddressSpace::Storage,
                inner: runtime_f32_array(),
                access: AccessMode::Read,
            }
        );

        let ty: syn::Type = parse_quote!(&mut [f32]);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Ptr {
                address_space: AddressSpace::Storage,
                inner: runtime_f32_array(),
                access: AccessMode::ReadWrite,
            }
        );

        // Non-slice references become function-local pointers.
        let ty: syn::Type = parse_quote!(&mut u32);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Ptr {
                address_space: AddressSpace::Function,
                inner: Box::new(WgslType::U32),
                access: AccessMode::ReadWrite,
            }
        );
    }

    #[test]
    fn test_array_types() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!([f32; 16]);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Array {
                element: Box::new(WgslType::F32),
                size: Some(16),
            }
        );

        // Non-literal lengths are rejected.
        let ty: syn::Type = parse_quote!([u32; N]);
        assert!(mapper.map_type(&ty).is_err());
    }

    #[test]
    fn test_vector_types() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!(Vec3<f32>);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Vec3(Box::new(WgslType::F32))
        );

        let ty: syn::Type = parse_quote!(vec4<u32>);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Vec4(Box::new(WgslType::U32))
        );

        // A vector name without generic arguments cannot be mapped.
        let ty: syn::Type = parse_quote!(Vec2);
        assert!(mapper.map_type(&ty).is_err());
    }

    #[test]
    fn test_marker_and_unknown_types() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!(GridPos);
        assert!(mapper.map_type(&ty).is_err());

        let ty: syn::Type = parse_quote!(RingContext);
        assert!(mapper.map_type(&ty).is_err());

        let ty: syn::Type = parse_quote!(Particle);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Struct("Particle".to_string())
        );
    }

    #[test]
    fn test_wgsl_output() {
        assert_eq!(WgslType::F32.to_wgsl(), "f32");
        assert_eq!(
            WgslType::Vec2(Box::new(WgslType::F32)).to_wgsl(),
            "vec2<f32>"
        );
        assert_eq!(WgslType::U64Pair.to_wgsl(), "vec2<u32>");
        assert_eq!(
            WgslType::Array {
                element: Box::new(WgslType::F32),
                size: Some(16)
            }
            .to_wgsl(),
            "array<f32, 16>"
        );
        assert_eq!(
            WgslType::Array {
                element: Box::new(WgslType::F32),
                size: None
            }
            .to_wgsl(),
            "array<f32>"
        );
        assert_eq!(
            WgslType::Atomic(Box::new(WgslType::U32)).to_wgsl(),
            "atomic<u32>"
        );
    }

    #[test]
    fn test_custom_types() {
        let mut mapper = TypeMapper::new();
        mapper.register_type("MyStruct", WgslType::Struct("MyStruct".to_string()));

        let ty: syn::Type = parse_quote!(MyStruct);
        assert_eq!(
            mapper.map_type(&ty).unwrap(),
            WgslType::Struct("MyStruct".to_string())
        );

        // Custom registrations take precedence over built-in primitives.
        mapper.register_type("u64", WgslType::U32);
        let ty: syn::Type = parse_quote!(u64);
        assert_eq!(mapper.map_type(&ty).unwrap(), WgslType::U32);
    }

    #[test]
    fn test_grid_pos_detection() {
        let ty: syn::Type = parse_quote!(GridPos);
        assert!(is_grid_pos_type(&ty));

        let ty: syn::Type = parse_quote!(crate::kernel::GridPos);
        assert!(is_grid_pos_type(&ty));

        let ty: syn::Type = parse_quote!(f32);
        assert!(!is_grid_pos_type(&ty));

        // References are not unwrapped.
        let ty: syn::Type = parse_quote!(&GridPos);
        assert!(!is_grid_pos_type(&ty));
    }

    #[test]
    fn test_ring_context_detection() {
        let ty: syn::Type = parse_quote!(RingContext);
        assert!(is_ring_context_type(&ty));

        let ty: syn::Type = parse_quote!(&RingContext);
        assert!(!is_ring_context_type(&ty));
        match &ty {
            syn::Type::Reference(r) => assert!(is_ring_context_type(&r.elem)),
            other => panic!("expected a reference type, got {:?}", other),
        }
    }

    #[test]
    fn test_reference_helpers() {
        let mapper = TypeMapper::new();

        let ty: syn::Type = parse_quote!(&mut f32);
        assert!(is_mutable_reference(&ty));

        let ty: syn::Type = parse_quote!(&f32);
        assert!(!is_mutable_reference(&ty));

        let ty: syn::Type = parse_quote!(f32);
        assert!(!is_mutable_reference(&ty));

        let ty: syn::Type = parse_quote!(&[u32]);
        assert_eq!(get_slice_element_type(&ty, &mapper), Some(WgslType::U32));

        let ty: syn::Type = parse_quote!(&u32);
        assert_eq!(get_slice_element_type(&ty, &mapper), None);
    }
}