1use core::fmt;
32use core::ops::{Add, Div, Mul, Neg, Sub};
33
34use axonml_core::Device;
35use axonml_core::backends::CpuBackend;
36#[cfg(feature = "cuda")]
37use axonml_core::backends::CudaBackend;
38use axonml_core::dtype::{Float, Numeric, Scalar};
39use axonml_core::error::{Error, Result};
40use axonml_core::storage::Storage;
41use num_traits::NumCast;
42
#[cfg(feature = "cuda")]
/// Helpers that funnel work to the CUDA backend when it is available.
mod cuda_accel {
    use super::*;
    use axonml_core::backends::cuda::get_cuda_backend;

    /// Returns the process-wide CUDA backend, or `None` when unavailable.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        get_cuda_backend()
    }

    /// Row-major matmul of `a` (m x k) by `b` (k x n) on the GPU, returning
    /// the m x n product. Returns `None` on any failure so callers can fall
    /// back to a CPU path.
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        let cuda = get_cuda()?;

        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(m * n).ok()?;

        // Column-major GEMM trick: computing B*A with m/n swapped in a
        // column-major convention yields the row-major C = A*B.
        // NOTE(review): assumes `gemm_f32` follows the cuBLAS argument
        // convention (transa, transb, m, n, k, alpha, A, lda, B, ldb, beta,
        // C, ldc) — verify against the backend's signature.
        cuda.gemm_f32(
            false, false, n, m, k, 1.0, &b_gpu, n, &a_gpu, k, 0.0, &mut c_gpu, n,
        )
        .ok()?;

        cuda.dtoh_copy(&c_gpu).ok()
    }
}
76
77use crate::shape::{
78 Shape, Strides, broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous,
79 linear_index, normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides,
80 unsqueeze,
81};
82
#[cfg(feature = "cuda")]
/// Reinterprets `&Tensor<T>` as `&Tensor<f32>` for the CUDA code paths.
///
/// # Panics
/// Panics when `T` is not `f32`.
unsafe fn gpu_ref<T: Scalar>(t: &Tensor<T>) -> &Tensor<f32> {
    assert!(
        is_f32::<T>(),
        "gpu_ref: only Tensor<f32> can be used for GPU operations, got {:?}",
        T::DTYPE
    );
    // SAFETY: the assert above guarantees T == f32, so `Tensor<T>` and
    // `Tensor<f32>` are the same type and this cast is an identity
    // reinterpretation.
    unsafe { &*(t as *const Tensor<T> as *const Tensor<f32>) }
}
100
#[cfg(feature = "cuda")]
/// Converts an owned `Tensor<f32>` into `Tensor<T>` for the CUDA code paths.
///
/// # Panics
/// Panics when `T` is not `f32`.
unsafe fn gpu_into<T: Scalar>(t: Tensor<f32>) -> Tensor<T> {
    assert!(
        is_f32::<T>(),
        "gpu_into: only Tensor<f32> can be produced from GPU operations, got {:?}",
        T::DTYPE
    );
    unsafe {
        // SAFETY: T == f32 was just asserted. Read the value back under the
        // new type, then `forget` the original so its fields are not
        // dropped twice.
        let out = std::ptr::read(&t as *const Tensor<f32> as *const Tensor<T>);
        std::mem::forget(t);
        out
    }
}
115
116#[cfg(feature = "cuda")]
/// Runtime type test used by the CUDA paths: is `T` exactly `f32`?
fn is_f32<T: 'static>() -> bool {
    use core::any::TypeId;

    TypeId::of::<T>() == TypeId::of::<f32>()
}
120
#[derive(Clone)]
/// An N-dimensional strided view over a storage buffer.
///
/// NOTE(review): `Clone` clones the `Storage` handle — assumed to share the
/// underlying buffer (see `clone_deep` for an independent copy); verify
/// against `Storage`'s `Clone` impl.
pub struct Tensor<T: Scalar> {
    // Backing buffer (CPU, or GPU with the `cuda` feature).
    pub(crate) storage: Storage<T>,
    // Extent of each dimension; empty for 0-D (scalar) tensors.
    pub(crate) shape: Shape,
    // Per-dimension element step; may describe a non-contiguous view.
    pub(crate) strides: Strides,
    // Start of this view within `storage`, in elements.
    pub(crate) offset: usize,
}
141
142impl<T: Scalar> Tensor<T> {
143 pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
156 let total = numel(shape);
157 if total != storage.len() {
158 return Err(Error::shape_mismatch(&[storage.len()], shape));
159 }
160
161 let shape = Shape::from_slice(shape);
162 let strides = contiguous_strides(&shape);
163
164 Ok(Self {
165 storage,
166 shape,
167 strides,
168 offset: 0,
169 })
170 }
171
172 pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
181 let storage = Storage::from_vec(data, Device::Cpu);
182 Self::from_storage(storage, shape)
183 }
184
185 pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
194 let storage = Storage::from_slice(data, Device::Cpu);
195 Self::from_storage(storage, shape)
196 }
197
198 pub fn scalar(value: T) -> Self {
206 Self {
207 storage: Storage::from_vec(vec![value], Device::Cpu),
208 shape: Shape::new(),
209 strides: Strides::new(),
210 offset: 0,
211 }
212 }
213
    /// Returns a tensor of the given shape filled with zeros.
    #[must_use]
    pub fn zeros(shape: &[usize]) -> Self {
        crate::creation::zeros(shape)
    }
219
    /// Returns a tensor of the given shape filled with ones.
    #[must_use]
    pub fn ones(shape: &[usize]) -> Self
    where
        T: Numeric,
    {
        crate::creation::ones(shape)
    }
228
    /// Returns a tensor of the given shape where every element is `value`.
    #[must_use]
    pub fn full(shape: &[usize], value: T) -> Self {
        crate::creation::full(shape, value)
    }
234
    /// Returns a tensor with elements drawn from the standard normal
    /// distribution.
    #[must_use]
    pub fn randn(shape: &[usize]) -> Self
    where
        T: Float,
        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
    {
        crate::creation::randn(shape)
    }
244
    /// Returns a tensor with elements drawn from the standard uniform
    /// distribution.
    #[must_use]
    pub fn rand(shape: &[usize]) -> Self
    where
        T: Float,
        rand::distributions::Standard: rand::distributions::Distribution<T>,
    {
        crate::creation::rand(shape)
    }
254
    /// The extent of each dimension.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
264
    /// The element stride of each dimension.
    #[must_use]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }
270
    /// Number of dimensions (rank); 0 for scalar tensors.
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }
276
    /// Total number of elements in the tensor.
    #[must_use]
    pub fn numel(&self) -> usize {
        numel(&self.shape)
    }
282
    /// `true` when the tensor contains no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.numel() == 0
    }
