1use crate::types::WgslType;
6
7#[derive(Debug, Clone)]
9pub struct SharedMemoryDecl {
10 pub name: String,
12 pub element_type: WgslType,
14 pub dimensions: Vec<u32>,
16}
17
18impl SharedMemoryDecl {
19 pub fn new_1d(name: &str, element_type: WgslType, size: u32) -> Self {
21 Self {
22 name: name.to_string(),
23 element_type,
24 dimensions: vec![size],
25 }
26 }
27
28 pub fn new_2d(name: &str, element_type: WgslType, width: u32, height: u32) -> Self {
30 Self {
31 name: name.to_string(),
32 element_type,
33 dimensions: vec![width, height],
34 }
35 }
36
37 pub fn new_3d(name: &str, element_type: WgslType, width: u32, height: u32, depth: u32) -> Self {
39 Self {
40 name: name.to_string(),
41 element_type,
42 dimensions: vec![width, height, depth],
43 }
44 }
45
46 pub fn total_elements(&self) -> u32 {
48 self.dimensions.iter().product()
49 }
50
51 pub fn to_wgsl(&self) -> String {
59 let type_str = self.element_type.to_wgsl();
60 match self.dimensions.len() {
61 0 => format!("var<workgroup> {}: {};", self.name, type_str),
62 1 => format!(
63 "var<workgroup> {}: array<{}, {}>;",
64 self.name, type_str, self.dimensions[0]
65 ),
66 2 => {
67 format!(
69 "var<workgroup> {}: array<array<{}, {}>, {}>;",
70 self.name, type_str, self.dimensions[0], self.dimensions[1]
71 )
72 }
73 3 => {
74 format!(
76 "var<workgroup> {}: array<array<array<{}, {}>, {}>, {}>;",
77 self.name, type_str, self.dimensions[0], self.dimensions[1], self.dimensions[2]
78 )
79 }
80 _ => {
81 let total = self.total_elements();
84 let dims_str = self
85 .dimensions
86 .iter()
87 .map(|d| d.to_string())
88 .collect::<Vec<_>>()
89 .join("x");
90 format!(
91 "var<workgroup> {}: array<{}, {}>; // linearized {}D ({})",
92 self.name,
93 type_str,
94 total,
95 self.dimensions.len(),
96 dims_str
97 )
98 }
99 }
100 }
101
102 pub fn linearized_index_formula(&self, index_vars: &[&str]) -> Option<String> {
107 if self.dimensions.len() < 4 || index_vars.len() != self.dimensions.len() {
108 return None;
109 }
110
111 let mut terms = Vec::new();
112 let mut stride = 1u32;
113
114 for (i, var) in index_vars.iter().enumerate() {
115 if stride == 1 {
116 terms.push(var.to_string());
117 } else {
118 terms.push(format!("{} * {}u", var, stride));
119 }
120 stride *= self.dimensions[i];
121 }
122
123 Some(terms.join(" + "))
124 }
125}
126
127#[derive(Debug, Clone, Default)]
129pub struct SharedMemoryConfig {
130 pub declarations: Vec<SharedMemoryDecl>,
132}
133
134impl SharedMemoryConfig {
135 pub fn new() -> Self {
137 Self::default()
138 }
139
140 pub fn add(&mut self, decl: SharedMemoryDecl) {
142 self.declarations.push(decl);
143 }
144
145 pub fn to_wgsl(&self) -> String {
147 self.declarations
148 .iter()
149 .map(|d| d.to_wgsl())
150 .collect::<Vec<_>>()
151 .join("\n")
152 }
153}
154
/// Marker type for a two-dimensional tile of `T` with extents `W` by `H`.
///
/// The only field is a zero-sized `PhantomData`, so the type holds no runtime
/// data; the extents are exposed through associated `const fn`s.
pub struct SharedTile<T, const W: usize, const H: usize> {
    _marker: std::marker::PhantomData<T>,
}

impl<T, const W: usize, const H: usize> SharedTile<T, W, H> {
    /// Returns the `W` const parameter.
    pub const fn width() -> usize {
        W
    }

    /// Returns the `H` const parameter.
    pub const fn height() -> usize {
        H
    }
}
174
/// Marker type for a one-dimensional array of `T` with `N` elements.
///
/// The only field is a zero-sized `PhantomData`, so the type holds no runtime
/// data; the length is exposed through [`SharedArray::size`].
pub struct SharedArray<T, const N: usize> {
    _marker: std::marker::PhantomData<T>,
}

impl<T, const N: usize> SharedArray<T, N> {
    /// Returns the `N` const parameter.
    pub const fn size() -> usize {
        N
    }
}
186
/// Marker type for a three-dimensional volume of `T` with extents `X`, `Y`
/// and `Z`.
///
/// The only field is a zero-sized `PhantomData`, so the type holds no runtime
/// data; the extents are exposed through associated `const fn`s.
pub struct SharedVolume<T, const X: usize, const Y: usize, const Z: usize> {
    _marker: std::marker::PhantomData<T>,
}

