1use core::fmt;
18use core::ops::{Add, Div, Mul, Neg, Sub};
19
20use axonml_core::Device;
21use axonml_core::backends::CpuBackend;
22#[cfg(feature = "cuda")]
23use axonml_core::backends::CudaBackend;
24use axonml_core::dtype::{Float, Numeric, Scalar};
25use axonml_core::error::{Error, Result};
26use axonml_core::storage::Storage;
27use num_traits::NumCast;
28
29#[cfg(feature = "cuda")]
mod cuda_accel {
    //! Optional CUDA helpers, compiled only with the `cuda` feature.
    use super::*;
    use axonml_core::backends::cuda::get_cuda_backend;

    /// Returns the process-wide CUDA backend, or `None` when unavailable.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        get_cuda_backend()
    }

    /// Matrix multiply `C[m,n] = A[m,k] · B[k,n]` on the GPU.
    ///
    /// Returns `None` on any failure (no backend, transfer error, launch error),
    /// so callers can fall back to a CPU implementation.
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        let cuda = get_cuda()?;

        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(m * n).ok()?;

        // NOTE(review): operands are passed as (B, A) with m/n swapped — presumably
        // the column-major-GEMM trick for producing a row-major C without explicit
        // transposes. Verify against `gemm_f32`'s documented parameter convention.
        cuda.gemm_f32(
            false, false, n, m, k, 1.0, &b_gpu, n, &a_gpu, k, 0.0, &mut c_gpu, n,
        )
        .ok()?;

        cuda.dtoh_copy(&c_gpu).ok()
    }
}
62
63use crate::shape::{
64 Shape, Strides, broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous,
65 linear_index, normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides,
66 unsqueeze,
67};
68
69#[cfg(feature = "cuda")]
/// Reinterprets `&Tensor<T>` as `&Tensor<f32>` to dispatch into f32-only CUDA kernels.
///
/// # Safety
/// `T` must be exactly `f32`. The runtime assert enforces this before the cast.
unsafe fn gpu_ref<T: Scalar>(t: &Tensor<T>) -> &Tensor<f32> {
    assert!(
        is_f32::<T>(),
        "gpu_ref: only Tensor<f32> can be used for GPU operations, got {:?}",
        T::DTYPE
    );
    // SAFETY: asserted above that T == f32, so Tensor<T> and Tensor<f32> are the
    // same type and the pointer round-trip is a no-op reinterpretation.
    unsafe { &*(t as *const Tensor<T> as *const Tensor<f32>) }
}
86
87#[cfg(feature = "cuda")]
/// Converts an owned `Tensor<f32>` into `Tensor<T>`, used to return CUDA results
/// from generic methods.
///
/// # Safety
/// `T` must be exactly `f32`. The runtime assert enforces this before the move.
unsafe fn gpu_into<T: Scalar>(t: Tensor<f32>) -> Tensor<T> {
    assert!(
        is_f32::<T>(),
        "gpu_into: only Tensor<f32> can be produced from GPU operations, got {:?}",
        T::DTYPE
    );
    // SAFETY: T == f32 (asserted), so `read` moves the same type; `forget(t)`
    // prevents a double-drop of the storage now owned by `out`.
    unsafe {
        let out = std::ptr::read(&t as *const Tensor<f32> as *const Tensor<T>);
        std::mem::forget(t);
        out
    }
}
101
102#[cfg(feature = "cuda")]
/// True exactly when the element type `T` is `f32`.
fn is_f32<T: 'static>() -> bool {
    use std::any::TypeId;
    TypeId::of::<T>() == TypeId::of::<f32>()
}
106
#[derive(Clone)]
pub struct Tensor<T: Scalar> {
    /// Backing buffer. View-producing ops (`transpose`, `permute`, …) call
    /// `storage.clone()` — presumably a cheap shared handle; TODO confirm.
    pub(crate) storage: Storage<T>,
    /// Size of each dimension (empty for 0-d scalars).
    pub(crate) shape: Shape,
    /// Per-dimension element strides (`isize`), used by `linear_index`.
    pub(crate) strides: Strides,
    /// Start of this view's data within `storage`, in elements.
    pub(crate) offset: usize,
}
127
128impl<T: Scalar> Tensor<T> {
129 pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
142 let total = numel(shape);
143 if total != storage.len() {
144 return Err(Error::shape_mismatch(&[storage.len()], shape));
145 }
146
147 let shape = Shape::from_slice(shape);
148 let strides = contiguous_strides(&shape);
149
150 Ok(Self {
151 storage,
152 shape,
153 strides,
154 offset: 0,
155 })
156 }
157
158 pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
167 let storage = Storage::from_vec(data, Device::Cpu);
168 Self::from_storage(storage, shape)
169 }
170
171 pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
180 let storage = Storage::from_slice(data, Device::Cpu);
181 Self::from_storage(storage, shape)
182 }
183
184 pub fn scalar(value: T) -> Self {
192 Self {
193 storage: Storage::from_vec(vec![value], Device::Cpu),
194 shape: Shape::new(),
195 strides: Strides::new(),
196 offset: 0,
197 }
198 }
199
    /// Tensor of the given shape filled with zeros (delegates to `crate::creation`).
    #[must_use]
    pub fn zeros(shape: &[usize]) -> Self {
        crate::creation::zeros(shape)
    }
205
    /// Tensor of the given shape filled with ones (delegates to `crate::creation`).
    #[must_use]
    pub fn ones(shape: &[usize]) -> Self
    where
        T: Numeric,
    {
        crate::creation::ones(shape)
    }
214
    /// Tensor of the given shape filled with `value` (delegates to `crate::creation`).
    #[must_use]
    pub fn full(shape: &[usize], value: T) -> Self {
        crate::creation::full(shape, value)
    }
220
    /// Tensor with standard-normal random entries (delegates to `crate::creation`).
    #[must_use]
    pub fn randn(shape: &[usize]) -> Self
    where
        T: Float,
        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
    {
        crate::creation::randn(shape)
    }
230
    /// Tensor with uniform random entries (delegates to `crate::creation`).
    #[must_use]
    pub fn rand(shape: &[usize]) -> Self
    where
        T: Float,
        rand::distributions::Standard: rand::distributions::Distribution<T>,
    {
        crate::creation::rand(shape)
    }
