ringkernel_wgpu_codegen/
shared.rs1use crate::types::WgslType;
6
7#[derive(Debug, Clone)]
9pub struct SharedMemoryDecl {
10 pub name: String,
12 pub element_type: WgslType,
14 pub dimensions: Vec<u32>,
16}
17
18impl SharedMemoryDecl {
19 pub fn new_1d(name: &str, element_type: WgslType, size: u32) -> Self {
21 Self {
22 name: name.to_string(),
23 element_type,
24 dimensions: vec![size],
25 }
26 }
27
28 pub fn new_2d(name: &str, element_type: WgslType, width: u32, height: u32) -> Self {
30 Self {
31 name: name.to_string(),
32 element_type,
33 dimensions: vec![width, height],
34 }
35 }
36
37 pub fn to_wgsl(&self) -> String {
39 let type_str = self.element_type.to_wgsl();
40 match self.dimensions.len() {
41 1 => format!(
42 "var<workgroup> {}: array<{}, {}>;",
43 self.name, type_str, self.dimensions[0]
44 ),
45 2 => format!(
46 "var<workgroup> {}: array<array<{}, {}>, {}>;",
47 self.name, type_str, self.dimensions[0], self.dimensions[1]
48 ),
49 _ => format!(
50 "var<workgroup> {}: array<{}, {}>; // TODO: higher dimensions",
51 self.name, type_str, self.dimensions[0]
52 ),
53 }
54 }
55}
56
57#[derive(Debug, Clone, Default)]
59pub struct SharedMemoryConfig {
60 pub declarations: Vec<SharedMemoryDecl>,
62}
63
64impl SharedMemoryConfig {
65 pub fn new() -> Self {
67 Self::default()
68 }
69
70 pub fn add(&mut self, decl: SharedMemoryDecl) {
72 self.declarations.push(decl);
73 }
74
75 pub fn to_wgsl(&self) -> String {
77 self.declarations
78 .iter()
79 .map(|d| d.to_wgsl())
80 .collect::<Vec<_>>()
81 .join("\n")
82 }
83}
84
/// Zero-sized, type-level description of a `W` x `H` tile of `T` elements.
///
/// It holds no data, only a `PhantomData<T>` marker. Its extents are exposed as
/// `const fn`s so they can be used in constant contexts.
/// NOTE(review): presumably a Rust-side stand-in for a 2-D workgroup tile in
/// transpiled kernels; confirm against the codegen that consumes it.
pub struct SharedTile<T, const W: usize, const H: usize> {
    _marker: core::marker::PhantomData<T>,
}

impl<T, const W: usize, const H: usize> SharedTile<T, W, H> {
    /// Returns the tile width, i.e. the const parameter `W`.
    pub const fn width() -> usize {
        W
    }

    /// Returns the tile height, i.e. the const parameter `H`.
    pub const fn height() -> usize {
        H
    }
}
104
/// Zero-sized, type-level description of an array of `N` elements of `T`.
///
/// It holds no data, only a `PhantomData<T>` marker. Its length is exposed as a
/// `const fn` so it can be used in constant contexts.
/// NOTE(review): presumably a Rust-side stand-in for a 1-D workgroup array in
/// transpiled kernels; confirm against the codegen that consumes it.
pub struct SharedArray<T, const N: usize> {
    _marker: core::marker::PhantomData<T>,
}

impl<T, const N: usize> SharedArray<T, N> {
    /// Returns the element count, i.e. the const parameter `N`.
    pub const fn size() -> usize {
        N
    }
}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// A 1-D declaration renders as a single `array<T, N>`.
    #[test]
    fn test_shared_memory_1d() {
        let rendered = SharedMemoryDecl::new_1d("cache", WgslType::F32, 256).to_wgsl();
        assert_eq!(rendered, "var<workgroup> cache: array<f32, 256>;");
    }

    /// A 2-D declaration renders as a nested `array<array<T, W>, H>`.
    #[test]
    fn test_shared_memory_2d() {
        let expected = "var<workgroup> tile: array<array<f32, 16>, 16>;";
        let rendered = SharedMemoryDecl::new_2d("tile", WgslType::F32, 16, 16).to_wgsl();
        assert_eq!(rendered, expected);
    }

    /// Every declaration added to a config appears in its rendered output.
    #[test]
    fn test_shared_memory_config() {
        let mut cfg = SharedMemoryConfig::new();
        cfg.add(SharedMemoryDecl::new_1d("a", WgslType::I32, 64));
        cfg.add(SharedMemoryDecl::new_1d("b", WgslType::F32, 128));

        let rendered = cfg.to_wgsl();
        for needle in [
            "var<workgroup> a: array<i32, 64>;",
            "var<workgroup> b: array<f32, 128>;",
        ] {
            assert!(
                rendered.contains(needle),
                "missing `{}` in:\n{}",
                needle,
                rendered
            );
        }
    }
}