1use core::fmt;
18use core::ops::{Add, Div, Mul, Neg, Sub};
19
20use axonml_core::Device;
21use axonml_core::backends::CpuBackend;
22#[cfg(feature = "cuda")]
23use axonml_core::backends::CudaBackend;
24use axonml_core::dtype::{Float, Numeric, Scalar};
25use axonml_core::error::{Error, Result};
26use axonml_core::storage::Storage;
27use num_traits::NumCast;
28
#[cfg(feature = "cuda")]
/// CUDA acceleration helpers, compiled only with the `cuda` feature.
mod cuda_accel {
    use super::*;
    use axonml_core::backends::cuda::get_cuda_backend;

    /// Returns the process-wide CUDA backend, or `None` when no device is
    /// available / initialization failed.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        get_cuda_backend()
    }

    /// Row-major matmul `C(m×n) = A(m×k) · B(k×n)` on the GPU.
    ///
    /// Returns `None` when no CUDA backend is available or any device
    /// allocation/copy/kernel launch fails (callers fall back to CPU).
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        let cuda = get_cuda()?;

        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(m * n).ok()?;

        // The operands are swapped (B before A) with leading dimensions
        // n/k/n — presumably the classic column-major cuBLAS trick:
        // computing B·A column-major yields row-major A·B. TODO confirm
        // against gemm_f32's documented argument order.
        cuda.gemm_f32(
            false, false, n, m, k, 1.0, &b_gpu, n, &a_gpu, k, 0.0, &mut c_gpu, n,
        )
        .ok()?;

        cuda.dtoh_copy(&c_gpu).ok()
    }
}
62
63use crate::shape::{
64 Shape, Strides, broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous,
65 linear_index, normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides,
66 unsqueeze,
67};
68
#[cfg(feature = "cuda")]
/// Reinterprets `&Tensor<T>` as `&Tensor<f32>` via a raw pointer cast.
///
/// # Safety
/// Callers must guarantee `T` is exactly `f32` (see `is_f32`); the cast is
/// only a type-level reinterpretation and performs no conversion.
unsafe fn gpu_ref<T: Scalar>(t: &Tensor<T>) -> &Tensor<f32> {
    &*(t as *const Tensor<T> as *const Tensor<f32>)
}
80
#[cfg(feature = "cuda")]
/// Reinterprets an owned `Tensor<f32>` as `Tensor<T>` without running any
/// destructor twice.
///
/// # Safety
/// Callers must guarantee `T` is exactly `f32`. The value is moved out via
/// `ptr::read` and the original is `forget`-ten so its fields are not
/// dropped a second time.
unsafe fn gpu_into<T: Scalar>(t: Tensor<f32>) -> Tensor<T> {
    let out = std::ptr::read(&t as *const Tensor<f32> as *const Tensor<T>);
    std::mem::forget(t);
    out
}
87
#[cfg(feature = "cuda")]
/// True iff `T` is exactly `f32` — the precondition for the `gpu_ref` /
/// `gpu_into` reinterpretation casts.
fn is_f32<T: 'static>() -> bool {
    std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
}
92
/// N-dimensional strided tensor view over reference-counted storage.
///
/// `Clone` is shallow: clones share `storage` and differ only in their
/// shape/strides/offset view parameters.
#[derive(Clone)]
pub struct Tensor<T: Scalar> {
    // Underlying (possibly shared) element buffer; may live on CPU or GPU.
    pub(crate) storage: Storage<T>,
    // Size of each dimension; empty for a 0-dim scalar.
    pub(crate) shape: Shape,
    // Element step per dimension (isize — negative strides are possible).
    pub(crate) strides: Strides,
    // Start position of this view within `storage`, in elements.
    pub(crate) offset: usize,
}
113
114impl<T: Scalar> Tensor<T> {
115 pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
128 let total = numel(shape);
129 if total != storage.len() {
130 return Err(Error::shape_mismatch(&[storage.len()], shape));
131 }
132
133 let shape = Shape::from_slice(shape);
134 let strides = contiguous_strides(&shape);
135
136 Ok(Self {
137 storage,
138 shape,
139 strides,
140 offset: 0,
141 })
142 }
143
144 pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
153 let storage = Storage::from_vec(data, Device::Cpu);
154 Self::from_storage(storage, shape)
155 }
156
157 pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
166 let storage = Storage::from_slice(data, Device::Cpu);
167 Self::from_storage(storage, shape)
168 }
169
170 pub fn scalar(value: T) -> Self {
178 Self {
179 storage: Storage::from_vec(vec![value], Device::Cpu),
180 shape: Shape::new(),
181 strides: Strides::new(),
182 offset: 0,
183 }
184 }
185
    /// All-zeros tensor of the given shape (delegates to `crate::creation`).
    #[must_use]
    pub fn zeros(shape: &[usize]) -> Self {
        crate::creation::zeros(shape)
    }
191
    /// All-ones tensor of the given shape (delegates to `crate::creation`).
    #[must_use]
    pub fn ones(shape: &[usize]) -> Self
    where
        T: Numeric,
    {
        crate::creation::ones(shape)
    }
200
    /// Tensor of the given shape with every element set to `value`.
    #[must_use]
    pub fn full(shape: &[usize], value: T) -> Self {
        crate::creation::full(shape, value)
    }
206
    /// Tensor with elements drawn from the standard normal distribution.
    #[must_use]
    pub fn randn(shape: &[usize]) -> Self
    where
        T: Float,
        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
    {
        crate::creation::randn(shape)
    }
216
    /// Tensor with elements drawn uniformly from the standard distribution
    /// (for floats: `[0, 1)`).
    #[must_use]
    pub fn rand(shape: &[usize]) -> Self
    where
        T: Float,
        rand::distributions::Standard: rand::distributions::Distribution<T>,
    {
        crate::creation::rand(shape)
    }
