ringkernel_cuda_codegen/
shared.rs

1//! Shared memory support for CUDA code generation.
2//!
3//! This module provides types and utilities for working with CUDA shared memory
4//! (`__shared__`) in the Rust DSL. Shared memory is fast on-chip memory that is
5//! shared among all threads in a block.
6//!
7//! # Overview
8//!
9//! Shared memory is crucial for efficient GPU programming:
10//! - Much faster than global memory (~100x lower latency)
11//! - Shared among all threads in a block
12//! - Limited size (typically 48KB-164KB per SM)
13//! - Requires explicit synchronization after writes
14//!
15//! # Usage in DSL
16//!
17//! ```ignore
18//! use ringkernel_cuda_codegen::shared::SharedTile;
19//!
20//! fn kernel(data: &[f32], out: &mut [f32], width: i32) {
21//!     // Declare a 16x16 shared memory tile
//!     let mut tile = SharedTile::<f32, 16, 16>::new();
23//!
24//!     // Load from global memory
25//!     let gx = block_idx_x() * 16 + thread_idx_x();
26//!     let gy = block_idx_y() * 16 + thread_idx_y();
27//!     tile.set(thread_idx_x(), thread_idx_y(), data[gy * width + gx]);
28//!
29//!     // Synchronize before reading
30//!     sync_threads();
31//!
32//!     // Read from shared memory
33//!     let val = tile.get(thread_idx_x(), thread_idx_y());
34//!     out[gy * width + gx] = val * 2.0;
35//! }
36//! ```
37//!
38//! # Generated CUDA
39//!
40//! The above DSL generates:
41//!
42//! ```cuda
43//! __shared__ float tile[16][16];
44//! int gx = blockIdx.x * 16 + threadIdx.x;
45//! int gy = blockIdx.y * 16 + threadIdx.y;
46//! tile[threadIdx.y][threadIdx.x] = data[gy * width + gx];
47//! __syncthreads();
48//! float val = tile[threadIdx.y][threadIdx.x];
49//! out[gy * width + gx] = val * 2.0f;
50//! ```
51
52use std::marker::PhantomData;
53
/// A two-dimensional tile of CUDA shared memory.
///
/// The transpiler turns a `SharedTile` into a 2D `__shared__` array. On the
/// host the type is zero-sized (its only field is a `PhantomData<T>`) and acts
/// purely as a marker.
///
/// # Type Parameters
///
/// * `T` - Element type (f32, f64, i32, etc.)
/// * `W` - Tile width, i.e. number of columns
/// * `H` - Tile height, i.e. number of rows
///
/// # Example
///
/// ```ignore
/// // 16x16 tile of floats
/// let tile = SharedTile::<f32, 16, 16>::new();
///
/// // 32x8 tile for matrix operations
/// let mat_tile = SharedTile::<f32, 32, 8>::new();
/// ```
#[derive(Debug)]
pub struct SharedTile<T, const W: usize, const H: usize> {
    _phantom: PhantomData<T>,
}

impl<T, const W: usize, const H: usize> SharedTile<T, W, H>
where
    T: Default + Copy,
{
    /// Creates the tile marker.
    ///
    /// The transpiler emits `__shared__ T tile[H][W];` for this call.
    #[inline]
    pub fn new() -> Self {
        SharedTile { _phantom: PhantomData }
    }

    /// Number of columns (`W`).
    #[inline]
    pub const fn width() -> usize {
        W
    }

    /// Number of rows (`H`).
    #[inline]
    pub const fn height() -> usize {
        H
    }

    /// Total element count, `width() * height()`.
    #[inline]
    pub const fn size() -> usize {
        Self::width() * Self::height()
    }

    /// Host-side stand-in for reading `tile[y][x]`.
    ///
    /// On the CPU the indices are ignored and `T::default()` is returned. The
    /// transpiler rewrites the call into `tile[y][x]`.
    ///
    /// * `x` - Column index (0..W)
    /// * `y` - Row index (0..H)
    #[inline]
    pub fn get(&self, _x: i32, _y: i32) -> T {
        Default::default()
    }

    /// Host-side stand-in for writing `tile[y][x] = value;`.
    ///
    /// On the CPU this is a no-op. The transpiler emits the real store.
    ///
    /// * `x` - Column index (0..W)
    /// * `y` - Row index (0..H)
    /// * `value` - Value to store
    #[inline]
    pub fn set(&mut self, _x: i32, _y: i32, _value: T) {
        // Intentionally empty: only the transpiled GPU code performs the write.
    }
}

