1use core::fmt;
18use core::ops::{Add, Div, Mul, Neg, Sub};
19
20use axonml_core::backends::CpuBackend;
21#[cfg(feature = "cuda")]
22use axonml_core::backends::CudaBackend;
23use axonml_core::dtype::{Float, Numeric, Scalar};
24use axonml_core::error::{Error, Result};
25use axonml_core::storage::Storage;
26use axonml_core::Device;
27use num_traits::NumCast;
28
#[cfg(feature = "cuda")]
mod cuda_accel {
    use super::*;
    use axonml_core::backends::cuda::get_cuda_backend;

    /// Returns the process-wide CUDA backend handle, or `None` when no
    /// usable device/context is available.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        get_cuda_backend()
    }

    /// Row-major matrix multiply `C(m×n) = A(m×k) · B(k×n)` on the GPU.
    /// Returns `None` if the backend is unavailable or any CUDA call fails.
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        let cuda = get_cuda()?;

        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(m * n).ok()?;

        // The operands and the (n, m) dimensions are swapped relative to the
        // mathematical C = A·B. NOTE(review): this is the standard trick for
        // getting a row-major C out of a column-major GEMM (compute Cᵀ = Bᵀ·Aᵀ)
        // — confirm `gemm_f32` follows cuBLAS column-major conventions.
        cuda.gemm_f32(
            false, false,
            n, m, k,
            1.0,
            &b_gpu, n,
            &a_gpu, k,
            0.0,
            &mut c_gpu, n,
        ).ok()?;

        cuda.dtoh_copy(&c_gpu).ok()
    }
}
67
68use crate::shape::{
69 broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous, linear_index,
70 normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides, unsqueeze, Shape,
71 Strides,
72};
73
#[cfg(feature = "cuda")]
// Reinterprets `&Tensor<T>` as `&Tensor<f32>`.
// SAFETY: callers must first verify `is_f32::<T>()`; `Tensor<T>` is then the
// same concrete type as `Tensor<f32>`, making the pointer cast sound.
unsafe fn gpu_ref<T: Scalar>(t: &Tensor<T>) -> &Tensor<f32> {
    &*(t as *const Tensor<T> as *const Tensor<f32>)
}
85
#[cfg(feature = "cuda")]
// Reinterprets an owned `Tensor<f32>` as `Tensor<T>` by value.
// SAFETY: callers must first verify `is_f32::<T>()`. `ptr::read` moves the
// value out through the cast pointer, and `mem::forget` prevents the original
// from being dropped, so ownership transfers exactly once (no double free).
unsafe fn gpu_into<T: Scalar>(t: Tensor<f32>) -> Tensor<T> {
    let out = std::ptr::read(&t as *const Tensor<f32> as *const Tensor<T>);
    std::mem::forget(t);
    out
}
92
#[cfg(feature = "cuda")]
/// True iff the concrete scalar type `T` is exactly `f32`.
fn is_f32<T: 'static>() -> bool {
    use std::any::TypeId;
    TypeId::of::<T>() == TypeId::of::<f32>()
}
97
#[derive(Clone)]
pub struct Tensor<T: Scalar> {
    // Backing buffer holding the elements (CPU or GPU; sharing semantics
    // are defined by `Storage`'s `Clone`).
    pub(crate) storage: Storage<T>,
    // Extent of each dimension; empty for 0-d (scalar) tensors.
    pub(crate) shape: Shape,
    // Per-dimension element strides; may be non-contiguous after views
    // such as `transpose`/`permute`.
    pub(crate) strides: Strides,
    // Index of this view's first element within `storage`.
    pub(crate) offset: usize,
}
118
119impl<T: Scalar> Tensor<T> {
120 pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
133 let total = numel(shape);
134 if total != storage.len() {
135 return Err(Error::shape_mismatch(&[storage.len()], shape));
136 }
137
138 let shape = Shape::from_slice(shape);
139 let strides = contiguous_strides(&shape);
140
141 Ok(Self {
142 storage,
143 shape,
144 strides,
145 offset: 0,
146 })
147 }
148
149 pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
158 let storage = Storage::from_vec(data, Device::Cpu);
159 Self::from_storage(storage, shape)
160 }
161
162 pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
171 let storage = Storage::from_slice(data, Device::Cpu);
172 Self::from_storage(storage, shape)
173 }
174
175 pub fn scalar(value: T) -> Self {
183 Self {
184 storage: Storage::from_vec(vec![value], Device::Cpu),
185 shape: Shape::new(),
186 strides: Strides::new(),
187 offset: 0,
188 }
189 }
190
    /// Tensor of the given shape filled with zeros (delegates to `crate::creation`).
    #[must_use]
    pub fn zeros(shape: &[usize]) -> Self {
        crate::creation::zeros(shape)
    }
196
    /// Tensor of the given shape filled with ones (delegates to `crate::creation`).
    #[must_use]
    pub fn ones(shape: &[usize]) -> Self
    where
        T: Numeric,
    {
        crate::creation::ones(shape)
    }
205
    /// Tensor of the given shape with every element set to `value`.
    #[must_use]
    pub fn full(shape: &[usize], value: T) -> Self {
        crate::creation::full(shape, value)
    }