226
    /// Dimension sizes of this tensor.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
236
    /// Per-dimension element strides (in elements, not bytes).
    #[must_use]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }
242
    /// Number of dimensions (0 for a scalar).
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }
248
    /// Total number of elements (product of all dimension sizes; 1 for a
    /// scalar).
    #[must_use]
    pub fn numel(&self) -> usize {
        numel(&self.shape)
    }
254
    /// True when the tensor contains no elements (some dimension is 0).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.numel() == 0
    }
260
    /// Size of dimension `dim` (negative values count from the end).
    ///
    /// # Errors
    /// Returns an error when `dim` is out of range.
    pub fn size(&self, dim: i64) -> Result<usize> {
        let idx = normalize_dim(dim, self.ndim())?;
        Ok(self.shape[idx])
    }
269
    /// Device the underlying storage lives on.
    #[must_use]
    pub fn device(&self) -> Device {
        self.storage.device()
    }
275
    /// True when elements are laid out in row-major order with no gaps.
    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape, &self.strides)
    }
281
    /// True for a 0-dimensional tensor (empty shape).
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }
287
288 pub fn get(&self, indices: &[usize]) -> Result<T> {
297 if indices.len() != self.ndim() {
298 return Err(Error::invalid_operation(format!(
299 "Expected {} indices, got {}",
300 self.ndim(),
301 indices.len()
302 )));
303 }
304
305 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
306 if idx >= dim {
307 return Err(Error::IndexOutOfBounds {
308 index: idx,
309 size: dim,
310 });
311 }
312 }
313
314 let offset = self.offset + linear_index(indices, &self.strides);
315 Ok(self.storage.as_slice()[offset])
316 }
317
    /// Writes `value` at the given multi-dimensional `indices`.
    ///
    /// NOTE(review): this takes `&self` yet mutates through
    /// `Storage::as_slice_mut` — so the write relies on `Storage`'s
    /// interior-mutability/shared-buffer semantics and is visible to every
    /// clone sharing this storage. Confirm that is intended.
    ///
    /// # Errors
    /// Fails when the number of indices differs from `ndim()` or any index
    /// is out of bounds for its dimension.
    pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
        if indices.len() != self.ndim() {
            return Err(Error::invalid_operation(format!(
                "Expected {} indices, got {}",
                self.ndim(),
                indices.len()
            )));
        }

        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
            if idx >= dim {
                return Err(Error::IndexOutOfBounds {
                    index: idx,
                    size: dim,
                });
            }
        }

        let offset = self.offset + linear_index(indices, &self.strides);
        self.storage.as_slice_mut()[offset] = value;
        Ok(())
    }
345
346 pub fn item(&self) -> Result<T> {
348 if self.numel() != 1 {
349 return Err(Error::invalid_operation(
350 "item() only works on single-element tensors",
351 ));
352 }
353
354 if self.is_scalar() {
355 Ok(self.storage.as_slice()[self.offset])
356 } else {
357 let indices = vec![0; self.ndim()];
359 self.get(&indices)
360 }
361 }
362
    /// Copies this tensor's elements into a freshly allocated `Vec` in
    /// row-major (logical) order, regardless of the underlying layout.
    ///
    /// GPU tensors are downloaded via the f32 path; the resulting
    /// `Vec<f32>` is reinterpreted as `Vec<T>` (sound only because the GPU
    /// path asserts `T == f32`).
    #[must_use]
    pub fn to_vec(&self) -> Vec<T> {
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let f32_vec = self_f32.to_vec_gpu();
            // SAFETY: T is f32 (asserted above); rebuild the Vec under the
            // new element type without dropping the original buffer.
            unsafe {
                let mut v = std::mem::ManuallyDrop::new(f32_vec);
                return Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity());
            }
        }

        if self.is_contiguous() {
            // Fast path: a single slice copy of the live range.
            let storage = self.storage.as_slice();
            storage[self.offset..self.offset + self.numel()].to_vec()
        } else {
            // Strided path: gather element-by-element in logical order.
            let mut result = Vec::with_capacity(self.numel());
            self.copy_data_to(&mut result);
            result
        }
    }
391
392 fn copy_data_to(&self, dst: &mut Vec<T>) {
394 dst.clear();
395 let storage = self.storage.as_slice();
396
397 let total = self.numel();
399 for i in 0..total {
400 let indices = crate::shape::unravel_index(i, &self.shape);
401 let offset = self.offset + linear_index(&indices, &self.strides);
402 dst.push(storage[offset]);
403 }
404 }
405
406 pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
418 let shape = reshape(&self.shape, new_shape)?;
419
420 if self.is_contiguous() {
421 Ok(Self {
423 storage: self.storage.clone(),
424 strides: contiguous_strides(&shape),
425 shape,
426 offset: self.offset,
427 })
428 } else {
429 let contig = self.contiguous();
431 Ok(Self {
432 storage: contig.storage,
433 strides: contiguous_strides(&shape),
434 shape,
435 offset: 0,
436 })
437 }
438 }
439
    /// Flattens to 1-D. `reshape(&[-1])` always succeeds for a valid
    /// tensor, hence the `expect`.
    #[must_use]
    pub fn flatten(&self) -> Self {
        self.reshape(&[-1]).expect("Flatten should never fail")
    }
