use crate::tensor::Tensor;
use crate::error::{GhostError, Result};

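/// Handle to a single AMD GPU as seen by the ROCm backend.
///
/// In this stub the reported properties (name, compute units, memory size) are
/// placeholder values rather than figures queried from the HIP runtime.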
pub struct RocmDevice {
    pub device_id: usize,
    pub name: String,
    pub compute_units: usize,
    pub memory_size: usize,
}

impl RocmDevice {
    pub fn new(device_id: usize) -> Result<Self> {
        #[cfg(feature = "rocm")]
        {
            // Placeholder capabilities; a real backend would query these from
            // the HIP runtime instead of hard-coding 64 CUs and 8 GiB.
            Ok(RocmDevice {
                device_id,
                name: format!("AMD GPU {}", device_id),
                compute_units: 64,
                memory_size: 8 * 1024 * 1024 * 1024,
            })
        }
        #[cfg(not(feature = "rocm"))]
        {
            Err(GhostError::DeviceError(
                "ROCm support not compiled. Enable the 'rocm' feature.".to_string(),
            ))
        }
    }

    pub fn device_count() -> Result<usize> {
        #[cfg(feature = "rocm")]
        {
            // Stub: a single device is reported. A real backend would call
            // hipGetDeviceCount and propagate runtime errors.
            Ok(1)
        }
        #[cfg(not(feature = "rocm"))]
        {
            Ok(0)
        }
    }

    pub fn synchronize(&self) -> Result<()> {
        // Nothing to wait on in the stub; a real backend would call
        // hipDeviceSynchronize here.
        #[cfg(feature = "rocm")]
        {
            Ok(())
        }
        #[cfg(not(feature = "rocm"))]
        {
            Ok(())
        }
    }
}

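/// Device-side memory buffer.
///
/// The stub never allocates real device memory: `ptr` stays 0 and the copy
/// methods only validate sizes, so no data actually moves to or from the GPU.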
pub struct RocmBuffer {
    ptr: usize,
    size: usize,
    device_id: usize,
}

impl RocmBuffer {
    pub fn allocate(size: usize, device_id: usize) -> Result<Self> {
        #[cfg(feature = "rocm")]
        {
            // Stub allocation: no device memory is reserved, so the pointer
            // stays null. A real backend would call hipMalloc here.
            Ok(RocmBuffer {
                ptr: 0,
                size,
                device_id,
            })
        }
        #[cfg(not(feature = "rocm"))]
        {
            let _ = (size, device_id);
            Err(GhostError::DeviceError("ROCm not available".to_string()))
        }
    }

    pub fn copy_from_host(&mut self, data: &[f32]) -> Result<()> {
        #[cfg(feature = "rocm")]
        {
            if data.len() * std::mem::size_of::<f32>() > self.size {
                return Err(GhostError::DeviceError("Buffer too small".to_string()));
            }
            // Stub: only the size check runs; a real backend would copy the
            // slice to the device with hipMemcpy (host-to-device).
            Ok(())
        }
        #[cfg(not(feature = "rocm"))]
        {
            let _ = data;
            Err(GhostError::DeviceError("ROCm not available".to_string()))
        }
    }

    pub fn copy_to_host(&self, data: &mut [f32]) -> Result<()> {
        #[cfg(feature = "rocm")]
        {
            if data.len() * std::mem::size_of::<f32>() > self.size {
                return Err(GhostError::DeviceError("Buffer too small".to_string()));
            }
            // Stub: the destination slice is left untouched; a real backend
            // would copy back with hipMemcpy (device-to-host).
            Ok(())
        }
        #[cfg(not(feature = "rocm"))]
        {
            let _ = data;
            Err(GhostError::DeviceError("ROCm not available".to_string()))
        }
    }
}

impl Drop for RocmBuffer {
    fn drop(&mut self) {
        #[cfg(feature = "rocm")]
        {
            // The stub owns no real allocation, so there is nothing to free.
            // A real backend would release the device memory here (e.g. hipFree).
        }
    }
}

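/// Launch configuration for a named GPU kernel, built in a fluent style.
///
/// The defaults are a 1x1x1 grid and a 256-thread block; `launch` is a no-op
/// in this stub backend.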
pub struct RocmKernel {
    name: String,
    grid_dim: (u32, u32, u32),
    block_dim: (u32, u32, u32),
}

impl RocmKernel {
    pub fn new(name: &str) -> Self {
        RocmKernel {
            name: name.to_string(),
            grid_dim: (1, 1, 1),
            block_dim: (256, 1, 1),
        }
    }

    pub fn grid(mut self, x: u32, y: u32, z: u32) -> Self {
        self.grid_dim = (x, y, z);
        self
    }

    pub fn block(mut self, x: u32, y: u32, z: u32) -> Self {
        self.block_dim = (x, y, z);
        self
    }

    pub fn launch(&self) -> Result<()> {
        #[cfg(feature = "rocm")]
        {
            // Stub: no kernel is dispatched. A real backend would launch the
            // compiled kernel (e.g. via hipModuleLaunchKernel) using
            // `self.grid_dim` and `self.block_dim`.
            Ok(())
        }
        #[cfg(not(feature = "rocm"))]
        {
            Err(GhostError::DeviceError("ROCm not available".to_string()))
        }
    }
}

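/// Tensor operations routed through the ROCm backend.
///
/// When the 'rocm' feature is disabled, or an op is not yet implemented, these
/// functions either fall back to the CPU implementation or return an error, as
/// noted on each function.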
pub mod ops {
    use super::*;

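    /// 2-D matrix multiply on the ROCm device.
    ///
    /// Shapes are validated first ([m, k] x [k, n] -> [m, n]). With the 'rocm'
    /// feature enabled this stages buffers and a 16x16-tiled kernel launch
    /// (currently stubbed); without it the call falls back to `Tensor::matmul`
    /// on the CPU.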
    pub fn matmul_rocm(a: &Tensor, b: &Tensor, device_id: usize) -> Result<Tensor> {
        let dims_a = a.dims();
        let dims_b = b.dims();

        if dims_a.len() != 2 || dims_b.len() != 2 {
            return Err(GhostError::InvalidShape("matmul requires 2D tensors".to_string()));
        }

        let (m, k) = (dims_a[0], dims_a[1]);
        let (k2, n) = (dims_b[0], dims_b[1]);

        if k != k2 {
            return Err(GhostError::ShapeMismatch {
                expected: vec![k],
                got: vec![k2],
            });
        }

        #[cfg(feature = "rocm")]
        {
            let size_a = m * k * std::mem::size_of::<f32>();
            let size_b = k * n * std::mem::size_of::<f32>();
            let size_c = m * n * std::mem::size_of::<f32>();

            let mut buf_a = RocmBuffer::allocate(size_a, device_id)?;
            let mut buf_b = RocmBuffer::allocate(size_b, device_id)?;
            let buf_c = RocmBuffer::allocate(size_c, device_id)?;

            buf_a.copy_from_host(&a.data_f32())?;
            buf_b.copy_from_host(&b.data_f32())?;

            // One thread per output element, in 16x16 tiles.
            let kernel = RocmKernel::new("matmul_kernel")
                .grid((n as u32 + 15) / 16, (m as u32 + 15) / 16, 1)
                .block(16, 16, 1);

            kernel.launch()?;

            // Note: launch and copy_to_host are stubs, so until a real backend
            // is wired in the result tensor comes back as zeros.
            let mut result_data = vec![0.0f32; m * n];
            buf_c.copy_to_host(&mut result_data)?;

            Tensor::from_slice(&result_data, &[m, n])
        }
        #[cfg(not(feature = "rocm"))]
        {
            let _ = device_id;
            a.matmul(b)
        }
    }

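    /// 2-D convolution on the ROCm device (NCHW layout).
    ///
    /// Only shape validation is performed; the device kernel is not yet
    /// implemented, so the call currently returns `NotImplemented`.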
    pub fn conv2d_rocm(
        input: &Tensor,
        kernel: &Tensor,
        stride: (usize, usize),
        padding: (usize, usize),
        device_id: usize,
    ) -> Result<Tensor> {
        let input_dims = input.dims();
        let kernel_dims = kernel.dims();

        if input_dims.len() != 4 || kernel_dims.len() != 4 {
            return Err(GhostError::InvalidShape("conv2d requires 4D tensors [N,C,H,W]".to_string()));
        }

        // Output spatial size is computed up front for the eventual kernel;
        // the leading underscores silence unused-variable warnings until the
        // device path is implemented.
        let (_batch, _in_channels, in_h, in_w) =
            (input_dims[0], input_dims[1], input_dims[2], input_dims[3]);
        let (_out_channels, _, k_h, k_w) =
            (kernel_dims[0], kernel_dims[1], kernel_dims[2], kernel_dims[3]);

        let _out_h = (in_h + 2 * padding.0 - k_h) / stride.0 + 1;
        let _out_w = (in_w + 2 * padding.1 - k_w) / stride.1 + 1;

        #[cfg(feature = "rocm")]
        {
            let _ = device_id;
            Err(GhostError::NotImplemented("ROCm conv2d - use CPU fallback".to_string()))
        }
        #[cfg(not(feature = "rocm"))]
        {
            let _ = device_id;
            Err(GhostError::NotImplemented("conv2d CPU fallback".to_string()))
        }
    }

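    /// Element-wise ReLU on the ROCm device.
    ///
    /// With the 'rocm' feature enabled this stages a buffer and a 256-thread
    /// kernel launch (currently stubbed); otherwise it falls back to the CPU
    /// `Tensor::relu`.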
    pub fn relu_rocm(input: &Tensor, device_id: usize) -> Result<Tensor> {
        #[cfg(feature = "rocm")]
        {
            let data = input.data_f32();
            let size = data.len();

            let buf_size = size * std::mem::size_of::<f32>();
            let mut buf = RocmBuffer::allocate(buf_size, device_id)?;
            buf.copy_from_host(&data)?;

            // One thread per element, 256 threads per block.
            let kernel = RocmKernel::new("relu_kernel")
                .grid((size as u32 + 255) / 256, 1, 1)
                .block(256, 1, 1);

            kernel.launch()?;

            let mut result = vec![0.0f32; size];
            buf.copy_to_host(&mut result)?;

            Tensor::from_slice(&result, input.dims())
        }
        #[cfg(not(feature = "rocm"))]
        {
            let _ = device_id;
            Ok(input.relu())
        }
    }

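    /// Batch normalization on the ROCm device; not implemented yet in either
    /// build configuration.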
    pub fn batch_norm_rocm(input: &Tensor, device_id: usize) -> Result<Tensor> {
        #[cfg(feature = "rocm")]
        {
            let _ = (input, device_id);
            Err(GhostError::NotImplemented("ROCm batch norm".to_string()))
        }
        #[cfg(not(feature = "rocm"))]
        {
            let _ = (input, device_id);
            Err(GhostError::DeviceError("ROCm not available".to_string()))
        }
    }
}

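/// Kernel source for the backend. The kernels use CUDA-style `__global__`
/// syntax, which HIP accepts; they are intended to be compiled at runtime
/// (e.g. with hiprtc) once the launch path is implemented.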
#[cfg(feature = "rocm")]
pub const ROCM_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void vector_add(float* a, float* b, float* c, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        c[idx] = a[idx] + b[idx];
    }
}

extern "C" __global__ void relu_kernel(float* data, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        data[idx] = fmaxf(0.0f, data[idx]);
    }
}

extern "C" __global__ void matmul_kernel(
    float* A, float* B, float* C,
    int M, int N, int K
) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; k++) {
            sum += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}
"#;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rocm_device_count() {
        // device_count() reports 0 when the 'rocm' feature is disabled and at
        // least one device otherwise. An unsigned count is trivially >= 0, so
        // the assertion checks the feature-dependent value instead.
        let count = RocmDevice::device_count().unwrap_or(0);
        if cfg!(feature = "rocm") {
            assert!(count >= 1);
        } else {
            assert_eq!(count, 0);
        }
    }

    #[test]
    #[cfg(feature = "rocm")]
    fn test_rocm_device_creation() {
        if let Ok(device) = RocmDevice::new(0) {
            assert_eq!(device.device_id, 0);
            assert!(!device.name.is_empty());
        }
    }

    #[test]
    fn test_rocm_kernel_config() {
        let kernel = RocmKernel::new("test_kernel")
            .grid(10, 1, 1)
            .block(256, 1, 1);

        assert_eq!(kernel.grid_dim, (10, 1, 1));
        assert_eq!(kernel.block_dim, (256, 1, 1));
    }
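
    // Additional test sketches exercising paths already present in this
    // module. They rely only on APIs used above (Tensor::from_slice,
    // ops::matmul_rocm, RocmBuffer); exact Tensor semantics are assumed from
    // that existing usage.

    #[test]
    fn test_matmul_rocm_rejects_mismatched_shapes() {
        // Inner dimensions (3 vs 4) do not match, so the shape check must
        // fail before any device work is attempted, under either feature set.
        let a = Tensor::from_slice(&[1.0f32; 6], &[2, 3]).unwrap();
        let b = Tensor::from_slice(&[1.0f32; 8], &[4, 2]).unwrap();
        assert!(ops::matmul_rocm(&a, &b, 0).is_err());
    }

    #[test]
    #[cfg(feature = "rocm")]
    fn test_rocm_buffer_copy_size_check() {
        // The stub buffer performs only a size check: copying four floats into
        // a 4-byte buffer must fail, while a 16-byte buffer accepts them.
        let host = [1.0f32, 2.0, 3.0, 4.0];
        let mut small = RocmBuffer::allocate(4, 0).unwrap();
        assert!(small.copy_from_host(&host).is_err());
        let mut large = RocmBuffer::allocate(16, 0).unwrap();
        assert!(large.copy_from_host(&host).is_ok());
    }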
}