211
    /// Tensor with elements drawn from the standard normal distribution
    /// (per the `rand_distr::StandardNormal` bound).
    #[must_use]
    pub fn randn(shape: &[usize]) -> Self
    where
        T: Float,
        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
    {
        crate::creation::randn(shape)
    }
221
    /// Tensor with elements drawn from `rand`'s `Standard` distribution.
    #[must_use]
    pub fn rand(shape: &[usize]) -> Self
    where
        T: Float,
        rand::distributions::Standard: rand::distributions::Distribution<T>,
    {
        crate::creation::rand(shape)
    }
231
    /// Extent of each dimension.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
241
    /// Element stride of each dimension.
    #[must_use]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }
247
    /// Number of dimensions (0 for scalar tensors).
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }
253
    /// Total number of elements (product of the shape).
    #[must_use]
    pub fn numel(&self) -> usize {
        numel(&self.shape)
    }
259
    /// True when the tensor holds no elements (some dimension is 0).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.numel() == 0
    }
265
266 pub fn size(&self, dim: i64) -> Result<usize> {
271 let idx = normalize_dim(dim, self.ndim())?;
272 Ok(self.shape[idx])
273 }
274
    /// Device on which the backing storage lives.
    #[must_use]
    pub fn device(&self) -> Device {
        self.storage.device()
    }
280
    /// True when the strides describe a dense row-major layout.
    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape, &self.strides)
    }
286
    /// True for 0-dimensional tensors (empty shape).
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }
292
293 pub fn get(&self, indices: &[usize]) -> Result<T> {
302 if indices.len() != self.ndim() {
303 return Err(Error::invalid_operation(format!(
304 "Expected {} indices, got {}",
305 self.ndim(),
306 indices.len()
307 )));
308 }
309
310 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
311 if idx >= dim {
312 return Err(Error::IndexOutOfBounds {
313 index: idx,
314 size: dim,
315 });
316 }
317 }
318
319 let offset = self.offset + linear_index(indices, &self.strides);
320 Ok(self.storage.as_slice()[offset])
321 }
322
323 pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
329 if indices.len() != self.ndim() {
330 return Err(Error::invalid_operation(format!(
331 "Expected {} indices, got {}",
332 self.ndim(),
333 indices.len()
334 )));
335 }
336
337 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
338 if idx >= dim {
339 return Err(Error::IndexOutOfBounds {
340 index: idx,
341 size: dim,
342 });
343 }
344 }
345
346 let offset = self.offset + linear_index(indices, &self.strides);
347 self.storage.as_slice_mut()[offset] = value;
348 Ok(())
349 }
350
351 pub fn item(&self) -> Result<T> {
353 if self.numel() != 1 {
354 return Err(Error::invalid_operation(
355 "item() only works on single-element tensors",
356 ));
357 }
358
359 if self.is_scalar() {
360 Ok(self.storage.as_slice()[self.offset])
361 } else {
362 let indices = vec![0; self.ndim()];
364 self.get(&indices)
365 }
366 }
367
    /// Copies the elements into a freshly allocated row-major `Vec`,
    /// downloading from the GPU when necessary.
    #[must_use]
    pub fn to_vec(&self) -> Vec<T> {
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let f32_vec = self_f32.to_vec_gpu();
            // SAFETY: the assert above guarantees T == f32, so the buffer can
            // be reinterpreted in place; ManuallyDrop hands ownership to the
            // rebuilt Vec exactly once (no double free).
            unsafe {
                let mut v = std::mem::ManuallyDrop::new(f32_vec);
                return Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity());
            }
        }

        if self.is_contiguous() {
            // Fast path: the view is a dense run inside the storage.
            let storage = self.storage.as_slice();
            storage[self.offset..self.offset + self.numel()].to_vec()
        } else {
            // Strided view: gather element by element.
            let mut result = Vec::with_capacity(self.numel());
            self.copy_data_to(&mut result);
            result
        }
    }
396
397 fn copy_data_to(&self, dst: &mut Vec<T>) {
399 dst.clear();
400 let storage = self.storage.as_slice();
401
402 let total = self.numel();
404 for i in 0..total {
405 let indices = crate::shape::unravel_index(i, &self.shape);
406 let offset = self.offset + linear_index(&indices, &self.strides);
407 dst.push(storage[offset]);
408 }
409 }
410
411 pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
423 let shape = reshape(&self.shape, new_shape)?;
424
425 if self.is_contiguous() {
426 Ok(Self {
428 storage: self.storage.clone(),
429 strides: contiguous_strides(&shape),
430 shape,
431 offset: self.offset,
432 })
433 } else {
434 let contig = self.contiguous();
436 Ok(Self {
437 storage: contig.storage,
438 strides: contiguous_strides(&shape),
439 shape,
440 offset: 0,
441 })
442 }
443 }
444
    /// Collapses the tensor into a single dimension.
    #[must_use]
    pub fn flatten(&self) -> Self {
        self.reshape(&[-1]).expect("Flatten should never fail")
    }