445
    /// Removes size-1 dimensions, returning a zero-copy view.
    ///
    /// With `Some(dim)`, only that dimension is removed (a no-op when its
    /// size is not 1); with `None`, every size-1 dimension is removed.
    ///
    /// # Errors
    /// Fails when the provided `dim` is out of range.
    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
        let dim = match dim {
            Some(d) => Some(normalize_dim(d, self.ndim())?),
            None => None,
        };

        let new_shape = squeeze(&self.shape, dim);
        // Strides must be dropped in lockstep with the removed dims.
        let new_strides: Strides = match dim {
            Some(d) => {
                let mut s = self.strides.clone();
                if d < self.shape.len() && self.shape[d] == 1 {
                    s.remove(d);
                }
                s
            }
            None => self
                .shape
                .iter()
                .zip(self.strides.iter())
                .filter(|(dim, _)| **dim != 1)
                .map(|(_, stride)| *stride)
                .collect(),
        };

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }
481
482 pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
487 let normalized = if dim < 0 {
488 (dim + self.ndim() as i64 + 1) as usize
489 } else {
490 dim as usize
491 };
492
493 let new_shape = unsqueeze(&self.shape, normalized)?;
494 let mut new_strides = Strides::with_capacity(new_shape.len());
495
496 for (i, _) in new_shape.iter().enumerate() {
497 if i < normalized {
498 new_strides.push(self.strides.get(i).copied().unwrap_or(1));
499 } else if i == normalized {
500 new_strides.push(1);
502 } else {
503 new_strides.push(self.strides[i - 1]);
504 }
505 }
506
507 Ok(Self {
508 storage: self.storage.clone(),
509 shape: new_shape,
510 strides: new_strides,
511 offset: self.offset,
512 })
513 }
514
515 pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
521 let d0 = normalize_dim(dim0, self.ndim())?;
522 let d1 = normalize_dim(dim1, self.ndim())?;
523
524 let new_shape = transpose_shape(&self.shape, d0, d1)?;
525 let new_strides = transpose_strides(&self.strides, d0, d1);
526
527 Ok(Self {
528 storage: self.storage.clone(),
529 shape: new_shape,
530 strides: new_strides,
531 offset: self.offset,
532 })
533 }
534
535 pub fn t(&self) -> Result<Self> {
537 if self.ndim() != 2 {
538 return Err(Error::invalid_operation("t() only works on 2D tensors"));
539 }
540 self.transpose(0, 1)
541 }
542
543 pub fn permute(&self, dims: &[usize]) -> Result<Self> {
548 if dims.len() != self.ndim() {
549 return Err(Error::invalid_operation(format!(
550 "Expected {} dimensions, got {}",
551 self.ndim(),
552 dims.len()
553 )));
554 }
555
556 let mut seen = vec![false; self.ndim()];
558 for &d in dims {
559 if d >= self.ndim() {
560 return Err(Error::InvalidDimension {
561 index: d as i64,
562 ndim: self.ndim(),
563 });
564 }
565 if seen[d] {
566 return Err(Error::invalid_operation("Duplicate dimension in permute"));
567 }
568 seen[d] = true;
569 }
570
571 let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
572 let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
573
574 Ok(Self {
575 storage: self.storage.clone(),
576 shape: new_shape,
577 strides: new_strides,
578 offset: self.offset,
579 })
580 }
581
    /// Returns a tensor with row-major, gap-free layout and zero offset —
    /// a cheap clone when already in that form, otherwise a full copy.
    #[must_use]
    pub fn contiguous(&self) -> Self {
        if self.is_contiguous() && self.offset == 0 {
            return self.clone();
        }

        // GPU path: compact on-device (f32 only).
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.contiguous_gpu();
            return unsafe { gpu_into(result) };
        }

        // CPU path: gather into a fresh row-major buffer.
        let data = self.to_vec();
        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
    }
600
    /// Moves (copies) this tensor to `device`; a cheap clone when already
    /// there. The CPU path materializes a contiguous buffer first, so the
    /// result always has offset 0.
    ///
    /// # Errors
    /// Propagates storage/device transfer failures.
    pub fn to_device(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        // Any GPU involvement (source or destination) goes through the
        // f32-only CUDA transfer path.
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() || device.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.to_device_f32(device)?;
            return Ok(unsafe { gpu_into(result) });
        }

        let contig = self.contiguous();
        let new_storage = contig.storage.to_device(device)?;

        Ok(Self {
            storage: new_storage,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: 0,
        })
    }
632
    /// Convenience for `to_device(Device::Cpu)`.
    ///
    /// # Errors
    /// Propagates device-transfer failures.
    pub fn cpu(&self) -> Result<Self> {
        self.to_device(Device::Cpu)
    }
637
    /// Deep copy: fresh storage (unlike `Clone`, which shares storage).
    /// GPU tensors round-trip through the CPU and are re-uploaded to their
    /// original device.
    #[must_use]
    pub fn clone_deep(&self) -> Self {
        let data = self.to_vec();
        let cpu = Self::from_vec(data, &self.shape).expect("Deep clone should never fail");
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return cpu.to_device(self.device()).unwrap();
        }
        cpu
    }
653}
654
655impl<T: Numeric> Tensor<T> {
    /// Fills every element of the underlying storage with `value`, in
    /// place.
    ///
    /// NOTE(review): mutates through `&self` via `Storage::as_slice_mut`,
    /// so the write is visible to all clones sharing this storage.
    ///
    /// # Panics
    /// Panics for GPU-resident tensors (unsupported).
    pub fn fill_(&self, value: T) {
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            panic!("fill_() not supported on GPU tensors — create a new tensor instead");
        }
        let mut data = self.storage.as_slice_mut();
        CpuBackend::fill(&mut data, value);
    }
669
    /// Sets every element to zero, in place (see `fill_` for caveats).
    pub fn zero_(&self) {
        self.fill_(T::zero());
    }
674
    /// Sum of all elements, returned as a scalar tensor on the source
    /// device.
    ///
    /// Note: the reduction always runs on the CPU — GPU tensors are
    /// downloaded via `to_vec()` and the scalar result re-uploaded.
    #[must_use]
    pub fn sum(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::sum(&data);
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return s.to_device(self.device()).unwrap();
        }
        s
    }
