Skip to main content

cuda_rust_wasm/transpiler/
memory_mapper.rs

//! CUDA memory space to target memory mapping.
//!
//! Maps CUDA memory address spaces (global, shared, constant, register, local,
//! and automatic) to their equivalents in Rust (for CPU fallback) and WGSL (for
//! WebGPU compute shaders).

7use crate::parser::ast::StorageClass;
8
/// Describes how a single CUDA memory space is expressed on one target
/// platform (Rust for the CPU fallback, or WGSL for WebGPU).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MemoryMapping {
    /// Target-specific storage qualifier or address-space annotation that is
    /// intended to precede a declaration in generated code.
    pub qualifier: String,
    /// Free-form, human-readable explanation of this mapping.
    pub description: String,
    /// `true` when writes to this memory space must be followed by explicit
    /// synchronisation before other threads can rely on them.
    pub requires_sync: bool,
    /// `true` when the memory space must not be written.
    pub read_only: bool,
}
21
22/// Maps CUDA memory spaces to Rust and WGSL equivalents.
23pub struct MemoryMapper;
24
25impl MemoryMapper {
26    // -----------------------------------------------------------------------
27    // Rust target
28    // -----------------------------------------------------------------------
29
30    /// Map a CUDA `StorageClass` to its Rust representation.
31    pub fn to_rust(storage: &StorageClass) -> MemoryMapping {
32        match storage {
33            StorageClass::Global => MemoryMapping {
34                qualifier: "/* global */ ".to_string(),
35                description: "Heap-allocated device buffer (Vec<T> or &mut [T])".to_string(),
36                requires_sync: false,
37                read_only: false,
38            },
39            StorageClass::Shared => MemoryMapping {
40                qualifier: "#[shared] ".to_string(),
41                description: "Thread-local shared memory (SharedMemory<T>)".to_string(),
42                requires_sync: true,
43                read_only: false,
44            },
45            StorageClass::Constant => MemoryMapping {
46                qualifier: "const ".to_string(),
47                description: "Compile-time constant (const or static)".to_string(),
48                requires_sync: false,
49                read_only: true,
50            },
51            StorageClass::Register => MemoryMapping {
52                qualifier: "".to_string(),
53                description: "Local variable (stack-allocated)".to_string(),
54                requires_sync: false,
55                read_only: false,
56            },
57            StorageClass::Local => MemoryMapping {
58                qualifier: "".to_string(),
59                description: "Local variable (stack-allocated)".to_string(),
60                requires_sync: false,
61                read_only: false,
62            },
63            StorageClass::Auto => MemoryMapping {
64                qualifier: "let ".to_string(),
65                description: "Auto storage (stack-allocated)".to_string(),
66                requires_sync: false,
67                read_only: false,
68            },
69        }
70    }
71
72    /// Generate a Rust variable declaration prefix for the given storage class.
73    ///
74    /// # Examples
75    /// - `StorageClass::Shared` -> `"/* __shared__ */ let mut "`
76    /// - `StorageClass::Constant` -> `"const "`
77    /// - `StorageClass::Auto` -> `"let "`
78    pub fn rust_var_prefix(storage: &StorageClass, mutable: bool) -> String {
79        match storage {
80            StorageClass::Shared => {
81                if mutable {
82                    "/* __shared__ */ let mut ".to_string()
83                } else {
84                    "/* __shared__ */ let ".to_string()
85                }
86            }
87            StorageClass::Constant => "const ".to_string(),
88            StorageClass::Global => {
89                if mutable {
90                    "/* __device__ */ static mut ".to_string()
91                } else {
92                    "/* __device__ */ static ".to_string()
93                }
94            }
95            StorageClass::Register | StorageClass::Local | StorageClass::Auto => {
96                if mutable {
97                    "let mut ".to_string()
98                } else {
99                    "let ".to_string()
100                }
101            }
102        }
103    }
104
105    // -----------------------------------------------------------------------
106    // WGSL target
107    // -----------------------------------------------------------------------
108
109    /// Map a CUDA `StorageClass` to its WGSL representation.
110    pub fn to_wgsl(storage: &StorageClass) -> MemoryMapping {
111        match storage {
112            StorageClass::Global => MemoryMapping {
113                qualifier: "var<storage, read_write>".to_string(),
114                description: "Storage buffer (read_write)".to_string(),
115                requires_sync: false,
116                read_only: false,
117            },
118            StorageClass::Shared => MemoryMapping {
119                qualifier: "var<workgroup>".to_string(),
120                description: "Workgroup memory (shared within workgroup)".to_string(),
121                requires_sync: true,
122                read_only: false,
123            },
124            StorageClass::Constant => MemoryMapping {
125                qualifier: "var<uniform>".to_string(),
126                description: "Uniform buffer (read-only)".to_string(),
127                requires_sync: false,
128                read_only: true,
129            },
130            StorageClass::Register => MemoryMapping {
131                qualifier: "var<private>".to_string(),
132                description: "Private variable (per-invocation)".to_string(),
133                requires_sync: false,
134                read_only: false,
135            },
136            StorageClass::Local => MemoryMapping {
137                qualifier: "var<private>".to_string(),
138                description: "Private variable (per-invocation)".to_string(),
139                requires_sync: false,
140                read_only: false,
141            },
142            StorageClass::Auto => MemoryMapping {
143                qualifier: "var".to_string(),
144                description: "Function-scope variable".to_string(),
145                requires_sync: false,
146                read_only: false,
147            },
148        }
149    }
150
151    /// Generate a WGSL variable declaration for the given storage class.
152    ///
153    /// # Arguments
154    /// * `storage` - The CUDA storage class
155    /// * `name` - Variable name
156    /// * `wgsl_type` - WGSL type string (e.g. "f32", "array<f32, 256>")
157    ///
158    /// # Returns
159    /// A complete WGSL variable declaration string.
160    pub fn wgsl_var_decl(storage: &StorageClass, name: &str, wgsl_type: &str) -> String {
161        let mapping = Self::to_wgsl(storage);
162        format!("{} {}: {};", mapping.qualifier, name, wgsl_type)
163    }
164
165    /// Generate a WGSL binding declaration for a kernel parameter.
166    ///
167    /// # Arguments
168    /// * `storage` - The CUDA storage class
169    /// * `group` - Binding group number
170    /// * `binding` - Binding index
171    /// * `name` - Variable name
172    /// * `wgsl_type` - WGSL type string
173    /// * `read_only` - Whether the binding is read-only
174    pub fn wgsl_binding_decl(
175        storage: &StorageClass,
176        group: u32,
177        binding: u32,
178        name: &str,
179        wgsl_type: &str,
180        read_only: bool,
181    ) -> String {
182        let access = match storage {
183            StorageClass::Constant => "var<storage, read>",
184            StorageClass::Global => {
185                if read_only {
186                    "var<storage, read>"
187                } else {
188                    "var<storage, read_write>"
189                }
190            }
191            _ => {
192                let mapping = Self::to_wgsl(storage);
193                return format!(
194                    "@group({group}) @binding({binding})\n{} {name}: {wgsl_type};",
195                    mapping.qualifier
196                );
197            }
198        };
199
200        format!(
201            "@group({group}) @binding({binding})\n{access} {name}: {wgsl_type};"
202        )
203    }
204
205    // -----------------------------------------------------------------------
206    // CUDA memory space name -> StorageClass
207    // -----------------------------------------------------------------------
208
209    /// Parse a CUDA memory space qualifier string into a `StorageClass`.
210    pub fn parse_cuda_qualifier(qualifier: &str) -> StorageClass {
211        match qualifier.trim() {
212            "__shared__" | "shared" => StorageClass::Shared,
213            "__constant__" | "constant" => StorageClass::Constant,
214            "__device__" | "device" => StorageClass::Global,
215            "__managed__" | "managed" => StorageClass::Global,
216            "register" => StorageClass::Register,
217            "local" => StorageClass::Local,
218            _ => StorageClass::Auto,
219        }
220    }
221
222    // -----------------------------------------------------------------------
223    // Query helpers
224    // -----------------------------------------------------------------------
225
226    /// Returns true if the storage class requires barrier synchronisation
227    /// before other threads can see writes.
228    pub fn requires_barrier(storage: &StorageClass) -> bool {
229        matches!(storage, StorageClass::Shared)
230    }
231
232    /// Returns true if the storage class is read-only.
233    pub fn is_read_only(storage: &StorageClass) -> bool {
234        matches!(storage, StorageClass::Constant)
235    }
236
237    /// Returns the WGSL barrier function name needed after writes to this
238    /// memory space, if any.
239    pub fn wgsl_barrier(storage: &StorageClass) -> Option<&'static str> {
240        match storage {
241            StorageClass::Shared => Some("workgroupBarrier()"),
242            StorageClass::Global => Some("storageBarrier()"),
243            _ => None,
244        }
245    }
246
247    /// Returns the Rust synchronization primitive needed after writes to this
248    /// memory space, if any.
249    pub fn rust_barrier(storage: &StorageClass) -> Option<&'static str> {
250        match storage {
251            StorageClass::Shared => {
252                Some("std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst)")
253            }
254            StorageClass::Global => {
255                Some("std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst)")
256            }
257            _ => None,
258        }
259    }
260}
261
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    // --- Rust target -------------------------------------------------------

    #[test]
    fn test_global_to_rust() {
        let mapping = MemoryMapper::to_rust(&StorageClass::Global);
        assert!(!mapping.requires_sync);
        assert!(!mapping.read_only);
    }

    #[test]
    fn test_shared_to_rust() {
        let mapping = MemoryMapper::to_rust(&StorageClass::Shared);
        assert!(mapping.requires_sync);
        assert!(!mapping.read_only);
    }

    #[test]
    fn test_constant_to_rust() {
        let mapping = MemoryMapper::to_rust(&StorageClass::Constant);
        assert!(!mapping.requires_sync);
        assert!(mapping.read_only);
        assert_eq!(mapping.qualifier, "const ");
    }

    #[test]
    fn test_register_to_rust() {
        let mapping = MemoryMapper::to_rust(&StorageClass::Register);
        assert!(!mapping.requires_sync);
        assert_eq!(mapping.qualifier, "");
    }

    #[test]
    fn test_local_and_auto_to_rust() {
        // Local must map exactly like Register; Auto differs only in its qualifier.
        assert_eq!(
            MemoryMapper::to_rust(&StorageClass::Local),
            MemoryMapper::to_rust(&StorageClass::Register)
        );
        let auto = MemoryMapper::to_rust(&StorageClass::Auto);
        assert_eq!(auto.qualifier, "let ");
        assert!(!auto.requires_sync);
        assert!(!auto.read_only);
    }

    #[test]
    fn test_rust_var_prefix_shared() {
        let prefix = MemoryMapper::rust_var_prefix(&StorageClass::Shared, true);
        assert!(prefix.contains("__shared__"));
        assert!(prefix.contains("let mut"));
    }

    #[test]
    fn test_rust_var_prefix_const() {
        let prefix = MemoryMapper::rust_var_prefix(&StorageClass::Constant, false);
        assert_eq!(prefix, "const ");
        // `mutable` is ignored for constants.
        assert_eq!(
            MemoryMapper::rust_var_prefix(&StorageClass::Constant, true),
            "const "
        );
    }

    #[test]
    fn test_rust_var_prefix_global_and_locals() {
        assert_eq!(
            MemoryMapper::rust_var_prefix(&StorageClass::Global, true),
            "/* __device__ */ static mut "
        );
        assert_eq!(
            MemoryMapper::rust_var_prefix(&StorageClass::Global, false),
            "/* __device__ */ static "
        );
        for storage in [StorageClass::Register, StorageClass::Local, StorageClass::Auto] {
            assert_eq!(MemoryMapper::rust_var_prefix(&storage, true), "let mut ");
            assert_eq!(MemoryMapper::rust_var_prefix(&storage, false), "let ");
        }
    }

    // --- WGSL target -------------------------------------------------------

    #[test]
    fn test_global_to_wgsl() {
        let mapping = MemoryMapper::to_wgsl(&StorageClass::Global);
        assert_eq!(mapping.qualifier, "var<storage, read_write>");
        assert!(!mapping.read_only);
    }

    #[test]
    fn test_shared_to_wgsl() {
        let mapping = MemoryMapper::to_wgsl(&StorageClass::Shared);
        assert_eq!(mapping.qualifier, "var<workgroup>");
        assert!(mapping.requires_sync);
    }

    #[test]
    fn test_constant_to_wgsl() {
        let mapping = MemoryMapper::to_wgsl(&StorageClass::Constant);
        assert_eq!(mapping.qualifier, "var<uniform>");
        assert!(mapping.read_only);
    }

    #[test]
    fn test_register_to_wgsl() {
        let mapping = MemoryMapper::to_wgsl(&StorageClass::Register);
        assert_eq!(mapping.qualifier, "var<private>");
    }

    #[test]
    fn test_local_and_auto_to_wgsl() {
        assert_eq!(
            MemoryMapper::to_wgsl(&StorageClass::Local).qualifier,
            "var<private>"
        );
        assert_eq!(MemoryMapper::to_wgsl(&StorageClass::Auto).qualifier, "var");
    }

    #[test]
    fn test_wgsl_var_decl() {
        let decl = MemoryMapper::wgsl_var_decl(
            &StorageClass::Shared,
            "shared_data",
            "array<f32, 256>",
        );
        assert_eq!(decl, "var<workgroup> shared_data: array<f32, 256>;");
    }

    #[test]
    fn test_wgsl_binding_decl() {
        let decl = MemoryMapper::wgsl_binding_decl(
            &StorageClass::Global,
            0,
            0,
            "data",
            "array<f32>",
            false,
        );
        assert_eq!(
            decl,
            "@group(0) @binding(0)\nvar<storage, read_write> data: array<f32>;"
        );
    }

    #[test]
    fn test_wgsl_binding_decl_readonly() {
        let decl = MemoryMapper::wgsl_binding_decl(
            &StorageClass::Global,
            0,
            1,
            "input",
            "array<f32>",
            true,
        );
        assert_eq!(
            decl,
            "@group(0) @binding(1)\nvar<storage, read> input: array<f32>;"
        );
    }

    #[test]
    fn test_wgsl_binding_decl_constant() {
        // Constant parameters are always bound read-only, regardless of `read_only`.
        let decl = MemoryMapper::wgsl_binding_decl(
            &StorageClass::Constant,
            1,
            2,
            "coeffs",
            "array<f32, 4>",
            false,
        );
        assert_eq!(
            decl,
            "@group(1) @binding(2)\nvar<storage, read> coeffs: array<f32, 4>;"
        );
    }

    // --- Qualifier parsing -------------------------------------------------

    #[test]
    fn test_parse_cuda_qualifier() {
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("__shared__"),
            StorageClass::Shared
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("__constant__"),
            StorageClass::Constant
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("__device__"),
            StorageClass::Global
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("register"),
            StorageClass::Register
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("unknown"),
            StorageClass::Auto
        ));
    }

    #[test]
    fn test_parse_cuda_qualifier_aliases_and_whitespace() {
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("__managed__"),
            StorageClass::Global
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("managed"),
            StorageClass::Global
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("  shared\t"),
            StorageClass::Shared
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("constant"),
            StorageClass::Constant
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("local"),
            StorageClass::Local
        ));
        // `__global__` is a function qualifier, not a memory space.
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("__global__"),
            StorageClass::Auto
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier(""),
            StorageClass::Auto
        ));
    }

    // --- Query helpers -----------------------------------------------------

    #[test]
    fn test_requires_barrier() {
        assert!(MemoryMapper::requires_barrier(&StorageClass::Shared));
        assert!(!MemoryMapper::requires_barrier(&StorageClass::Global));
        assert!(!MemoryMapper::requires_barrier(&StorageClass::Register));
    }

    #[test]
    fn test_is_read_only() {
        assert!(MemoryMapper::is_read_only(&StorageClass::Constant));
        assert!(!MemoryMapper::is_read_only(&StorageClass::Global));
        assert!(!MemoryMapper::is_read_only(&StorageClass::Shared));
    }

    #[test]
    fn test_query_helpers_agree_with_mappings() {
        for storage in [
            StorageClass::Global,
            StorageClass::Shared,
            StorageClass::Constant,
            StorageClass::Register,
            StorageClass::Local,
            StorageClass::Auto,
        ] {
            let rust = MemoryMapper::to_rust(&storage);
            let wgsl = MemoryMapper::to_wgsl(&storage);
            assert_eq!(MemoryMapper::requires_barrier(&storage), rust.requires_sync);
            assert_eq!(MemoryMapper::requires_barrier(&storage), wgsl.requires_sync);
            assert_eq!(MemoryMapper::is_read_only(&storage), rust.read_only);
            assert_eq!(MemoryMapper::is_read_only(&storage), wgsl.read_only);
        }
    }

    #[test]
    fn test_wgsl_barrier() {
        assert_eq!(
            MemoryMapper::wgsl_barrier(&StorageClass::Shared),
            Some("workgroupBarrier()")
        );
        assert_eq!(
            MemoryMapper::wgsl_barrier(&StorageClass::Global),
            Some("storageBarrier()")
        );
        assert_eq!(MemoryMapper::wgsl_barrier(&StorageClass::Register), None);
    }

    #[test]
    fn test_rust_barrier() {
        assert!(MemoryMapper::rust_barrier(&StorageClass::Shared).is_some());
        assert!(MemoryMapper::rust_barrier(&StorageClass::Global).is_some());
        assert!(MemoryMapper::rust_barrier(&StorageClass::Register).is_none());
    }
}