450
    /// Removes size-1 dimensions: the given one (when it actually has size 1)
    /// if `dim` is `Some`, otherwise every size-1 dimension. Storage is shared.
    ///
    /// # Errors
    /// Fails when `Some(dim)` is out of range for this tensor's rank.
    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
        // Normalize a possibly negative dimension index up front.
        let dim = match dim {
            Some(d) => Some(normalize_dim(d, self.ndim())?),
            None => None,
        };

        let new_shape = squeeze(&self.shape, dim);
        let new_strides: Strides = match dim {
            Some(d) => {
                // Drop the stride only when the dimension is actually removed
                // (a non-size-1 dim is left untouched, mirroring `squeeze`).
                let mut s = self.strides.clone();
                if d < self.shape.len() && self.shape[d] == 1 {
                    s.remove(d);
                }
                s
            }
            // Keep strides only for dimensions whose extent is not 1.
            None => self
                .shape
                .iter()
                .zip(self.strides.iter())
                .filter(|(&dim, _)| dim != 1)
                .map(|(_, &stride)| stride)
                .collect(),
        };

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }
486
    /// Inserts a size-1 dimension at `dim`; negative values insert from the
    /// end, so `-1` appends a trailing axis. Storage is shared.
    ///
    /// # Errors
    /// Fails when the insertion position is out of range (via `unsqueeze`).
    pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
        // `ndim + 1` positions are valid for insertion, hence the `+ 1`.
        let normalized = if dim < 0 {
            (dim + self.ndim() as i64 + 1) as usize
        } else {
            dim as usize
        };

        let new_shape = unsqueeze(&self.shape, normalized)?;
        let mut new_strides = Strides::with_capacity(new_shape.len());

        for (i, _) in new_shape.iter().enumerate() {
            if i < normalized {
                new_strides.push(self.strides.get(i).copied().unwrap_or(1));
            } else if i == normalized {
                // Any stride indexes a size-1 axis correctly (its index is
                // always 0). NOTE(review): 1 is pushed here, which may differ
                // from the contiguous-stride convention — confirm
                // `is_contiguous` tolerates size-1 axes with stride 1.
                new_strides.push(1);
            } else {
                new_strides.push(self.strides[i - 1]);
            }
        }

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }
519
520 pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
526 let d0 = normalize_dim(dim0, self.ndim())?;
527 let d1 = normalize_dim(dim1, self.ndim())?;
528
529 let new_shape = transpose_shape(&self.shape, d0, d1)?;
530 let new_strides = transpose_strides(&self.strides, d0, d1);
531
532 Ok(Self {
533 storage: self.storage.clone(),
534 shape: new_shape,
535 strides: new_strides,
536 offset: self.offset,
537 })
538 }
539
540 pub fn t(&self) -> Result<Self> {
542 if self.ndim() != 2 {
543 return Err(Error::invalid_operation("t() only works on 2D tensors"));
544 }
545 self.transpose(0, 1)
546 }
547
548 pub fn permute(&self, dims: &[usize]) -> Result<Self> {
553 if dims.len() != self.ndim() {
554 return Err(Error::invalid_operation(format!(
555 "Expected {} dimensions, got {}",
556 self.ndim(),
557 dims.len()
558 )));
559 }
560
561 let mut seen = vec![false; self.ndim()];
563 for &d in dims {
564 if d >= self.ndim() {
565 return Err(Error::InvalidDimension {
566 index: d as i64,
567 ndim: self.ndim(),
568 });
569 }
570 if seen[d] {
571 return Err(Error::invalid_operation("Duplicate dimension in permute"));
572 }
573 seen[d] = true;
574 }
575
576 let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
577 let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
578
579 Ok(Self {
580 storage: self.storage.clone(),
581 shape: new_shape,
582 strides: new_strides,
583 offset: self.offset,
584 })
585 }
586
    /// Returns a tensor whose data is dense and row-major, copying only when
    /// necessary.
    #[must_use]
    pub fn contiguous(&self) -> Self {
        // Already dense and starting at the buffer head: share the storage.
        if self.is_contiguous() && self.offset == 0 {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.contiguous_gpu();
            return unsafe { gpu_into(result) };
        }

        // CPU: gather through the strides into a fresh dense buffer.
        let data = self.to_vec();
        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
    }
605
606 pub fn to_device(&self, device: Device) -> Result<Self> {
615 if self.device() == device {
616 return Ok(self.clone());
617 }
618
619 #[cfg(feature = "cuda")]
620 if self.storage.is_gpu() || device.is_gpu() {
621 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
622 let self_f32 = unsafe { gpu_ref(self) };
623 let result = self_f32.to_device_f32(device)?;
624 return Ok(unsafe { gpu_into(result) });
625 }
626
627 let contig = self.contiguous();
628 let new_storage = contig.storage.to_device(device)?;
629
630 Ok(Self {
631 storage: new_storage,
632 shape: self.shape.clone(),
633 strides: self.strides.clone(),
634 offset: 0,
635 })
636 }
637
    /// Moves the tensor to the CPU (no-op clone when already there).
    pub fn cpu(&self) -> Result<Self> {
        self.to_device(Device::Cpu)
    }
642
    /// Deep copy: materializes the data into new storage, round-tripping
    /// through the CPU and back to the original device for GPU tensors.
    #[must_use]
    pub fn clone_deep(&self) -> Self {
        let data = self.to_vec();
        let cpu = Self::from_vec(data, &self.shape).expect("Deep clone should never fail");
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return cpu.to_device(self.device()).unwrap();
        }
        cpu
    }