impl<T, const W: usize, const H: usize> Default for SharedTile<T, W, H>
where
    T: Default + Copy,
{
    fn default() -> Self {
        SharedTile::new()
    }
}
142
/// A one-dimensional array in CUDA shared memory.
///
/// A lighter alternative to `SharedTile` when data is accessed linearly. On
/// the host it is zero-sized (only a `PhantomData<T>` field) and serves as a
/// marker for the transpiler.
///
/// # Type Parameters
///
/// * `T` - Element type
/// * `N` - Number of elements
#[derive(Debug)]
pub struct SharedArray<T, const N: usize> {
    _phantom: PhantomData<T>,
}

impl<T, const N: usize> SharedArray<T, N>
where
    T: Default + Copy,
{
    /// Creates the array marker.
    ///
    /// The transpiler emits `__shared__ T arr[N];` for this call.
    #[inline]
    pub fn new() -> Self {
        SharedArray { _phantom: PhantomData }
    }

    /// Number of elements (`N`).
    #[inline]
    pub const fn size() -> usize {
        N
    }

    /// Host-side stand-in for reading element `idx`.
    ///
    /// On the CPU the index is ignored and `T::default()` is returned.
    #[inline]
    pub fn get(&self, _idx: i32) -> T {
        Default::default()
    }

    /// Host-side stand-in for writing element `idx`.
    ///
    /// On the CPU this is a no-op. The transpiled GPU code performs the store.
    #[inline]
    pub fn set(&mut self, _idx: i32, _value: T) {
        // Intentionally empty: the real write exists only in generated CUDA.
    }
}

impl<T, const N: usize> Default for SharedArray<T, N>
where
    T: Default + Copy,
{
    fn default() -> Self {
        SharedArray::new()
    }
}
191
/// Describes one `__shared__` variable for the code generator.
#[derive(Debug, Clone)]
pub struct SharedMemoryDecl {
    /// Identifier used for the variable in generated CUDA.
    pub name: String,
    /// Element type, spelled as a CUDA type string (e.g. `"float"`).
    pub element_type: String,
    /// Extents, outermost first: `[size]` for 1D, `[height, width]` for 2D.
    pub dimensions: Vec<usize>,
}

impl SharedMemoryDecl {
    /// Builds a one-dimensional declaration holding `size` elements.
    pub fn array(name: impl Into<String>, element_type: impl Into<String>, size: usize) -> Self {
        let name = name.into();
        let element_type = element_type.into();
        Self {
            name,
            element_type,
            dimensions: vec![size],
        }
    }

    /// Builds a two-dimensional declaration of `height` rows by `width` columns.
    ///
    /// Extents are stored row-major, outermost first, as `[height, width]`.
    pub fn tile(
        name: impl Into<String>,
        element_type: impl Into<String>,
        width: usize,
        height: usize,
    ) -> Self {
        let (name, element_type) = (name.into(), element_type.into());
        Self {
            name,
            element_type,
            dimensions: vec![height, width],
        }
    }

    /// Renders the CUDA declaration, e.g. `__shared__ float tile[16][16];`.
    ///
    /// One `[extent]` suffix is written per entry of `dimensions`, in order.
    pub fn to_cuda_decl(&self) -> String {
        let mut dims = String::new();
        for extent in &self.dimensions {
            dims.push_str(&format!("[{}]", extent));
        }
        format!("__shared__ {} {}{};", self.element_type, self.name, dims)
    }