691
    /// Product of all elements, returned as a scalar tensor on the source
    /// device (reduction runs on the CPU; see `sum`).
    #[must_use]
    pub fn prod(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::prod(&data);
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return s.to_device(self.device()).unwrap();
        }
        s
    }
704
    /// Maximum element as a scalar tensor on the source device (reduction
    /// runs on the CPU).
    ///
    /// # Errors
    /// Returns `EmptyTensor` for an empty tensor — which also makes the
    /// inner `unwrap` safe.
    pub fn max(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::max(&data).unwrap();
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s.to_device(self.device()).unwrap());
        }
        Ok(s)
    }
719
    /// Minimum element as a scalar tensor on the source device (reduction
    /// runs on the CPU).
    ///
    /// # Errors
    /// Returns `EmptyTensor` for an empty tensor — which also makes the
    /// inner `unwrap` safe.
    pub fn min(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::min(&data).unwrap();
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s.to_device(self.device()).unwrap());
        }
        Ok(s)
    }
734
735 pub fn argmax(&self) -> Result<usize> {
737 if self.is_empty() {
738 return Err(Error::EmptyTensor);
739 }
740 let data = self.to_vec();
741 Ok(CpuBackend::argmax(&data).unwrap())
742 }
743
744 pub fn argmin(&self) -> Result<usize> {
746 if self.is_empty() {
747 return Err(Error::EmptyTensor);
748 }
749 let data = self.to_vec();
750 Ok(CpuBackend::argmin(&data).unwrap())
751 }
752
    /// Concatenates `tensors` along dimension `dim`. The result lands on
    /// the first tensor's device (the copy itself runs on the CPU).
    ///
    /// # Errors
    /// Fails when the list is empty, `dim` is out of range, or the shapes
    /// disagree on any non-concatenation dimension.
    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(Error::invalid_operation("cat requires at least one tensor"));
        }
        let ndim = tensors[0].ndim();
        if dim >= ndim {
            return Err(Error::invalid_operation("cat dimension out of range"));
        }

        // All inputs must match the first tensor everywhere except `dim`.
        for t in &tensors[1..] {
            if t.ndim() != ndim {
                return Err(Error::invalid_operation(
                    "cat: all tensors must have same ndim",
                ));
            }
            for d in 0..ndim {
                if d != dim && t.shape[d] != tensors[0].shape[d] {
                    return Err(Error::invalid_operation(
                        "cat: shapes must match on non-cat dims",
                    ));
                }
            }
        }

        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
        out_shape[dim] = total_dim_size;

        // View every tensor as (outer, cat-dim, inner) and copy one inner
        // run at a time; `dim_offset` tracks where each input starts along
        // the output's cat dimension.
        let outer_size: usize = out_shape[..dim].iter().product();
        let inner_size: usize = out_shape[dim + 1..].iter().product();
        let total_numel: usize = out_shape.iter().product();
        let mut result = vec![T::zero(); total_numel];

        let mut dim_offset = 0;
        for t in tensors {
            let t_data = t.contiguous().to_vec();
            let t_dim_size = t.shape[dim];
            for outer in 0..outer_size {
                for d in 0..t_dim_size {
                    let src_base = outer * t_dim_size * inner_size + d * inner_size;
                    let dst_base =
                        outer * total_dim_size * inner_size + (dim_offset + d) * inner_size;
                    result[dst_base..dst_base + inner_size]
                        .copy_from_slice(&t_data[src_base..src_base + inner_size]);
                }
            }
            dim_offset += t_dim_size;
        }

        let out = Self::from_vec(result, &out_shape)?;
        #[cfg(feature = "cuda")]
        if tensors[0].device().is_gpu() {
            return Ok(out.to_device(tensors[0].device()).unwrap());
        }
        Ok(out)
    }
812}
813
814impl<T: Float> Tensor<T> {
    /// Arithmetic mean of all elements as a scalar tensor on the source
    /// device (reduction runs on the CPU).
    ///
    /// # Errors
    /// Returns `EmptyTensor` for an empty tensor — which also makes the
    /// inner `unwrap` safe.
    pub fn mean(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::mean(&data).unwrap();
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s.to_device(self.device()).unwrap());
        }
        Ok(s)
    }
833
    /// Element-wise ReLU, returning a new tensor.
    #[must_use]
    pub fn relu(&self) -> Self {
        // GPU fast path: reinterpret as Tensor<f32> (the only supported
        // GPU dtype) and run the CUDA kernel.
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).relu_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::relu(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
851
    /// Element-wise logistic sigmoid, returning a new tensor (GPU fast
    /// path for f32, CPU backend otherwise).
    #[must_use]
    pub fn sigmoid(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).sigmoid_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sigmoid(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
865
    /// Element-wise hyperbolic tangent, returning a new tensor (GPU fast
    /// path for f32, CPU backend otherwise).
    #[must_use]
    pub fn tanh(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).tanh_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::tanh(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
879
    /// Element-wise natural exponential, returning a new tensor (GPU fast
    /// path for f32, CPU backend otherwise).
    #[must_use]
    pub fn exp(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).exp_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::exp(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
893
    /// Element-wise natural logarithm, returning a new tensor (GPU fast
    /// path for f32, CPU backend otherwise).
    #[must_use]
    pub fn ln(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).ln_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::ln(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
907
    /// Element-wise square root, returning a new tensor (GPU fast path for
    /// f32, CPU backend otherwise).
    #[must_use]
    pub fn sqrt(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).sqrt_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sqrt(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
921
    /// Element-wise power `x^exp`, returning a new tensor.
    #[must_use]
    pub fn pow(&self, exp: T) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY-by-assert: T is f32 here, so this pointer cast is a
            // plain reinterpretation of the exponent.
            let exp_f32: f32 = unsafe { *(&exp as *const T as *const f32) };
            return unsafe { gpu_into(gpu_ref(self).pow_cuda(exp_f32)) };
        }
        let data = self.to_vec();
        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
        Self::from_vec(result, &self.shape).unwrap()
    }
935
    /// Element-wise GELU activation (GPU fast path for f32; otherwise
    /// delegates to `crate::ops::gelu`).
    #[must_use]
    pub fn gelu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).gelu_cuda()) };
        }
        crate::ops::gelu(self)
    }