658}
659
660impl<T: Numeric> Tensor<T> {
    /// Overwrites every element with `value` in place.
    ///
    /// # Panics
    /// Panics for GPU-resident tensors, which do not support in-place fill.
    pub fn fill_(&self, value: T) {
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            panic!("fill_() not supported on GPU tensors — create a new tensor instead");
        }
        let mut data = self.storage.as_slice_mut();
        CpuBackend::fill(&mut data, value);
    }
674
    /// Sets every element to zero in place (see `fill_` for GPU behavior).
    pub fn zero_(&self) {
        self.fill_(T::zero());
    }
679
    /// Sum of all elements as a 0-d tensor on the source device.
    #[must_use]
    pub fn sum(&self) -> Self {
        // The reduction runs on the CPU; GPU tensors are downloaded first and
        // the scalar result is uploaded back afterwards.
        let data = self.to_vec();
        let result = CpuBackend::sum(&data);
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return s.to_device(self.device()).unwrap();
        }
        s
    }
696
    /// Product of all elements as a 0-d tensor on the source device.
    #[must_use]
    pub fn prod(&self) -> Self {
        // CPU reduction; GPU tensors round-trip through host memory.
        let data = self.to_vec();
        let result = CpuBackend::prod(&data);
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return s.to_device(self.device()).unwrap();
        }
        s
    }
709
    /// Maximum element as a 0-d tensor on the source device.
    ///
    /// # Errors
    /// Returns `Error::EmptyTensor` for empty tensors.
    pub fn max(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        // Non-empty is checked above, so the backend reduction cannot be None.
        let result = CpuBackend::max(&data).unwrap();
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s.to_device(self.device()).unwrap());
        }
        Ok(s)
    }
724
    /// Minimum element as a 0-d tensor on the source device.
    ///
    /// # Errors
    /// Returns `Error::EmptyTensor` for empty tensors.
    pub fn min(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        // Non-empty is checked above, so the backend reduction cannot be None.
        let result = CpuBackend::min(&data).unwrap();
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s.to_device(self.device()).unwrap());
        }
        Ok(s)
    }
739
    /// Flat (row-major) index of the maximum element.
    ///
    /// # Errors
    /// Returns `Error::EmptyTensor` for empty tensors.
    pub fn argmax(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmax(&data).unwrap())
    }
748
    /// Flat (row-major) index of the minimum element.
    ///
    /// # Errors
    /// Returns `Error::EmptyTensor` for empty tensors.
    pub fn argmin(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmin(&data).unwrap())
    }
757
    /// Concatenates `tensors` along dimension `dim`.
    ///
    /// # Errors
    /// Fails on an empty list, an out-of-range `dim`, differing ranks, or
    /// shapes that disagree on any non-concatenation dimension.
    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(Error::invalid_operation("cat requires at least one tensor"));
        }
        let ndim = tensors[0].ndim();
        if dim >= ndim {
            return Err(Error::invalid_operation("cat dimension out of range"));
        }

        // All tensors must agree with the first one except along `dim`.
        for t in &tensors[1..] {
            if t.ndim() != ndim {
                return Err(Error::invalid_operation("cat: all tensors must have same ndim"));
            }
            for d in 0..ndim {
                if d != dim && t.shape[d] != tensors[0].shape[d] {
                    return Err(Error::invalid_operation("cat: shapes must match on non-cat dims"));
                }
            }
        }

        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
        out_shape[dim] = total_dim_size;

        // View every tensor as [outer, dim, inner]; copy inner-sized rows
        // into the right slot of the output's concatenation axis.
        let outer_size: usize = out_shape[..dim].iter().product();
        let inner_size: usize = out_shape[dim + 1..].iter().product();
        let total_numel: usize = out_shape.iter().product();
        let mut result = vec![T::zero(); total_numel];

        let mut dim_offset = 0;
        for t in tensors {
            let t_data = t.contiguous().to_vec();
            let t_dim_size = t.shape[dim];
            for outer in 0..outer_size {
                for d in 0..t_dim_size {
                    let src_base = outer * t_dim_size * inner_size + d * inner_size;
                    let dst_base = outer * total_dim_size * inner_size
                        + (dim_offset + d) * inner_size;
                    result[dst_base..dst_base + inner_size]
                        .copy_from_slice(&t_data[src_base..src_base + inner_size]);
                }
            }
            dim_offset += t_dim_size;
        }

        // Assembled on the CPU; move back to the inputs' device if needed.
        let out = Self::from_vec(result, &out_shape)?;
        #[cfg(feature = "cuda")]
        if tensors[0].device().is_gpu() {
            return Ok(out.to_device(tensors[0].device()).unwrap());
        }
        Ok(out)
    }
813}
814
815impl<T: Float> Tensor<T> {
    /// Arithmetic mean of all elements as a 0-d tensor on the source device.
    ///
    /// # Errors
    /// Returns `Error::EmptyTensor` for empty tensors.
    pub fn mean(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        // Non-empty is checked above, so the backend mean cannot be None.
        let result = CpuBackend::mean(&data).unwrap();
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s.to_device(self.device()).unwrap());
        }
        Ok(s)
    }