288
    /// The extent of dimension `dim` (negative values count from the end).
    ///
    /// # Errors
    /// Fails when `dim` is out of range for this tensor's rank.
    pub fn size(&self, dim: i64) -> Result<usize> {
        let idx = normalize_dim(dim, self.ndim())?;
        Ok(self.shape[idx])
    }
297
    /// The device holding this tensor's storage.
    #[must_use]
    pub fn device(&self) -> Device {
        self.storage.device()
    }
303
    /// `true` when the strides describe a dense row-major layout.
    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape, &self.strides)
    }
309
    /// `true` for 0-dimensional (scalar) tensors.
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }
315
316 pub fn get(&self, indices: &[usize]) -> Result<T> {
325 if indices.len() != self.ndim() {
326 return Err(Error::invalid_operation(format!(
327 "Expected {} indices, got {}",
328 self.ndim(),
329 indices.len()
330 )));
331 }
332
333 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
334 if idx >= dim {
335 return Err(Error::IndexOutOfBounds {
336 index: idx,
337 size: dim,
338 });
339 }
340 }
341
342 let offset = self.offset + linear_index(indices, &self.strides);
343 Ok(self.storage.as_slice()[offset])
344 }
345
346 pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
352 if indices.len() != self.ndim() {
353 return Err(Error::invalid_operation(format!(
354 "Expected {} indices, got {}",
355 self.ndim(),
356 indices.len()
357 )));
358 }
359
360 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
361 if idx >= dim {
362 return Err(Error::IndexOutOfBounds {
363 index: idx,
364 size: dim,
365 });
366 }
367 }
368
369 let offset = self.offset + linear_index(indices, &self.strides);
370 self.storage.as_slice_mut()[offset] = value;
371 Ok(())
372 }
373
374 pub fn item(&self) -> Result<T> {
376 if self.numel() != 1 {
377 return Err(Error::invalid_operation(
378 "item() only works on single-element tensors",
379 ));
380 }
381
382 let data = self.to_vec();
384 if data.is_empty() {
385 Err(Error::invalid_operation("item() on empty tensor"))
386 } else {
387 Ok(data[0])
388 }
389 }
390
    /// Copies the tensor's elements into a `Vec` in row-major order,
    /// transferring from the GPU when necessary.
    #[must_use]
    pub fn to_vec(&self) -> Vec<T> {
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let f32_vec = self_f32.to_vec_gpu();
            unsafe {
                // SAFETY: the assert above guarantees T == f32. ManuallyDrop
                // prevents the original Vec's buffer from being freed before
                // it is re-owned under the new element type.
                let mut v = std::mem::ManuallyDrop::new(f32_vec);
                return Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity());
            }
        }

        if self.is_contiguous() {
            // Fast path: a contiguous view is one flat slice of storage.
            let storage = self.storage.as_slice();
            storage[self.offset..self.offset + self.numel()].to_vec()
        } else {
            // Strided view: gather element by element.
            let mut result = Vec::with_capacity(self.numel());
            self.copy_data_to(&mut result);
            result
        }
    }
419
420 fn copy_data_to(&self, dst: &mut Vec<T>) {
422 dst.clear();
423 let storage = self.storage.as_slice();
424
425 let total = self.numel();
427 for i in 0..total {
428 let indices = crate::shape::unravel_index(i, &self.shape);
429 let offset = self.offset + linear_index(&indices, &self.strides);
430 dst.push(storage[offset]);
431 }
432 }
433
434 pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
446 let shape = reshape(&self.shape, new_shape)?;
447
448 if self.is_contiguous() {
449 Ok(Self {
451 storage: self.storage.clone(),
452 strides: contiguous_strides(&shape),
453 shape,
454 offset: self.offset,
455 })
456 } else {
457 let contig = self.contiguous();
459 Ok(Self {
460 storage: contig.storage,
461 strides: contiguous_strides(&shape),
462 shape,
463 offset: 0,
464 })
465 }
466 }
467
    /// Flattens to 1-D, preserving row-major element order.
    #[must_use]
    pub fn flatten(&self) -> Self {
        self.reshape(&[-1]).expect("Flatten should never fail")
    }
473
    /// Removes size-1 dimensions: all of them when `dim` is `None`,
    /// otherwise only the given dimension (when its extent is 1).
    /// Returns a view sharing storage.
    ///
    /// # Errors
    /// Fails when `dim` is out of range.
    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
        let dim = match dim {
            Some(d) => Some(normalize_dim(d, self.ndim())?),
            None => None,
        };

        let new_shape = squeeze(&self.shape, dim);
        let new_strides: Strides = match dim {
            Some(d) => {
                let mut s = self.strides.clone();
                // Only drop the stride when the dimension really has extent
                // 1; otherwise the shape is unchanged and so are the strides.
                if d < self.shape.len() && self.shape[d] == 1 {
                    s.remove(d);
                }
                s
            }
            // Drop the stride of every size-1 dimension, mirroring what
            // `squeeze` does to the shape.
            None => self
                .shape
                .iter()
                .zip(self.strides.iter())
                .filter(|(dim, _)| **dim != 1)
                .map(|(_, stride)| *stride)
                .collect(),
        };

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }
509
    /// Inserts a size-1 dimension at `dim` (negative values insert from the
    /// end, so `-1` appends a trailing dimension). Returns a view.
    ///
    /// # Errors
    /// Fails when the insertion position is out of range.
    pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
        // There are `ndim + 1` valid insertion points, hence the extra `+ 1`
        // when wrapping negative indices.
        let normalized = if dim < 0 {
            (dim + self.ndim() as i64 + 1) as usize
        } else {
            dim as usize
        };

        let new_shape = unsqueeze(&self.shape, normalized)?;
        let mut new_strides = Strides::with_capacity(new_shape.len());

        for (i, _) in new_shape.iter().enumerate() {
            if i < normalized {
                new_strides.push(self.strides.get(i).copied().unwrap_or(1));
            } else if i == normalized {
                // The stride of a size-1 dimension never advances the index,
                // so any value works; 1 is the conventional choice.
                new_strides.push(1);
            } else {
                // Dimensions after the insertion point shift right by one.
                new_strides.push(self.strides[i - 1]);
            }
        }

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }
542
543 pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
549 let d0 = normalize_dim(dim0, self.ndim())?;
550 let d1 = normalize_dim(dim1, self.ndim())?;
551
552 let new_shape = transpose_shape(&self.shape, d0, d1)?;
553 let new_strides = transpose_strides(&self.strides, d0, d1);
554
555 Ok(Self {
556 storage: self.storage.clone(),
557 shape: new_shape,
558 strides: new_strides,
559 offset: self.offset,
560 })
561 }
562
563 pub fn t(&self) -> Result<Self> {
565 if self.ndim() != 2 {
566 return Err(Error::invalid_operation("t() only works on 2D tensors"));
567 }
568 self.transpose(0, 1)
569 }
570
571 pub fn permute(&self, dims: &[usize]) -> Result<Self> {
576 if dims.len() != self.ndim() {
577 return Err(Error::invalid_operation(format!(
578 "Expected {} dimensions, got {}",
579 self.ndim(),
580 dims.len()
581 )));
582 }
583
584 let mut seen = vec![false; self.ndim()];
586 for &d in dims {
587 if d >= self.ndim() {
588 return Err(Error::InvalidDimension {
589 index: d as i64,
590 ndim: self.ndim(),
591 });
592 }
593 if seen[d] {
594 return Err(Error::invalid_operation("Duplicate dimension in permute"));
595 }
596 seen[d] = true;
597 }
598
599 let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
600 let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
601
602 Ok(Self {
603 storage: self.storage.clone(),
604 shape: new_shape,
605 strides: new_strides,
606 offset: self.offset,
607 })
608 }
609
    /// Returns a tensor with a dense row-major layout.
    ///
    /// Cheap clone when already contiguous with zero offset; otherwise the
    /// elements are copied into a fresh buffer.
    #[must_use]
    pub fn contiguous(&self) -> Self {
        if self.is_contiguous() && self.offset == 0 {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.contiguous_gpu();
            return unsafe { gpu_into(result) };
        }

        // `to_vec` already gathers in row-major order, handling strides.
        let data = self.to_vec();
        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
    }