946
    /// Element-wise SiLU/swish activation (GPU fast path for f32;
    /// otherwise delegates to `crate::ops::silu`).
    #[must_use]
    pub fn silu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).silu_cuda()) };
        }
        crate::ops::silu(self)
    }
957
    /// Softmax along `dim`.
    ///
    /// NOTE(review): the CPU path swallows errors from
    /// `crate::ops::softmax` and returns a clone of `self`, while the GPU
    /// path panics on failure — confirm this asymmetry is intended.
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe { gpu_into(self_f32.softmax_cuda(dim).expect("CUDA softmax failed")) };
        }
        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
    }
969
    /// Log-softmax along `dim`, computed as `ln(softmax(x))`.
    ///
    /// NOTE(review): this composition is less numerically stable than a
    /// fused `x - max - ln(sum(exp(...)))` formulation; small softmax
    /// outputs can underflow before the log.
    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Self {
        let softmax_result = self.softmax(dim);
        softmax_result.ln()
    }
976
    /// Mean along `dim` (negative values count from the end). With
    /// `keepdim` the reduced dimension is retained with size 1; otherwise
    /// it is removed, and a full reduction yields shape `[1]`.
    ///
    /// NOTE(review): an out-of-range `dim` silently returns a clone rather
    /// than an error — confirm this contract.
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        // GPU path: on-device sum along `dim`, then scale by 1/dim_size.
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let summed = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            let dim_size = self.shape[dim];
            let result = summed.mul_scalar_cuda(1.0 / dim_size as f32);
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // A full reduction of a 1-D tensor would leave an empty shape;
        // normalize to [1].
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the input as (outer, dim, inner) and average across `dim`
        // for each (outer, inner) position.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let mean = sum / NumCast::from(dim_size).unwrap();
                let result_idx = outer * inner_size + inner;
                result[result_idx] = mean;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1041
    /// Sum along `dim` (negative values count from the end). With
    /// `keepdim` the reduced dimension is retained with size 1; otherwise
    /// it is removed, and a full reduction yields shape `[1]`.
    ///
    /// NOTE(review): an out-of-range `dim` silently returns a clone rather
    /// than an error — same contract question as `mean_dim`.
    #[must_use]
    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        // GPU path: on-device reduction (f32 only).
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // Normalize a fully reduced (empty) shape to [1].
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the input as (outer, dim, inner) and sum across `dim` for
        // each (outer, inner) position.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let result_idx = outer * inner_size + inner;
                result[result_idx] = sum;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1103
1104 #[must_use]
1106 pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
1107 let mean = self.mean_dim(dim, true);
1109 let sq = self.mul(self).unwrap_or_else(|_| self.clone());
1110 let mean_sq = sq.mean_dim(dim, keepdim);
1111 let mean_keepdim = if keepdim {
1112 mean.clone()
1113 } else {
1114 self.mean_dim(dim, keepdim)
1115 };
1116 let mean_squared = mean_keepdim
1117 .mul(&mean_keepdim)
1118 .unwrap_or_else(|_| mean_keepdim.clone());
1119 mean_sq
1120 .sub(&mean_squared)
1121 .unwrap_or_else(|_| mean_sq.clone())
1122 }
1123
    /// Materializes this tensor broadcast to `shape`, copying elements
    /// (not a view).
    ///
    /// NOTE(review): on broadcast failure the CPU path silently falls back
    /// to treating `shape` as the result shape, and panics are possible
    /// from out-of-range storage indexing in that case — confirm intended.
    #[must_use]
    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
        if self.shape.as_slice() == shape {
            return self.clone();
        }

        // GPU path (f32 only); panics if the CUDA broadcast fails.
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe {
                gpu_into(
                    self_f32
                        .broadcast_to_cuda(shape)
                        .expect("CUDA broadcast_to failed"),
                )
            };
        }

        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
        // Broadcast strides map each output index back to a source element
        // (stride 0 on expanded dims).
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];
        let self_data = self.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            result_data[i] = self_data[self_idx];
        }

        Self::from_vec(result_data, &result_shape).unwrap()
    }
1159
    /// Copies the sub-tensor selected by `ranges` (one half-open range per
    /// leading dimension; unspecified trailing dimensions are taken in
    /// full). The result lands on the source device, but the copy itself
    /// runs on the CPU.
    ///
    /// NOTE(review): ranges are not validated against the shape; an
    /// out-of-bounds or inverted range will panic on indexing/underflow.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
        let mut new_shape = Vec::with_capacity(self.ndim());
        for (i, range) in ranges.iter().enumerate() {
            if i < self.ndim() {
                new_shape.push(range.end - range.start);
            }
        }
        // Dimensions beyond the provided ranges keep their full extent.
        for i in ranges.len()..self.ndim() {
            new_shape.push(self.shape[i]);
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result_data = vec![T::zero(); new_numel];
        let self_data = self.to_vec();

        // Depth-first walk over the selected index region, appending
        // elements in row-major order.
        let mut result_idx = 0;
        Self::slice_recursive(
            &self_data,
            &self.shape,
            ranges,
            0,
            0,
            &mut result_data,
            &mut result_idx,
        );

        let out = Self::from_vec(result_data, &new_shape).unwrap();
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return out.to_device(self.device()).unwrap();
        }
        out
    }