834
    /// Element-wise ReLU; runs a CUDA kernel for GPU (f32) tensors,
    /// otherwise the CPU backend.
    #[must_use]
    pub fn relu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).relu_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::relu(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
852
    /// Element-wise sigmoid; CUDA kernel on GPU (f32), CPU backend otherwise.
    #[must_use]
    pub fn sigmoid(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).sigmoid_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sigmoid(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
866
    /// Element-wise tanh; CUDA kernel on GPU (f32), CPU backend otherwise.
    #[must_use]
    pub fn tanh(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).tanh_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::tanh(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
880
    /// Element-wise exponential; CUDA kernel on GPU (f32), CPU backend otherwise.
    #[must_use]
    pub fn exp(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).exp_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::exp(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
894
    /// Element-wise natural logarithm; CUDA kernel on GPU (f32), CPU backend
    /// otherwise.
    #[must_use]
    pub fn ln(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).ln_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::ln(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
908
    /// Element-wise square root; CUDA kernel on GPU (f32), CPU backend otherwise.
    #[must_use]
    pub fn sqrt(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).sqrt_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sqrt(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
922
    /// Raises every element to the power `exp`.
    #[must_use]
    pub fn pow(&self, exp: T) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY: the assert above guarantees T == f32, so reading the
            // exponent through an f32 pointer is sound.
            let exp_f32: f32 = unsafe { *(&exp as *const T as *const f32) };
            return unsafe { gpu_into(gpu_ref(self).pow_cuda(exp_f32)) };
        }
        let data = self.to_vec();
        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
        Self::from_vec(result, &self.shape).unwrap()
    }
936
    /// Element-wise GELU; CUDA kernel on GPU (f32), `crate::ops::gelu` otherwise.
    #[must_use]
    pub fn gelu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).gelu_cuda()) };
        }
        crate::ops::gelu(self)
    }
947
    /// Element-wise SiLU; CUDA kernel on GPU (f32), `crate::ops::silu` otherwise.
    #[must_use]
    pub fn silu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).silu_cuda()) };
        }
        crate::ops::silu(self)
    }
958
    /// Softmax along `dim`.
    ///
    /// NOTE(review): on the CPU path an error from `crate::ops::softmax` is
    /// silently swallowed and the input is returned unchanged; the GPU path
    /// panics instead — confirm this asymmetry is intended.
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe { gpu_into(self_f32.softmax_cuda(dim).expect("CUDA softmax failed")) };
        }
        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
    }
970
971 #[must_use]
973 pub fn log_softmax(&self, dim: i32) -> Self {
974 let softmax_result = self.softmax(dim);
975 softmax_result.ln()
976 }
977
    /// Mean along dimension `dim` (negative counts from the end); `keepdim`
    /// retains the reduced dimension with size 1. An out-of-range `dim`
    /// silently returns an unchanged clone.
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        // NOTE(review): a dim below -ndim wraps to a huge usize and also
        // lands here, silently returning the input — confirm intended.
        if dim >= ndim {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let summed = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            // mean = sum * (1 / dim_size), applied as a GPU scalar multiply.
            let dim_size = self.shape[dim];
            let result = summed.mul_scalar_cuda(1.0 / dim_size as f32);
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // Reducing the only dimension yields a 1-element 1-D tensor.
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the data as [outer, dim, inner] and average over `dim`.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let mean = sum / NumCast::from(dim_size).unwrap();
                let result_idx = outer * inner_size + inner;
                result[result_idx] = mean;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1042
    /// Sum along dimension `dim` (negative counts from the end); `keepdim`
    /// retains the reduced dimension with size 1. An out-of-range `dim`
    /// silently returns an unchanged clone.
    #[must_use]
    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        // NOTE(review): a dim below -ndim wraps to a huge usize and also
        // lands here, silently returning the input — confirm intended.
        if dim >= ndim {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // Reducing the only dimension yields a 1-element 1-D tensor.
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the data as [outer, dim, inner] and accumulate over `dim`.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let result_idx = outer * inner_size + inner;
                result[result_idx] = sum;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1104
1105 #[must_use]
1107 pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
1108 let mean = self.mean_dim(dim, true);
1110 let sq = self.mul(self).unwrap_or_else(|_| self.clone());
1111 let mean_sq = sq.mean_dim(dim, keepdim);
1112 let mean_keepdim = if keepdim { mean.clone() } else { self.mean_dim(dim, keepdim) };
1113 let mean_squared = mean_keepdim.mul(&mean_keepdim).unwrap_or_else(|_| mean_keepdim.clone());
1114 mean_sq.sub(&mean_squared).unwrap_or_else(|_| mean_sq.clone())
1115 }
1116
    /// Materializes this tensor broadcast to `shape`; the result owns a
    /// dense copy of the data.
    #[must_use]
    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
        if self.shape.as_slice() == shape {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe { gpu_into(self_f32.broadcast_to_cuda(shape).expect("CUDA broadcast_to failed")) };
        }

        // Incompatible shapes fall back to the target shape verbatim instead
        // of erroring; the indexing below may then read out of bounds —
        // NOTE(review): confirm callers only pass broadcast-compatible shapes.
        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];
        let self_data = self.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            result_data[i] = self_data[self_idx];
        }

        Self::from_vec(result_data, &result_shape).unwrap()
    }