240
    /// Dimension sizes of this view.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
250
    /// Per-dimension element strides of this view.
    #[must_use]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }
256
    /// Number of dimensions (0 for scalars).
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }
262
    /// Total number of elements (product of dimension sizes).
    #[must_use]
    pub fn numel(&self) -> usize {
        numel(&self.shape)
    }
268
    /// True when the tensor contains no elements (some dimension is 0).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.numel() == 0
    }
274
    /// Size of dimension `dim`; negative indices count from the end.
    ///
    /// # Errors
    /// Returns an error when `dim` is out of range.
    pub fn size(&self, dim: i64) -> Result<usize> {
        let idx = normalize_dim(dim, self.ndim())?;
        Ok(self.shape[idx])
    }
283
    /// Device on which the backing storage lives.
    #[must_use]
    pub fn device(&self) -> Device {
        self.storage.device()
    }
289
    /// True when this view's strides describe a dense row-major layout.
    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape, &self.strides)
    }
295
    /// True for 0-dimensional tensors (empty shape).
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }
301
302 pub fn get(&self, indices: &[usize]) -> Result<T> {
311 if indices.len() != self.ndim() {
312 return Err(Error::invalid_operation(format!(
313 "Expected {} indices, got {}",
314 self.ndim(),
315 indices.len()
316 )));
317 }
318
319 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
320 if idx >= dim {
321 return Err(Error::IndexOutOfBounds {
322 index: idx,
323 size: dim,
324 });
325 }
326 }
327
328 let offset = self.offset + linear_index(indices, &self.strides);
329 Ok(self.storage.as_slice()[offset])
330 }
331
332 pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
338 if indices.len() != self.ndim() {
339 return Err(Error::invalid_operation(format!(
340 "Expected {} indices, got {}",
341 self.ndim(),
342 indices.len()
343 )));
344 }
345
346 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
347 if idx >= dim {
348 return Err(Error::IndexOutOfBounds {
349 index: idx,
350 size: dim,
351 });
352 }
353 }
354
355 let offset = self.offset + linear_index(indices, &self.strides);
356 self.storage.as_slice_mut()[offset] = value;
357 Ok(())
358 }
359
360 pub fn item(&self) -> Result<T> {
362 if self.numel() != 1 {
363 return Err(Error::invalid_operation(
364 "item() only works on single-element tensors",
365 ));
366 }
367
368 if self.is_scalar() {
369 Ok(self.storage.as_slice()[self.offset])
370 } else {
371 let indices = vec![0; self.ndim()];
373 self.get(&indices)
374 }
375 }
376
    /// Copies this tensor's elements into a freshly allocated row-major `Vec`.
    #[must_use]
    pub fn to_vec(&self) -> Vec<T> {
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let f32_vec = self_f32.to_vec_gpu();
            // SAFETY: T == f32 (asserted), so the buffer can be reinterpreted in
            // place. ManuallyDrop hands the allocation to the rebuilt Vec<T>
            // without a double-free.
            unsafe {
                let mut v = std::mem::ManuallyDrop::new(f32_vec);
                return Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity());
            }
        }

        if self.is_contiguous() {
            // Fast path: the view is a dense window into storage.
            let storage = self.storage.as_slice();
            storage[self.offset..self.offset + self.numel()].to_vec()
        } else {
            // Strided view: gather element by element.
            let mut result = Vec::with_capacity(self.numel());
            self.copy_data_to(&mut result);
            result
        }
    }
405
406 fn copy_data_to(&self, dst: &mut Vec<T>) {
408 dst.clear();
409 let storage = self.storage.as_slice();
410
411 let total = self.numel();
413 for i in 0..total {
414 let indices = crate::shape::unravel_index(i, &self.shape);
415 let offset = self.offset + linear_index(&indices, &self.strides);
416 dst.push(storage[offset]);
417 }
418 }
419
420 pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
432 let shape = reshape(&self.shape, new_shape)?;
433
434 if self.is_contiguous() {
435 Ok(Self {
437 storage: self.storage.clone(),
438 strides: contiguous_strides(&shape),
439 shape,
440 offset: self.offset,
441 })
442 } else {
443 let contig = self.contiguous();
445 Ok(Self {
446 storage: contig.storage,
447 strides: contiguous_strides(&shape),
448 shape,
449 offset: 0,
450 })
451 }
452 }
453
    /// Collapses the tensor into one dimension (reshape to `[-1]`).
    #[must_use]
    pub fn flatten(&self) -> Self {
        self.reshape(&[-1]).expect("Flatten should never fail")
    }