1197
    /// Recursive helper for `slice`: walks dimension `dim` over the
    /// selected range (or the full extent when no range was given),
    /// writing selected elements of the row-major `data` into `result` in
    /// order via the `result_idx` cursor.
    fn slice_recursive(
        data: &[T],
        shape: &[usize],
        ranges: &[std::ops::Range<usize>],
        dim: usize,
        offset: usize,
        result: &mut [T],
        result_idx: &mut usize,
    ) {
        // Base case: all dimensions fixed — `offset` addresses one element.
        if dim == shape.len() {
            result[*result_idx] = data[offset];
            *result_idx += 1;
            return;
        }

        // Row-major stride of dimension `dim` in the source data.
        let stride: usize = shape[dim + 1..].iter().product();
        let (start, end) = if dim < ranges.len() {
            (ranges[dim].start, ranges[dim].end)
        } else {
            (0, shape[dim])
        };

        for i in start..end {
            Self::slice_recursive(
                data,
                shape,
                ranges,
                dim + 1,
                offset + i * stride,
                result,
                result_idx,
            );
        }
    }
1232}
1233
1234impl<T: Numeric> Tensor<T> {
    /// Element-wise addition with broadcasting.
    ///
    /// Device handling: when either operand is GPU-resident (f32 only),
    /// both are brought to that device and the CUDA kernels run; otherwise
    /// the broadcasted sum is computed on the CPU directly from the
    /// operands' strided storage.
    ///
    /// # Errors
    /// Fails when the shapes are not broadcast-compatible or a device
    /// transfer fails.
    pub fn add(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.add_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_add_cuda(o)?) });
                    }
                }
                // Mixed devices: migrate the CPU operand, then retry so the
                // all-GPU branch above handles the computation.
                let target_device = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target_device)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target_device)?
                };
                return a_gpu.add(&b_gpu);
            }
        }
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] + other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1293
    /// Element-wise subtraction with broadcasting (device handling as in
    /// `add`).
    ///
    /// # Errors
    /// Fails when the shapes are not broadcast-compatible or a device
    /// transfer fails.
    pub fn sub(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.sub_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_sub_cuda(o)?) });
                    }
                }
                // Mixed devices: migrate the CPU operand, then retry.
                let target = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target)?
                };
                return a_gpu.sub(&b_gpu);
            }
        }
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] - other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1347
    /// Element-wise multiplication with broadcasting (device handling as
    /// in `add`).
    ///
    /// # Errors
    /// Fails when the shapes are not broadcast-compatible or a device
    /// transfer fails.
    pub fn mul(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.mul_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_mul_cuda(o)?) });
                    }
                }
                // Mixed devices: migrate the CPU operand, then retry.
                let target = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target)?
                };
                return a_gpu.mul(&b_gpu);
            }
        }
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] * other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1401
1402 pub fn div(&self, other: &Self) -> Result<Self> {
1404 #[cfg(feature = "cuda")]
1405 {
1406 let self_gpu = self.device().is_gpu();
1407 let other_gpu = other.device().is_gpu();
1408 if self_gpu || other_gpu {
1409 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1410 if self_gpu && other_gpu {
1411 let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1412 if self.shape == other.shape {
1413 return Ok(unsafe { gpu_into(s.div_cuda(o)?) });
1414 } else {
1415 return Ok(unsafe { gpu_into(s.broadcast_div_cuda(o)?) });
1416 }
1417 }
1418 let target = if self_gpu {
1419 self.device()
1420 } else {
1421 other.device()
1422 };
1423 let a_gpu = if self_gpu {
1424 self.clone()
1425 } else {
1426 self.to_device(target)?
1427 };
1428 let b_gpu = if other_gpu {
1429 other.clone()
1430 } else {
1431 other.to_device(target)?
1432 };
1433 return a_gpu.div(&b_gpu);
1434 }
1435 }
1436 let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1437 let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1438 let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1439
1440 let total = numel(&result_shape);
1441 let mut result_data = vec![T::zero(); total];
1442
1443 let self_data = self.storage.as_slice();
1444 let other_data = other.storage.as_slice();
1445
1446 for i in 0..total {
1447 let indices = crate::shape::unravel_index(i, &result_shape);
1448 let self_idx = self.offset + linear_index(&indices, &self_strides);
1449 let other_idx = other.offset + linear_index(&indices, &other_strides);
1450 result_data[i] = self_data[self_idx] / other_data[other_idx];
1451 }
1452
1453 Self::from_vec(result_data, &result_shape)
1454 }
1455
1456 #[must_use]
1458 pub fn add_scalar(&self, scalar: T) -> Self {
1459 #[cfg(feature = "cuda")]
1460 if self.device().is_gpu() {
1461 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1462 let self_f32 = unsafe { gpu_ref(self) };
1463 let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1464 return unsafe { gpu_into(self_f32.add_scalar_cuda(scalar_f32)) };
1465 }
1466 let data = self.to_vec();
1467 let mut result = vec![T::zero(); data.len()];
1468 CpuBackend::add_scalar(&mut result, &data, scalar);
1469 Self::from_vec(result, &self.shape).unwrap()
1470 }
1471
1472 #[must_use]
1474 pub fn mul_scalar(&self, scalar: T) -> Self {
1475 #[cfg(feature = "cuda")]
1476 if self.device().is_gpu() {
1477 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1478 let self_f32 = unsafe { gpu_ref(self) };
1479 let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1480 return unsafe { gpu_into(self_f32.mul_scalar_cuda(scalar_f32)) };
1481 }
1482 let data = self.to_vec();
1483 let mut result = vec![T::zero(); data.len()];
1484 CpuBackend::mul_scalar(&mut result, &data, scalar);
1485 Self::from_vec(result, &self.shape).unwrap()
1486 }
1487
1488 #[must_use]
1490 pub fn neg(&self) -> Self {
1491 #[cfg(feature = "cuda")]
1492 if self.device().is_gpu() {
1493 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1494 let self_f32 = unsafe { gpu_ref(self) };
1495 return unsafe { gpu_into(self_f32.neg_cuda()) };
1496 }
1497 let data = self.to_vec();
1498 let mut result = vec![T::zero(); data.len()];
1499 CpuBackend::neg(&mut result, &data);
1500 Self::from_vec(result, &self.shape).unwrap()
1501 }
1502
    /// Matrix product of two tensors.
    ///
    /// Supports a plain 2-D multiply (`[m, k] x [k, n] -> [m, n]`) and a
    /// batched multiply where both operands carry identical leading batch
    /// dimensions (`[..., m, k] x [..., k, n] -> [..., m, n]`). Batch
    /// dimensions are NOT broadcast — they must match exactly.
    ///
    /// # Errors
    /// Fails if either operand has fewer than 2 dimensions, the inner
    /// dimensions differ, or the batch dimensions differ.
    pub fn matmul(&self, other: &Self) -> Result<Self> {
        // NOTE(review): unlike `mul`/`div`, this branch checks only `self`'s
        // device; if `other` is still on the CPU it is reinterpreted via
        // `gpu_ref` without a device transfer — confirm `matmul_cuda`
        // tolerates mixed-device operands.
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
            return Ok(unsafe { gpu_into(s.matmul_cuda(o)?) });
        }
        if self.ndim() < 2 || other.ndim() < 2 {
            return Err(Error::invalid_operation(
                "matmul requires at least 2D tensors",
            ));
        }

        // Trailing two dims define the matrix shapes: self is [.., m, k1],
        // other is [.., k2, n]; the multiply is valid only when k1 == k2.
        let m = self.shape[self.ndim() - 2];
        let k1 = self.shape[self.ndim() - 1];
        let k2 = other.shape[other.ndim() - 2];
        let n = other.shape[other.ndim() - 1];

        if k1 != k2 {
            return Err(Error::invalid_operation(format!(
                "matmul inner dimensions must match: {k1} vs {k2}"
            )));
        }

        // Fast path: single (non-batched) 2-D matrix multiply.
        if self.ndim() == 2 && other.ndim() == 2 {
            let a_data = self.contiguous().to_vec();
            let b_data = other.contiguous().to_vec();

            // Opportunistic GPU offload for large f32 multiplies: only worth
            // the host<->device copies when the work (m*n*k) is big enough.
            #[cfg(feature = "cuda")]
            {
                let flops = m * n * k1;
                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
                    && flops >= 4_000_000
                {
                    // SAFETY: the TypeId check above proves T == f32, so
                    // &[T] and &[f32] have identical layout.
                    let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
                    let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
                    if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
                        // SAFETY: T == f32 (checked above); take ownership of
                        // the f32 buffer as Vec<T> without a double free.
                        let c_t: Vec<T> = unsafe {
                            let mut v = std::mem::ManuallyDrop::new(c_f32);
                            Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
                        };
                        return Self::from_vec(c_t, &[m, n]);
                    }
                }
            }

            // CPU fallback (also taken when the GPU path declines or fails).
            let mut c_data = vec![T::zero(); m * n];
            CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
            return Self::from_vec(c_data, &[m, n]);
        }

        // Batched path: all leading dims must match exactly (no broadcasting).
        let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
        let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();

        if batch_dims_self != batch_dims_other {
            return Err(Error::invalid_operation(format!(
                "matmul batch dimensions must match: {:?} vs {:?}",
                batch_dims_self, batch_dims_other
            )));
        }

        // Flatten batches: each batch occupies a contiguous stride of the
        // (contiguous) a/b/c buffers.
        let batch_size: usize = batch_dims_self.iter().product();
        let a_stride = m * k1;
        let b_stride = k1 * n;
        let c_stride = m * n;

        let a_data = self.contiguous().to_vec();
        let b_data = other.contiguous().to_vec();
        let mut c_data = vec![T::zero(); batch_size * m * n];

        // Opportunistic GPU offload, same size heuristic as the 2-D path.
        // NOTE(review): each batch does its own htod/dtoh round trip —
        // correct, but a single batched GEMM would be much cheaper.
        #[cfg(feature = "cuda")]
        {
            let flops = m * n * k1;
            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && flops >= 4_000_000 {
                // SAFETY: TypeId check proves T == f32; identical layout.
                let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
                let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
                let mut gpu_ok = true;
                for batch in 0..batch_size {
                    let a_slice = &a_f32[batch * a_stride..(batch + 1) * a_stride];
                    let b_slice = &b_f32[batch * b_stride..(batch + 1) * b_stride];
                    if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
                        c_data[batch * c_stride..(batch + 1) * c_stride]
                            .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
                    } else {
                        gpu_ok = false;
                        break;
                    }
                }
                if gpu_ok {
                    let mut output_shape = batch_dims_self;
                    output_shape.push(m);
                    output_shape.push(n);
                    return Self::from_vec(c_data, &output_shape);
                }
                // A mid-loop GPU failure leaves c_data partially written;
                // reset it before the CPU fallback below.
                c_data = vec![T::zero(); batch_size * m * n];
            }
        }

        // CPU fallback: one backend GEMM per batch.
        for batch in 0..batch_size {
            let a_slice = &a_data[batch * a_stride..(batch + 1) * a_stride];
            let b_slice = &b_data[batch * b_stride..(batch + 1) * b_stride];
            let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
            CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
        }

        // Output keeps the shared batch dims, with [m, n] appended.
        let mut output_shape = batch_dims_self;
        output_shape.push(m);
        output_shape.push(n);

        Self::from_vec(c_data, &output_shape)
    }