impl<T, const X: usize, const Y: usize, const Z: usize> SharedVolume<T, X, Y, Z> {
    /// Returns the `X` const parameter.
    pub const fn width() -> usize {
        X
    }

    /// Returns the `Y` const parameter.
    pub const fn height() -> usize {
        Y
    }

    /// Returns the `Z` const parameter.
    pub const fn depth() -> usize {
        Z
    }

    /// Returns the element count, `X * Y * Z`.
    pub const fn total() -> usize {
        Self::width() * Self::height() * Self::depth()
    }
}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `f32` declaration from an explicit extent list, for ranks
    /// that have no dedicated constructor.
    fn f32_decl(name: &str, dimensions: Vec<u32>) -> SharedMemoryDecl {
        SharedMemoryDecl {
            name: name.to_string(),
            element_type: WgslType::F32,
            dimensions,
        }
    }

    // A 1D declaration renders as a single `array`.
    #[test]
    fn test_shared_memory_1d() {
        let rendered = SharedMemoryDecl::new_1d("cache", WgslType::F32, 256).to_wgsl();
        assert_eq!(rendered, "var<workgroup> cache: array<f32, 256>;");
    }

    // A 2D declaration renders as an array of arrays.
    #[test]
    fn test_shared_memory_2d() {
        let rendered = SharedMemoryDecl::new_2d("tile", WgslType::F32, 16, 16).to_wgsl();
        assert_eq!(rendered, "var<workgroup> tile: array<array<f32, 16>, 16>;");
    }

    // Every declaration added to a config shows up in its rendered output.
    #[test]
    fn test_shared_memory_config() {
        let mut config = SharedMemoryConfig::new();
        config.add(SharedMemoryDecl::new_1d("a", WgslType::I32, 64));
        config.add(SharedMemoryDecl::new_1d("b", WgslType::F32, 128));

        let rendered = config.to_wgsl();
        assert!(rendered.contains("var<workgroup> a: array<i32, 64>;"));
        assert!(rendered.contains("var<workgroup> b: array<f32, 128>;"));
    }

    // A 3D declaration renders as three nested arrays.
    #[test]
    fn test_shared_memory_3d() {
        let volume = SharedMemoryDecl::new_3d("volume", WgslType::F32, 8, 8, 8);
        assert_eq!(
            volume.to_wgsl(),
            "var<workgroup> volume: array<array<array<f32, 8>, 8>, 8>;"
        );
        assert_eq!(volume.total_elements(), 512);
    }

    // Same 3D rendering with 10-element extents.
    #[test]
    fn test_shared_memory_3d_asymmetric() {
        let padded = SharedMemoryDecl::new_3d("tile_with_halo", WgslType::F32, 10, 10, 10);
        assert_eq!(
            padded.to_wgsl(),
            "var<workgroup> tile_with_halo: array<array<array<f32, 10>, 10>, 10>;"
        );
        assert_eq!(padded.total_elements(), 1000);
    }

    // Four dimensions fall back to a flat array plus a descriptive comment.
    #[test]
    fn test_shared_memory_4d_linearized() {
        let rendered = f32_decl("hypercube", vec![4, 4, 4, 4]).to_wgsl();
        assert!(rendered.contains("array<f32, 256>"));
        assert!(rendered.contains("linearized 4D"));
        assert!(rendered.contains("4x4x4x4"));
    }

    // Strides are the running product of the preceding extents: 1, 4, 32, 64.
    #[test]
    fn test_linearized_index_formula() {
        let decl = f32_decl("data", vec![4, 8, 2, 3]);
        let formula = decl
            .linearized_index_formula(&["x", "y", "z", "t"])
            .unwrap();
        assert_eq!(formula, "x + y * 4u + z * 32u + t * 64u");
    }

    // Declarations below four dimensions have no linearized formula.
    #[test]
    fn test_linearized_index_formula_returns_none_for_3d() {
        let volume = SharedMemoryDecl::new_3d("vol", WgslType::F32, 8, 8, 8);
        assert!(volume.linearized_index_formula(&["x", "y", "z"]).is_none());
    }

    // The marker type reports its const parameters and their product.
    #[test]
    fn test_shared_volume_marker() {
        type Cube = SharedVolume<f32, 8, 8, 8>;
        assert_eq!(Cube::width(), 8);
        assert_eq!(Cube::height(), 8);
        assert_eq!(Cube::depth(), 8);
        assert_eq!(Cube::total(), 512);

        type Slab = SharedVolume<f32, 16, 8, 4>;
        assert_eq!(Slab::width(), 16);
        assert_eq!(Slab::height(), 8);
        assert_eq!(Slab::depth(), 4);
        assert_eq!(Slab::total(), 512);
    }
}