use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};

use crate::blas::{gpu_matmul_f32, gpu_matmul_f64};
use crate::buffer::CudaBuffer;
use crate::conv::gpu_conv2d_f32;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::kernels::{
    gpu_add, gpu_add_f64, gpu_mul, gpu_mul_f64, gpu_neg, gpu_neg_f64, gpu_relu, gpu_relu_f64,
    gpu_sub, gpu_sub_f64,
};
use crate::transfer::{cpu_to_gpu, gpu_to_cpu};

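/// Element types that can live in GPU memory. With the `cuda` feature this
/// additionally requires `cudarc`'s `DeviceRepr` so values can be passed to
/// kernels directly; only `f32` and `f64` are implemented.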
#[cfg(feature = "cuda")]
pub trait GpuFloat: Float + cudarc::driver::DeviceRepr {}

#[cfg(feature = "cuda")]
impl GpuFloat for f32 {}
#[cfg(feature = "cuda")]
impl GpuFloat for f64 {}

#[cfg(not(feature = "cuda"))]
pub trait GpuFloat: Float {}

#[cfg(not(feature = "cuda"))]
impl GpuFloat for f32 {}
#[cfg(not(feature = "cuda"))]
impl GpuFloat for f64 {}

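/// A dense tensor whose element storage lives in CUDA device memory.
///
/// Shape and device metadata stay on the host; only the element buffer is
/// on the GPU. A sketch of a round trip (illustrative, assuming a CUDA
/// device at ordinal 0):
///
/// ```ignore
/// let gpu = cuda_default(&t)?;   // upload a contiguous CPU Tensor<f32>
/// let out = gpu.relu()?;         // run a kernel on-device
/// let back = out.cpu()?;         // download the result
/// ```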
pub struct GpuTensor<T: GpuFloat> {
    buffer: CudaBuffer<T>,
    shape: Vec<usize>,
    device: GpuDevice,
}

impl<T: GpuFloat> GpuTensor<T> {
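    /// Returns the tensor's shape as a slice of dimension sizes.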
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

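    /// Returns the total number of elements (the product of the shape).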
    #[inline]
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

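    /// Returns the device this tensor's buffer is allocated on.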
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }

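    /// Returns a reference to the underlying device buffer.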
    #[inline]
    pub fn buffer(&self) -> &CudaBuffer<T> {
        &self.buffer
    }

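    /// Returns the number of dimensions (the rank) of the tensor.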
    #[inline]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

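    /// Copies the tensor back to host memory as a CPU [`Tensor`],
    /// preserving its shape.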
    pub fn cpu(&self) -> FerrotorchResult<Tensor<T>> {
        tensor_to_cpu(self)
    }
}

impl<T: GpuFloat> std::fmt::Debug for GpuTensor<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuTensor")
            .field("shape", &self.shape)
            .field("numel", &self.numel())
            .field("device_ordinal", &self.device.ordinal())
            .finish_non_exhaustive()
    }
}

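/// Returns true when `T` is statically `f32`; used to dispatch to the
/// concrete kernels.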
#[inline]
fn is_f32<T: GpuFloat>() -> bool {
    std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
}

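/// Returns true when `T` is statically `f64`.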
#[inline]
fn is_f64<T: GpuFloat>() -> bool {
    std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>()
}

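/// Checks that two operands agree in shape and device before an
/// element-wise op.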
fn validate_shapes<T: GpuFloat>(a: &GpuTensor<T>, b: &GpuTensor<T>) -> GpuResult<()> {
    if a.shape() != b.shape() {
        // Shapes must match exactly; the error reports element counts.
        return Err(GpuError::LengthMismatch {
            a: a.numel(),
            b: b.numel(),
        });
    }
    if a.device.ordinal() != b.device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: a.device.ordinal(),
            got: b.device.ordinal(),
        });
    }
    Ok(())
}

impl<T: GpuFloat> GpuTensor<T> {
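    /// Element-wise addition: `self + other`.
    ///
    /// Dispatches to the `f32` or `f64` CUDA kernel when `T` is one of
    /// those types; any other `GpuFloat` falls back to a round trip
    /// through host memory.
    ///
    /// # Errors
    ///
    /// Fails if the operands differ in shape or device, or if a kernel
    /// launch or transfer fails.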
    pub fn add(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_add(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else if is_f64::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f64>(&other.buffer) };
            let out_buf = gpu_add_f64(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a + b)
        }
    }

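    /// Element-wise subtraction: `self - other`. Dispatch and error
    /// behavior match [`GpuTensor::add`].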
    pub fn sub(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_sub(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else if is_f64::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f64>(&other.buffer) };
            let out_buf = gpu_sub_f64(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a - b)
        }
    }

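    /// Element-wise (Hadamard) multiplication: `self * other`. Dispatch
    /// and error behavior match [`GpuTensor::add`].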
    pub fn mul(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_mul(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else if is_f64::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f64>(&other.buffer) };
            let out_buf = gpu_mul_f64(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a * b)
        }
    }

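    /// Element-wise negation: `-self`.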
    pub fn neg(&self) -> GpuResult<GpuTensor<T>> {
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let out_buf = gpu_neg(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else if is_f64::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let out_buf = gpu_neg_f64(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            unary_cpu_fallback(self, |x| -x)
        }
    }

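    /// Element-wise ReLU: each element `x` becomes `max(x, 0)`.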
    pub fn relu(&self) -> GpuResult<GpuTensor<T>> {
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let out_buf = gpu_relu(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else if is_f64::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let out_buf = gpu_relu_f64(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            unary_cpu_fallback(self, |x| {
                let z = <T as num_traits::Zero>::zero();
                if x > z { x } else { z }
            })
        }
    }

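    /// 2-D matrix multiplication: `(m, k) x (k, n) -> (m, n)`.
    ///
    /// Both operands must be 2-D, agree on the inner dimension `k`, and
    /// live on the same device. A sketch of typical usage (assumes `a` and
    /// `b` are contiguous CPU `Tensor<f32>`s of shapes `[2, 3]` and
    /// `[3, 2]`, and a CUDA device at ordinal 0):
    ///
    /// ```ignore
    /// let ga = cuda(&a, 0)?;
    /// let gb = cuda(&b, 0)?;
    /// let gc = ga.matmul(&gb)?;
    /// assert_eq!(gc.shape(), &[2, 2]);
    /// ```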
    pub fn matmul(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        if self.ndim() != 2 {
            // `expected` is a placeholder: any 2-D shape is acceptable.
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![0, 0],
                got: self.shape.clone(),
            });
        }
        if other.ndim() != 2 {
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![0, 0],
                got: other.shape.clone(),
            });
        }

        let m = self.shape[0];
        let k = self.shape[1];
        let k2 = other.shape[0];
        let n = other.shape[1];

        if k != k2 {
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![k, n],
                got: vec![k2, n],
            });
        }

        if self.device.ordinal() != other.device.ordinal() {
            return Err(GpuError::DeviceMismatch {
                expected: self.device.ordinal(),
                got: other.device.ordinal(),
            });
        }

        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_matmul_f32(a_buf, b_buf, m, k, n, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: vec![m, n],
                device: self.device.clone(),
            })
        } else {
            // `GpuFloat` is only implemented for `f32` and `f64`, so this
            // branch handles `f64`.
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f64>(&other.buffer) };
            let out_buf = gpu_matmul_f64(a_buf, b_buf, m, k, n, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: vec![m, n],
                device: self.device.clone(),
            })
        }
    }

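    /// 2-D convolution over a 4-D input.
    ///
    /// `self` and `weight` must both be 4-D; the bias check against
    /// `weight.shape[0]` implies the weight is laid out with output
    /// channels first (consistent with an NCHW input). `stride` and
    /// `padding` are `(height, width)` pairs. Only `f32` is currently
    /// supported.
    ///
    /// # Errors
    ///
    /// Fails on wrong ranks, mismatched devices, a non-`f32` element type,
    /// or a kernel failure.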
    pub fn conv2d(
        &self,
        weight: &GpuTensor<T>,
        bias: Option<&GpuTensor<T>>,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> GpuResult<GpuTensor<T>> {
        if self.ndim() != 4 {
            // `expected` is a placeholder: any 4-D shape is acceptable.
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![0, 0, 0, 0],
                got: self.shape.clone(),
            });
        }
        if weight.ndim() != 4 {
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![0, 0, 0, 0],
                got: weight.shape.clone(),
            });
        }
        if let Some(b) = bias {
            if b.ndim() != 1 {
                return Err(GpuError::ShapeMismatch {
                    op: "conv2d",
                    expected: vec![weight.shape[0]],
                    got: b.shape.clone(),
                });
            }
        }
        if self.device.ordinal() != weight.device.ordinal() {
            return Err(GpuError::DeviceMismatch {
                expected: self.device.ordinal(),
                got: weight.device.ordinal(),
            });
        }
        if let Some(b) = bias {
            if self.device.ordinal() != b.device.ordinal() {
                return Err(GpuError::DeviceMismatch {
                    expected: self.device.ordinal(),
                    got: b.device.ordinal(),
                });
            }
        }

        if !is_f32::<T>() {
            // Only an f32 conv2d kernel exists; ShapeMismatch with empty
            // shapes doubles as an "unsupported dtype" signal here.
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![],
                got: vec![],
            });
        }

        let input_shape: [usize; 4] = [self.shape[0], self.shape[1], self.shape[2], self.shape[3]];
        let weight_shape: [usize; 4] = [
            weight.shape[0],
            weight.shape[1],
            weight.shape[2],
            weight.shape[3],
        ];

        let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
        let w_buf = unsafe { transmute_buffer_ref::<T, f32>(&weight.buffer) };
        let b_buf = bias.map(|b| unsafe { transmute_buffer_ref::<T, f32>(&b.buffer) });

        let (out_buf, out_shape) = gpu_conv2d_f32(
            a_buf,
            w_buf,
            b_buf,
            input_shape,
            weight_shape,
            stride,
            padding,
            &self.device,
        )?;

        let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
        Ok(GpuTensor {
            buffer: out_buf,
            shape: out_shape.to_vec(),
            device: self.device.clone(),
        })
    }
}

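/// Reinterprets a `&CudaBuffer<T>` as a `&CudaBuffer<U>`.
///
/// # Safety
///
/// Callers must only use this when `T` and `U` are the same type at
/// runtime (as established via `is_f32`/`is_f64`); size and alignment are
/// re-checked with debug assertions.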
#[cfg(feature = "cuda")]
unsafe fn transmute_buffer_ref<T, U>(buf: &CudaBuffer<T>) -> &CudaBuffer<U> {
    debug_assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<U>());
    debug_assert_eq!(std::mem::align_of::<T>(), std::mem::align_of::<U>());
    unsafe { &*(buf as *const CudaBuffer<T> as *const CudaBuffer<U>) }
}

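/// Reinterprets an owned `CudaBuffer<U>` as a `CudaBuffer<T>` without
/// copying or freeing device memory.
///
/// # Safety
///
/// Same requirements as [`transmute_buffer_ref`].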
#[cfg(feature = "cuda")]
unsafe fn transmute_buffer<U, T>(buf: CudaBuffer<U>) -> CudaBuffer<T> {
    debug_assert_eq!(std::mem::size_of::<U>(), std::mem::size_of::<T>());
    debug_assert_eq!(std::mem::align_of::<U>(), std::mem::align_of::<T>());
    let result = unsafe { std::ptr::read(&buf as *const CudaBuffer<U> as *const CudaBuffer<T>) };
    // `buf` was bitwise-moved into `result`; forget it so the device
    // allocation is not freed twice.
    std::mem::forget(buf);
    result
}

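/// Stubs for builds without the `cuda` feature; no GPU buffer should exist
/// in that configuration, so reaching either function is a bug.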
#[cfg(not(feature = "cuda"))]
unsafe fn transmute_buffer_ref<T, U>(buf: &CudaBuffer<T>) -> &CudaBuffer<U> {
    let _ = buf;
    unreachable!("transmute_buffer_ref called without cuda feature")
}

#[cfg(not(feature = "cuda"))]
unsafe fn transmute_buffer<U, T>(buf: CudaBuffer<U>) -> CudaBuffer<T> {
    let _ = buf;
    unreachable!("transmute_buffer called without cuda feature")
}

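/// Fallback for `GpuFloat` types without a dedicated kernel: download both
/// operands, apply `op` element-wise on the host, and upload the result.
/// Correct but slow.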
fn binary_cpu_fallback<T: GpuFloat>(
    a: &GpuTensor<T>,
    b: &GpuTensor<T>,
    op: fn(T, T) -> T,
) -> GpuResult<GpuTensor<T>> {
    let a_cpu = gpu_to_cpu(&a.buffer, &a.device)?;
    let b_cpu = gpu_to_cpu(&b.buffer, &b.device)?;
    let result: Vec<T> = a_cpu
        .iter()
        .zip(b_cpu.iter())
        .map(|(&x, &y)| op(x, y))
        .collect();
    let out_buf = cpu_to_gpu(&result, &a.device)?;
    Ok(GpuTensor {
        buffer: out_buf,
        shape: a.shape.clone(),
        device: a.device.clone(),
    })
}

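/// Unary counterpart of [`binary_cpu_fallback`]: round-trips one tensor
/// through host memory to apply `op`.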
fn unary_cpu_fallback<T: GpuFloat>(a: &GpuTensor<T>, op: fn(T) -> T) -> GpuResult<GpuTensor<T>> {
    let a_cpu = gpu_to_cpu(&a.buffer, &a.device)?;
    let result: Vec<T> = a_cpu.iter().map(|&x| op(x)).collect();
    let out_buf = cpu_to_gpu(&result, &a.device)?;
    Ok(GpuTensor {
        buffer: out_buf,
        shape: a.shape.clone(),
        device: a.device.clone(),
    })
}

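/// Uploads a contiguous CPU [`Tensor`] to the given device.
///
/// # Errors
///
/// Fails if the tensor is non-contiguous, its data is inaccessible, or the
/// host-to-device transfer fails.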
pub fn tensor_to_gpu<T: GpuFloat>(
    tensor: &Tensor<T>,
    device: &GpuDevice,
) -> GpuResult<GpuTensor<T>> {
    if !tensor.is_contiguous() {
        // Non-contiguous tensors are rejected rather than repacked;
        // LengthMismatch doubles as the error signal.
        return Err(GpuError::LengthMismatch {
            a: tensor.numel(),
            b: tensor.data().map_or(0, |d| d.len()),
        });
    }

    let data = tensor.data().map_err(|_e| GpuError::InvalidDevice {
        ordinal: device.ordinal(),
        count: 0,
    })?;

    let buffer = cpu_to_gpu(data, device)?;
    Ok(GpuTensor {
        buffer,
        shape: tensor.shape().to_vec(),
        device: device.clone(),
    })
}

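/// Downloads a GPU tensor into a freshly allocated CPU [`Tensor`],
/// preserving its shape.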
pub fn tensor_to_cpu<T: GpuFloat>(gpu_tensor: &GpuTensor<T>) -> FerrotorchResult<Tensor<T>> {
    let host_data = gpu_to_cpu(&gpu_tensor.buffer, &gpu_tensor.device).map_err(|e| {
        FerrotorchError::InvalidArgument {
            message: format!("GPU-to-CPU transfer failed: {e}"),
        }
    })?;

    let storage = TensorStorage::cpu(host_data);
    Tensor::from_storage(storage, gpu_tensor.shape.clone(), false)
}

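/// Convenience entry point: opens the CUDA device at `ordinal` and uploads
/// `tensor` to it. A sketch of a round trip (assumes `t` is a contiguous
/// `Tensor<f32>` and device 0 exists):
///
/// ```ignore
/// let gpu = cuda(&t, 0)?;
/// let back = gpu.cpu()?;
/// assert_eq!(back.shape(), t.shape());
/// ```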
pub fn cuda<T: GpuFloat>(tensor: &Tensor<T>, ordinal: usize) -> GpuResult<GpuTensor<T>> {
    let device = GpuDevice::new(ordinal)?;
    tensor_to_gpu(tensor, &device)
}

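/// Like [`cuda`], but always targets device ordinal 0.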
pub fn cuda_default<T: GpuFloat>(tensor: &Tensor<T>) -> GpuResult<GpuTensor<T>> {
    cuda(tensor, 0)
}

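// These tests require the `cuda` feature and a working CUDA device at
// ordinal 0; without the feature the whole module is compiled out.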
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};

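    /// Builds a contiguous CPU tensor to use as test input.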
    fn cpu_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data);
        Tensor::from_storage(storage, shape, false).expect("cpu_tensor")
    }

    #[test]
    fn tensor_to_gpu_round_trip() {
        let t = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let back = gpu.cpu().expect("cpu()");

        assert_eq!(back.shape(), &[2, 3]);
        assert_eq!(back.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn gpu_tensor_shape_preserved() {
        let t = cpu_tensor(vec![1.0; 24], vec![2, 3, 4]);
        let gpu = cuda_default(&t).expect("cuda_default");

        assert_eq!(gpu.shape(), &[2, 3, 4]);
        assert_eq!(gpu.numel(), 24);
        assert_eq!(gpu.ndim(), 3);
    }

    #[test]
    fn gpu_tensor_add() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        let b = cpu_tensor(vec![10.0, 20.0, 30.0, 40.0], vec![4]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.add(&gb).expect("gpu add");
        let result = gc.cpu().expect("cpu");

        assert_eq!(result.shape(), &[4]);
        let data = result.data().unwrap();
        assert!((data[0] - 11.0).abs() < 1e-6);
        assert!((data[1] - 22.0).abs() < 1e-6);
        assert!((data[2] - 33.0).abs() < 1e-6);
        assert!((data[3] - 44.0).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_relu() {
        let t = cpu_tensor(vec![-3.0, -1.0, 0.0, 1.0, 3.0], vec![5]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let out = gpu.relu().expect("relu");
        let result = out.cpu().expect("cpu");

        let data = result.data().unwrap();
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[1] - 0.0).abs() < 1e-6);
        assert!((data[2] - 0.0).abs() < 1e-6);
        assert!((data[3] - 1.0).abs() < 1e-6);
        assert!((data[4] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn tensor_to_cpu_correct_values() {
        let original = vec![0.5, -1.5, 2.25, 0.0, 100.0, -0.001];
        let t = cpu_tensor(original.clone(), vec![2, 3]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let back = tensor_to_cpu(&gpu).expect("tensor_to_cpu");

        let data = back.data().unwrap();
        for (i, (&got, &expected)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (got - expected).abs() < 1e-6,
                "element {i}: got {got}, expected {expected}",
            );
        }
    }

    #[test]
    fn gpu_tensor_sub() {
        let a = cpu_tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0], vec![3]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.sub(&gb).expect("gpu sub");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 9.0).abs() < 1e-6);
        assert!((data[1] - 18.0).abs() < 1e-6);
        assert!((data[2] - 27.0).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_mul() {
        let a = cpu_tensor(vec![2.0, 3.0, 4.0], vec![3]);
        let b = cpu_tensor(vec![10.0, 10.0, 10.0], vec![3]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.mul(&gb).expect("gpu mul");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 20.0).abs() < 1e-6);
        assert!((data[1] - 30.0).abs() < 1e-6);
        assert!((data[2] - 40.0).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_neg() {
        let t = cpu_tensor(vec![1.0, -2.0, 0.0, 3.5], vec![4]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let out = gpu.neg().expect("neg");
        let result = out.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - (-1.0)).abs() < 1e-6);
        assert!((data[1] - 2.0).abs() < 1e-6);
        assert!((data[2] - 0.0).abs() < 1e-6);
        assert!((data[3] - (-3.5)).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_matmul_basic() {
        // [[1, 2, 3], [4, 5, 6]] x [[7, 8], [9, 10], [11, 12]]
        //   = [[58, 64], [139, 154]]
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = cpu_tensor(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.matmul(&gb).expect("gpu matmul");
        assert_eq!(gc.shape(), &[2, 2]);

        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 58.0).abs() < 1e-4);
        assert!((data[1] - 64.0).abs() < 1e-4);
        assert!((data[2] - 139.0).abs() < 1e-4);
        assert!((data[3] - 154.0).abs() < 1e-4);
    }

    #[test]
    fn gpu_tensor_matmul_identity() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let i = cpu_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gi = tensor_to_gpu(&i, &device).expect("i to gpu");

        let gc = ga.matmul(&gi).expect("gpu matmul identity");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-6);
        assert!((data[1] - 2.0).abs() < 1e-6);
        assert!((data[2] - 3.0).abs() < 1e-6);
        assert!((data[3] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_matmul_inner_dim_mismatch() {
        // Inner dimensions differ: (2, 3) x (2, 2) is invalid.
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.matmul(&gb).unwrap_err();
        match err {
            GpuError::ShapeMismatch { op: "matmul", .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn gpu_tensor_matmul_not_2d() {
        // A 1-D operand must be rejected.
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.matmul(&gb).unwrap_err();
        match err {
            GpuError::ShapeMismatch { op: "matmul", .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn gpu_tensor_add_shape_mismatch() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0], vec![3]);
        let b = cpu_tensor(vec![1.0, 2.0], vec![2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.add(&gb).unwrap_err();
        match err {
            GpuError::LengthMismatch { .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn gpu_tensor_empty_round_trip() {
        let t = cpu_tensor(vec![], vec![0]);
        let gpu = cuda_default(&t).expect("cuda_default");
        assert_eq!(gpu.numel(), 0);
        assert_eq!(gpu.shape(), &[0]);

        let back = gpu.cpu().expect("cpu");
        assert_eq!(back.shape(), &[0]);
        assert_eq!(back.data().unwrap().len(), 0);
    }

    #[test]
    fn gpu_tensor_scalar_round_trip() {
        let storage = TensorStorage::cpu(vec![42.0f32]);
        let t = Tensor::from_storage(storage, vec![], false).expect("scalar");
        let gpu = cuda_default(&t).expect("cuda_default");
        assert_eq!(gpu.shape(), &[] as &[usize]);
        assert_eq!(gpu.numel(), 1);

        let back = gpu.cpu().expect("cpu");
        assert!(back.is_scalar());
        assert!((back.item().unwrap() - 42.0).abs() < 1e-6);
    }
}