1629
1630 pub fn dot(&self, other: &Self) -> Result<Self> {
1632 if self.ndim() != 1 || other.ndim() != 1 {
1633 return Err(Error::invalid_operation("dot requires 1D tensors"));
1634 }
1635
1636 if self.shape[0] != other.shape[0] {
1637 return Err(Error::shape_mismatch(&self.shape, &other.shape));
1638 }
1639
1640 let a_data = self.to_vec();
1641 let b_data = other.to_vec();
1642 let result = CpuBackend::dot(&a_data, &b_data);
1643
1644 Ok(Self::scalar(result))
1645 }
1646}
1647
1648impl<T: Numeric> Add for &Tensor<T> {
1653 type Output = Tensor<T>;
1654
1655 fn add(self, other: Self) -> Self::Output {
1656 self.add(other).expect("Addition failed")
1657 }
1658}
1659
1660impl<T: Numeric> Sub for &Tensor<T> {
1661 type Output = Tensor<T>;
1662
1663 fn sub(self, other: Self) -> Self::Output {
1664 self.sub(other).expect("Subtraction failed")
1665 }
1666}
1667
1668impl<T: Numeric> Mul for &Tensor<T> {
1669 type Output = Tensor<T>;
1670
1671 fn mul(self, other: Self) -> Self::Output {
1672 self.mul(other).expect("Multiplication failed")
1673 }
1674}
1675
1676impl<T: Numeric> Div for &Tensor<T> {
1677 type Output = Tensor<T>;
1678
1679 fn div(self, other: Self) -> Self::Output {
1680 self.div(other).expect("Division failed")
1681 }
1682}
1683
1684impl<T: Numeric> Neg for &Tensor<T> {
1685 type Output = Tensor<T>;
1686
1687 fn neg(self) -> Self::Output {
1688 self.neg()
1689 }
1690}
1691
1692impl<T: Numeric> Add<T> for &Tensor<T> {
1694 type Output = Tensor<T>;
1695
1696 fn add(self, scalar: T) -> Self::Output {
1697 self.add_scalar(scalar)
1698 }
1699}
1700
1701impl<T: Numeric> Mul<T> for &Tensor<T> {
1702 type Output = Tensor<T>;
1703
1704 fn mul(self, scalar: T) -> Self::Output {
1705 self.mul_scalar(scalar)
1706 }
1707}
1708
1709impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
1714 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1715 write!(
1716 f,
1717 "Tensor(shape={:?}, device={}",
1718 self.shape(),
1719 self.device()
1720 )?;
1721 if self.numel() <= 10 {
1722 write!(f, ", data={:?}", self.to_vec())?;
1723 }
1724 write!(f, ")")
1725 }
1726}
1727
1728impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
1729 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1730 if self.is_scalar() {
1731 write!(f, "{}", self.item().unwrap())
1732 } else if self.ndim() == 1 {
1733 write!(f, "[")?;
1734 let data = self.to_vec();
1735 for (i, val) in data.iter().enumerate() {
1736 if i > 0 {
1737 write!(f, ", ")?;
1738 }
1739 write!(f, "{val}")?;
1740 }
1741 write!(f, "]")
1742 } else {
1743 write!(f, "Tensor(shape={:?})", self.shape())
1744 }
1745 }
1746}
1747
// Unit tests covering construction, views, arithmetic, broadcasting,
// reductions, matmul, activations, and scalar tensors on the CPU path.
#[cfg(test)]
mod tests {
    use super::*;

