use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};

use crate::blas::{gpu_matmul_f32, gpu_matmul_f64};
use crate::buffer::CudaBuffer;
use crate::conv::gpu_conv2d_f32;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::kernels::{gpu_add, gpu_mul, gpu_neg, gpu_relu, gpu_sub};
use crate::transfer::{cpu_to_gpu, gpu_to_cpu};
/// Marker trait for element types supported by [`GpuTensor`].
///
/// With the `cuda` feature enabled, the element type must also be
/// representable in device memory (`cudarc::driver::DeviceRepr`).
/// Implemented for `f32` and `f64`.
#[cfg(feature = "cuda")]
pub trait GpuFloat: Float + cudarc::driver::DeviceRepr {}

#[cfg(feature = "cuda")]
impl GpuFloat for f32 {}
#[cfg(feature = "cuda")]
impl GpuFloat for f64 {}

/// Marker trait for element types supported by [`GpuTensor`] (no extra
/// requirements without the `cuda` feature).
#[cfg(not(feature = "cuda"))]
pub trait GpuFloat: Float {}

#[cfg(not(feature = "cuda"))]
impl GpuFloat for f32 {}
#[cfg(not(feature = "cuda"))]
impl GpuFloat for f64 {}

/// A dense tensor whose data lives in GPU device memory.
///
/// The buffer holds `shape.iter().product()` elements in row-major order on
/// the device identified by `device`.
pub struct GpuTensor<T: GpuFloat> {
    buffer: CudaBuffer<T>,
    shape: Vec<usize>,
    device: GpuDevice,
}

impl<T: GpuFloat> GpuTensor<T> {
    /// Returns the shape of the tensor.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the total number of elements.
    #[inline]
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns the device this tensor lives on.
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }

    /// Returns the underlying device buffer.
    #[inline]
    pub fn buffer(&self) -> &CudaBuffer<T> {
        &self.buffer
    }

    /// Returns the number of dimensions.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Copies this tensor back to host memory as a CPU [`Tensor`].
    pub fn cpu(&self) -> FerrotorchResult<Tensor<T>> {
        tensor_to_cpu(self)
    }
}

impl<T: GpuFloat> std::fmt::Debug for GpuTensor<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuTensor")
            .field("shape", &self.shape)
            .field("numel", &self.numel())
            .field("device_ordinal", &self.device.ordinal())
            .finish_non_exhaustive()
    }
}

/// Returns `true` if `T` is exactly `f32`, selecting the dedicated kernel path.
#[inline]
fn is_f32<T: GpuFloat>() -> bool {
    std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
}

/// Checks that two tensors have identical shapes and live on the same device.
fn validate_shapes<T: GpuFloat>(a: &GpuTensor<T>, b: &GpuTensor<T>) -> GpuResult<()> {
    if a.shape() != b.shape() {
        // Shape differences are reported through element counts, which is
        // what `LengthMismatch` carries.
        return Err(GpuError::LengthMismatch {
            a: a.numel(),
            b: b.numel(),
        });
    }
    if a.device.ordinal() != b.device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: a.device.ordinal(),
            got: b.device.ordinal(),
        });
    }
    Ok(())
}

impl<T: GpuFloat> GpuTensor<T> {
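    /// Element-wise addition. Both operands must have the same shape and live
    /// on the same device. For `f32` a dedicated CUDA kernel runs on-device;
    /// other element types fall back to a CPU round trip.
    ///
    /// A minimal usage sketch, mirroring the tests below (marked `ignore`
    /// since it needs a CUDA device; the surrounding setup is assumed):
    ///
    /// ```ignore
    /// let a = cuda_default(&cpu_a)?; // cpu_a: Tensor<f32>
    /// let b = cuda_default(&cpu_b)?; // same shape as cpu_a
    /// let sum = a.add(&b)?;
    /// let host = sum.cpu()?;         // copy the result back to the host
    /// ```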
    pub fn add(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_add(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a + b)
        }
    }

    /// Element-wise subtraction; same shape and device rules as [`Self::add`].
    pub fn sub(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_sub(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a - b)
        }
    }

    /// Element-wise (Hadamard) multiplication; same rules as [`Self::add`].
    pub fn mul(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_mul(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a * b)
        }
    }

    /// Element-wise negation.
    pub fn neg(&self) -> GpuResult<GpuTensor<T>> {
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let out_buf = gpu_neg(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            unary_cpu_fallback(self, |x| -x)
        }
    }

    /// Element-wise rectified linear unit: `max(x, 0)`.
    pub fn relu(&self) -> GpuResult<GpuTensor<T>> {
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let out_buf = gpu_relu(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            unary_cpu_fallback(self, |x| {
                let z = <T as num_traits::Zero>::zero();
                if x > z { x } else { z }
            })
        }
    }

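    /// 2-D matrix multiplication: `[m, k] × [k, n] → [m, n]`.
    ///
    /// Both operands must be 2-D, agree on the inner dimension `k`, and live
    /// on the same device. Dedicated kernels exist for `f32` and `f64`.
    ///
    /// A minimal sketch, mirroring `gpu_tensor_matmul_basic` below (`ignore`d
    /// because it needs a CUDA device; the setup names are assumed):
    ///
    /// ```ignore
    /// let ga = tensor_to_gpu(&a, &device)?; // a: [2, 3]
    /// let gb = tensor_to_gpu(&b, &device)?; // b: [3, 2]
    /// let gc = ga.matmul(&gb)?;             // gc: [2, 2]
    /// ```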
    pub fn matmul(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        // Only 2-D operands are supported; `vec![0, 0]` is a placeholder
        // meaning "any 2-D shape".
        if self.ndim() != 2 {
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![0, 0],
                got: self.shape.clone(),
            });
        }
        if other.ndim() != 2 {
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![0, 0],
                got: other.shape.clone(),
            });
        }

        let m = self.shape[0];
        let k = self.shape[1];
        let k2 = other.shape[0];
        let n = other.shape[1];

        if k != k2 {
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![k, n],
                got: vec![k2, n],
            });
        }

        if self.device.ordinal() != other.device.ordinal() {
            return Err(GpuError::DeviceMismatch {
                expected: self.device.ordinal(),
                got: other.device.ordinal(),
            });
        }

        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_matmul_f32(a_buf, b_buf, m, k, n, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: vec![m, n],
                device: self.device.clone(),
            })
        } else {
            // `GpuFloat` is implemented only for `f32` and `f64` in this
            // crate, so a non-`f32` tensor must be `f64` here.
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f64>(&other.buffer) };
            let out_buf = gpu_matmul_f64(a_buf, b_buf, m, k, n, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: vec![m, n],
                device: self.device.clone(),
            })
        }
    }

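    /// 2-D convolution; only `f32` is supported.
    ///
    /// `self` and `weight` must both be 4-D (presumably NCHW input and OIHW
    /// weights, the usual convention, so `weight.shape[0]` is the number of
    /// output channels), and `bias`, when given, must be 1-D with
    /// `weight.shape[0]` elements. All operands must share a device. Each
    /// output spatial extent follows the standard rule
    /// `out = (in + 2 * pad - kernel) / stride + 1` (integer division).
    ///
    /// An illustrative sketch under those layout assumptions (`ignore`d:
    /// needs a CUDA device):
    ///
    /// ```ignore
    /// // input: [1, 3, 32, 32], weight: [8, 3, 3, 3], bias: [8]
    /// let out = input.conv2d(&weight, Some(&bias), (1, 1), (1, 1))?;
    /// assert_eq!(out.shape(), &[1, 8, 32, 32]); // (32 + 2 - 3) / 1 + 1 = 32
    /// ```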
    pub fn conv2d(
        &self,
        weight: &GpuTensor<T>,
        bias: Option<&GpuTensor<T>>,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> GpuResult<GpuTensor<T>> {
        if self.ndim() != 4 {
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![0, 0, 0, 0],
                got: self.shape.clone(),
            });
        }
        if weight.ndim() != 4 {
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![0, 0, 0, 0],
                got: weight.shape.clone(),
            });
        }
        if let Some(b) = bias {
            if b.ndim() != 1 {
                return Err(GpuError::ShapeMismatch {
                    op: "conv2d",
                    expected: vec![weight.shape[0]],
                    got: b.shape.clone(),
                });
            }
        }
        if self.device.ordinal() != weight.device.ordinal() {
            return Err(GpuError::DeviceMismatch {
                expected: self.device.ordinal(),
                got: weight.device.ordinal(),
            });
        }
        if let Some(b) = bias {
            if self.device.ordinal() != b.device.ordinal() {
                return Err(GpuError::DeviceMismatch {
                    expected: self.device.ordinal(),
                    got: b.device.ordinal(),
                });
            }
        }

        // Only an f32 kernel exists; reject other element types. The empty
        // shapes signal "unsupported dtype" rather than a real shape error.
        if !is_f32::<T>() {
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![],
                got: vec![],
            });
        }

        let input_shape: [usize; 4] = [self.shape[0], self.shape[1], self.shape[2], self.shape[3]];
        let weight_shape: [usize; 4] = [
            weight.shape[0],
            weight.shape[1],
            weight.shape[2],
            weight.shape[3],
        ];

        let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
        let w_buf = unsafe { transmute_buffer_ref::<T, f32>(&weight.buffer) };
        let b_buf = bias.map(|b| unsafe { transmute_buffer_ref::<T, f32>(&b.buffer) });

        let (out_buf, out_shape) = gpu_conv2d_f32(
            a_buf,
            w_buf,
            b_buf,
            input_shape,
            weight_shape,
            stride,
            padding,
            &self.device,
        )?;

        let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
        Ok(GpuTensor {
            buffer: out_buf,
            shape: out_shape.to_vec(),
            device: self.device.clone(),
        })
    }
}

/// Reinterprets a `&CudaBuffer<T>` as a `&CudaBuffer<U>`.
///
/// # Safety
///
/// `T` and `U` must have identical size and alignment, and every bit pattern
/// of `T` must be a valid `U`. Callers in this module only cast between a
/// generic `T` and the concrete float type it was proven to be.
#[cfg(feature = "cuda")]
unsafe fn transmute_buffer_ref<T, U>(buf: &CudaBuffer<T>) -> &CudaBuffer<U> {
    debug_assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<U>());
    debug_assert_eq!(std::mem::align_of::<T>(), std::mem::align_of::<U>());
    unsafe { &*(buf as *const CudaBuffer<T> as *const CudaBuffer<U>) }
}

/// Reinterprets an owned `CudaBuffer<U>` as a `CudaBuffer<T>`.
///
/// # Safety
///
/// Same requirements as [`transmute_buffer_ref`]. The source buffer is
/// `forget`ten after the bitwise read, so the allocation is moved rather
/// than duplicated or dropped twice.
#[cfg(feature = "cuda")]
unsafe fn transmute_buffer<U, T>(buf: CudaBuffer<U>) -> CudaBuffer<T> {
    debug_assert_eq!(std::mem::size_of::<U>(), std::mem::size_of::<T>());
    debug_assert_eq!(std::mem::align_of::<U>(), std::mem::align_of::<T>());
    let result = unsafe { std::ptr::read(&buf as *const CudaBuffer<U> as *const CudaBuffer<T>) };
    std::mem::forget(buf);
    result
}

/// Stub for builds without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
unsafe fn transmute_buffer_ref<T, U>(buf: &CudaBuffer<T>) -> &CudaBuffer<U> {
    let _ = buf;
    unreachable!("transmute_buffer_ref called without cuda feature")
}

/// Stub for builds without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
unsafe fn transmute_buffer<U, T>(buf: CudaBuffer<U>) -> CudaBuffer<T> {
    let _ = buf;
    unreachable!("transmute_buffer called without cuda feature")
}

/// Applies `op` element-wise by copying both operands to the host, computing
/// on the CPU, and uploading the result. Used for element types without a
/// dedicated GPU kernel.
fn binary_cpu_fallback<T: GpuFloat>(
    a: &GpuTensor<T>,
    b: &GpuTensor<T>,
    op: fn(T, T) -> T,
) -> GpuResult<GpuTensor<T>> {
    let a_cpu = gpu_to_cpu(&a.buffer, &a.device)?;
    let b_cpu = gpu_to_cpu(&b.buffer, &b.device)?;
    let result: Vec<T> = a_cpu
        .iter()
        .zip(b_cpu.iter())
        .map(|(&x, &y)| op(x, y))
        .collect();
    let out_buf = cpu_to_gpu(&result, &a.device)?;
    Ok(GpuTensor {
        buffer: out_buf,
        shape: a.shape.clone(),
        device: a.device.clone(),
    })
}

/// Unary counterpart of [`binary_cpu_fallback`]: a host round trip for
/// element types without a dedicated GPU kernel.
fn unary_cpu_fallback<T: GpuFloat>(a: &GpuTensor<T>, op: fn(T) -> T) -> GpuResult<GpuTensor<T>> {
    let a_cpu = gpu_to_cpu(&a.buffer, &a.device)?;
    let result: Vec<T> = a_cpu.iter().map(|&x| op(x)).collect();
    let out_buf = cpu_to_gpu(&result, &a.device)?;
    Ok(GpuTensor {
        buffer: out_buf,
        shape: a.shape.clone(),
        device: a.device.clone(),
    })
}

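/// Uploads a contiguous CPU [`Tensor`] to the given device.
///
/// A minimal sketch, mirroring the tests below (`ignore`d: needs a CUDA
/// device):
///
/// ```ignore
/// let device = GpuDevice::new(0)?;
/// let gpu = tensor_to_gpu(&t, &device)?; // t: contiguous CPU Tensor<f32>
/// ```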
pub fn tensor_to_gpu<T: GpuFloat>(
    tensor: &Tensor<T>,
    device: &GpuDevice,
) -> GpuResult<GpuTensor<T>> {
    // The raw host slice is copied as-is, so only contiguous tensors are
    // accepted; the rejection is reported through `LengthMismatch`.
    if !tensor.is_contiguous() {
        return Err(GpuError::LengthMismatch {
            a: tensor.numel(),
            b: tensor.data().map_or(0, |d| d.len()),
        });
    }

    // Map a failed host-data access onto `InvalidDevice`.
    let data = tensor.data().map_err(|_e| GpuError::InvalidDevice {
        ordinal: device.ordinal(),
        count: 0,
    })?;

    let buffer = cpu_to_gpu(data, device)?;
    Ok(GpuTensor {
        buffer,
        shape: tensor.shape().to_vec(),
        device: device.clone(),
    })
}

/// Downloads a [`GpuTensor`] into a freshly allocated CPU [`Tensor`].
pub fn tensor_to_cpu<T: GpuFloat>(gpu_tensor: &GpuTensor<T>) -> FerrotorchResult<Tensor<T>> {
    let host_data = gpu_to_cpu(&gpu_tensor.buffer, &gpu_tensor.device).map_err(|e| {
        FerrotorchError::InvalidArgument {
            message: format!("GPU-to-CPU transfer failed: {e}"),
        }
    })?;

    let storage = TensorStorage::cpu(host_data);
    Tensor::from_storage(storage, gpu_tensor.shape.clone(), false)
}

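/// Convenience entry point: constructs a handle for the device at `ordinal`
/// and uploads `tensor` to it.
///
/// A minimal sketch (`ignore`d: needs CUDA hardware):
///
/// ```ignore
/// let storage = TensorStorage::cpu(vec![1.0f32, 2.0]);
/// let t = Tensor::from_storage(storage, vec![2], false)?;
/// let gpu = cuda(&t, 0)?; // device ordinal 0
/// assert_eq!(gpu.shape(), &[2]);
/// ```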
pub fn cuda<T: GpuFloat>(tensor: &Tensor<T>, ordinal: usize) -> GpuResult<GpuTensor<T>> {
    let device = GpuDevice::new(ordinal)?;
    tensor_to_gpu(tensor, &device)
}

/// Shorthand for [`cuda`] with device ordinal 0.
pub fn cuda_default<T: GpuFloat>(tensor: &Tensor<T>) -> GpuResult<GpuTensor<T>> {
    cuda(tensor, 0)
}

#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};

    fn cpu_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data);
        Tensor::from_storage(storage, shape, false).expect("cpu_tensor")
    }

    #[test]
    fn tensor_to_gpu_round_trip() {
        let t = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let back = gpu.cpu().expect("cpu()");

        assert_eq!(back.shape(), &[2, 3]);
        assert_eq!(back.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn gpu_tensor_shape_preserved() {
        let t = cpu_tensor(vec![1.0; 24], vec![2, 3, 4]);
        let gpu = cuda_default(&t).expect("cuda_default");

        assert_eq!(gpu.shape(), &[2, 3, 4]);
        assert_eq!(gpu.numel(), 24);
        assert_eq!(gpu.ndim(), 3);
    }

    #[test]
    fn gpu_tensor_add() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        let b = cpu_tensor(vec![10.0, 20.0, 30.0, 40.0], vec![4]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.add(&gb).expect("gpu add");
        let result = gc.cpu().expect("cpu");

        assert_eq!(result.shape(), &[4]);
        let data = result.data().unwrap();
        assert!((data[0] - 11.0).abs() < 1e-6);
        assert!((data[1] - 22.0).abs() < 1e-6);
        assert!((data[2] - 33.0).abs() < 1e-6);
        assert!((data[3] - 44.0).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_relu() {
        let t = cpu_tensor(vec![-3.0, -1.0, 0.0, 1.0, 3.0], vec![5]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let out = gpu.relu().expect("relu");
        let result = out.cpu().expect("cpu");

        let data = result.data().unwrap();
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[1] - 0.0).abs() < 1e-6);
        assert!((data[2] - 0.0).abs() < 1e-6);
        assert!((data[3] - 1.0).abs() < 1e-6);
        assert!((data[4] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn tensor_to_cpu_correct_values() {
        let original = vec![0.5, -1.5, 2.25, 0.0, 100.0, -0.001];
        let t = cpu_tensor(original.clone(), vec![2, 3]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let back = tensor_to_cpu(&gpu).expect("tensor_to_cpu");

        let data = back.data().unwrap();
        for (i, (&got, &expected)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (got - expected).abs() < 1e-6,
                "element {i}: got {got}, expected {expected}",
            );
        }
    }

    #[test]
    fn gpu_tensor_sub() {
        let a = cpu_tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0], vec![3]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.sub(&gb).expect("gpu sub");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 9.0).abs() < 1e-6);
        assert!((data[1] - 18.0).abs() < 1e-6);
        assert!((data[2] - 27.0).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_mul() {
        let a = cpu_tensor(vec![2.0, 3.0, 4.0], vec![3]);
        let b = cpu_tensor(vec![10.0, 10.0, 10.0], vec![3]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.mul(&gb).expect("gpu mul");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 20.0).abs() < 1e-6);
        assert!((data[1] - 30.0).abs() < 1e-6);
        assert!((data[2] - 40.0).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_neg() {
        let t = cpu_tensor(vec![1.0, -2.0, 0.0, 3.5], vec![4]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let out = gpu.neg().expect("neg");
        let result = out.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - (-1.0)).abs() < 1e-6);
        assert!((data[1] - 2.0).abs() < 1e-6);
        assert!((data[2] - 0.0).abs() < 1e-6);
        assert!((data[3] - (-3.5)).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_matmul_basic() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = cpu_tensor(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.matmul(&gb).expect("gpu matmul");
        assert_eq!(gc.shape(), &[2, 2]);

        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 58.0).abs() < 1e-4);
        assert!((data[1] - 64.0).abs() < 1e-4);
        assert!((data[2] - 139.0).abs() < 1e-4);
        assert!((data[3] - 154.0).abs() < 1e-4);
    }

    #[test]
    fn gpu_tensor_matmul_identity() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let i = cpu_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gi = tensor_to_gpu(&i, &device).expect("i to gpu");

        let gc = ga.matmul(&gi).expect("gpu matmul identity");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-6);
        assert!((data[1] - 2.0).abs() < 1e-6);
        assert!((data[2] - 3.0).abs() < 1e-6);
        assert!((data[3] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_matmul_inner_dim_mismatch() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.matmul(&gb).unwrap_err();
        match err {
            GpuError::ShapeMismatch { op: "matmul", .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn gpu_tensor_matmul_not_2d() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.matmul(&gb).unwrap_err();
        match err {
            GpuError::ShapeMismatch { op: "matmul", .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn gpu_tensor_add_shape_mismatch() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0], vec![3]);
        let b = cpu_tensor(vec![1.0, 2.0], vec![2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.add(&gb).unwrap_err();
        match err {
            GpuError::LengthMismatch { .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn gpu_tensor_empty_round_trip() {
        let t = cpu_tensor(vec![], vec![0]);
        let gpu = cuda_default(&t).expect("cuda_default");
        assert_eq!(gpu.numel(), 0);
        assert_eq!(gpu.shape(), &[0]);

        let back = gpu.cpu().expect("cpu");
        assert_eq!(back.shape(), &[0]);
        assert_eq!(back.data().unwrap().len(), 0);
    }

    #[test]
    fn gpu_tensor_scalar_round_trip() {
        let storage = TensorStorage::cpu(vec![42.0f32]);
        let t = Tensor::from_storage(storage, vec![], false).expect("scalar");
        let gpu = cuda_default(&t).expect("cuda_default");
        assert_eq!(gpu.shape(), &[] as &[usize]);
        assert_eq!(gpu.numel(), 1);

        let back = gpu.cpu().expect("cpu");
        assert!(back.is_scalar());
        assert!((back.item().unwrap() - 42.0).abs() < 1e-6);
    }
}