628
629 #[must_use]
638 pub fn map<F: Fn(T) -> T>(&self, f: F) -> Self {
639 let data = self.to_vec(); let result: Vec<T> = data.into_iter().map(f).collect();
641 Self::from_vec(result, &self.shape).unwrap()
642 }
643
644 #[must_use]
650 pub fn zip_map<F: Fn(T, T) -> T>(&self, other: &Self, f: F) -> Self {
651 let a = self.to_vec();
652 let b = other.to_vec();
653 debug_assert_eq!(
654 a.len(),
655 b.len(),
656 "zip_map requires same number of elements: {} vs {}",
657 a.len(),
658 b.len()
659 );
660 let result: Vec<T> = a.into_iter().zip(b).map(|(x, y)| f(x, y)).collect();
661 Self::from_vec(result, &self.shape).unwrap()
662 }
663
664 #[must_use]
666 pub fn zip_map3<F: Fn(T, T, T) -> T>(&self, b: &Self, c: &Self, f: F) -> Self {
667 let a_data = self.to_vec();
668 let b_data = b.to_vec();
669 let c_data = c.to_vec();
670 debug_assert_eq!(a_data.len(), b_data.len());
671 debug_assert_eq!(a_data.len(), c_data.len());
672 let result: Vec<T> = a_data
673 .into_iter()
674 .zip(b_data)
675 .zip(c_data)
676 .map(|((a, b), c)| f(a, b, c))
677 .collect();
678 Self::from_vec(result, &self.shape).unwrap()
679 }
680
681 pub fn to_device(&self, device: Device) -> Result<Self> {
690 if self.device() == device {
691 return Ok(self.clone());
692 }
693
694 #[cfg(feature = "cuda")]
695 if self.storage.is_gpu() || device.is_gpu() {
696 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
697 let self_f32 = unsafe { gpu_ref(self) };
698 let result = self_f32.to_device_f32(device)?;
699 return Ok(unsafe { gpu_into(result) });
700 }
701
702 let contig = self.contiguous();
703 let new_storage = contig.storage.to_device(device)?;
704
705 Ok(Self {
706 storage: new_storage,
707 shape: self.shape.clone(),
708 strides: self.strides.clone(),
709 offset: 0,
710 })
711 }
712
    /// Convenience wrapper: moves the tensor to the CPU.
    ///
    /// # Errors
    /// Propagates transfer failures from [`Self::to_device`].
    pub fn cpu(&self) -> Result<Self> {
        self.to_device(Device::Cpu)
    }
717
    /// Clones the tensor together with its data, landing on the same device
    /// as `self` (a plain `Clone` keeps the same `Storage` handle).
    #[must_use]
    pub fn clone_deep(&self) -> Self {
        // Round-trip through a fresh CPU buffer, then move back if needed.
        let data = self.to_vec();
        let cpu = Self::from_vec(data, &self.shape).expect("Deep clone should never fail");
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return cpu.to_device(self.device()).unwrap();
        }
        cpu
    }
733}
734
735impl<T: Numeric> Tensor<T> {
    /// Fills the backing storage with `value`, in place.
    ///
    /// NOTE(review): writes the *entire* backing storage, ignoring this
    /// view's offset/strides — other views sharing the storage are affected.
    ///
    /// # Panics
    /// Panics when called on a GPU tensor.
    pub fn fill_(&self, value: T) {
        assert!(
            self.storage.is_cpu(),
            "fill_() not supported on GPU tensors — create a new tensor and transfer instead"
        );
        let mut data = self.storage.as_slice_mut();
        CpuBackend::fill(&mut data, value);
    }
753
    /// Zeroes the backing storage in place (see [`Self::fill_`] for caveats).
    pub fn zero_(&self) {
        self.fill_(T::zero());
    }
758
    /// Sums all elements into a scalar tensor; the result stays on the GPU
    /// when the input lives there.
    #[must_use]
    pub fn sum(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let mut t = self_f32.clone();
            // Reduce one dimension at a time until a single element remains.
            while t.ndim() > 1 {
                t = t.sum_dim_cuda(0);
            }
            // A 1-D tensor with several elements needs one final reduction.
            if t.numel() > 1 {
                t = t.sum_dim_cuda(0);
            }
            return unsafe { gpu_into(t) };
        }

        let data = self.to_vec();
        let result = CpuBackend::sum(&data);
        Self::scalar(result)
    }
786
    /// Product of all elements as a scalar tensor.
    ///
    /// Always reduces on the CPU; the scalar is transferred back when the
    /// source device is a GPU.
    #[must_use]
    pub fn prod(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::prod(&data);
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return s
                .to_device(self.device())
                .expect("prod: device transfer failed");
        }
        s
    }
803
    /// Maximum element as a scalar tensor (CPU reduction, result moved back
    /// to the source device when it is a GPU).
    ///
    /// # Errors
    /// Fails with `Error::EmptyTensor` on an empty tensor.
    pub fn max(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::max(&data).expect("max on non-empty tensor");
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s
                .to_device(self.device())
                .expect("max: device transfer failed"));
        }
        Ok(s)
    }