459
460 pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
465 let dim = match dim {
466 Some(d) => Some(normalize_dim(d, self.ndim())?),
467 None => None,
468 };
469
470 let new_shape = squeeze(&self.shape, dim);
471 let new_strides: Strides = match dim {
472 Some(d) => {
473 let mut s = self.strides.clone();
474 if d < self.shape.len() && self.shape[d] == 1 {
475 s.remove(d);
476 }
477 s
478 }
479 None => self
480 .shape
481 .iter()
482 .zip(self.strides.iter())
483 .filter(|(dim, _)| **dim != 1)
484 .map(|(_, stride)| *stride)
485 .collect(),
486 };
487
488 Ok(Self {
489 storage: self.storage.clone(),
490 shape: new_shape,
491 strides: new_strides,
492 offset: self.offset,
493 })
494 }
495
496 pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
501 let normalized = if dim < 0 {
502 (dim + self.ndim() as i64 + 1) as usize
503 } else {
504 dim as usize
505 };
506
507 let new_shape = unsqueeze(&self.shape, normalized)?;
508 let mut new_strides = Strides::with_capacity(new_shape.len());
509
510 for (i, _) in new_shape.iter().enumerate() {
511 if i < normalized {
512 new_strides.push(self.strides.get(i).copied().unwrap_or(1));
513 } else if i == normalized {
514 new_strides.push(1);
516 } else {
517 new_strides.push(self.strides[i - 1]);
518 }
519 }
520
521 Ok(Self {
522 storage: self.storage.clone(),
523 shape: new_shape,
524 strides: new_strides,
525 offset: self.offset,
526 })
527 }
528
529 pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
535 let d0 = normalize_dim(dim0, self.ndim())?;
536 let d1 = normalize_dim(dim1, self.ndim())?;
537
538 let new_shape = transpose_shape(&self.shape, d0, d1)?;
539 let new_strides = transpose_strides(&self.strides, d0, d1);
540
541 Ok(Self {
542 storage: self.storage.clone(),
543 shape: new_shape,
544 strides: new_strides,
545 offset: self.offset,
546 })
547 }
548
549 pub fn t(&self) -> Result<Self> {
551 if self.ndim() != 2 {
552 return Err(Error::invalid_operation("t() only works on 2D tensors"));
553 }
554 self.transpose(0, 1)
555 }
556
557 pub fn permute(&self, dims: &[usize]) -> Result<Self> {
562 if dims.len() != self.ndim() {
563 return Err(Error::invalid_operation(format!(
564 "Expected {} dimensions, got {}",
565 self.ndim(),
566 dims.len()
567 )));
568 }
569
570 let mut seen = vec![false; self.ndim()];
572 for &d in dims {
573 if d >= self.ndim() {
574 return Err(Error::InvalidDimension {
575 index: d as i64,
576 ndim: self.ndim(),
577 });
578 }
579 if seen[d] {
580 return Err(Error::invalid_operation("Duplicate dimension in permute"));
581 }
582 seen[d] = true;
583 }
584
585 let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
586 let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
587
588 Ok(Self {
589 storage: self.storage.clone(),
590 shape: new_shape,
591 strides: new_strides,
592 offset: self.offset,
593 })
594 }
595
    /// Returns a tensor whose data is densely packed in row-major order.
    ///
    /// Cheap clone when the view is already contiguous and starts at offset 0;
    /// otherwise materialises a fresh copy.
    #[must_use]
    pub fn contiguous(&self) -> Self {
        if self.is_contiguous() && self.offset == 0 {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.contiguous_gpu();
            return unsafe { gpu_into(result) };
        }

        // to_vec() gathers elements in row-major order regardless of strides.
        let data = self.to_vec();
        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
    }
614
615 pub fn to_device(&self, device: Device) -> Result<Self> {
624 if self.device() == device {
625 return Ok(self.clone());
626 }
627
628 #[cfg(feature = "cuda")]
629 if self.storage.is_gpu() || device.is_gpu() {
630 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
631 let self_f32 = unsafe { gpu_ref(self) };
632 let result = self_f32.to_device_f32(device)?;
633 return Ok(unsafe { gpu_into(result) });
634 }
635
636 let contig = self.contiguous();
637 let new_storage = contig.storage.to_device(device)?;
638
639 Ok(Self {
640 storage: new_storage,
641 shape: self.shape.clone(),
642 strides: self.strides.clone(),
643 offset: 0,
644 })
645 }
646
    /// Moves this tensor to the CPU (shorthand for `to_device(Device::Cpu)`).
    ///
    /// # Errors
    /// Propagates any transfer failure from `to_device`.
    pub fn cpu(&self) -> Result<Self> {
        self.to_device(Device::Cpu)
    }
651
652 #[must_use]
658 pub fn clone_deep(&self) -> Self {
659 let data = self.to_vec();
660 let cpu = Self::from_vec(data, &self.shape).expect("Deep clone should never fail");
661 #[cfg(feature = "cuda")]
662 if self.device().is_gpu() {
663 return cpu.to_device(self.device()).unwrap();
664 }
665 cpu
666 }
667}
668
669impl<T: Numeric> Tensor<T> {
    /// Fills the backing storage with `value`, in place.
    ///
    /// # Panics
    /// Panics on GPU tensors (CPU-only operation).
    pub fn fill_(&self, value: T) {
        assert!(
            self.storage.is_cpu(),
            "fill_() not supported on GPU tensors — create a new tensor and transfer instead"
        );
        // NOTE(review): this writes through `&self` via the storage's mutable
        // accessor and overwrites the ENTIRE buffer, not just this view's window
        // — confirm intended behaviour for offset / non-contiguous views.
        let mut data = self.storage.as_slice_mut();
        CpuBackend::fill(&mut data, value);
    }
687
    /// Zeroes the backing storage in place (see `fill_` for the caveats).
    pub fn zero_(&self) {
        self.fill_(T::zero());
    }
692
    /// Sums all elements into a single-element tensor (0-d on the CPU path).
    #[must_use]
    pub fn sum(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            // Reduce one leading dimension at a time on-device until a single
            // element remains.
            // NOTE(review): unlike the other GPU paths this skips the
            // `is_f32` assert / `gpu_ref` indirection and calls `sum_dim_cuda`
            // on `Tensor<T>` directly — confirm this is intentional.
            let mut t = self.clone();
            while t.ndim() > 1 {
                t = t.sum_dim_cuda(0);
            }
            // A 1-d tensor with several elements still needs one final reduction.
            if t.numel() > 1 {
                t = t.sum_dim_cuda(0);
            }
            return t;
        }

        let data = self.to_vec();
        let result = CpuBackend::sum(&data);
        Self::scalar(result)
    }