    /// Renders an element access expression such as `tile[y][x]`.
    ///
    /// Each entry of `indices` becomes one bracketed subscript, in the order
    /// given. The count is not checked against `dimensions`.
    pub fn to_cuda_access(&self, indices: &[String]) -> String {
        indices.iter().fold(self.name.clone(), |mut access, index| {
            access.push_str(&format!("[{}]", index));
            access
        })
    }
}
252
253/// Shared memory configuration for a kernel.
254#[derive(Debug, Clone, Default)]
255pub struct SharedMemoryConfig {
256    /// All shared memory declarations in the kernel.
257    pub declarations: Vec<SharedMemoryDecl>,
258}
259
260impl SharedMemoryConfig {
261    /// Create a new empty configuration.
262    pub fn new() -> Self {
263        Self {
264            declarations: Vec::new(),
265        }
266    }
267
268    /// Add a shared memory declaration.
269    pub fn add(&mut self, decl: SharedMemoryDecl) {
270        self.declarations.push(decl);
271    }
272
273    /// Add a 1D shared array.
274    pub fn add_array(
275        &mut self,
276        name: impl Into<String>,
277        element_type: impl Into<String>,
278        size: usize,
279    ) {
280        self.declarations
281            .push(SharedMemoryDecl::array(name, element_type, size));
282    }
283
284    /// Add a 2D shared tile.
285    pub fn add_tile(
286        &mut self,
287        name: impl Into<String>,
288        element_type: impl Into<String>,
289        width: usize,
290        height: usize,
291    ) {
292        self.declarations
293            .push(SharedMemoryDecl::tile(name, element_type, width, height));
294    }
295
296    /// Generate all CUDA shared memory declarations.
297    pub fn generate_declarations(&self, indent: &str) -> String {
298        self.declarations
299            .iter()
300            .map(|d| format!("{}{}", indent, d.to_cuda_decl()))
301            .collect::<Vec<_>>()
302            .join("\n")
303    }
304
305    /// Check if any shared memory is used.
306    pub fn is_empty(&self) -> bool {
307        self.declarations.is_empty()
308    }
309
310    /// Calculate total shared memory size in bytes.
311    pub fn total_bytes(&self) -> usize {
312        self.declarations
313            .iter()
314            .map(|d| {
315                let elem_size = match d.element_type.as_str() {
316                    "float" => 4,
317                    "double" => 8,
318                    "int" => 4,
319                    "unsigned int" => 4,
320                    "long long" | "unsigned long long" => 8,
321                    "short" | "unsigned short" => 2,
322                    "char" | "unsigned char" => 1,
323                    _ => 4, // Default assumption
324                };
325                let count: usize = d.dimensions.iter().product();
326                elem_size * count
327            })
328            .sum()
329    }
330}
331
/// Parse a `SharedTile` type string to extract its element type and dimensions.
///
/// Accepts both the plain (`SharedTile<T, W, H>`) and turbofish
/// (`SharedTile::<T, W, H>`) spellings. The parser ignores whitespace around
/// the whole string, `::`, `<`, `>` and the commas. This means token-stream
/// renderings such as `SharedTile :: < f32 , 16 , 16 >` parse too.
///
/// # Arguments
///
/// * `type_str` - The type path (e.g., `SharedTile::<f32, 16, 16>`)
///
/// # Returns
///
/// `Some((element_type, width, height))` when all of the following hold:
/// - there are exactly three generic arguments;
/// - the first argument is non-empty;
/// - the last two arguments parse as `usize`.
///
/// Otherwise the result is `None`. The element type is returned in its Rust
/// spelling (e.g. `"f32"`), not mapped to a CUDA type.
pub fn parse_shared_tile_type(type_str: &str) -> Option<(String, usize, usize)> {
    // Pattern: SharedTile<T, W, H> or SharedTile::<T, W, H>
    let inner = type_str
        .trim()
        .strip_prefix("SharedTile")?
        .trim_start()
        // Turbofish form; the plain form has no `::` and is left untouched.
        .trim_start_matches("::")
        .trim_start()
        .strip_prefix('<')?
        .strip_suffix('>')?;

    let parts: Vec<&str> = inner.split(',').map(str::trim).collect();
    if parts.len() != 3 || parts[0].is_empty() {
        return None;
    }

    let width: usize = parts[1].parse().ok()?;
    let height: usize = parts[2].parse().ok()?;

    Some((parts[0].to_string(), width, height))
}
360
/// Parse a `SharedArray` type string to extract its element type and length.
///
/// Accepts both `SharedArray<T, N>` and the turbofish `SharedArray::<T, N>`.
/// The parser ignores whitespace around the whole string, `::`, `<`, `>` and
/// the comma. This means token-stream renderings such as
/// `SharedArray :: < f32 , 256 >` parse too.
///
/// # Arguments
///
/// * `type_str` - The type path (e.g., `SharedArray::<f32, 256>`)
///
/// # Returns
///
/// `Some((element_type, size))` when all of the following hold:
/// - there are exactly two generic arguments;
/// - the first argument is non-empty;
/// - the second argument parses as `usize`.
///
/// Otherwise the result is `None`. The element type is returned in its Rust
/// spelling (e.g. `"f32"`).
pub fn parse_shared_array_type(type_str: &str) -> Option<(String, usize)> {
    // Pattern: SharedArray<T, N> or SharedArray::<T, N>
    let inner = type_str
        .trim()
        .strip_prefix("SharedArray")?
        .trim_start()
        // Turbofish form; the plain form has no `::` and is left untouched.
        .trim_start_matches("::")
        .trim_start()
        .strip_prefix('<')?
        .strip_suffix('>')?;

    let parts: Vec<&str> = inner.split(',').map(str::trim).collect();
    if parts.len() != 2 || parts[0].is_empty() {
        return None;
    }

    let size: usize = parts[1].parse().ok()?;

    Some((parts[0].to_string(), size))
}
388
/// Map a Rust element type name to the CUDA type used for shared memory.
///
/// `bool` is widened to `int`. `usize` and `isize` map to `unsigned long long`
/// and `long long`; this assumes a 64-bit device ABI where `size_t` is
/// 8 bytes. Any other unrecognized name falls back to `float`.
pub fn rust_to_cuda_element_type(rust_type: &str) -> &'static str {
    match rust_type {
        "f32" => "float",
        "f64" => "double",
        "i32" => "int",
        "u32" => "unsigned int",
        "i64" => "long long",
        "u64" => "unsigned long long",
        // Previously these fell through to `float`, silently turning integer
        // shared arrays into floating-point ones.
        "isize" => "long long",
        "usize" => "unsigned long long",
        "i16" => "short",
        "u16" => "unsigned short",
        "i8" => "char",
        "u8" => "unsigned char",
        "bool" => "int",
        _ => "float", // Default fallback
    }
}
406
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shared_tile_dimensions() {
        type Square = SharedTile<f32, 16, 16>;
        type Wide = SharedTile<f32, 32, 8>;

        assert_eq!(
            (Square::width(), Square::height(), Square::size()),
            (16, 16, 256)
        );
        assert_eq!((Wide::width(), Wide::height(), Wide::size()), (32, 8, 256));
    }

    #[test]
    fn test_shared_array_size() {
        let sizes = [
            SharedArray::<f32, 256>::size(),
            SharedArray::<i32, 1024>::size(),
        ];
        assert_eq!(sizes, [256, 1024]);
    }

    #[test]
    fn test_shared_memory_decl_1d() {
        let buffer = SharedMemoryDecl::array("buffer", "float", 256);
        let subscript = vec![String::from("i")];

        assert_eq!(buffer.to_cuda_decl(), "__shared__ float buffer[256];");
        assert_eq!(buffer.to_cuda_access(&subscript), "buffer[i]");
    }

    #[test]
    fn test_shared_memory_decl_2d() {
        let tile = SharedMemoryDecl::tile("tile", "float", 16, 16);
        let subscripts = vec![String::from("y"), String::from("x")];

        assert_eq!(tile.to_cuda_decl(), "__shared__ float tile[16][16];");
        assert_eq!(tile.to_cuda_access(&subscripts), "tile[y][x]");
    }

    #[test]
    fn test_shared_memory_config() {
        let mut config = SharedMemoryConfig::new();
        config.add_tile("tile", "float", 16, 16);
        config.add_array("temp", "int", 128);

        let rendered = config.generate_declarations("    ");
        for expected in &["__shared__ float tile[16][16];", "__shared__ int temp[128];"] {
            assert!(
                rendered.contains(*expected),
                "missing `{}` in:\n{}",
                expected,
                rendered
            );
        }
    }

    #[test]
    fn test_total_bytes() {
        let mut config = SharedMemoryConfig::new();
        config.add_tile("tile", "float", 16, 16);
        config.add_array("temp", "double", 64);

        let tile_bytes = 16 * 16 * 4;
        let temp_bytes = 64 * 8;
        assert_eq!(config.total_bytes(), tile_bytes + temp_bytes);
    }

    #[test]
    fn test_parse_shared_tile_type() {
        let cases: [(&str, &str, usize, usize); 2] = [
            ("SharedTile::<f32, 16, 16>", "f32", 16, 16),
            ("SharedTile<i32, 32, 8>", "i32", 32, 8),
        ];
        for &(input, elem, width, height) in cases.iter() {
            assert_eq!(
                parse_shared_tile_type(input),
                Some((elem.to_string(), width, height))
            );
        }
    }

    #[test]
    fn test_parse_shared_array_type() {
        let cases: [(&str, &str, usize); 2] = [
            ("SharedArray::<f32, 256>", "f32", 256),
            ("SharedArray<u32, 1024>", "u32", 1024),
        ];
        for &(input, elem, len) in cases.iter() {
            assert_eq!(
                parse_shared_array_type(input),
                Some((elem.to_string(), len))
            );
        }
    }

    #[test]
    fn test_rust_to_cuda_element_type() {
        let expectations = [
            ("f32", "float"),
            ("f64", "double"),
            ("i32", "int"),
            ("u64", "unsigned long long"),
        ];
        for &(rust_ty, cuda_ty) in expectations.iter() {
            assert_eq!(rust_to_cuda_element_type(rust_ty), cuda_ty);
        }
    }
}