822
    /// Minimum element as a scalar tensor (CPU reduction, result moved back
    /// to the source device when it is a GPU).
    ///
    /// # Errors
    /// Fails with `Error::EmptyTensor` on an empty tensor.
    pub fn min(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::min(&data).expect("min on non-empty tensor");
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s
                .to_device(self.device())
                .expect("min: device transfer failed"));
        }
        Ok(s)
    }
841
842 pub fn argmax(&self) -> Result<usize> {
844 if self.is_empty() {
845 return Err(Error::EmptyTensor);
846 }
847 let data = self.to_vec();
848 Ok(CpuBackend::argmax(&data).unwrap())
849 }
850
851 pub fn argmin(&self) -> Result<usize> {
853 if self.is_empty() {
854 return Err(Error::EmptyTensor);
855 }
856 let data = self.to_vec();
857 Ok(CpuBackend::argmin(&data).unwrap())
858 }
859
    /// Concatenates `tensors` along dimension `dim`.
    ///
    /// All inputs must have the same rank and agree on every extent except
    /// `dim`. The result is assembled on the CPU, then moved to the first
    /// tensor's device when that device is a GPU.
    ///
    /// # Errors
    /// Fails on an empty input list, an out-of-range `dim`, or mismatched
    /// shapes.
    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(Error::invalid_operation("cat requires at least one tensor"));
        }
        let ndim = tensors[0].ndim();
        if dim >= ndim {
            return Err(Error::invalid_operation("cat dimension out of range"));
        }

        for t in &tensors[1..] {
            if t.ndim() != ndim {
                return Err(Error::invalid_operation(
                    "cat: all tensors must have same ndim",
                ));
            }
            for d in 0..ndim {
                if d != dim && t.shape[d] != tensors[0].shape[d] {
                    return Err(Error::invalid_operation(
                        "cat: shapes must match on non-cat dims",
                    ));
                }
            }
        }

        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
        out_shape[dim] = total_dim_size;

        // View every tensor as (outer, cat-dim, inner) and copy whole inner
        // rows into their slot of the output.
        let outer_size: usize = out_shape[..dim].iter().product();
        let inner_size: usize = out_shape[dim + 1..].iter().product();
        let total_numel: usize = out_shape.iter().product();
        let mut result = vec![T::zero(); total_numel];

        // Running start position along the concatenated dimension.
        let mut dim_offset = 0;
        for t in tensors {
            let t_data = t.contiguous().to_vec();
            let t_dim_size = t.shape[dim];
            for outer in 0..outer_size {
                for d in 0..t_dim_size {
                    let src_base = outer * t_dim_size * inner_size + d * inner_size;
                    let dst_base =
                        outer * total_dim_size * inner_size + (dim_offset + d) * inner_size;
                    result[dst_base..dst_base + inner_size]
                        .copy_from_slice(&t_data[src_base..src_base + inner_size]);
                }
            }
            dim_offset += t_dim_size;
        }

        let out = Self::from_vec(result, &out_shape)?;
        #[cfg(feature = "cuda")]
        if tensors[0].device().is_gpu() {
            return Ok(out.to_device(tensors[0].device()).unwrap());
        }
        Ok(out)
    }
919}
920
921impl<T: Float> Tensor<T> {
926 pub fn mean(&self) -> Result<Self> {
931 if self.is_empty() {
932 return Err(Error::EmptyTensor);
933 }
934 #[cfg(feature = "cuda")]
935 if self.device().is_gpu() {
936 let s = self.sum(); let n = self.numel() as f32;
938 return Ok(s.mul_scalar(T::from(1.0 / n as f64).unwrap_or(T::zero())));
940 }
941
942 let data = self.to_vec();
943 let result = CpuBackend::mean(&data).expect("mean on non-empty tensor");
944 Ok(Self::scalar(result))
945 }
946
    /// Elementwise ReLU.
    #[must_use]
    pub fn relu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).relu_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::relu(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
964
    /// Elementwise logistic sigmoid.
    #[must_use]
    pub fn sigmoid(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).sigmoid_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sigmoid(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
978
    /// Elementwise hyperbolic tangent.
    #[must_use]
    pub fn tanh(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).tanh_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::tanh(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
992
    /// Elementwise natural exponential.
    #[must_use]
    pub fn exp(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).exp_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::exp(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
1006
    /// Elementwise natural logarithm.
    #[must_use]
    pub fn ln(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).ln_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::ln(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
1020
    /// Elementwise square root.
    #[must_use]
    pub fn sqrt(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).sqrt_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sqrt(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
1034
    /// Raises every element to the power `exp`.
    #[must_use]
    pub fn pow(&self, exp: T) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY: the assert above guarantees T == f32, so this
            // reinterpreting read of `exp` is a plain f32 load.
            let exp_f32: f32 = unsafe { *(&exp as *const T as *const f32) };
            return unsafe { gpu_into(gpu_ref(self).pow_cuda(exp_f32)) };
        }
        let data = self.to_vec();
        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
        Self::from_vec(result, &self.shape).unwrap()
    }
1048
    /// Elementwise GELU activation (CPU path delegates to `crate::ops::gelu`).
    #[must_use]
    pub fn gelu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).gelu_cuda()) };
        }
        crate::ops::gelu(self)
    }
1059
    /// Elementwise SiLU activation (CPU path delegates to `crate::ops::silu`).
    #[must_use]
    pub fn silu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).silu_cuda()) };
        }
        crate::ops::silu(self)
    }
1070
    /// Softmax along dimension `dim`.
    ///
    /// NOTE(review): the CPU path swallows errors from `crate::ops::softmax`
    /// and silently returns `self` unchanged — confirm this fallback is
    /// intentional.
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe { gpu_into(self_f32.softmax_cuda(dim).expect("CUDA softmax failed")) };
        }
        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
    }
1082
    /// Log-softmax along `dim`, computed as `ln(softmax(x))`.
    ///
    /// NOTE(review): this two-step form is less numerically stable than a
    /// fused max-shifted formulation; tiny probabilities can underflow to 0
    /// before the log.
    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Self {
        let softmax_result = self.softmax(dim);
        softmax_result.ln()
    }
1089
    /// Mean along dimension `dim` (negative values count from the end).
    /// With `keepdim` the reduced dimension stays as size 1; otherwise it
    /// is removed (a fully reduced tensor becomes shape `[1]`).
    ///
    /// NOTE(review): an out-of-range `dim` silently returns a clone instead
    /// of erroring — confirm this is the intended contract.
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let summed = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            // mean = sum / dim_size, done on-device via a scalar multiply.
            let dim_size = self.shape[dim];
            let result = summed.mul_scalar_cuda(1.0 / dim_size as f32);
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // Decompose the index space as (outer, dim, inner) and average over
        // the middle component.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let mean = sum / NumCast::from(dim_size).unwrap();
                let result_idx = outer * inner_size + inner;
                result[result_idx] = mean;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1154
    /// Sum along dimension `dim` (negative values count from the end).
    /// With `keepdim` the reduced dimension stays as size 1; otherwise it
    /// is removed (a fully reduced tensor becomes shape `[1]`).
    ///
    /// NOTE(review): an out-of-range `dim` silently returns a clone instead
    /// of erroring — confirm this is the intended contract.
    #[must_use]
    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // Decompose the index space as (outer, dim, inner) and sum over the
        // middle component.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let result_idx = outer * inner_size + inner;
                result[result_idx] = sum;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1216