720
721 #[must_use]
725 pub fn prod(&self) -> Self {
726 let data = self.to_vec();
727 let result = CpuBackend::prod(&data);
728 let s = Self::scalar(result);
729 #[cfg(feature = "cuda")]
730 if self.device().is_gpu() {
731 return s.to_device(self.device()).expect("prod: device transfer failed");
732 }
733 s
734 }
735
736 pub fn max(&self) -> Result<Self> {
740 if self.is_empty() {
741 return Err(Error::EmptyTensor);
742 }
743 let data = self.to_vec();
744 let result = CpuBackend::max(&data).expect("max on non-empty tensor");
745 let s = Self::scalar(result);
746 #[cfg(feature = "cuda")]
747 if self.device().is_gpu() {
748 return Ok(s.to_device(self.device()).expect("max: device transfer failed"));
749 }
750 Ok(s)
751 }
752
753 pub fn min(&self) -> Result<Self> {
757 if self.is_empty() {
758 return Err(Error::EmptyTensor);
759 }
760 let data = self.to_vec();
761 let result = CpuBackend::min(&data).expect("min on non-empty tensor");
762 let s = Self::scalar(result);
763 #[cfg(feature = "cuda")]
764 if self.device().is_gpu() {
765 return Ok(s.to_device(self.device()).expect("min: device transfer failed"));
766 }
767 Ok(s)
768 }
769
770 pub fn argmax(&self) -> Result<usize> {
772 if self.is_empty() {
773 return Err(Error::EmptyTensor);
774 }
775 let data = self.to_vec();
776 Ok(CpuBackend::argmax(&data).unwrap())
777 }
778
779 pub fn argmin(&self) -> Result<usize> {
781 if self.is_empty() {
782 return Err(Error::EmptyTensor);
783 }
784 let data = self.to_vec();
785 Ok(CpuBackend::argmin(&data).unwrap())
786 }
787
    /// Concatenates `tensors` along dimension `dim` into a new tensor on the
    /// first tensor's device.
    ///
    /// # Errors
    /// Errors on an empty list, an out-of-range `dim`, mismatched ndim, or
    /// mismatched sizes on any non-concatenated dimension.
    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(Error::invalid_operation("cat requires at least one tensor"));
        }
        let ndim = tensors[0].ndim();
        if dim >= ndim {
            return Err(Error::invalid_operation("cat dimension out of range"));
        }

        // All inputs must agree with the first tensor everywhere except `dim`.
        for t in &tensors[1..] {
            if t.ndim() != ndim {
                return Err(Error::invalid_operation(
                    "cat: all tensors must have same ndim",
                ));
            }
            for d in 0..ndim {
                if d != dim && t.shape[d] != tensors[0].shape[d] {
                    return Err(Error::invalid_operation(
                        "cat: shapes must match on non-cat dims",
                    ));
                }
            }
        }

        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
        out_shape[dim] = total_dim_size;

        // View the output as (outer, cat-dim, inner) and copy inner runs whole.
        let outer_size: usize = out_shape[..dim].iter().product();
        let inner_size: usize = out_shape[dim + 1..].iter().product();
        let total_numel: usize = out_shape.iter().product();
        let mut result = vec![T::zero(); total_numel];

        // `dim_offset` tracks where the current tensor's slab begins along `dim`.
        let mut dim_offset = 0;
        for t in tensors {
            let t_data = t.contiguous().to_vec();
            let t_dim_size = t.shape[dim];
            for outer in 0..outer_size {
                for d in 0..t_dim_size {
                    let src_base = outer * t_dim_size * inner_size + d * inner_size;
                    let dst_base =
                        outer * total_dim_size * inner_size + (dim_offset + d) * inner_size;
                    result[dst_base..dst_base + inner_size]
                        .copy_from_slice(&t_data[src_base..src_base + inner_size]);
                }
            }
            dim_offset += t_dim_size;
        }

        let out = Self::from_vec(result, &out_shape)?;
        #[cfg(feature = "cuda")]
        if tensors[0].device().is_gpu() {
            return Ok(out.to_device(tensors[0].device()).unwrap());
        }
        Ok(out)
    }
847}
848
849impl<T: Float> Tensor<T> {
854 pub fn mean(&self) -> Result<Self> {
859 if self.is_empty() {
860 return Err(Error::EmptyTensor);
861 }
862 #[cfg(feature = "cuda")]
863 if self.device().is_gpu() {
864 let s = self.sum(); let n = self.numel() as f32;
866 return Ok(s.mul_scalar(T::from(1.0 / n as f64).unwrap_or(T::zero())));
868 }
869
870 let data = self.to_vec();
871 let result = CpuBackend::mean(&data).expect("mean on non-empty tensor");
872 Ok(Self::scalar(result))
873 }
874
875 #[must_use]
881 pub fn relu(&self) -> Self {
882 #[cfg(feature = "cuda")]
883 if self.device().is_gpu() {
884 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
885 return unsafe { gpu_into(gpu_ref(self).relu_cuda()) };
886 }
887 let data = self.to_vec();
888 let mut result = vec![T::zero(); data.len()];
889 CpuBackend::relu(&mut result, &data);
890 Self::from_vec(result, &self.shape).unwrap()
891 }
892
893 #[must_use]
895 pub fn sigmoid(&self) -> Self {
896 #[cfg(feature = "cuda")]
897 if self.device().is_gpu() {
898 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
899 return unsafe { gpu_into(gpu_ref(self).sigmoid_cuda()) };
900 }
901 let data = self.to_vec();
902 let mut result = vec![T::zero(); data.len()];
903 CpuBackend::sigmoid(&mut result, &data);
904 Self::from_vec(result, &self.shape).unwrap()
905 }
906
907 #[must_use]
909 pub fn tanh(&self) -> Self {
910 #[cfg(feature = "cuda")]
911 if self.device().is_gpu() {
912 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
913 return unsafe { gpu_into(gpu_ref(self).tanh_cuda()) };
914 }
915 let data = self.to_vec();
916 let mut result = vec![T::zero(); data.len()];
917 CpuBackend::tanh(&mut result, &data);
918 Self::from_vec(result, &self.shape).unwrap()
919 }
920
921 #[must_use]
923 pub fn exp(&self) -> Self {
924 #[cfg(feature = "cuda")]
925 if self.device().is_gpu() {
926 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
927 return unsafe { gpu_into(gpu_ref(self).exp_cuda()) };
928 }
929 let data = self.to_vec();
930 let mut result = vec![T::zero(); data.len()];
931 CpuBackend::exp(&mut result, &data);
932 Self::from_vec(result, &self.shape).unwrap()
933 }
934
935 #[must_use]
937 pub fn ln(&self) -> Self {
938 #[cfg(feature = "cuda")]
939 if self.device().is_gpu() {
940 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
941 return unsafe { gpu_into(gpu_ref(self).ln_cuda()) };
942 }
943 let data = self.to_vec();
944 let mut result = vec![T::zero(); data.len()];
945 CpuBackend::ln(&mut result, &data);
946 Self::from_vec(result, &self.shape).unwrap()
947 }
948
949 #[must_use]
951 pub fn sqrt(&self) -> Self {
952 #[cfg(feature = "cuda")]
953 if self.device().is_gpu() {
954 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
955 return unsafe { gpu_into(gpu_ref(self).sqrt_cuda()) };
956 }
957 let data = self.to_vec();
958 let mut result = vec![T::zero(); data.len()];
959 CpuBackend::sqrt(&mut result, &data);
960 Self::from_vec(result, &self.shape).unwrap()
961 }
962
963 #[must_use]
965 pub fn pow(&self, exp: T) -> Self {
966 #[cfg(feature = "cuda")]
967 if self.device().is_gpu() {
968 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
969 let exp_f32: f32 = unsafe { *(&exp as *const T as *const f32) };
970 return unsafe { gpu_into(gpu_ref(self).pow_cuda(exp_f32)) };
971 }
972 let data = self.to_vec();
973 let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
974 Self::from_vec(result, &self.shape).unwrap()
975 }
976
    /// GELU activation (CUDA kernel on GPU, otherwise `crate::ops::gelu`).
    #[must_use]
    pub fn gelu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).gelu_cuda()) };
        }
        crate::ops::gelu(self)
    }
