1use std::marker::PhantomData;
53
/// Host-side handle for a 2-D tile of `W` x `H` elements of type `T`.
///
/// The handle owns no storage: it only records the element type and the
/// compile-time extents. Consequently [`SharedTile::get`] always yields
/// `T::default()` and [`SharedTile::set`] throws its value away, so on the
/// host a read never observes an earlier write.
// NOTE(review): presumably the kernel code generator lowers this type to a
// CUDA shared-memory array (cf. `parse_shared_tile_type`); confirm.
#[derive(Debug)]
pub struct SharedTile<T, const W: usize, const H: usize> {
    _phantom: PhantomData<T>,
}

impl<T: Default + Copy, const W: usize, const H: usize> SharedTile<T, W, H> {
    /// Builds a new, data-less tile handle.
    #[inline]
    pub fn new() -> Self {
        SharedTile { _phantom: PhantomData }
    }

    /// Number of columns, i.e. the `W` const parameter.
    #[inline]
    pub const fn width() -> usize {
        W
    }

    /// Number of rows, i.e. the `H` const parameter.
    #[inline]
    pub const fn height() -> usize {
        H
    }

    /// Element count of the whole tile (`width() * height()`).
    #[inline]
    pub const fn size() -> usize {
        Self::width() * Self::height()
    }

    /// Host stub for reading the element at `(_x, _y)`.
    ///
    /// The coordinates are ignored; the result is always `T::default()`.
    #[inline]
    pub fn get(&self, _x: i32, _y: i32) -> T {
        Default::default()
    }

    /// Host stub for writing `_value` at `(_x, _y)`; it has no effect.
    #[inline]
    pub fn set(&mut self, _x: i32, _y: i32, _value: T) {}
}

impl<T: Default + Copy, const W: usize, const H: usize> Default for SharedTile<T, W, H> {
    /// Same as [`SharedTile::new`].
    fn default() -> Self {
        SharedTile::new()
    }
}
142
/// Host-side handle for a 1-D array of `N` elements of type `T`.
///
/// No elements are stored: the type only carries `T` and the length `N`.
/// [`SharedArray::get`] therefore always returns `T::default()` and
/// [`SharedArray::set`] discards what it is given.
// NOTE(review): presumably lowered to a CUDA shared-memory array by the code
// generator (cf. `parse_shared_array_type`); confirm.
#[derive(Debug)]
pub struct SharedArray<T, const N: usize> {
    _phantom: PhantomData<T>,
}

impl<T: Default + Copy, const N: usize> SharedArray<T, N> {
    /// Builds a new, data-less array handle.
    #[inline]
    pub fn new() -> Self {
        SharedArray { _phantom: PhantomData }
    }

    /// Number of elements, i.e. the `N` const parameter.
    #[inline]
    pub const fn size() -> usize {
        N
    }

    /// Host stub for reading index `_idx`; always returns `T::default()`.
    #[inline]
    pub fn get(&self, _idx: i32) -> T {
        Default::default()
    }

    /// Host stub for writing `_value` at `_idx`; it has no effect.
    #[inline]
    pub fn set(&mut self, _idx: i32, _value: T) {}
}

impl<T: Default + Copy, const N: usize> Default for SharedArray<T, N> {
    /// Same as [`SharedArray::new`].
    fn default() -> Self {
        SharedArray::new()
    }
}
191
/// One shared-memory variable, rendered as CUDA C++ source text.
#[derive(Debug, Clone)]
pub struct SharedMemoryDecl {
    /// Variable name used in both the declaration and element accesses.
    pub name: String,
    /// Element type spelled as CUDA source text (e.g. `"float"`).
    pub element_type: String,
    /// Array extents, outermost first; each one becomes a `[n]` suffix.
    pub dimensions: Vec<usize>,
}

impl SharedMemoryDecl {
    /// Describes a 1-D array holding `size` elements.
    pub fn array(name: impl Into<String>, element_type: impl Into<String>, size: usize) -> Self {
        let dimensions = Vec::from([size]);
        Self {
            name: name.into(),
            element_type: element_type.into(),
            dimensions,
        }
    }

