cuda_rust_wasm/transpiler/
memory_mapper.rs1use crate::parser::ast::StorageClass;
/// How a CUDA storage class is expressed in a target language.
///
/// Values are built by `MemoryMapper::to_rust` and `MemoryMapper::to_wgsl`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MemoryMapping {
    /// Text placed in front of a declaration; may be empty.
    pub qualifier: String,
    /// Human-readable explanation of the chosen representation.
    pub description: String,
    /// `true` when accesses to this memory need explicit synchronization.
    pub requires_sync: bool,
    /// `true` when the memory is treated as read-only in the target.
    pub read_only: bool,
}
/// Stateless namespace for translating CUDA storage classes into Rust and
/// WGSL declarations, qualifiers and barrier calls.
#[derive(Debug, Clone, Copy, Default)]
pub struct MemoryMapper;
25impl MemoryMapper {
26 pub fn to_rust(storage: &StorageClass) -> MemoryMapping {
32 match storage {
33 StorageClass::Global => MemoryMapping {
34 qualifier: "/* global */ ".to_string(),
35 description: "Heap-allocated device buffer (Vec<T> or &mut [T])".to_string(),
36 requires_sync: false,
37 read_only: false,
38 },
39 StorageClass::Shared => MemoryMapping {
40 qualifier: "#[shared] ".to_string(),
41 description: "Thread-local shared memory (SharedMemory<T>)".to_string(),
42 requires_sync: true,
43 read_only: false,
44 },
45 StorageClass::Constant => MemoryMapping {
46 qualifier: "const ".to_string(),
47 description: "Compile-time constant (const or static)".to_string(),
48 requires_sync: false,
49 read_only: true,
50 },
51 StorageClass::Register => MemoryMapping {
52 qualifier: "".to_string(),
53 description: "Local variable (stack-allocated)".to_string(),
54 requires_sync: false,
55 read_only: false,
56 },
57 StorageClass::Local => MemoryMapping {
58 qualifier: "".to_string(),
59 description: "Local variable (stack-allocated)".to_string(),
60 requires_sync: false,
61 read_only: false,
62 },
63 StorageClass::Auto => MemoryMapping {
64 qualifier: "let ".to_string(),
65 description: "Auto storage (stack-allocated)".to_string(),
66 requires_sync: false,
67 read_only: false,
68 },
69 }
70 }
71
72 pub fn rust_var_prefix(storage: &StorageClass, mutable: bool) -> String {
79 match storage {
80 StorageClass::Shared => {
81 if mutable {
82 "/* __shared__ */ let mut ".to_string()
83 } else {
84 "/* __shared__ */ let ".to_string()
85 }
86 }
87 StorageClass::Constant => "const ".to_string(),
88 StorageClass::Global => {
89 if mutable {
90 "/* __device__ */ static mut ".to_string()
91 } else {
92 "/* __device__ */ static ".to_string()
93 }
94 }
95 StorageClass::Register | StorageClass::Local | StorageClass::Auto => {
96 if mutable {
97 "let mut ".to_string()
98 } else {
99 "let ".to_string()
100 }
101 }
102 }
103 }
104
105 pub fn to_wgsl(storage: &StorageClass) -> MemoryMapping {
111 match storage {
112 StorageClass::Global => MemoryMapping {
113 qualifier: "var<storage, read_write>".to_string(),
114 description: "Storage buffer (read_write)".to_string(),
115 requires_sync: false,
116 read_only: false,
117 },
118 StorageClass::Shared => MemoryMapping {
119 qualifier: "var<workgroup>".to_string(),
120 description: "Workgroup memory (shared within workgroup)".to_string(),
121 requires_sync: true,
122 read_only: false,
123 },
124 StorageClass::Constant => MemoryMapping {
125 qualifier: "var<uniform>".to_string(),
126 description: "Uniform buffer (read-only)".to_string(),
127 requires_sync: false,
128 read_only: true,
129 },
130 StorageClass::Register => MemoryMapping {
131 qualifier: "var<private>".to_string(),
132 description: "Private variable (per-invocation)".to_string(),
133 requires_sync: false,
134 read_only: false,
135 },
136 StorageClass::Local => MemoryMapping {
137 qualifier: "var<private>".to_string(),
138 description: "Private variable (per-invocation)".to_string(),
139 requires_sync: false,
140 read_only: false,
141 },
142 StorageClass::Auto => MemoryMapping {
143 qualifier: "var".to_string(),
144 description: "Function-scope variable".to_string(),
145 requires_sync: false,
146 read_only: false,
147 },
148 }
149 }
150
151 pub fn wgsl_var_decl(storage: &StorageClass, name: &str, wgsl_type: &str) -> String {
161 let mapping = Self::to_wgsl(storage);
162 format!("{} {}: {};", mapping.qualifier, name, wgsl_type)
163 }
164
165 pub fn wgsl_binding_decl(
175 storage: &StorageClass,
176 group: u32,
177 binding: u32,
178 name: &str,
179 wgsl_type: &str,
180 read_only: bool,
181 ) -> String {
182 let access = match storage {
183 StorageClass::Constant => "var<storage, read>",
184 StorageClass::Global => {
185 if read_only {
186 "var<storage, read>"
187 } else {
188 "var<storage, read_write>"
189 }
190 }
191 _ => {
192 let mapping = Self::to_wgsl(storage);
193 return format!(
194 "@group({group}) @binding({binding})\n{} {name}: {wgsl_type};",
195 mapping.qualifier
196 );
197 }
198 };
199
200 format!(
201 "@group({group}) @binding({binding})\n{access} {name}: {wgsl_type};"
202 )
203 }
204
205 pub fn parse_cuda_qualifier(qualifier: &str) -> StorageClass {
211 match qualifier.trim() {
212 "__shared__" | "shared" => StorageClass::Shared,
213 "__constant__" | "constant" => StorageClass::Constant,
214 "__device__" | "device" => StorageClass::Global,
215 "__managed__" | "managed" => StorageClass::Global,
216 "register" => StorageClass::Register,
217 "local" => StorageClass::Local,
218 _ => StorageClass::Auto,
219 }
220 }
221
222 pub fn requires_barrier(storage: &StorageClass) -> bool {
229 matches!(storage, StorageClass::Shared)
230 }
231
232 pub fn is_read_only(storage: &StorageClass) -> bool {
234 matches!(storage, StorageClass::Constant)
235 }
236
237 pub fn wgsl_barrier(storage: &StorageClass) -> Option<&'static str> {
240 match storage {
241 StorageClass::Shared => Some("workgroupBarrier()"),
242 StorageClass::Global => Some("storageBarrier()"),
243 _ => None,
244 }
245 }
246
247 pub fn rust_barrier(storage: &StorageClass) -> Option<&'static str> {
250 match storage {
251 StorageClass::Shared => {
252 Some("std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst)")
253 }
254 StorageClass::Global => {
255 Some("std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst)")
256 }
257 _ => None,
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Every storage class, for checks that must hold across all variants.
    fn all_storage_classes() -> [StorageClass; 6] {
        [
            StorageClass::Global,
            StorageClass::Shared,
            StorageClass::Constant,
            StorageClass::Register,
            StorageClass::Local,
            StorageClass::Auto,
        ]
    }

    #[test]
    fn test_global_to_rust() {
        let mapping = MemoryMapper::to_rust(&StorageClass::Global);
        assert!(!mapping.requires_sync);
        assert!(!mapping.read_only);
    }

    #[test]
    fn test_shared_to_rust() {
        let mapping = MemoryMapper::to_rust(&StorageClass::Shared);
        assert!(mapping.requires_sync);
        assert!(!mapping.read_only);
    }

    #[test]
    fn test_constant_to_rust() {
        let mapping = MemoryMapper::to_rust(&StorageClass::Constant);
        assert!(!mapping.requires_sync);
        assert!(mapping.read_only);
        assert_eq!(mapping.qualifier, "const ");
    }

    #[test]
    fn test_register_to_rust() {
        let mapping = MemoryMapper::to_rust(&StorageClass::Register);
        assert!(!mapping.requires_sync);
        assert_eq!(mapping.qualifier, "");
    }

    #[test]
    fn test_local_and_auto_to_rust() {
        assert_eq!(MemoryMapper::to_rust(&StorageClass::Local).qualifier, "");
        assert_eq!(MemoryMapper::to_rust(&StorageClass::Auto).qualifier, "let ");
    }

    #[test]
    fn test_rust_var_prefix_shared() {
        let prefix = MemoryMapper::rust_var_prefix(&StorageClass::Shared, true);
        assert!(prefix.contains("__shared__"));
        assert!(prefix.contains("let mut"));
        assert_eq!(
            MemoryMapper::rust_var_prefix(&StorageClass::Shared, false),
            "/* __shared__ */ let "
        );
    }

    #[test]
    fn test_rust_var_prefix_const() {
        let prefix = MemoryMapper::rust_var_prefix(&StorageClass::Constant, false);
        assert_eq!(prefix, "const ");
        // `mutable` is ignored for constants.
        assert_eq!(MemoryMapper::rust_var_prefix(&StorageClass::Constant, true), "const ");
    }

    #[test]
    fn test_rust_var_prefix_global_and_locals() {
        assert_eq!(
            MemoryMapper::rust_var_prefix(&StorageClass::Global, true),
            "/* __device__ */ static mut "
        );
        assert_eq!(
            MemoryMapper::rust_var_prefix(&StorageClass::Global, false),
            "/* __device__ */ static "
        );
        for storage in [StorageClass::Register, StorageClass::Local, StorageClass::Auto] {
            assert_eq!(MemoryMapper::rust_var_prefix(&storage, true), "let mut ");
            assert_eq!(MemoryMapper::rust_var_prefix(&storage, false), "let ");
        }
    }

    #[test]
    fn test_global_to_wgsl() {
        let mapping = MemoryMapper::to_wgsl(&StorageClass::Global);
        assert_eq!(mapping.qualifier, "var<storage, read_write>");
        assert!(!mapping.read_only);
    }

    #[test]
    fn test_shared_to_wgsl() {
        let mapping = MemoryMapper::to_wgsl(&StorageClass::Shared);
        assert_eq!(mapping.qualifier, "var<workgroup>");
        assert!(mapping.requires_sync);
    }

    #[test]
    fn test_constant_to_wgsl() {
        let mapping = MemoryMapper::to_wgsl(&StorageClass::Constant);
        assert_eq!(mapping.qualifier, "var<uniform>");
        assert!(mapping.read_only);
    }

    #[test]
    fn test_register_to_wgsl() {
        let mapping = MemoryMapper::to_wgsl(&StorageClass::Register);
        assert_eq!(mapping.qualifier, "var<private>");
    }

    #[test]
    fn test_local_and_auto_to_wgsl() {
        assert_eq!(MemoryMapper::to_wgsl(&StorageClass::Local).qualifier, "var<private>");
        assert_eq!(MemoryMapper::to_wgsl(&StorageClass::Auto).qualifier, "var");
    }

    #[test]
    fn test_flags_agree_across_helpers() {
        for storage in &all_storage_classes() {
            let rust = MemoryMapper::to_rust(storage);
            let wgsl = MemoryMapper::to_wgsl(storage);
            assert_eq!(rust.read_only, MemoryMapper::is_read_only(storage));
            assert_eq!(wgsl.read_only, MemoryMapper::is_read_only(storage));
            assert_eq!(rust.requires_sync, MemoryMapper::requires_barrier(storage));
            assert_eq!(wgsl.requires_sync, MemoryMapper::requires_barrier(storage));
        }
    }

    #[test]
    fn test_wgsl_var_decl() {
        let decl = MemoryMapper::wgsl_var_decl(
            &StorageClass::Shared,
            "shared_data",
            "array<f32, 256>",
        );
        assert_eq!(decl, "var<workgroup> shared_data: array<f32, 256>;");
        assert_eq!(
            MemoryMapper::wgsl_var_decl(&StorageClass::Register, "acc", "f32"),
            "var<private> acc: f32;"
        );
        assert_eq!(
            MemoryMapper::wgsl_var_decl(&StorageClass::Auto, "i", "u32"),
            "var i: u32;"
        );
    }

    #[test]
    fn test_wgsl_binding_decl() {
        let decl = MemoryMapper::wgsl_binding_decl(
            &StorageClass::Global,
            0,
            0,
            "data",
            "array<f32>",
            false,
        );
        assert_eq!(
            decl,
            "@group(0) @binding(0)\nvar<storage, read_write> data: array<f32>;"
        );
    }

    #[test]
    fn test_wgsl_binding_decl_readonly() {
        let decl = MemoryMapper::wgsl_binding_decl(
            &StorageClass::Global,
            0,
            1,
            "input",
            "array<f32>",
            true,
        );
        assert_eq!(decl, "@group(0) @binding(1)\nvar<storage, read> input: array<f32>;");
        assert!(!decl.contains("read_write"));
    }

    #[test]
    fn test_wgsl_binding_decl_constant() {
        let decl = MemoryMapper::wgsl_binding_decl(
            &StorageClass::Constant,
            1,
            2,
            "params",
            "array<f32>",
            false,
        );
        assert_eq!(decl, "@group(1) @binding(2)\nvar<storage, read> params: array<f32>;");
    }

    #[test]
    fn test_parse_cuda_qualifier() {
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("__shared__"),
            StorageClass::Shared
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("__constant__"),
            StorageClass::Constant
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("__device__"),
            StorageClass::Global
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("register"),
            StorageClass::Register
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("unknown"),
            StorageClass::Auto
        ));
    }

    #[test]
    fn test_parse_cuda_qualifier_variants() {
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("__managed__"),
            StorageClass::Global
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("device"),
            StorageClass::Global
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("shared"),
            StorageClass::Shared
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("constant"),
            StorageClass::Constant
        ));
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("local"),
            StorageClass::Local
        ));
        // Surrounding whitespace is trimmed before matching.
        assert!(matches!(
            MemoryMapper::parse_cuda_qualifier("  __shared__\t"),
            StorageClass::Shared
        ));
        assert!(matches!(MemoryMapper::parse_cuda_qualifier(""), StorageClass::Auto));
    }

    #[test]
    fn test_requires_barrier() {
        assert!(MemoryMapper::requires_barrier(&StorageClass::Shared));
        assert!(!MemoryMapper::requires_barrier(&StorageClass::Global));
        assert!(!MemoryMapper::requires_barrier(&StorageClass::Register));
    }

    #[test]
    fn test_is_read_only() {
        assert!(MemoryMapper::is_read_only(&StorageClass::Constant));
        assert!(!MemoryMapper::is_read_only(&StorageClass::Global));
        assert!(!MemoryMapper::is_read_only(&StorageClass::Shared));
    }

    #[test]
    fn test_wgsl_barrier() {
        assert_eq!(
            MemoryMapper::wgsl_barrier(&StorageClass::Shared),
            Some("workgroupBarrier()")
        );
        assert_eq!(
            MemoryMapper::wgsl_barrier(&StorageClass::Global),
            Some("storageBarrier()")
        );
        assert_eq!(MemoryMapper::wgsl_barrier(&StorageClass::Register), None);
        assert_eq!(MemoryMapper::wgsl_barrier(&StorageClass::Constant), None);
    }

    #[test]
    fn test_rust_barrier() {
        assert!(MemoryMapper::rust_barrier(&StorageClass::Shared).is_some());
        assert!(MemoryMapper::rust_barrier(&StorageClass::Global).is_some());
        assert!(MemoryMapper::rust_barrier(&StorageClass::Register).is_none());
        assert!(MemoryMapper::rust_barrier(&StorageClass::Auto).is_none());
    }
}
433}