987
    /// SiLU activation (CUDA kernel on GPU, otherwise `crate::ops::silu`).
    #[must_use]
    pub fn silu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).silu_cuda()) };
        }
        crate::ops::silu(self)
    }
998
    /// Softmax along `dim`.
    ///
    /// NOTE(review): the CPU fallback silently returns `self.clone()` when
    /// `crate::ops::softmax` errors (e.g. an invalid dim), while the CUDA path
    /// panics instead — consider unifying the failure behaviour.
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe { gpu_into(self_f32.softmax_cuda(dim).expect("CUDA softmax failed")) };
        }
        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
    }
1010
    /// Log-softmax along `dim`, computed as `ln(softmax(x))`.
    ///
    /// NOTE(review): this formulation can produce `-inf` via `ln(0)` when a
    /// softmax entry underflows; a fused `x - logsumexp(x)` would be more
    /// numerically stable.
    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Self {
        let softmax_result = self.softmax(dim);
        softmax_result.ln()
    }
1017
    /// Mean along `dim` (negative dims count from the end). When `keepdim` is
    /// true the reduced dimension stays with size 1; otherwise it is removed.
    ///
    /// NOTE(review): an out-of-range `dim` (including one more negative than
    /// `-ndim`) silently returns a clone rather than erroring.
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            // On-device: sum along the dim, then scale by 1/dim_size.
            let summed = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            let dim_size = self.shape[dim];
            let result = summed.mul_scalar_cuda(1.0 / dim_size as f32);
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // Reducing the only dimension leaves a 1-element 1-d tensor, not a 0-d one.
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the input as (outer, dim, inner) and accumulate along `dim`.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let mean = sum / NumCast::from(dim_size).unwrap();
                let result_idx = outer * inner_size + inner;
                result[result_idx] = mean;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1082
    /// Sum along `dim` (negative dims count from the end). When `keepdim` is
    /// true the reduced dimension stays with size 1; otherwise it is removed.
    ///
    /// NOTE(review): an out-of-range `dim` silently returns a clone rather than
    /// erroring — mirrors `mean_dim`.
    #[must_use]
    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // Reducing the only dimension leaves a 1-element 1-d tensor, not a 0-d one.
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the input as (outer, dim, inner) and accumulate along `dim`.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let result_idx = outer * inner_size + inner;
                result[result_idx] = sum;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1144
1145 #[must_use]
1147 pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
1148 let mean = self.mean_dim(dim, true);
1150 let sq = self.mul(self).unwrap_or_else(|_| self.clone());
1151 let mean_sq = sq.mean_dim(dim, keepdim);
1152 let mean_keepdim = if keepdim {
1153 mean.clone()
1154 } else {
1155 self.mean_dim(dim, keepdim)
1156 };
1157 let mean_squared = mean_keepdim
1158 .mul(&mean_keepdim)
1159 .unwrap_or_else(|_| mean_keepdim.clone());
1160 mean_sq
1161 .sub(&mean_squared)
1162 .unwrap_or_else(|_| mean_sq.clone())
1163 }
1164
    /// Materialises this tensor broadcast to `shape` (no-op clone when shapes
    /// already match).
    ///
    /// NOTE(review): on the CPU path a failed `broadcast_shape` silently falls
    /// back to `shape.into()`, after which `linear_index` can index out of
    /// bounds and panic — consider returning a `Result` instead.
    #[must_use]
    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
        if self.shape.as_slice() == shape {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe {
                gpu_into(
                    self_f32
                        .broadcast_to_cuda(shape)
                        .expect("CUDA broadcast_to failed"),
                )
            };
        }

        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
        // Zero strides on broadcast axes make repeated reads hit the same element.
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];
        let self_data = self.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            result_data[i] = self_data[self_idx];
        }

        Self::from_vec(result_data, &result_shape).unwrap()
    }
1200
    /// Copies the sub-block selected by `ranges` into a new tensor; dimensions
    /// not covered by `ranges` are taken in full. Result stays on `self`'s device.
    ///
    /// NOTE(review): range bounds are not validated — a range extending past its
    /// dimension panics on an out-of-bounds index during the copy.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
        // Output shape: range lengths for sliced dims, original sizes for the rest.
        let mut new_shape = Vec::with_capacity(self.ndim());
        for (i, range) in ranges.iter().enumerate() {
            if i < self.ndim() {
                new_shape.push(range.end - range.start);
            }
        }
        for i in ranges.len()..self.ndim() {
            new_shape.push(self.shape[i]);
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result_data = vec![T::zero(); new_numel];
        // to_vec() yields contiguous row-major data, which the recursion assumes.
        let self_data = self.to_vec();

        let mut result_idx = 0;
        Self::slice_recursive(
            &self_data,
            &self.shape,
            ranges,
            0,
            0,
            &mut result_data,
            &mut result_idx,
        );

        let out = Self::from_vec(result_data, &new_shape).unwrap();
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return out.to_device(self.device()).unwrap();
        }
        out
    }
