cuda_rust_wasm/runtime/
grid.rs

1//! Grid and block dimension types
2
3use serde::{Deserialize, Serialize};
4
5/// 3D dimension type (similar to CUDA's dim3)
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7pub struct Dim3 {
8    pub x: u32,
9    pub y: u32,
10    pub z: u32,
11}
12
13impl Dim3 {
14    /// Create a new Dim3
15    pub fn new(x: u32, y: u32, z: u32) -> Self {
16        Self { x, y, z }
17    }
18    
19    /// Create a 1D dimension
20    pub fn one_d(x: u32) -> Self {
21        Self { x, y: 1, z: 1 }
22    }
23    
24    /// Create a 2D dimension
25    pub fn two_d(x: u32, y: u32) -> Self {
26        Self { x, y, z: 1 }
27    }
28    
29    /// Get total number of elements
30    pub fn size(&self) -> u32 {
31        self.x * self.y * self.z
32    }
33}
34
35impl From<u32> for Dim3 {
36    fn from(x: u32) -> Self {
37        Self::one_d(x)
38    }
39}
40
41impl From<(u32, u32)> for Dim3 {
42    fn from((x, y): (u32, u32)) -> Self {
43        Self::two_d(x, y)
44    }
45}
46
47impl From<(u32, u32, u32)> for Dim3 {
48    fn from((x, y, z): (u32, u32, u32)) -> Self {
49        Self::new(x, y, z)
50    }
51}
52
53/// Grid configuration for kernel launch
54#[derive(Debug, Clone, Copy)]
55pub struct Grid {
56    pub dim: Dim3,
57}
58
59impl Grid {
60    /// Create a new grid configuration
61    pub fn new<D: Into<Dim3>>(dim: D) -> Self {
62        Self { dim: dim.into() }
63    }
64    
65    /// Get total number of blocks
66    pub fn num_blocks(&self) -> u32 {
67        self.dim.size()
68    }
69}
70
71/// Block configuration for kernel launch
72#[derive(Debug, Clone, Copy)]
73pub struct Block {
74    pub dim: Dim3,
75}
76
77impl Block {
78    /// Create a new block configuration
79    pub fn new<D: Into<Dim3>>(dim: D) -> Self {
80        Self { dim: dim.into() }
81    }
82    
83    /// Get total number of threads per block
84    pub fn num_threads(&self) -> u32 {
85        self.dim.size()
86    }
87    
88    /// Validate block dimensions against hardware limits
89    pub fn validate(&self) -> crate::Result<()> {
90        // Typical CUDA limits
91        const MAX_THREADS_PER_BLOCK: u32 = 1024;
92        const MAX_BLOCK_DIM_X: u32 = 1024;
93        const MAX_BLOCK_DIM_Y: u32 = 1024;
94        const MAX_BLOCK_DIM_Z: u32 = 64;
95        
96        if self.num_threads() > MAX_THREADS_PER_BLOCK {
97            return Err(crate::runtime_error!(
98                "Block size {} exceeds maximum threads per block {}",
99                self.num_threads(),
100                MAX_THREADS_PER_BLOCK
101            ));
102        }
103        
104        if self.dim.x > MAX_BLOCK_DIM_X {
105            return Err(crate::runtime_error!(
106                "Block x dimension {} exceeds maximum {}",
107                self.dim.x,
108                MAX_BLOCK_DIM_X
109            ));
110        }
111        
112        if self.dim.y > MAX_BLOCK_DIM_Y {
113            return Err(crate::runtime_error!(
114                "Block y dimension {} exceeds maximum {}",
115                self.dim.y,
116                MAX_BLOCK_DIM_Y
117            ));
118        }
119        
120        if self.dim.z > MAX_BLOCK_DIM_Z {
121            return Err(crate::runtime_error!(
122                "Block z dimension {} exceeds maximum {}",
123                self.dim.z,
124                MAX_BLOCK_DIM_Z
125            ));
126        }
127        
128        Ok(())
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dim3_creation() {
        let d1 = Dim3::one_d(256);
        assert_eq!(d1, Dim3 { x: 256, y: 1, z: 1 });
        assert_eq!(d1.size(), 256);

        let d2 = Dim3::two_d(16, 16);
        assert_eq!(d2, Dim3 { x: 16, y: 16, z: 1 });
        assert_eq!(d2.size(), 256);

        let d3 = Dim3::new(8, 8, 4);
        assert_eq!(d3.size(), 256);
    }

    #[test]
    fn test_dim3_from_conversions() {
        assert_eq!(Dim3::from(7u32), Dim3::new(7, 1, 1));
        assert_eq!(Dim3::from((3u32, 4u32)), Dim3::new(3, 4, 1));
        assert_eq!(Dim3::from((3u32, 4u32, 5u32)), Dim3::new(3, 4, 5));
    }

    #[test]
    fn test_grid_num_blocks() {
        assert_eq!(Grid::new((4u32, 2u32, 2u32)).num_blocks(), 16);
        assert_eq!(Grid::new(Dim3::new(10, 1, 1)).num_blocks(), 10);
    }

    #[test]
    fn test_block_num_threads() {
        assert_eq!(Block::new((8u32, 8u32, 4u32)).num_threads(), 256);
    }

    #[test]
    fn test_block_validation() {
        let valid_block = Block::new(256);
        assert!(valid_block.validate().is_ok());

        let invalid_block = Block::new(2048);
        assert!(invalid_block.validate().is_err());
    }

    #[test]
    fn test_block_validation_limits() {
        // Exactly at the limits is accepted.
        assert!(Block::new(1024).validate().is_ok());
        assert!(Block::new((32u32, 32u32)).validate().is_ok());
        assert!(Block::new((1u32, 16u32, 64u32)).validate().is_ok());

        // z is capped at 64 even when the total is within budget.
        assert!(Block::new((1u32, 1u32, 65u32)).validate().is_err());

        // The total thread budget applies across axes.
        assert!(Block::new((64u32, 32u32)).validate().is_err());
    }
}
158}