1217 #[must_use]
1219 pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
1220 let mean = self.mean_dim(dim, true);
1222 let sq = self.mul(self).unwrap_or_else(|_| self.clone());
1223 let mean_sq = sq.mean_dim(dim, keepdim);
1224 let mean_keepdim = if keepdim {
1225 mean.clone()
1226 } else {
1227 self.mean_dim(dim, keepdim)
1228 };
1229 let mean_squared = mean_keepdim
1230 .mul(&mean_keepdim)
1231 .unwrap_or_else(|_| mean_keepdim.clone());
1232 mean_sq
1233 .sub(&mean_squared)
1234 .unwrap_or_else(|_| mean_sq.clone())
1235 }
1236
    /// Materializes this tensor broadcast to `shape` (no-op clone when the
    /// shapes already match).
    ///
    /// NOTE(review): on the CPU path a failed `broadcast_shape` is silently
    /// replaced by the target shape, after which the computed strides may
    /// index incorrectly — confirm whether this should be an error instead.
    #[must_use]
    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
        if self.shape.as_slice() == shape {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe {
                gpu_into(
                    self_f32
                        .broadcast_to_cuda(shape)
                        .expect("CUDA broadcast_to failed"),
                )
            };
        }

        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
        // Virtual strides: broadcast dimensions advance by 0 elements.
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];
        let self_data = self.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            result_data[i] = self_data[self_idx];
        }

        Self::from_vec(result_data, &result_shape).unwrap()
    }
1272
    /// Copies the sub-tensor selected by `ranges` (one half-open range per
    /// leading dimension; unspecified trailing dimensions are taken whole).
    /// The result is on the original device.
    ///
    /// NOTE(review): range bounds are not validated against the shape —
    /// out-of-range or reversed ranges will underflow or panic.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
        let mut new_shape = Vec::with_capacity(self.ndim());
        for (i, range) in ranges.iter().enumerate() {
            // Extra ranges beyond the rank are ignored.
            if i < self.ndim() {
                new_shape.push(range.end - range.start);
            }
        }
        // Dimensions without a range keep their full extent.
        for i in ranges.len()..self.ndim() {
            new_shape.push(self.shape[i]);
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result_data = vec![T::zero(); new_numel];
        let self_data = self.to_vec();

        let mut result_idx = 0;
        Self::slice_recursive(
            &self_data,
            &self.shape,
            ranges,
            0,
            0,
            &mut result_data,
            &mut result_idx,
        );

        let out = Self::from_vec(result_data, &new_shape).unwrap();
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return out.to_device(self.device()).unwrap();
        }
        out
    }
1310
    /// Depth-first copy of the sub-range selected by `ranges` into `result`,
    /// advancing `result_idx` as elements are emitted in row-major order.
    /// Dimensions at or beyond `ranges.len()` are taken in full.
    fn slice_recursive(
        data: &[T],
        shape: &[usize],
        ranges: &[std::ops::Range<usize>],
        dim: usize,
        offset: usize,
        result: &mut [T],
        result_idx: &mut usize,
    ) {
        // Base case: every dimension fixed — `offset` names one element.
        if dim == shape.len() {
            result[*result_idx] = data[offset];
            *result_idx += 1;
            return;
        }

        // Row-major stride of dimension `dim` in the *source* data.
        let stride: usize = shape[dim + 1..].iter().product();
        let (start, end) = if dim < ranges.len() {
            (ranges[dim].start, ranges[dim].end)
        } else {
            (0, shape[dim])
        };

        for i in start..end {
            Self::slice_recursive(
                data,
                shape,
                ranges,
                dim + 1,
                offset + i * stride,
                result,
                result_idx,
            );
        }
    }
1345}
1346
1347impl<T: Numeric> Tensor<T> {
    /// Elementwise addition with broadcasting and automatic device
    /// promotion: when either operand is on a GPU, both are moved there.
    ///
    /// # Errors
    /// Fails when the shapes cannot broadcast together, or when a device
    /// transfer / CUDA kernel fails.
    pub fn add(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.add_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_add_cuda(o)?) });
                    }
                }
                // Mixed devices: move the CPU operand up, then retry.
                let target_device = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target_device)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target_device)?
                };
                return a_gpu.add(&b_gpu);
            }
        }
        // Fast path: identical shapes, both contiguous — one flat loop.
        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
            let a = self.storage.as_slice();
            let b = other.storage.as_slice();
            let ao = self.offset;
            let bo = other.offset;
            let n = numel(&self.shape);
            let mut result_data = vec![T::zero(); n];
            for i in 0..n {
                result_data[i] = a[ao + i] + b[bo + i];
            }
            return Self::from_vec(result_data, &self.shape);
        }

        // General path: broadcast both operands via virtual strides.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] + other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1420
    /// Elementwise subtraction with broadcasting and automatic device
    /// promotion: when either operand is on a GPU, both are moved there.
    ///
    /// # Errors
    /// Fails when the shapes cannot broadcast together, or when a device
    /// transfer / CUDA kernel fails.
    pub fn sub(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.sub_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_sub_cuda(o)?) });
                    }
                }
                // Mixed devices: move the CPU operand up, then retry.
                let target = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target)?
                };
                return a_gpu.sub(&b_gpu);
            }
        }
        // Fast path: identical shapes, both contiguous — one flat loop.
        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
            let a = self.storage.as_slice();
            let b = other.storage.as_slice();
            let (ao, bo) = (self.offset, other.offset);
            let n = numel(&self.shape);
            let mut r = vec![T::zero(); n];
            for i in 0..n {
                r[i] = a[ao + i] - b[bo + i];
            }
            return Self::from_vec(r, &self.shape);
        }

        // General path: broadcast both operands via virtual strides.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] - other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1487