1238
1239 fn slice_recursive(
1240 data: &[T],
1241 shape: &[usize],
1242 ranges: &[std::ops::Range<usize>],
1243 dim: usize,
1244 offset: usize,
1245 result: &mut [T],
1246 result_idx: &mut usize,
1247 ) {
1248 if dim == shape.len() {
1249 result[*result_idx] = data[offset];
1250 *result_idx += 1;
1251 return;
1252 }
1253
1254 let stride: usize = shape[dim + 1..].iter().product();
1255 let (start, end) = if dim < ranges.len() {
1256 (ranges[dim].start, ranges[dim].end)
1257 } else {
1258 (0, shape[dim])
1259 };
1260
1261 for i in start..end {
1262 Self::slice_recursive(
1263 data,
1264 shape,
1265 ranges,
1266 dim + 1,
1267 offset + i * stride,
1268 result,
1269 result_idx,
1270 );
1271 }
1272 }
1273}
1274
1275impl<T: Numeric> Tensor<T> {
    /// Element-wise addition with NumPy-style broadcasting.
    ///
    /// # Errors
    /// Errors when the shapes cannot be broadcast together, or a GPU kernel /
    /// transfer fails.
    pub fn add(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                // Both on GPU: run the matching CUDA kernel directly.
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.add_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_add_cuda(o)?) });
                    }
                }
                // Mixed devices: move the CPU-side operand up, then recurse.
                let target_device = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target_device)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target_device)?
                };
                return a_gpu.add(&b_gpu);
            }
        }
        // CPU path: broadcast both operands' strides onto the result shape
        // (broadcast axes get stride 0) and combine element by element.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] + other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1334
    /// Element-wise subtraction with NumPy-style broadcasting.
    ///
    /// # Errors
    /// Errors when the shapes cannot be broadcast together, or a GPU kernel /
    /// transfer fails.
    pub fn sub(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                // Both on GPU: run the matching CUDA kernel directly.
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.sub_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_sub_cuda(o)?) });
                    }
                }
                // Mixed devices: move the CPU-side operand up, then recurse.
                let target = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target)?
                };
                return a_gpu.sub(&b_gpu);
            }
        }
        // CPU path: broadcast both operands' strides onto the result shape
        // (broadcast axes get stride 0) and combine element by element.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] - other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1388
    /// Element-wise multiplication with NumPy-style broadcasting.
    ///
    /// # Errors
    /// Errors when the shapes cannot be broadcast together, or a GPU kernel /
    /// transfer fails.
    pub fn mul(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                // Both on GPU: run the matching CUDA kernel directly.
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.mul_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_mul_cuda(o)?) });
                    }
                }
                // Mixed devices: move the CPU-side operand up, then recurse.
                let target = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target)?
                };
                return a_gpu.mul(&b_gpu);
            }
        }
        // CPU path: broadcast both operands' strides onto the result shape
        // (broadcast axes get stride 0) and combine element by element.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] * other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1442