    /// Describes a 2-D tile declared as `[height][width]`.
    ///
    /// Because the row extent comes first, accesses take the form
    /// `name[row][col]`.
    pub fn tile(
        name: impl Into<String>,
        element_type: impl Into<String>,
        width: usize,
        height: usize,
    ) -> Self {
        // Height is the outer extent: C arrays put the row index first.
        let dimensions = Vec::from([height, width]);
        Self {
            name: name.into(),
            element_type: element_type.into(),
            dimensions,
        }
    }

    /// Renders the declaration statement, e.g. `__shared__ float tile[16][16];`.
    pub fn to_cuda_decl(&self) -> String {
        let extents = Self::bracketed(&self.dimensions);
        format!("__shared__ {} {}{};", self.element_type, self.name, extents)
    }

    /// Renders an element access such as `tile[y][x]`.
    ///
    /// Emits one `[..]` per entry of `indices`, in order. The number of
    /// indices is not checked against `dimensions`.
    pub fn to_cuda_access(&self, indices: &[String]) -> String {
        format!("{}{}", self.name, Self::bracketed(indices))
    }

    /// Wraps every item in square brackets and concatenates the pieces.
    fn bracketed<I>(items: I) -> String
    where
        I: IntoIterator,
        I::Item: std::fmt::Display,
    {
        items
            .into_iter()
            .map(|item| format!("[{}]", item))
            .collect()
    }
}
252
253#[derive(Debug, Clone, Default)]
255pub struct SharedMemoryConfig {
256 pub declarations: Vec<SharedMemoryDecl>,
258}
259
260impl SharedMemoryConfig {
261 pub fn new() -> Self {
263 Self {
264 declarations: Vec::new(),
265 }
266 }
267
268 pub fn add(&mut self, decl: SharedMemoryDecl) {
270 self.declarations.push(decl);
271 }
272
273 pub fn add_array(
275 &mut self,
276 name: impl Into<String>,
277 element_type: impl Into<String>,
278 size: usize,
279 ) {
280 self.declarations
281 .push(SharedMemoryDecl::array(name, element_type, size));
282 }
283
284 pub fn add_tile(
286 &mut self,
287 name: impl Into<String>,
288 element_type: impl Into<String>,
289 width: usize,
290 height: usize,
291 ) {
292 self.declarations
293 .push(SharedMemoryDecl::tile(name, element_type, width, height));
294 }
295
296 pub fn generate_declarations(&self, indent: &str) -> String {
298 self.declarations
299 .iter()
300 .map(|d| format!("{}{}", indent, d.to_cuda_decl()))
301 .collect::<Vec<_>>()
302 .join("\n")
303 }
304
305 pub fn is_empty(&self) -> bool {
307 self.declarations.is_empty()
308 }
309
310 pub fn total_bytes(&self) -> usize {
312 self.declarations
313 .iter()
314 .map(|d| {
315 let elem_size = match d.element_type.as_str() {
316 "float" => 4,
317 "double" => 8,
318 "int" => 4,
319 "unsigned int" => 4,
320 "long long" | "unsigned long long" => 8,
321 "short" | "unsigned short" => 2,
322 "char" | "unsigned char" => 1,
323 _ => 4, };
325 let count: usize = d.dimensions.iter().product();
326 elem_size * count
327 })
328 .sum()
329 }
330}
331
/// Parses a `SharedTile` type string into `(element_type, width, height)`.
///
/// Accepts both `SharedTile<T, W, H>` and the turbofish form
/// `SharedTile::<T, W, H>`. Whitespace around tokens is ignored, so
/// token-stream renderings like `SharedTile < f32 , 16 , 16 >` also parse.
/// `W` and `H` must be plain decimal `usize` literals.
///
/// Returns `None` in these cases:
/// - the `SharedTile` prefix or the angle brackets are missing;
/// - the argument count is not exactly three;
/// - the element type is empty;
/// - a dimension does not parse as `usize`.
pub fn parse_shared_tile_type(type_str: &str) -> Option<(String, usize, usize)> {
    let inner = type_str
        .trim()
        .strip_prefix("SharedTile")?
        .trim_start()
        .trim_start_matches("::")
        .trim_start()
        .strip_prefix('<')?
        .strip_suffix('>')?;

    let parts: Vec<&str> = inner.split(',').map(str::trim).collect();
    match parts.as_slice() {
        // An empty element type (`SharedTile<, 16, 16>`) is malformed.
        [element_type, width, height] if !element_type.is_empty() => Some((
            String::from(*element_type),
            width.parse::<usize>().ok()?,
            height.parse::<usize>().ok()?,
        )),
        _ => None,
    }
}
360
/// Parses a `SharedArray` type string into `(element_type, size)`.
///
/// Accepts both `SharedArray<T, N>` and the turbofish form
/// `SharedArray::<T, N>`. Whitespace around tokens is ignored, so
/// token-stream renderings like `SharedArray < f32 , 256 >` also parse.
/// `N` must be a plain decimal `usize` literal.
///
/// Returns `None` in these cases:
/// - the `SharedArray` prefix or the angle brackets are missing;
/// - the argument count is not exactly two;
/// - the element type is empty;
/// - the size does not parse as `usize`.
pub fn parse_shared_array_type(type_str: &str) -> Option<(String, usize)> {
    let inner = type_str
        .trim()
        .strip_prefix("SharedArray")?
        .trim_start()
        .trim_start_matches("::")
        .trim_start()
        .strip_prefix('<')?
        .strip_suffix('>')?;

    let parts: Vec<&str> = inner.split(',').map(str::trim).collect();
    match parts.as_slice() {
        // An empty element type (`SharedArray<, 256>`) is malformed.
        [element_type, size] if !element_type.is_empty() => {
            Some((String::from(*element_type), size.parse::<usize>().ok()?))
        }
        _ => None,
    }
}
388
/// Maps a Rust primitive type name to the CUDA C++ spelling used for
/// shared-memory element declarations.
///
/// `bool` is widened to `int`. Any name not listed falls back to `"float"`.
// NOTE(review): plain `char` has implementation-defined signedness in C/C++;
// `signed char` would be the unambiguous spelling for `i8`. It is left as
// `char` here because other code may match on that exact spelling.
pub fn rust_to_cuda_element_type(rust_type: &str) -> &'static str {
    match rust_type {
        "f32" => "float",
        "f64" => "double",
        "i8" => "char",
        "u8" => "unsigned char",
        "i16" => "short",
        "u16" => "unsigned short",
        "i32" => "int",
        "u32" => "unsigned int",
        "i64" => "long long",
        "u64" => "unsigned long long",
        // Pointer-sized integers previously fell through to `float`, turning
        // an integer into a floating-point declaration. They are spelled as
        // 64-bit integers, assuming a 64-bit device target (the usual case
        // for CUDA) — confirm if 32-bit targets matter.
        "isize" => "long long",
        "usize" => "unsigned long long",
        "bool" => "int",
        _ => "float",
    }
}
406
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shared_tile_dimensions() {
        assert_eq!(SharedTile::<f32, 16, 16>::width(), 16);
        assert_eq!(SharedTile::<f32, 16, 16>::height(), 16);
        assert_eq!(SharedTile::<f32, 16, 16>::size(), 256);

        assert_eq!(SharedTile::<f32, 32, 8>::width(), 32);
        assert_eq!(SharedTile::<f32, 32, 8>::height(), 8);
        assert_eq!(SharedTile::<f32, 32, 8>::size(), 256);
    }

    #[test]
    fn test_shared_tile_host_stub() {
        // On the host, writes are discarded and reads yield T::default().
        let mut tile = SharedTile::<i32, 4, 4>::default();
        tile.set(1, 2, 7);
        assert_eq!(tile.get(1, 2), 0);
    }

    #[test]
    fn test_shared_array_size() {
        assert_eq!(SharedArray::<f32, 256>::size(), 256);
        assert_eq!(SharedArray::<i32, 1024>::size(), 1024);
    }

    #[test]
    fn test_shared_array_host_stub() {
        let mut arr = SharedArray::<u32, 8>::new();
        arr.set(3, 11);
        assert_eq!(arr.get(3), 0);
    }

    #[test]
    fn test_shared_memory_decl_1d() {
        let decl = SharedMemoryDecl::array("buffer", "float", 256);
        assert_eq!(decl.to_cuda_decl(), "__shared__ float buffer[256];");
        assert_eq!(decl.to_cuda_access(&["i".to_string()]), "buffer[i]");
    }

    #[test]
    fn test_shared_memory_decl_2d() {
        let decl = SharedMemoryDecl::tile("tile", "float", 16, 16);
        assert_eq!(decl.to_cuda_decl(), "__shared__ float tile[16][16];");
        assert_eq!(
            decl.to_cuda_access(&["y".to_string(), "x".to_string()]),
            "tile[y][x]"
        );
    }

    #[test]
    fn test_shared_memory_decl_2d_non_square_is_row_major() {
        let decl = SharedMemoryDecl::tile("tile", "float", 32, 8);
        assert_eq!(decl.to_cuda_decl(), "__shared__ float tile[8][32];");
    }

    #[test]
    fn test_shared_memory_config() {
        let mut config = SharedMemoryConfig::new();
        config.add_tile("tile", "float", 16, 16);
        config.add_array("temp", "int", 128);

        let decls = config.generate_declarations("    ");
        assert!(decls.contains("__shared__ float tile[16][16];"));
        assert!(decls.contains("__shared__ int temp[128];"));
    }

    #[test]
    fn test_generate_declarations_exact_layout() {
        let mut config = SharedMemoryConfig::new();
        config.add_tile("tile", "float", 16, 16);
        config.add_array("temp", "int", 128);
        assert_eq!(
            config.generate_declarations("  "),
            "  __shared__ float tile[16][16];\n  __shared__ int temp[128];"
        );
    }

    #[test]
    fn test_empty_config() {
        let config = SharedMemoryConfig::new();
        assert!(config.is_empty());
        assert_eq!(config.total_bytes(), 0);
        assert_eq!(config.generate_declarations("  "), "");
    }

    #[test]
    fn test_total_bytes() {
        let mut config = SharedMemoryConfig::new();
        config.add_tile("tile", "float", 16, 16); // 16 * 16 * 4 = 1024 bytes
        config.add_array("temp", "double", 64); // 64 * 8 = 512 bytes
        assert_eq!(config.total_bytes(), 1024 + 512);
    }

    #[test]
    fn test_parse_shared_tile_type() {
        let result = parse_shared_tile_type("SharedTile::<f32, 16, 16>");
        assert_eq!(result, Some(("f32".to_string(), 16, 16)));

        let result2 = parse_shared_tile_type("SharedTile<i32, 32, 8>");
        assert_eq!(result2, Some(("i32".to_string(), 32, 8)));
    }

    #[test]
    fn test_parse_shared_tile_type_rejects_malformed() {
        assert_eq!(parse_shared_tile_type("SharedTile<f32, 16>"), None);
        assert_eq!(parse_shared_tile_type("SharedTile<f32, W, 16>"), None);
        assert_eq!(parse_shared_tile_type("SharedArray<f32, 16, 16>"), None);
        assert_eq!(parse_shared_tile_type("SharedTile(f32, 16, 16)"), None);
    }

    #[test]
    fn test_parse_shared_array_type() {
        let result = parse_shared_array_type("SharedArray::<f32, 256>");
        assert_eq!(result, Some(("f32".to_string(), 256)));

        let result2 = parse_shared_array_type("SharedArray<u32, 1024>");
        assert_eq!(result2, Some(("u32".to_string(), 1024)));
    }

    #[test]
    fn test_parse_shared_array_type_rejects_malformed() {
        assert_eq!(parse_shared_array_type("SharedArray<f32>"), None);
        assert_eq!(parse_shared_array_type("SharedArray<f32, 16, 16>"), None);
        assert_eq!(parse_shared_array_type("SharedArray<f32, N>"), None);
        assert_eq!(parse_shared_array_type("SharedTile<f32, 16>"), None);
    }

    #[test]
    fn test_rust_to_cuda_element_type() {
        assert_eq!(rust_to_cuda_element_type("f32"), "float");
        assert_eq!(rust_to_cuda_element_type("f64"), "double");
        assert_eq!(rust_to_cuda_element_type("i32"), "int");
        assert_eq!(rust_to_cuda_element_type("u64"), "unsigned long long");
    }

    #[test]
    fn test_rust_to_cuda_element_type_special_cases() {
        // bool is widened to int; unknown names fall back to float.
        assert_eq!(rust_to_cuda_element_type("bool"), "int");
        assert_eq!(rust_to_cuda_element_type("u8"), "unsigned char");
        assert_eq!(rust_to_cuda_element_type("String"), "float");
    }
}