1use crate::error::{CudaError, CudaResult};
4use crate::stream::CudaStream;
5use crate::blas::CuBlas;
6use ghostflow_core::{DType, Shape, Strides, Tensor};
7
/// A tensor whose backing storage lives in GPU (device) memory, allocated
/// from the global GPU memory pool.
///
/// NOTE(review): no `Drop` impl is visible in this chunk, so buffers obtained
/// from the pool appear never to be returned — confirm the pool reclaims
/// allocations elsewhere, otherwise every tensor leaks device memory.
#[derive(Debug)]
pub struct CudaTensor {
    // Raw pointer returned by the pool allocator. Typed `*mut f32` even
    // though `dtype` can vary — presumably only F32 is fully supported today;
    // confirm before adding other dtypes.
    ptr: *mut f32,
    // Logical dimensions of the tensor.
    shape: Shape,
    // Element strides; always `shape.default_strides()` at construction.
    strides: Strides,
    // Element type; its size feeds into `size_bytes`.
    dtype: DType,
    // Device ordinal this tensor's memory belongs to.
    device_id: i32,
    // Total allocation size in bytes (`numel * dtype.size_bytes()`).
    size_bytes: usize,
}
24
25impl CudaTensor {
26 pub fn new(shape: &[usize], dtype: DType, device_id: i32) -> CudaResult<Self> {
28 let shape = Shape::new(shape);
29 let strides = shape.default_strides();
30 let size_bytes = shape.numel() * dtype.size_bytes();
31
32 let pool = crate::memory::get_global_gpu_pool();
34 let ptr = (*pool).allocate(size_bytes)
35 .map_err(|_e| CudaError::OutOfMemory)? as *mut f32;
36
37 Ok(CudaTensor {
38 ptr,
39 shape,
40 strides,
41 dtype,
42 device_id,
43 size_bytes,
44 })
45 }
46
    /// Uploads a host `Tensor` into newly allocated device memory.
    ///
    /// NOTE(review): the device copy always uses `default_strides()`; if
    /// `tensor` is non-contiguous this presumably mismatches its layout —
    /// confirm callers only pass contiguous tensors.
    #[cfg_attr(not(feature = "cuda"), allow(unused_variables))]
    pub fn from_tensor(tensor: &Tensor, device_id: i32) -> CudaResult<Self> {
        let shape = tensor.shape().clone();
        let strides = shape.default_strides();
        let dtype = tensor.dtype();
        let size_bytes = shape.numel() * dtype.size_bytes();

        let data = tensor.data_f32();

        let pool = crate::memory::get_global_gpu_pool();
        let ptr = (*pool).allocate(size_bytes)
            .map_err(|_| CudaError::OutOfMemory)? as *mut f32;

        // Placeholder for the real host-to-device memcpy; the tuple lists the
        // arguments (dst, src, len) the eventual copy call will need.
        #[cfg(feature = "cuda")]
        unsafe {
            let _ = (ptr, data.as_ptr(), data.len());
        }

        Ok(CudaTensor {
            ptr,
            shape,
            strides,
            dtype,
            device_id,
            size_bytes,
        })
    }
78
    /// Uploads a host `Tensor` to the device using an asynchronous copy
    /// enqueued on `stream`.
    ///
    /// NOTE(review): `data` is a stack/heap-local staging buffer — once the
    /// real async memcpy is implemented, its lifetime must outlive the copy
    /// (or the stream must be synchronized before return); verify then.
    #[cfg_attr(not(feature = "cuda"), allow(unused_variables))]
    pub fn from_tensor_async(tensor: &Tensor, device_id: i32, stream: &CudaStream) -> CudaResult<Self> {
        let shape = tensor.shape().clone();
        let strides = shape.default_strides();
        let dtype = tensor.dtype();
        let size_bytes = shape.numel() * dtype.size_bytes();

        let data = tensor.data_f32();

        let pool = crate::memory::get_global_gpu_pool();
        let ptr = (*pool).allocate(size_bytes)
            .map_err(|_| CudaError::OutOfMemory)? as *mut f32;

        // Placeholder for the async host-to-device memcpy
        // (dst, src, len, stream).
        #[cfg(feature = "cuda")]
        unsafe {
            let _ = (ptr, data.as_ptr(), data.len(), stream);
        }

        Ok(CudaTensor {
            ptr,
            shape,
            strides,
            dtype,
            device_id,
            size_bytes,
        })
    }
110
    /// Downloads the device data and wraps it in a host `Tensor`.
    ///
    /// NOTE(review): the device-to-host copy below is still a placeholder, so
    /// the returned tensor is currently all zeros.
    pub fn to_tensor(&self) -> CudaResult<Tensor> {
        // Host staging buffer sized to the element count.
        let data = vec![0.0f32; self.numel()];

        // Placeholder for the device-to-host memcpy (dst, src, len).
        #[cfg(feature = "cuda")]
        unsafe {
            let _ = (data.as_ptr(), self.ptr, data.len());
        }

        Tensor::from_slice(&data, self.shape.dims())
            .map_err(|e| CudaError::InvalidValue(e.to_string()))
    }
124
    /// Asynchronously downloads the device data on `stream`.
    ///
    /// NOTE(review): unlike `to_tensor` this returns a raw `Vec<f32>` rather
    /// than a `Tensor` — presumably because the copy may still be in flight
    /// when this returns; confirm callers synchronize the stream before
    /// reading the buffer.
    #[cfg_attr(not(feature = "cuda"), allow(unused_variables))]
    pub fn to_tensor_async(&self, stream: &CudaStream) -> CudaResult<Vec<f32>> {
        let data = vec![0.0f32; self.numel()];

        // Placeholder for the async device-to-host memcpy
        // (dst, src, len, stream).
        #[cfg(feature = "cuda")]
        unsafe {
            let _ = (data.as_ptr(), self.ptr, data.len(), stream);
        }

        Ok(data)
    }
138
    /// The tensor's shape object.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// The dimension sizes as a slice.
    pub fn dims(&self) -> &[usize] {
        self.shape.dims()
    }

    /// Total number of elements.
    pub fn numel(&self) -> usize {
        self.shape.numel()
    }

    /// Element data type.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Device ordinal that owns this tensor's memory.
    pub fn device_id(&self) -> i32 {
        self.device_id
    }

    /// Raw const pointer to the buffer. Presumably a device pointer from the
    /// GPU pool and not host-dereferenceable — confirm with the pool impl.
    pub fn as_ptr(&self) -> *const f32 {
        self.ptr
    }

    /// Raw mutable pointer to the buffer (same caveat as `as_ptr`).
    pub fn as_mut_ptr(&mut self) -> *mut f32 {
        self.ptr
    }

    /// Size of the allocation in bytes.
    pub fn size_bytes(&self) -> usize {
        self.size_bytes
    }
178
    /// Creates an F32 tensor of zeros on `device_id`.
    ///
    /// NOTE(review): the zero-fill below is a placeholder (e.g. for a future
    /// cudaMemset) — until implemented, the buffer holds whatever the pool
    /// returned and is NOT actually zeroed.
    pub fn zeros(shape: &[usize], device_id: i32) -> CudaResult<Self> {
        let tensor = Self::new(shape, DType::F32, device_id)?;

        // Placeholder for the device zero-fill (ptr, bytes).
        #[cfg(feature = "cuda")]
        unsafe {
            let _ = (tensor.ptr, tensor.size_bytes);
        }

        Ok(tensor)
    }
191
192 #[cfg_attr(not(feature = "cuda"), allow(unused_variables))]
194 pub fn ones(shape: &[usize], device_id: i32) -> CudaResult<Self> {
195 let shape_obj = Shape::new(shape);
196 let numel = shape_obj.numel();
197 let data: Vec<f32> = vec![1.0; numel];
198 let size_bytes = numel * std::mem::size_of::<f32>();
199
200 let pool = crate::memory::get_global_gpu_pool();
202 let ptr = (*pool).allocate(size_bytes)
203 .map_err(|_| CudaError::OutOfMemory)? as *mut f32;
204
205 #[cfg(feature = "cuda")]
207 unsafe {
208 let _ = (ptr, data.as_ptr(), size_bytes);
210 }
211
212 Ok(CudaTensor {
213 ptr,
214 shape: shape_obj.clone(),
215 strides: shape_obj.default_strides(),
216 dtype: DType::F32,
217 device_id,
218 size_bytes,
219 })
220 }
221
    /// Creates an F32 tensor where every element equals `value`.
    ///
    /// NOTE(review): the host-to-device upload below is a placeholder, so the
    /// device buffer is not actually filled yet.
    #[cfg_attr(not(feature = "cuda"), allow(unused_variables))]
    pub fn full(shape: &[usize], value: f32, device_id: i32) -> CudaResult<Self> {
        let shape_obj = Shape::new(shape);
        let numel = shape_obj.numel();
        // Host-side staging buffer with the fill value for every element.
        let data: Vec<f32> = vec![value; numel];
        let size_bytes = numel * std::mem::size_of::<f32>();

        let pool = crate::memory::get_global_gpu_pool();
        let ptr = (*pool).allocate(size_bytes)
            .map_err(|_| CudaError::OutOfMemory)? as *mut f32;

        // Placeholder for the host-to-device upload of `data`
        // (dst, src, bytes).
        #[cfg(feature = "cuda")]
        unsafe {
            let _ = (ptr, data.as_ptr(), size_bytes);
        }

        Ok(CudaTensor {
            ptr,
            shape: shape_obj.clone(),
            strides: shape_obj.default_strides(),
            dtype: DType::F32,
            device_id,
            size_bytes,
        })
    }
251
252 pub fn reshape(&self, new_shape: &[usize]) -> CudaResult<Self> {
254 let new_shape = Shape::new(new_shape);
255 if new_shape.numel() != self.shape.numel() {
256 return Err(CudaError::InvalidValue(
257 format!("Cannot reshape {} elements to {:?}", self.numel(), new_shape.dims())
258 ));
259 }
260
261 let pool = crate::memory::get_global_gpu_pool();
263 let ptr = (*pool).allocate(self.size_bytes)
264 .map_err(|_| CudaError::OutOfMemory)? as *mut f32;
265
266 #[cfg(feature = "cuda")]
267 unsafe {
268 let _ = (ptr, self.ptr, self.size_bytes);
270 }
271
272 Ok(CudaTensor {
273 ptr,
274 shape: new_shape.clone(),
275 strides: new_shape.default_strides(),
276 dtype: self.dtype,
277 device_id: self.device_id,
278 size_bytes: self.size_bytes,
279 })
280 }
281
282 pub fn transpose(&self, dim0: usize, dim1: usize) -> CudaResult<Self> {
284 let cpu = self.to_tensor()?;
287 let transposed = cpu.transpose(dim0, dim1)
288 .map_err(|e| CudaError::InvalidValue(e.to_string()))?;
289 Self::from_tensor(&transposed, self.device_id)
290 }
291
    /// Whether `strides` are contiguous for `shape` (delegates to
    /// `Strides::is_contiguous`).
    pub fn is_contiguous(&self) -> bool {
        self.strides.is_contiguous(&self.shape)
    }
296
    /// Deep copy: allocates a fresh device buffer and (via the placeholder
    /// below) copies this tensor's bytes into it.
    pub fn clone_tensor(&self) -> CudaResult<Self> {
        let pool = crate::memory::get_global_gpu_pool();
        let ptr = (*pool).allocate(self.size_bytes)
            .map_err(|_| CudaError::OutOfMemory)? as *mut f32;

        // Placeholder for the device-to-device memcpy (dst, src, bytes).
        #[cfg(feature = "cuda")]
        unsafe {
            let _ = (ptr, self.ptr, self.size_bytes);
        }

        Ok(CudaTensor {
            ptr,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
            device_id: self.device_id,
            size_bytes: self.size_bytes,
        })
    }
318
319 pub fn add(&self, other: &CudaTensor) -> CudaResult<CudaTensor> {
323 if self.shape.dims() != other.shape.dims() {
324 return Err(CudaError::InvalidValue("Shape mismatch for add".into()));
325 }
326
327 let mut result = self.clone_tensor()?;
329
330 let cublas = CuBlas::new()?;
332 cublas.saxpy(
333 self.numel() as i32,
334 1.0,
335 other.as_ptr() as *const f32,
336 1,
337 result.as_mut_ptr() as *mut f32,
338 1,
339 )?;
340
341 Ok(result)
342 }
343
344 pub fn sub(&self, other: &CudaTensor) -> CudaResult<CudaTensor> {
346 if self.shape.dims() != other.shape.dims() {
347 return Err(CudaError::InvalidValue("Shape mismatch for sub".into()));
348 }
349
350 let mut result = self.clone_tensor()?;
351
352 let cublas = CuBlas::new()?;
354 cublas.saxpy(
355 self.numel() as i32,
356 -1.0,
357 other.as_ptr() as *const f32,
358 1,
359 result.as_mut_ptr() as *mut f32,
360 1,
361 )?;
362
363 Ok(result)
364 }
365
366 pub fn mul_scalar(&self, scalar: f32) -> CudaResult<CudaTensor> {
368 let mut result = self.clone_tensor()?;
369
370 let cublas = CuBlas::new()?;
371 cublas.sscal(
372 self.numel() as i32,
373 scalar,
374 result.as_mut_ptr() as *mut f32,
375 1,
376 )?;
377
378 Ok(result)
379 }
380
381 pub fn matmul(&self, other: &CudaTensor) -> CudaResult<CudaTensor> {
383 let cublas = CuBlas::new()?;
384 cublas.matmul(self, other)
385 }
386
387 pub fn dot(&self, other: &CudaTensor) -> CudaResult<f32> {
389 if self.numel() != other.numel() {
390 return Err(CudaError::InvalidValue("Size mismatch for dot".into()));
391 }
392
393 let cublas = CuBlas::new()?;
394 cublas.sdot(
395 self.numel() as i32,
396 self.as_ptr() as *const f32,
397 1,
398 other.as_ptr() as *const f32,
399 1,
400 )
401 }
402
403 pub fn norm(&self) -> CudaResult<f32> {
405 let cublas = CuBlas::new()?;
406 cublas.snrm2(
407 self.numel() as i32,
408 self.as_ptr() as *const f32,
409 1,
410 )
411 }
412
413 pub fn sum(&self) -> CudaResult<f32> {
415 let cpu = self.to_tensor()?;
418 Ok(cpu.data_f32().iter().sum())
419 }
420
421 pub fn mean(&self) -> CudaResult<f32> {
423 let sum = self.sum()?;
424 Ok(sum / self.numel() as f32)
425 }
426
427 pub fn max(&self) -> CudaResult<f32> {
429 let cpu = self.to_tensor()?;
430 Ok(cpu.data_f32().iter().cloned().fold(f32::NEG_INFINITY, f32::max))
431 }
432
433 pub fn min(&self) -> CudaResult<f32> {
435 let cpu = self.to_tensor()?;
436 Ok(cpu.data_f32().iter().cloned().fold(f32::INFINITY, f32::min))
437 }
438
439 pub fn relu(&self) -> CudaResult<CudaTensor> {
441 let cpu = self.to_tensor()?;
443 let result = cpu.relu();
444 Self::from_tensor(&result, self.device_id)
445 }
446
447 pub fn sigmoid(&self) -> CudaResult<CudaTensor> {
449 let cpu = self.to_tensor()?;
450 let result = cpu.sigmoid();
451 Self::from_tensor(&result, self.device_id)
452 }
453
454 pub fn gelu(&self) -> CudaResult<CudaTensor> {
456 let cpu = self.to_tensor()?;
457 let result = cpu.gelu();
458 Self::from_tensor(&result, self.device_id)
459 }
460
461 pub fn softmax(&self, dim: i32) -> CudaResult<CudaTensor> {
463 let cpu = self.to_tensor()?;
464 let result = cpu.softmax(dim);
465 Self::from_tensor(&result, self.device_id)
466 }
467
468 pub fn exp(&self) -> CudaResult<CudaTensor> {
470 let cpu = self.to_tensor()?;
471 let result = cpu.exp();
472 Self::from_tensor(&result, self.device_id)
473 }
474
475 pub fn log(&self) -> CudaResult<CudaTensor> {
477 let cpu = self.to_tensor()?;
478 let result = cpu.log();
479 Self::from_tensor(&result, self.device_id)
480 }
481
482 pub fn sqrt(&self) -> CudaResult<CudaTensor> {
484 let cpu = self.to_tensor()?;
485 let result = cpu.sqrt();
486 Self::from_tensor(&result, self.device_id)
487 }
488
489 pub fn pow(&self, exp: f32) -> CudaResult<CudaTensor> {
491 let cpu = self.to_tensor()?;
492 let result = cpu.pow(exp);
493 Self::from_tensor(&result, self.device_id)
494 }
495}
496
// Deep clone via `clone_tensor`. Since `Clone::clone` cannot return an
// error, a failed device allocation panics with the `expect` message below —
// callers who need fallible cloning should use `clone_tensor` directly.
impl Clone for CudaTensor {
    fn clone(&self) -> Self {
        self.clone_tensor().expect("Failed to clone CudaTensor")
    }
}