1443 pub fn div(&self, other: &Self) -> Result<Self> {
1445 #[cfg(feature = "cuda")]
1446 {
1447 let self_gpu = self.device().is_gpu();
1448 let other_gpu = other.device().is_gpu();
1449 if self_gpu || other_gpu {
1450 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1451 if self_gpu && other_gpu {
1452 let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1453 if self.shape == other.shape {
1454 return Ok(unsafe { gpu_into(s.div_cuda(o)?) });
1455 } else {
1456 return Ok(unsafe { gpu_into(s.broadcast_div_cuda(o)?) });
1457 }
1458 }
1459 let target = if self_gpu {
1460 self.device()
1461 } else {
1462 other.device()
1463 };
1464 let a_gpu = if self_gpu {
1465 self.clone()
1466 } else {
1467 self.to_device(target)?
1468 };
1469 let b_gpu = if other_gpu {
1470 other.clone()
1471 } else {
1472 other.to_device(target)?
1473 };
1474 return a_gpu.div(&b_gpu);
1475 }
1476 }
1477 let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1478 let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1479 let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1480
1481 let total = numel(&result_shape);
1482 let mut result_data = vec![T::zero(); total];
1483
1484 let self_data = self.storage.as_slice();
1485 let other_data = other.storage.as_slice();
1486
1487 for i in 0..total {
1488 let indices = crate::shape::unravel_index(i, &result_shape);
1489 let self_idx = self.offset + linear_index(&indices, &self_strides);
1490 let other_idx = other.offset + linear_index(&indices, &other_strides);
1491 result_data[i] = self_data[self_idx] / other_data[other_idx];
1492 }
1493
1494 Self::from_vec(result_data, &result_shape)
1495 }
1496
1497 #[must_use]
1499 pub fn add_scalar(&self, scalar: T) -> Self {
1500 #[cfg(feature = "cuda")]
1501 if self.device().is_gpu() {
1502 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1503 let self_f32 = unsafe { gpu_ref(self) };
1504 let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1505 return unsafe { gpu_into(self_f32.add_scalar_cuda(scalar_f32)) };
1506 }
1507 let data = self.to_vec();
1508 let mut result = vec![T::zero(); data.len()];
1509 CpuBackend::add_scalar(&mut result, &data, scalar);
1510 Self::from_vec(result, &self.shape).unwrap()
1511 }
1512
1513 #[must_use]
1515 pub fn mul_scalar(&self, scalar: T) -> Self {
1516 #[cfg(feature = "cuda")]
1517 if self.device().is_gpu() {
1518 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1519 let self_f32 = unsafe { gpu_ref(self) };
1520 let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1521 return unsafe { gpu_into(self_f32.mul_scalar_cuda(scalar_f32)) };
1522 }
1523 let data = self.to_vec();
1524 let mut result = vec![T::zero(); data.len()];
1525 CpuBackend::mul_scalar(&mut result, &data, scalar);
1526 Self::from_vec(result, &self.shape).unwrap()
1527 }
1528
1529 #[must_use]
1531 pub fn neg(&self) -> Self {
1532 #[cfg(feature = "cuda")]
1533 if self.device().is_gpu() {
1534 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1535 let self_f32 = unsafe { gpu_ref(self) };
1536 return unsafe { gpu_into(self_f32.neg_cuda()) };
1537 }
1538 let data = self.to_vec();
1539 let mut result = vec![T::zero(); data.len()];
1540 CpuBackend::neg(&mut result, &data);
1541 Self::from_vec(result, &self.shape).unwrap()
1542 }
1543
1544 pub fn matmul(&self, other: &Self) -> Result<Self> {
1551 #[cfg(feature = "cuda")]
1552 if self.device().is_gpu() {
1553 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1554 let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1555 return Ok(unsafe { gpu_into(s.matmul_cuda(o)?) });
1556 }
1557 if self.ndim() < 2 || other.ndim() < 2 {
1558 return Err(Error::invalid_operation(
1559 "matmul requires at least 2D tensors",
1560 ));
1561 }
1562
1563 let m = self.shape[self.ndim() - 2];
1564 let k1 = self.shape[self.ndim() - 1];
1565 let k2 = other.shape[other.ndim() - 2];
1566 let n = other.shape[other.ndim() - 1];
1567
1568 if k1 != k2 {
1569 return Err(Error::invalid_operation(format!(
1570 "matmul inner dimensions must match: {k1} vs {k2}"
1571 )));
1572 }
1573
1574 if self.ndim() == 2 && other.ndim() == 2 {
1576 let a_data = self.contiguous().to_vec();
1577 let b_data = other.contiguous().to_vec();
1578
1579 #[cfg(feature = "cuda")]
1583 {
1584 let flops = m * n * k1;
1585 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
1586 && flops >= 4_000_000
1587 {
1588 debug_assert!(std::mem::size_of::<T>() == std::mem::size_of::<f32>());
1589 let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1591 let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1592 if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
1593 let c_t: Vec<T> = unsafe {
1595 let mut v = std::mem::ManuallyDrop::new(c_f32);
1596 Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
1597 };
1598 return Self::from_vec(c_t, &[m, n]);
1599 }
1600 }
1601 }
1602
1603 let mut c_data = vec![T::zero(); m * n];
1604 CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
1605 return Self::from_vec(c_data, &[m, n]);
1606 }
1607
1608 let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1610 let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1611
1612 let broadcast_batch = if batch_dims_self != batch_dims_other {
1614 let max_len = batch_dims_self.len().max(batch_dims_other.len());
1616 let pad_a = vec![1usize; max_len - batch_dims_self.len()];
1617 let pad_b = vec![1usize; max_len - batch_dims_other.len()];
1618 let a_dims: Vec<usize> = pad_a.iter().chain(batch_dims_self.iter()).copied().collect();
1619 let b_dims: Vec<usize> = pad_b.iter().chain(batch_dims_other.iter()).copied().collect();
1620
1621 let mut out_dims = Vec::with_capacity(max_len);
1622 for i in 0..max_len {
1623 if a_dims[i] == b_dims[i] {
1624 out_dims.push(a_dims[i]);
1625 } else if a_dims[i] == 1 {
1626 out_dims.push(b_dims[i]);
1627 } else if b_dims[i] == 1 {
1628 out_dims.push(a_dims[i]);
1629 } else {
1630 return Err(Error::invalid_operation(format!(
1631 "matmul batch dimensions not broadcastable: {:?} vs {:?}",
1632 batch_dims_self, batch_dims_other
1633 )));
1634 }
1635 }
1636 Some((a_dims, b_dims, out_dims))
1637 } else {
1638 None
1639 };
1640
1641 let (batch_size, a_batch_idx, b_batch_idx) = if let Some((a_dims, b_dims, out_dims)) = &broadcast_batch {
1642 let bs: usize = out_dims.iter().product();
1643 let mut a_idx = Vec::with_capacity(bs);
1645 let mut b_idx = Vec::with_capacity(bs);
1646 for flat in 0..bs {
1647 let mut remaining = flat;
1648 let mut ai = 0usize;
1649 let mut bi = 0usize;
1650 let mut a_stride_acc = 1usize;
1651 let mut b_stride_acc = 1usize;
1652 for d in (0..out_dims.len()).rev() {
1653 let out_d = out_dims[d];
1654 let idx = remaining % out_d;
1655 remaining /= out_d;
1656 let a_d = a_dims[d];
1657 let b_d = b_dims[d];
1658 ai += (idx % a_d) * a_stride_acc;
1659 bi += (idx % b_d) * b_stride_acc;
1660 a_stride_acc *= a_d;
1661 b_stride_acc *= b_d;
1662 }
1663 a_idx.push(ai);
1664 b_idx.push(bi);
1665 }
1666 (bs, a_idx, b_idx)
1667 } else {
1668 let bs: usize = batch_dims_self.iter().product();
1669 let idx: Vec<usize> = (0..bs).collect();
1670 (bs, idx.clone(), idx)
1671 };
1672
1673 let a_stride = m * k1;
1674 let b_stride = k1 * n;
1675 let c_stride = m * n;
1676
1677 let a_data = self.contiguous().to_vec();
1678 let b_data = other.contiguous().to_vec();
1679 let mut c_data = vec![T::zero(); batch_size * m * n];
1680
1681 #[cfg(feature = "cuda")]
1683 {
1684 let flops = m * n * k1;
1685 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && flops >= 4_000_000 {
1686 let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1687 let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1688 let mut gpu_ok = true;
1689 for batch in 0..batch_size {
1690 let ai = a_batch_idx[batch];
1691 let bi = b_batch_idx[batch];
1692 let a_slice = &a_f32[ai * a_stride..(ai + 1) * a_stride];
1693 let b_slice = &b_f32[bi * b_stride..(bi + 1) * b_stride];
1694 if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
1695 c_data[batch * c_stride..(batch + 1) * c_stride]
1696 .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
1697 } else {
1698 gpu_ok = false;
1699 break;
1700 }
1701 }
1702 if gpu_ok {
1703 let mut output_shape = batch_dims_self;
1704 output_shape.push(m);
1705 output_shape.push(n);
1706 return Self::from_vec(c_data, &output_shape);
1707 }
1708 c_data = vec![T::zero(); batch_size * m * n];
1710 }
1711 }
1712
1713 for batch in 0..batch_size {
1715 let ai = a_batch_idx[batch];
1716 let bi = b_batch_idx[batch];
1717 let a_slice = &a_data[ai * a_stride..(ai + 1) * a_stride];
1718 let b_slice = &b_data[bi * b_stride..(bi + 1) * b_stride];
1719 let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1720 CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1721 }
1722
1723 let mut output_shape = if let Some((_, _, ref out_dims)) = broadcast_batch {
1725 out_dims.clone()
1726 } else {
1727 batch_dims_self
1728 };
1729 output_shape.push(m);
1730 output_shape.push(n);
1731
1732 Self::from_vec(c_data, &output_shape)
1733 }
1734
1735 pub fn dot(&self, other: &Self) -> Result<Self> {
1737 if self.ndim() != 1 || other.ndim() != 1 {
1738 return Err(Error::invalid_operation("dot requires 1D tensors"));
1739 }
1740
1741 if self.shape[0] != other.shape[0] {
1742 return Err(Error::shape_mismatch(&self.shape, &other.shape));
1743 }
1744
1745 let a_data = self.to_vec();
1746 let b_data = other.to_vec();
1747 let result = CpuBackend::dot(&a_data, &b_data);
1748
1749 Ok(Self::scalar(result))
1750 }
1751}
1752
1753impl<T: Numeric> Add for &Tensor<T> {
1758 type Output = Tensor<T>;
1759
1760 fn add(self, other: Self) -> Self::Output {
1761 self.add(other).expect("Addition failed")
1762 }
1763}
1764
1765impl<T: Numeric> Sub for &Tensor<T> {
1766 type Output = Tensor<T>;
1767
1768 fn sub(self, other: Self) -> Self::Output {
1769 self.sub(other).expect("Subtraction failed")
1770 }
1771}
1772
1773impl<T: Numeric> Mul for &Tensor<T> {
1774 type Output = Tensor<T>;
1775
1776 fn mul(self, other: Self) -> Self::Output {
1777 self.mul(other).expect("Multiplication failed")
1778 }
1779}
1780
1781impl<T: Numeric> Div for &Tensor<T> {
1782 type Output = Tensor<T>;
1783
1784 fn div(self, other: Self) -> Self::Output {
1785 self.div(other).expect("Division failed")
1786 }
1787}
1788
1789impl<T: Numeric> Neg for &Tensor<T> {
1790 type Output = Tensor<T>;
1791
1792 fn neg(self) -> Self::Output {
1793 self.neg()
1794 }
1795}
1796
1797impl<T: Numeric> Add<T> for &Tensor<T> {
1799 type Output = Tensor<T>;
1800
1801 fn add(self, scalar: T) -> Self::Output {
1802 self.add_scalar(scalar)
1803 }
1804}
1805
1806impl<T: Numeric> Mul<T> for &Tensor<T> {
1807 type Output = Tensor<T>;
1808
1809 fn mul(self, scalar: T) -> Self::Output {
1810 self.mul_scalar(scalar)
1811 }
1812}
1813
1814impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
1819 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1820 write!(
1821 f,
1822 "Tensor(shape={:?}, device={}",
1823 self.shape(),
1824 self.device()
1825 )?;
1826 if self.numel() <= 10 {
1827 write!(f, ", data={:?}", self.to_vec())?;
1828 }
1829 write!(f, ")")
1830 }
1831}
1832
1833impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
1834 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1835 if self.is_scalar() {
1836 write!(f, "{}", self.item().unwrap())
1837 } else if self.ndim() == 1 {
1838 write!(f, "[")?;
1839 let data = self.to_vec();
1840 for (i, val) in data.iter().enumerate() {
1841 if i > 0 {
1842 write!(f, ", ")?;
1843 }
1844 write!(f, "{val}")?;
1845 }
1846 write!(f, "]")
1847 } else {
1848 write!(f, "Tensor(shape={:?})", self.shape())
1849 }
1850 }
1851}
1852
1853impl Tensor<f32> {
1858 #[must_use]
1867 pub fn to_f16_precision(&self) -> Self {
1868 let data = self.to_vec();
1869 let f16_data: Vec<f32> = data
1870 .iter()
1871 .map(|&v| {
1872 let h = half::f16::from_f32(v);
1873 h.to_f32()
1874 })
1875 .collect();
1876 Self::from_vec(f16_data, self.shape()).unwrap()
1877 }
1878
1879 #[must_use]
1884 pub fn to_f32_precision(&self) -> Self {
1885 self.clone()
1886 }
1887
1888 #[must_use]
1891 pub fn has_f16_rounding_error(&self) -> bool {
1892 let data = self.to_vec();
1893 data.iter().any(|&v| {
1894 let h = half::f16::from_f32(v);
1895 (h.to_f32() - v).abs() > f32::EPSILON
1896 })
1897 }
1898}
1899
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_vec() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
    }

    #[test]
    fn test_get_set() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        // Row-major layout: [0,0]=1, [0,1]=2, [1,0]=3, [1,1]=4.
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(tensor.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(tensor.get(&[1, 0]).unwrap(), 3.0);
        assert_eq!(tensor.get(&[1, 1]).unwrap(), 4.0);

        tensor.set(&[0, 0], 99.0).unwrap();
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 99.0);
    }

    #[test]
    fn test_reshape() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let reshaped = tensor.reshape(&[3, 2]).expect("reshape failed");
        assert_eq!(reshaped.shape(), &[3, 2]);

        // -1 infers the dimension: flattens to a 6-element vector.
        let flattened = tensor.reshape(&[-1]).expect("reshape failed");
        assert_eq!(flattened.shape(), &[6]);
    }

    #[test]
    fn test_transpose() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = tensor.t().unwrap();
        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(transposed.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(transposed.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(transposed.get(&[1, 0]).unwrap(), 2.0);
    }

    #[test]
    fn test_arithmetic() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let sum = &lhs + &rhs;
        assert_eq!(sum.to_vec(), vec![5.0, 7.0, 9.0]);

        let product = &lhs * &rhs;
        assert_eq!(product.to_vec(), vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_broadcasting() {
        let vector = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let single = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        // The size-1 operand broadcasts across the length-3 operand.
        let sum = &vector + &single;
        assert_eq!(sum.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_sum() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let total = tensor.sum();
        assert_eq!(total.item().unwrap(), 10.0);
    }

    #[test]
    fn test_matmul() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let product = lhs.matmul(&rhs).unwrap();

        assert_eq!(product.shape(), &[2, 2]);
        assert_eq!(product.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_relu() {
        let tensor = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        let rectified = tensor.relu();
        assert_eq!(rectified.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_scalar() {
        let scalar = Tensor::<f32>::scalar(42.0);
        assert!(scalar.is_scalar());
        assert_eq!(scalar.numel(), 1);
        assert_eq!(scalar.item().unwrap(), 42.0);
    }
}