1146
    /// Copies the sub-tensor selected by `ranges` (one `start..end` per
    /// leading dimension; trailing dimensions are taken in full).
    ///
    /// NOTE(review): ranges are not bounds-checked here — an `end` beyond a
    /// dimension's extent will panic during the copy. Confirm callers validate.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
        let mut new_shape = Vec::with_capacity(self.ndim());
        for (i, range) in ranges.iter().enumerate() {
            // Ranges beyond ndim() contribute nothing to the shape.
            if i < self.ndim() {
                new_shape.push(range.end - range.start);
            }
        }
        // Dimensions without an explicit range keep their full extent.
        for i in ranges.len()..self.ndim() {
            new_shape.push(self.shape[i]);
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result_data = vec![T::zero(); new_numel];
        let self_data = self.to_vec();

        let mut result_idx = 0;
        Self::slice_recursive(
            &self_data,
            &self.shape,
            ranges,
            0,
            0,
            &mut result_data,
            &mut result_idx,
        );

        // Assembled on the CPU; move back to the source device if needed.
        let out = Self::from_vec(result_data, &new_shape).unwrap();
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return out.to_device(self.device()).unwrap();
        }
        out
    }
1184
    /// Depth-first walk over `shape` that copies the elements selected by
    /// `ranges` from row-major `data` into `result`, advancing `result_idx`
    /// as it goes.
    fn slice_recursive(
        data: &[T],
        shape: &[usize],
        ranges: &[std::ops::Range<usize>],
        dim: usize,
        offset: usize,
        result: &mut [T],
        result_idx: &mut usize,
    ) {
        // Past the last dimension: `offset` now addresses a single element.
        if dim == shape.len() {
            result[*result_idx] = data[offset];
            *result_idx += 1;
            return;
        }

        // Row-major stride of this dimension; dimensions without an explicit
        // range are traversed in full (0..size).
        let stride: usize = shape[dim + 1..].iter().product();
        let (start, end) = if dim < ranges.len() {
            (ranges[dim].start, ranges[dim].end)
        } else {
            (0, shape[dim])
        };

        for i in start..end {
            Self::slice_recursive(
                data,
                shape,
                ranges,
                dim + 1,
                offset + i * stride,
                result,
                result_idx,
            );
        }
    }
1219}
1220
1221impl<T: Numeric> Tensor<T> {
    /// Element-wise addition with broadcasting.
    ///
    /// # Errors
    /// Fails when the shapes cannot be broadcast together, or when a CUDA
    /// kernel reports an error on the GPU path.
    pub fn add(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.add_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_add_cuda(o)?) });
                    }
                }
                // Mixed devices: move the CPU operand to the GPU and retry.
                let target_device = if self_gpu { self.device() } else { other.device() };
                let a_gpu = if self_gpu { self.clone() } else { self.to_device(target_device)? };
                let b_gpu = if other_gpu { other.clone() } else { other.to_device(target_device)? };
                return a_gpu.add(&b_gpu);
            }
        }
        // CPU path: iterate the broadcast result shape, mapping each output
        // index back into both operands via broadcast strides.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] + other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1268
    /// Element-wise subtraction with broadcasting.
    ///
    /// # Errors
    /// Fails when the shapes cannot be broadcast together, or when a CUDA
    /// kernel reports an error on the GPU path.
    pub fn sub(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.sub_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_sub_cuda(o)?) });
                    }
                }
                // Mixed devices: move the CPU operand to the GPU and retry.
                let target = if self_gpu { self.device() } else { other.device() };
                let a_gpu = if self_gpu { self.clone() } else { self.to_device(target)? };
                let b_gpu = if other_gpu { other.clone() } else { other.to_device(target)? };
                return a_gpu.sub(&b_gpu);
            }
        }
        // CPU path: iterate the broadcast result shape, mapping each output
        // index back into both operands via broadcast strides.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] - other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1310
    /// Element-wise multiplication with broadcasting.
    ///
    /// # Errors
    /// Fails when the shapes cannot be broadcast together, or when a CUDA
    /// kernel reports an error on the GPU path.
    pub fn mul(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.mul_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_mul_cuda(o)?) });
                    }
                }
                // Mixed devices: move the CPU operand to the GPU and retry.
                let target = if self_gpu { self.device() } else { other.device() };
                let a_gpu = if self_gpu { self.clone() } else { self.to_device(target)? };
                let b_gpu = if other_gpu { other.clone() } else { other.to_device(target)? };
                return a_gpu.mul(&b_gpu);
            }
        }
        // CPU path: iterate the broadcast result shape, mapping each output
        // index back into both operands via broadcast strides.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] * other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1352
    /// Element-wise division with broadcasting.
    ///
    /// NOTE(review): no zero-divisor check on the CPU path — behavior on a
    /// zero element depends on `T`'s `Div` impl (e.g. primitive integers
    /// would panic). Confirm callers guard against this.
    ///
    /// # Errors
    /// Fails when the shapes cannot be broadcast together, or when a CUDA
    /// kernel reports an error on the GPU path.
    pub fn div(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.div_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_div_cuda(o)?) });
                    }
                }
                // Mixed devices: move the CPU operand to the GPU and retry.
                let target = if self_gpu { self.device() } else { other.device() };
                let a_gpu = if self_gpu { self.clone() } else { self.to_device(target)? };
                let b_gpu = if other_gpu { other.clone() } else { other.to_device(target)? };
                return a_gpu.div(&b_gpu);
            }
        }
        // CPU path: iterate the broadcast result shape, mapping each output
        // index back into both operands via broadcast strides.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] / other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1394