    // Construction from a flat Vec plus an explicit shape.
    #[test]
    fn test_from_vec() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
    }

    // Element read/write by multi-dimensional index.
    // NOTE(review): `set` is called on a non-`mut` binding, which implies
    // `set` takes `&self` (interior mutability in Storage) — confirm.
    #[test]
    fn test_get_set() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(t.get(&[1, 0]).unwrap(), 3.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 4.0);

        t.set(&[0, 0], 99.0).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 99.0);
    }

    // Reshape, including the -1 "infer this dimension" convention.
    #[test]
    fn test_reshape() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.reshape(&[3, 2]).unwrap();
        assert_eq!(r.shape(), &[3, 2]);

        let r = t.reshape(&[-1]).unwrap();
        assert_eq!(r.shape(), &[6]);
    }

    // 2-D transpose swaps the axes and the element at (i, j) moves to (j, i).
    #[test]
    fn test_transpose() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.t().unwrap();
        assert_eq!(r.shape(), &[3, 2]);
        assert_eq!(r.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(r.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(r.get(&[1, 0]).unwrap(), 2.0);
    }

    // Element-wise + and * through the operator-overload impls.
    #[test]
    fn test_arithmetic() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![5.0, 7.0, 9.0]);

        let d = &a * &b;
        assert_eq!(d.to_vec(), vec![4.0, 10.0, 18.0]);
    }

    // A [1] tensor broadcasts against a [3] tensor.
    #[test]
    fn test_broadcasting() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    // Full reduction: sum collapses to a scalar tensor.
    #[test]
    fn test_sum() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let s = t.sum();
        assert_eq!(s.item().unwrap(), 10.0);
    }

    // 2x2 matrix multiply with a hand-computed expected product.
    #[test]
    fn test_matmul() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let c = a.matmul(&b).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    // ReLU clamps negatives to zero and passes non-negatives through.
    #[test]
    fn test_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        let r = t.relu();
        assert_eq!(r.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    // A scalar tensor holds a single element retrievable via `item`.
    #[test]
    fn test_scalar() {
        let s = Tensor::<f32>::scalar(42.0);
        assert!(s.is_scalar());
        assert_eq!(s.numel(), 1);
        assert_eq!(s.item().unwrap(), 42.0);
    }
}