cuda_rust_wasm/runtime/
grid.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7pub struct Dim3 {
8 pub x: u32,
9 pub y: u32,
10 pub z: u32,
11}
12
13impl Dim3 {
14 pub fn new(x: u32, y: u32, z: u32) -> Self {
16 Self { x, y, z }
17 }
18
19 pub fn one_d(x: u32) -> Self {
21 Self { x, y: 1, z: 1 }
22 }
23
24 pub fn two_d(x: u32, y: u32) -> Self {
26 Self { x, y, z: 1 }
27 }
28
29 pub fn size(&self) -> u32 {
31 self.x * self.y * self.z
32 }
33}
34
35impl From<u32> for Dim3 {
36 fn from(x: u32) -> Self {
37 Self::one_d(x)
38 }
39}
40
41impl From<(u32, u32)> for Dim3 {
42 fn from((x, y): (u32, u32)) -> Self {
43 Self::two_d(x, y)
44 }
45}
46
47impl From<(u32, u32, u32)> for Dim3 {
48 fn from((x, y, z): (u32, u32, u32)) -> Self {
49 Self::new(x, y, z)
50 }
51}
52
53#[derive(Debug, Clone, Copy)]
55pub struct Grid {
56 pub dim: Dim3,
57}
58
59impl Grid {
60 pub fn new<D: Into<Dim3>>(dim: D) -> Self {
62 Self { dim: dim.into() }
63 }
64
65 pub fn num_blocks(&self) -> u32 {
67 self.dim.size()
68 }
69}
70
71#[derive(Debug, Clone, Copy)]
73pub struct Block {
74 pub dim: Dim3,
75}
76
77impl Block {
78 pub fn new<D: Into<Dim3>>(dim: D) -> Self {
80 Self { dim: dim.into() }
81 }
82
83 pub fn num_threads(&self) -> u32 {
85 self.dim.size()
86 }
87
88 pub fn validate(&self) -> crate::Result<()> {
90 const MAX_THREADS_PER_BLOCK: u32 = 1024;
92 const MAX_BLOCK_DIM_X: u32 = 1024;
93 const MAX_BLOCK_DIM_Y: u32 = 1024;
94 const MAX_BLOCK_DIM_Z: u32 = 64;
95
96 if self.num_threads() > MAX_THREADS_PER_BLOCK {
97 return Err(crate::runtime_error!(
98 "Block size {} exceeds maximum threads per block {}",
99 self.num_threads(),
100 MAX_THREADS_PER_BLOCK
101 ));
102 }
103
104 if self.dim.x > MAX_BLOCK_DIM_X {
105 return Err(crate::runtime_error!(
106 "Block x dimension {} exceeds maximum {}",
107 self.dim.x,
108 MAX_BLOCK_DIM_X
109 ));
110 }
111
112 if self.dim.y > MAX_BLOCK_DIM_Y {
113 return Err(crate::runtime_error!(
114 "Block y dimension {} exceeds maximum {}",
115 self.dim.y,
116 MAX_BLOCK_DIM_Y
117 ));
118 }
119
120 if self.dim.z > MAX_BLOCK_DIM_Z {
121 return Err(crate::runtime_error!(
122 "Block z dimension {} exceeds maximum {}",
123 self.dim.z,
124 MAX_BLOCK_DIM_Z
125 ));
126 }
127
128 Ok(())
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dim3_creation() {
        let d1 = Dim3::one_d(256);
        assert_eq!(d1, Dim3 { x: 256, y: 1, z: 1 });
        assert_eq!(d1.size(), 256);

        let d2 = Dim3::two_d(16, 16);
        assert_eq!(d2, Dim3 { x: 16, y: 16, z: 1 });
        assert_eq!(d2.size(), 256);

        let d3 = Dim3::new(8, 8, 4);
        assert_eq!(d3.size(), 256);
    }

    /// The `From` impls must agree with the named constructors.
    #[test]
    fn test_dim3_from_conversions() {
        assert_eq!(Dim3::from(256u32), Dim3::one_d(256));
        assert_eq!(Dim3::from((16u32, 16u32)), Dim3::two_d(16, 16));
        assert_eq!(Dim3::from((8u32, 8u32, 4u32)), Dim3::new(8, 8, 4));
    }

    #[test]
    fn test_grid_num_blocks() {
        assert_eq!(Grid::new((4u32, 2u32)).num_blocks(), 8);
        assert_eq!(Grid::new(Dim3::new(2, 3, 4)).num_blocks(), 24);
    }

    #[test]
    fn test_block_validation() {
        let valid_block = Block::new(256);
        assert!(valid_block.validate().is_ok());

        let invalid_block = Block::new(2048);
        assert!(invalid_block.validate().is_err());
    }

    /// Exercises the exact total-thread limit and the per-axis limits.
    #[test]
    fn test_block_dimension_limits() {
        // Exactly at the total-thread limit is still valid.
        assert!(Block::new((32u32, 32u32)).validate().is_ok());
        // One thread over the total limit along y.
        assert!(Block::new((1u32, 1025u32)).validate().is_err());
        // Total is small, but z exceeds its per-axis limit.
        assert!(Block::new((1u32, 1u32, 65u32)).validate().is_err());
        // z exactly at its per-axis limit is valid.
        assert!(Block::new((1u32, 1u32, 64u32)).validate().is_ok());
    }
}