1395 #[must_use]
1397 pub fn add_scalar(&self, scalar: T) -> Self {
1398 #[cfg(feature = "cuda")]
1399 if self.device().is_gpu() {
1400 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1401 let self_f32 = unsafe { gpu_ref(self) };
1402 let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1403 return unsafe { gpu_into(self_f32.add_scalar_cuda(scalar_f32)) };
1404 }
1405 let data = self.to_vec();
1406 let mut result = vec![T::zero(); data.len()];
1407 CpuBackend::add_scalar(&mut result, &data, scalar);
1408 Self::from_vec(result, &self.shape).unwrap()
1409 }
1410
1411 #[must_use]
1413 pub fn mul_scalar(&self, scalar: T) -> Self {
1414 #[cfg(feature = "cuda")]
1415 if self.device().is_gpu() {
1416 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1417 let self_f32 = unsafe { gpu_ref(self) };
1418 let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1419 return unsafe { gpu_into(self_f32.mul_scalar_cuda(scalar_f32)) };
1420 }
1421 let data = self.to_vec();
1422 let mut result = vec![T::zero(); data.len()];
1423 CpuBackend::mul_scalar(&mut result, &data, scalar);
1424 Self::from_vec(result, &self.shape).unwrap()
1425 }
1426
1427 #[must_use]
1429 pub fn neg(&self) -> Self {
1430 #[cfg(feature = "cuda")]
1431 if self.device().is_gpu() {
1432 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1433 let self_f32 = unsafe { gpu_ref(self) };
1434 return unsafe { gpu_into(self_f32.neg_cuda()) };
1435 }
1436 let data = self.to_vec();
1437 let mut result = vec![T::zero(); data.len()];
1438 CpuBackend::neg(&mut result, &data);
1439 Self::from_vec(result, &self.shape).unwrap()
1440 }
1441
1442 pub fn matmul(&self, other: &Self) -> Result<Self> {
1449 #[cfg(feature = "cuda")]
1450 if self.device().is_gpu() {
1451 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1452 let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1453 return Ok(unsafe { gpu_into(s.matmul_cuda(o)?) });
1454 }
1455 if self.ndim() < 2 || other.ndim() < 2 {
1456 return Err(Error::invalid_operation(
1457 "matmul requires at least 2D tensors",
1458 ));
1459 }
1460
1461 let m = self.shape[self.ndim() - 2];
1462 let k1 = self.shape[self.ndim() - 1];
1463 let k2 = other.shape[other.ndim() - 2];
1464 let n = other.shape[other.ndim() - 1];
1465
1466 if k1 != k2 {
1467 return Err(Error::invalid_operation(format!(
1468 "matmul inner dimensions must match: {k1} vs {k2}"
1469 )));
1470 }
1471
1472 if self.ndim() == 2 && other.ndim() == 2 {
1474 let a_data = self.contiguous().to_vec();
1475 let b_data = other.contiguous().to_vec();
1476
1477 #[cfg(feature = "cuda")]
1481 {
1482 let flops = m * n * k1;
1483 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
1484 && flops >= 4_000_000
1485 {
1486 let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1487 let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1488 if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
1489 let c_t: Vec<T> = unsafe {
1490 let mut v = std::mem::ManuallyDrop::new(c_f32);
1491 Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
1492 };
1493 return Self::from_vec(c_t, &[m, n]);
1494 }
1495 }
1496 }
1497
1498 let mut c_data = vec![T::zero(); m * n];
1499 CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
1500 return Self::from_vec(c_data, &[m, n]);
1501 }
1502
1503 let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1505 let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1506
1507 if batch_dims_self != batch_dims_other {
1508 return Err(Error::invalid_operation(format!(
1509 "matmul batch dimensions must match: {:?} vs {:?}",
1510 batch_dims_self, batch_dims_other
1511 )));
1512 }
1513
1514 let batch_size: usize = batch_dims_self.iter().product();
1515 let a_stride = m * k1;
1516 let b_stride = k1 * n;
1517 let c_stride = m * n;
1518
1519 let a_data = self.contiguous().to_vec();
1520 let b_data = other.contiguous().to_vec();
1521 let mut c_data = vec![T::zero(); batch_size * m * n];
1522
1523 #[cfg(feature = "cuda")]
1525 {
1526 let flops = m * n * k1;
1527 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
1528 && flops >= 4_000_000
1529 {
1530 let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1531 let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1532 let mut gpu_ok = true;
1533 for batch in 0..batch_size {
1534 let a_slice = &a_f32[batch * a_stride..(batch + 1) * a_stride];
1535 let b_slice = &b_f32[batch * b_stride..(batch + 1) * b_stride];
1536 if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
1537 c_data[batch * c_stride..(batch + 1) * c_stride]
1538 .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
1539 } else {
1540 gpu_ok = false;
1541 break;
1542 }
1543 }
1544 if gpu_ok {
1545 let mut output_shape = batch_dims_self;
1546 output_shape.push(m);
1547 output_shape.push(n);
1548 return Self::from_vec(c_data, &output_shape);
1549 }
1550 c_data = vec![T::zero(); batch_size * m * n];
1552 }
1553 }
1554
1555 for batch in 0..batch_size {
1557 let a_slice = &a_data[batch * a_stride..(batch + 1) * a_stride];
1558 let b_slice = &b_data[batch * b_stride..(batch + 1) * b_stride];
1559 let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1560 CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1561 }
1562
1563 let mut output_shape = batch_dims_self;
1565 output_shape.push(m);
1566 output_shape.push(n);
1567
1568 Self::from_vec(c_data, &output_shape)
1569 }
1570
1571 pub fn dot(&self, other: &Self) -> Result<Self> {
1573 if self.ndim() != 1 || other.ndim() != 1 {
1574 return Err(Error::invalid_operation("dot requires 1D tensors"));
1575 }
1576
1577 if self.shape[0] != other.shape[0] {
1578 return Err(Error::shape_mismatch(&self.shape, &other.shape));
1579 }
1580
1581 let a_data = self.to_vec();
1582 let b_data = other.to_vec();
1583 let result = CpuBackend::dot(&a_data, &b_data);
1584
1585 Ok(Self::scalar(result))
1586 }
1587}
1588
1589impl<T: Numeric> Add for &Tensor<T> {
1594 type Output = Tensor<T>;
1595
1596 fn add(self, other: Self) -> Self::Output {
1597 self.add(other).expect("Addition failed")
1598 }
1599}
1600
1601impl<T: Numeric> Sub for &Tensor<T> {
1602 type Output = Tensor<T>;
1603
1604 fn sub(self, other: Self) -> Self::Output {
1605 self.sub(other).expect("Subtraction failed")
1606 }
1607}
1608
1609impl<T: Numeric> Mul for &Tensor<T> {
1610 type Output = Tensor<T>;
1611
1612 fn mul(self, other: Self) -> Self::Output {
1613 self.mul(other).expect("Multiplication failed")
1614 }
1615}
1616
1617impl<T: Numeric> Div for &Tensor<T> {
1618 type Output = Tensor<T>;
1619
1620 fn div(self, other: Self) -> Self::Output {
1621 self.div(other).expect("Division failed")
1622 }
1623}
1624
1625impl<T: Numeric> Neg for &Tensor<T> {
1626 type Output = Tensor<T>;
1627
1628 fn neg(self) -> Self::Output {
1629 self.neg()
1630 }
1631}
1632
1633impl<T: Numeric> Add<T> for &Tensor<T> {
1635 type Output = Tensor<T>;
1636
1637 fn add(self, scalar: T) -> Self::Output {
1638 self.add_scalar(scalar)
1639 }
1640}
1641
1642impl<T: Numeric> Mul<T> for &Tensor<T> {
1643 type Output = Tensor<T>;
1644
1645 fn mul(self, scalar: T) -> Self::Output {
1646 self.mul_scalar(scalar)
1647 }
1648}
1649
1650impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
1655 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1656 write!(
1657 f,
1658 "Tensor(shape={:?}, device={}",
1659 self.shape(),
1660 self.device()
1661 )?;
1662 if self.numel() <= 10 {
1663 write!(f, ", data={:?}", self.to_vec())?;
1664 }
1665 write!(f, ")")
1666 }
1667}
1668
1669impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
1670 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1671 if self.is_scalar() {
1672 write!(f, "{}", self.item().unwrap())
1673 } else if self.ndim() == 1 {
1674 write!(f, "[")?;
1675 let data = self.to_vec();
1676 for (i, val) in data.iter().enumerate() {
1677 if i > 0 {
1678 write!(f, ", ")?;
1679 }
1680 write!(f, "{val}")?;
1681 }
1682 write!(f, "]")
1683 } else {
1684 write!(f, "Tensor(shape={:?})", self.shape())
1685 }
1686 }
1687}
1688
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_vec() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
    }

    #[test]
    fn test_get_set() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(tensor.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(tensor.get(&[1, 0]).unwrap(), 3.0);
        assert_eq!(tensor.get(&[1, 1]).unwrap(), 4.0);

        // NOTE: `set` works through an immutable binding here.
        tensor.set(&[0, 0], 99.0).unwrap();
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 99.0);
    }

    #[test]
    fn test_reshape() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let reshaped = tensor.reshape(&[3, 2]).unwrap();
        assert_eq!(reshaped.shape(), &[3, 2]);

        // A -1 dimension is inferred from the element count.
        let flattened = tensor.reshape(&[-1]).unwrap();
        assert_eq!(flattened.shape(), &[6]);
    }

    #[test]
    fn test_transpose() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = tensor.t().unwrap();
        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(transposed.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(transposed.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(transposed.get(&[1, 0]).unwrap(), 2.0);
    }

    #[test]
    fn test_arithmetic() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let sum = &lhs + &rhs;
        assert_eq!(sum.to_vec(), vec![5.0, 7.0, 9.0]);

        let product = &lhs * &rhs;
        assert_eq!(product.to_vec(), vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_broadcasting() {
        let vector = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let single = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        let sum = &vector + &single;
        assert_eq!(sum.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_sum() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        assert_eq!(tensor.sum().item().unwrap(), 10.0);
    }

    #[test]
    fn test_matmul() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let product = lhs.matmul(&rhs).unwrap();

        assert_eq!(product.shape(), &[2, 2]);
        assert_eq!(product.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_relu() {
        let tensor = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        assert_eq!(tensor.relu().to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_scalar() {
        let scalar = Tensor::<f32>::scalar(42.0);
        assert!(scalar.is_scalar());
        assert_eq!(scalar.numel(), 1);
        assert_eq!(scalar.item().unwrap(), 42.0);
    }
}