1488 pub fn mul(&self, other: &Self) -> Result<Self> {
1490 #[cfg(feature = "cuda")]
1491 {
1492 let self_gpu = self.device().is_gpu();
1493 let other_gpu = other.device().is_gpu();
1494 if self_gpu || other_gpu {
1495 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1496 if self_gpu && other_gpu {
1497 let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1498 if self.shape == other.shape {
1499 return Ok(unsafe { gpu_into(s.mul_cuda(o)?) });
1500 } else {
1501 return Ok(unsafe { gpu_into(s.broadcast_mul_cuda(o)?) });
1502 }
1503 }
1504 let target = if self_gpu {
1505 self.device()
1506 } else {
1507 other.device()
1508 };
1509 let a_gpu = if self_gpu {
1510 self.clone()
1511 } else {
1512 self.to_device(target)?
1513 };
1514 let b_gpu = if other_gpu {
1515 other.clone()
1516 } else {
1517 other.to_device(target)?
1518 };
1519 return a_gpu.mul(&b_gpu);
1520 }
1521 }
1522 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1524 let a = self.storage.as_slice();
1525 let b = other.storage.as_slice();
1526 let (ao, bo) = (self.offset, other.offset);
1527 let n = numel(&self.shape);
1528 let mut r = vec![T::zero(); n];
1529 for i in 0..n {
1530 r[i] = a[ao + i] * b[bo + i];
1531 }
1532 return Self::from_vec(r, &self.shape);
1533 }
1534
1535 let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1536 let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1537 let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1538
1539 let total = numel(&result_shape);
1540 let mut result_data = vec![T::zero(); total];
1541
1542 let self_data = self.storage.as_slice();
1543 let other_data = other.storage.as_slice();
1544
1545 for i in 0..total {
1546 let indices = crate::shape::unravel_index(i, &result_shape);
1547 let self_idx = self.offset + linear_index(&indices, &self_strides);
1548 let other_idx = other.offset + linear_index(&indices, &other_strides);
1549 result_data[i] = self_data[self_idx] * other_data[other_idx];
1550 }
1551
1552 Self::from_vec(result_data, &result_shape)
1553 }
1554
1555 pub fn div(&self, other: &Self) -> Result<Self> {
1557 #[cfg(feature = "cuda")]
1558 {
1559 let self_gpu = self.device().is_gpu();
1560 let other_gpu = other.device().is_gpu();
1561 if self_gpu || other_gpu {
1562 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1563 if self_gpu && other_gpu {
1564 let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1565 if self.shape == other.shape {
1566 return Ok(unsafe { gpu_into(s.div_cuda(o)?) });
1567 } else {
1568 return Ok(unsafe { gpu_into(s.broadcast_div_cuda(o)?) });
1569 }
1570 }
1571 let target = if self_gpu {
1572 self.device()
1573 } else {
1574 other.device()
1575 };
1576 let a_gpu = if self_gpu {
1577 self.clone()
1578 } else {
1579 self.to_device(target)?
1580 };
1581 let b_gpu = if other_gpu {
1582 other.clone()
1583 } else {
1584 other.to_device(target)?
1585 };
1586 return a_gpu.div(&b_gpu);
1587 }
1588 }
1589 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1591 let a = self.storage.as_slice();
1592 let b = other.storage.as_slice();
1593 let (ao, bo) = (self.offset, other.offset);
1594 let n = numel(&self.shape);
1595 let mut r = vec![T::zero(); n];
1596 for i in 0..n {
1597 r[i] = a[ao + i] / b[bo + i];
1598 }
1599 return Self::from_vec(r, &self.shape);
1600 }
1601
1602 let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1603 let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1604 let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1605
1606 let total = numel(&result_shape);
1607 let mut result_data = vec![T::zero(); total];
1608
1609 let self_data = self.storage.as_slice();
1610 let other_data = other.storage.as_slice();
1611
1612 for i in 0..total {
1613 let indices = crate::shape::unravel_index(i, &result_shape);
1614 let self_idx = self.offset + linear_index(&indices, &self_strides);
1615 let other_idx = other.offset + linear_index(&indices, &other_strides);
1616 result_data[i] = self_data[self_idx] / other_data[other_idx];
1617 }
1618
1619 Self::from_vec(result_data, &result_shape)
1620 }
1621
1622 #[must_use]
1624 pub fn add_scalar(&self, scalar: T) -> Self {
1625 #[cfg(feature = "cuda")]
1626 if self.device().is_gpu() {
1627 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1628 let self_f32 = unsafe { gpu_ref(self) };
1629 let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1630 return unsafe { gpu_into(self_f32.add_scalar_cuda(scalar_f32)) };
1631 }
1632 let data = self.to_vec();
1633 let mut result = vec![T::zero(); data.len()];
1634 CpuBackend::add_scalar(&mut result, &data, scalar);
1635 Self::from_vec(result, &self.shape).unwrap()
1636 }
1637
1638 #[must_use]
1640 pub fn mul_scalar(&self, scalar: T) -> Self {
1641 #[cfg(feature = "cuda")]
1642 if self.device().is_gpu() {
1643 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1644 let self_f32 = unsafe { gpu_ref(self) };
1645 let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1646 return unsafe { gpu_into(self_f32.mul_scalar_cuda(scalar_f32)) };
1647 }
1648 let data = self.to_vec();
1649 let mut result = vec![T::zero(); data.len()];
1650 CpuBackend::mul_scalar(&mut result, &data, scalar);
1651 Self::from_vec(result, &self.shape).unwrap()
1652 }
1653
1654 #[must_use]
1656 pub fn neg(&self) -> Self {
1657 #[cfg(feature = "cuda")]
1658 if self.device().is_gpu() {
1659 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1660 let self_f32 = unsafe { gpu_ref(self) };
1661 return unsafe { gpu_into(self_f32.neg_cuda()) };
1662 }
1663 let data = self.to_vec();
1664 let mut result = vec![T::zero(); data.len()];
1665 CpuBackend::neg(&mut result, &data);
1666 Self::from_vec(result, &self.shape).unwrap()
1667 }
1668
1669 pub fn matmul(&self, other: &Self) -> Result<Self> {
1676 #[cfg(feature = "cuda")]
1677 if self.device().is_gpu() {
1678 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1679 let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1680 return Ok(unsafe { gpu_into(s.matmul_cuda(o)?) });
1681 }
1682 if self.ndim() < 2 || other.ndim() < 2 {
1683 return Err(Error::invalid_operation(
1684 "matmul requires at least 2D tensors",
1685 ));
1686 }
1687
1688 let m = self.shape[self.ndim() - 2];
1689 let k1 = self.shape[self.ndim() - 1];
1690 let k2 = other.shape[other.ndim() - 2];
1691 let n = other.shape[other.ndim() - 1];
1692
1693 if k1 != k2 {
1694 return Err(Error::invalid_operation(format!(
1695 "matmul inner dimensions must match: {k1} vs {k2}"
1696 )));
1697 }
1698
1699 if self.ndim() == 2 && other.ndim() == 2 {
1710 let self_fast = self.is_contiguous() && self.offset == 0;
1711 let other_fast = other.is_contiguous() && other.offset == 0;
1712
1713 let self_storage = self.storage.as_slice();
1718 let other_storage = other.storage.as_slice();
1719
1720 let a_slice: &[T] = if self_fast {
1721 &self_storage[..m * k1]
1722 } else {
1723 &[]
1726 };
1727 let b_slice: &[T] = if other_fast {
1728 &other_storage[..k1 * n]
1729 } else {
1730 &[]
1731 };
1732
1733 let a_owned: Option<Vec<T>> = if self_fast {
1735 None
1736 } else {
1737 Some(self.contiguous().to_vec())
1738 };
1739 let b_owned: Option<Vec<T>> = if other_fast {
1740 None
1741 } else {
1742 Some(other.contiguous().to_vec())
1743 };
1744 let a: &[T] = a_owned.as_deref().unwrap_or(a_slice);
1745 let b: &[T] = b_owned.as_deref().unwrap_or(b_slice);
1746
1747 #[cfg(feature = "cuda")]
1751 {
1752 let flops = m * n * k1;
1753 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
1754 && flops >= 4_000_000
1755 {
1756 debug_assert!(std::mem::size_of::<T>() == std::mem::size_of::<f32>());
1757 let a_f32: &[f32] = unsafe { std::mem::transmute(a) };
1759 let b_f32: &[f32] = unsafe { std::mem::transmute(b) };
1760 if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
1761 let c_t: Vec<T> = unsafe {
1763 let mut v = std::mem::ManuallyDrop::new(c_f32);
1764 Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
1765 };
1766 return Self::from_vec(c_t, &[m, n]);
1767 }
1768 }
1769 }
1770
1771 let mut c_data: Vec<T> = vec![T::zero(); m * n];
1779 CpuBackend::matmul(&mut c_data, a, b, m, n, k1);
1780 return Self::from_vec(c_data, &[m, n]);
1781 }
1782
1783 let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1785 let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1786
1787 let broadcast_batch = if batch_dims_self == batch_dims_other {
1789 None
1790 } else {
1791 let max_len = batch_dims_self.len().max(batch_dims_other.len());
1793 let pad_a = vec![1usize; max_len - batch_dims_self.len()];
1794 let pad_b = vec![1usize; max_len - batch_dims_other.len()];
1795 let a_dims: Vec<usize> = pad_a
1796 .iter()
1797 .chain(batch_dims_self.iter())
1798 .copied()
1799 .collect();
1800 let b_dims: Vec<usize> = pad_b
1801 .iter()
1802 .chain(batch_dims_other.iter())
1803 .copied()
1804 .collect();
1805
1806 let mut out_dims = Vec::with_capacity(max_len);
1807 for i in 0..max_len {
1808 if a_dims[i] == b_dims[i] {
1809 out_dims.push(a_dims[i]);
1810 } else if a_dims[i] == 1 {
1811 out_dims.push(b_dims[i]);
1812 } else if b_dims[i] == 1 {
1813 out_dims.push(a_dims[i]);
1814 } else {
1815 return Err(Error::invalid_operation(format!(
1816 "matmul batch dimensions not broadcastable: {:?} vs {:?}",
1817 batch_dims_self, batch_dims_other
1818 )));
1819 }
1820 }
1821 Some((a_dims, b_dims, out_dims))
1822 };
1823
1824 let (batch_size, a_batch_idx, b_batch_idx) =
1825 if let Some((a_dims, b_dims, out_dims)) = &broadcast_batch {
1826 let bs: usize = out_dims.iter().product();
1827 let mut a_idx = Vec::with_capacity(bs);
1829 let mut b_idx = Vec::with_capacity(bs);
1830 for flat in 0..bs {
1831 let mut remaining = flat;
1832 let mut ai = 0usize;
1833 let mut bi = 0usize;
1834 let mut a_stride_acc = 1usize;
1835 let mut b_stride_acc = 1usize;
1836 for d in (0..out_dims.len()).rev() {
1837 let out_d = out_dims[d];
1838 let idx = remaining % out_d;
1839 remaining /= out_d;
1840 let a_d = a_dims[d];
1841 let b_d = b_dims[d];
1842 ai += (idx % a_d) * a_stride_acc;
1843 bi += (idx % b_d) * b_stride_acc;
1844 a_stride_acc *= a_d;
1845 b_stride_acc *= b_d;
1846 }
1847 a_idx.push(ai);
1848 b_idx.push(bi);
1849 }
1850 (bs, a_idx, b_idx)
1851 } else {
1852 let bs: usize = batch_dims_self.iter().product();
1853 let idx: Vec<usize> = (0..bs).collect();
1854 (bs, idx.clone(), idx)
1855 };
1856
1857 let a_stride = m * k1;
1858 let b_stride = k1 * n;
1859 let c_stride = m * n;
1860
1861 let a_data = self.contiguous().to_vec();
1862 let b_data = other.contiguous().to_vec();
1863 let mut c_data = vec![T::zero(); batch_size * m * n];
1864
1865 #[cfg(feature = "cuda")]
1867 {
1868 let flops = m * n * k1;
1869 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && flops >= 4_000_000 {
1870 let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1871 let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1872 let mut gpu_ok = true;
1873 for batch in 0..batch_size {
1874 let ai = a_batch_idx[batch];
1875 let bi = b_batch_idx[batch];
1876 let a_slice = &a_f32[ai * a_stride..(ai + 1) * a_stride];
1877 let b_slice = &b_f32[bi * b_stride..(bi + 1) * b_stride];
1878 if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
1879 c_data[batch * c_stride..(batch + 1) * c_stride]
1880 .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
1881 } else {
1882 gpu_ok = false;
1883 break;
1884 }
1885 }
1886 if gpu_ok {
1887 let mut output_shape = batch_dims_self;
1888 output_shape.push(m);
1889 output_shape.push(n);
1890 return Self::from_vec(c_data, &output_shape);
1891 }
1892 c_data = vec![T::zero(); batch_size * m * n];
1894 }
1895 }
1896
1897 for batch in 0..batch_size {
1899 let ai = a_batch_idx[batch];
1900 let bi = b_batch_idx[batch];
1901 let a_slice = &a_data[ai * a_stride..(ai + 1) * a_stride];
1902 let b_slice = &b_data[bi * b_stride..(bi + 1) * b_stride];
1903 let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1904 CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1905 }
1906
1907 let mut output_shape = if let Some((_, _, ref out_dims)) = broadcast_batch {
1909 out_dims.clone()
1910 } else {
1911 batch_dims_self
1912 };
1913 output_shape.push(m);
1914 output_shape.push(n);
1915
1916 Self::from_vec(c_data, &output_shape)
1917 }
1918
1919 pub fn dot(&self, other: &Self) -> Result<Self> {
1921 if self.ndim() != 1 || other.ndim() != 1 {
1922 return Err(Error::invalid_operation("dot requires 1D tensors"));
1923 }
1924
1925 if self.shape[0] != other.shape[0] {
1926 return Err(Error::shape_mismatch(&self.shape, &other.shape));
1927 }
1928
1929 let a_data = self.to_vec();
1930 let b_data = other.to_vec();
1931 let result = CpuBackend::dot(&a_data, &b_data);
1932
1933 Ok(Self::scalar(result))
1934 }
1935}
1936
1937impl<T: Numeric> Add for &Tensor<T> {
1942 type Output = Tensor<T>;
1943
1944 fn add(self, other: Self) -> Self::Output {
1945 self.add(other).expect("Addition failed")
1946 }
1947}
1948
1949impl<T: Numeric> Sub for &Tensor<T> {
1950 type Output = Tensor<T>;
1951
1952 fn sub(self, other: Self) -> Self::Output {
1953 self.sub(other).expect("Subtraction failed")
1954 }
1955}
1956
1957impl<T: Numeric> Mul for &Tensor<T> {
1958 type Output = Tensor<T>;
1959
1960 fn mul(self, other: Self) -> Self::Output {
1961 self.mul(other).expect("Multiplication failed")
1962 }
1963}
1964
1965impl<T: Numeric> Div for &Tensor<T> {
1966 type Output = Tensor<T>;
1967
1968 fn div(self, other: Self) -> Self::Output {
1969 self.div(other).expect("Division failed")
1970 }
1971}
1972
1973impl<T: Numeric> Neg for &Tensor<T> {
1974 type Output = Tensor<T>;
1975
1976 fn neg(self) -> Self::Output {
1977 self.neg()
1978 }
1979}
1980
1981impl<T: Numeric> Add<T> for &Tensor<T> {
1983 type Output = Tensor<T>;
1984
1985 fn add(self, scalar: T) -> Self::Output {
1986 self.add_scalar(scalar)
1987 }
1988}
1989
1990impl<T: Numeric> Mul<T> for &Tensor<T> {
1991 type Output = Tensor<T>;
1992
1993 fn mul(self, scalar: T) -> Self::Output {
1994 self.mul_scalar(scalar)
1995 }
1996}
1997
1998impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
2003 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2004 write!(
2005 f,
2006 "Tensor(shape={:?}, device={}",
2007 self.shape(),
2008 self.device()
2009 )?;
2010 if self.numel() <= 10 {
2011 write!(f, ", data={:?}", self.to_vec())?;
2012 }
2013 write!(f, ")")
2014 }
2015}
2016
2017impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
2018 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2019 if self.is_scalar() {
2020 write!(f, "{}", self.item().unwrap())
2021 } else if self.ndim() == 1 {
2022 write!(f, "[")?;
2023 let data = self.to_vec();
2024 for (i, val) in data.iter().enumerate() {
2025 if i > 0 {
2026 write!(f, ", ")?;
2027 }
2028 write!(f, "{val}")?;
2029 }
2030 write!(f, "]")
2031 } else {
2032 write!(f, "Tensor(shape={:?})", self.shape())
2033 }
2034 }
2035}
2036
2037impl Tensor<f32> {
2042 #[must_use]
2051 pub fn to_f16_precision(&self) -> Self {
2052 let data = self.to_vec();
2053 let f16_data: Vec<f32> = data
2054 .iter()
2055 .map(|&v| {
2056 let h = half::f16::from_f32(v);
2057 h.to_f32()
2058 })
2059 .collect();
2060 Self::from_vec(f16_data, self.shape()).unwrap()
2061 }
2062
2063 #[must_use]
2068 pub fn to_f32_precision(&self) -> Self {
2069 self.clone()
2070 }
2071
2072 #[must_use]
2075 pub fn has_f16_rounding_error(&self) -> bool {
2076 let data = self.to_vec();
2077 data.iter().any(|&v| {
2078 let h = half::f16::from_f32(v);
2079 (h.to_f32() - v).abs() > f32::EPSILON
2080 })
2081 }
2082}
2083
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction from a flat buffer with an explicit shape.
    #[test]
    fn test_from_vec() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
    }

    /// Element access and in-place update via `get`/`set`.
    #[test]
    fn test_get_set() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let expected = [
            ([0usize, 0], 1.0),
            ([0, 1], 2.0),
            ([1, 0], 3.0),
            ([1, 1], 4.0),
        ];
        for (idx, want) in expected {
            assert_eq!(tensor.get(&idx).unwrap(), want);
        }

        tensor.set(&[0, 0], 99.0).unwrap();
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 99.0);
    }

    /// Reshape with explicit dims and with `-1` inference.
    #[test]
    fn test_reshape() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let reshaped = tensor.reshape(&[3, 2]).expect("reshape failed");
        assert_eq!(reshaped.shape(), &[3, 2]);

        let flattened = tensor.reshape(&[-1]).expect("reshape failed");
        assert_eq!(flattened.shape(), &[6]);
    }

    /// 2-D transpose swaps the axes.
    #[test]
    fn test_transpose() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = tensor.t().unwrap();
        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(transposed.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(transposed.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(transposed.get(&[1, 0]).unwrap(), 2.0);
    }

    /// Operator-overloaded element-wise add and mul.
    #[test]
    fn test_arithmetic() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let sum = &lhs + &rhs;
        assert_eq!(sum.to_vec(), vec![5.0, 7.0, 9.0]);

        let product = &lhs * &rhs;
        assert_eq!(product.to_vec(), vec![4.0, 10.0, 18.0]);
    }

    /// A `[1]`-shaped tensor broadcasts against a `[3]`-shaped one.
    #[test]
    fn test_broadcasting() {
        let vector = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let single = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        let sum = &vector + &single;
        assert_eq!(sum.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    /// Full reduction to a scalar tensor.
    #[test]
    fn test_sum() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let total = tensor.sum();
        assert_eq!(total.item().unwrap(), 10.0);
    }

    /// 2x2 matrix product against hand-computed values.
    #[test]
    fn test_matmul() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let product = lhs.matmul(&rhs).unwrap();

        assert_eq!(product.shape(), &[2, 2]);
        assert_eq!(product.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    /// ReLU clamps negatives to zero and passes non-negatives through.
    #[test]
    fn test_relu() {
        let tensor = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        let activated = tensor.relu();
        assert_eq!(activated.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    /// Scalar tensors report themselves as such and expose `item`.
    #[test]
    fn test_scalar() {
        let scalar = Tensor::<f32>::scalar(42.0);
        assert!(scalar.is_scalar());
        assert_eq!(scalar.numel(), 1);
        assert_eq!(scalar.item().unwrap(), 42.0);
    }
}