1use core::fmt;
18use core::ops::{Add, Div, Mul, Neg, Sub};
19
20use axonml_core::Device;
21use axonml_core::backends::CpuBackend;
22#[cfg(feature = "cuda")]
23use axonml_core::backends::CudaBackend;
24use axonml_core::dtype::{Float, Numeric, Scalar};
25use axonml_core::error::{Error, Result};
26use axonml_core::storage::Storage;
27use num_traits::NumCast;
28
#[cfg(feature = "cuda")]
mod cuda_accel {
    //! Thin convenience wrappers around the process-wide CUDA backend.
    use super::*;
    use axonml_core::backends::cuda::get_cuda_backend;

    /// Returns the global CUDA backend handle, or `None` when unavailable.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        get_cuda_backend()
    }

    /// Row-major `C (m x n) = A (m x k) * B (k x n)` on the GPU, returning the
    /// result as a host vector. Any allocation, copy, or launch failure maps
    /// to `None` (best-effort: caller is expected to fall back to CPU).
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        let cuda = get_cuda()?;

        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(m * n).ok()?;

        // NOTE(review): operands are passed as (B, A) with leading dims (n, k)
        // — this looks like the standard trick of computing row-major A*B via
        // a column-major GEMM; confirm against `gemm_f32`'s layout convention.
        cuda.gemm_f32(
            false, false, n, m, k, 1.0, &b_gpu, n, &a_gpu, k, 0.0, &mut c_gpu, n,
        )
        .ok()?;

        cuda.dtoh_copy(&c_gpu).ok()
    }
}
62
63use crate::shape::{
64 Shape, Strides, broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous,
65 linear_index, normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides,
66 unsqueeze,
67};
68
#[cfg(feature = "cuda")]
/// Reinterprets a `&Tensor<T>` as `&Tensor<f32>` so f32-only GPU methods can
/// be called from generic code.
///
/// # Safety
/// Sound only when `T` is exactly `f32`; the assert below enforces this at
/// runtime, making the pointer cast an identity reinterpretation.
unsafe fn gpu_ref<T: Scalar>(t: &Tensor<T>) -> &Tensor<f32> {
    assert!(
        is_f32::<T>(),
        "gpu_ref: only Tensor<f32> can be used for GPU operations, got {:?}",
        T::DTYPE
    );
    // SAFETY: T == f32 was just verified, so `Tensor<T>` and `Tensor<f32>`
    // are the same concrete type behind the pointer.
    unsafe { &*(t as *const Tensor<T> as *const Tensor<f32>) }
}
86
#[cfg(feature = "cuda")]
/// Converts an owned `Tensor<f32>` into `Tensor<T>` by value, for returning
/// GPU results through generic APIs.
///
/// # Safety
/// Sound only when `T` is exactly `f32`; the assert below enforces this.
unsafe fn gpu_into<T: Scalar>(t: Tensor<f32>) -> Tensor<T> {
    assert!(
        is_f32::<T>(),
        "gpu_into: only Tensor<f32> can be produced from GPU operations, got {:?}",
        T::DTYPE
    );
    unsafe {
        // SAFETY: bitwise move of `t` into the (identical) target type.
        // `forget` prevents a double drop of the moved-from value.
        let out = std::ptr::read(&t as *const Tensor<f32> as *const Tensor<T>);
        std::mem::forget(t);
        out
    }
}
101
#[cfg(feature = "cuda")]
/// `true` iff the monomorphized `T` is exactly `f32`.
fn is_f32<T: 'static>() -> bool {
    std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
}
106
/// An N-dimensional strided view over element storage.
///
/// NOTE(review): `Clone` clones the `Storage` handle; whether that shares or
/// copies the underlying buffer depends on `Storage`'s clone semantics —
/// view-producing methods here (`transpose`, `permute`, …) rely on it sharing.
#[derive(Clone)]
pub struct Tensor<T: Scalar> {
    // Element buffer (CPU or GPU resident).
    pub(crate) storage: Storage<T>,
    // Logical dimension sizes.
    pub(crate) shape: Shape,
    // Per-dimension strides, in elements (signed).
    pub(crate) strides: Strides,
    // Start of this view inside `storage`, in elements.
    pub(crate) offset: usize,
}
127
128impl<T: Scalar> Tensor<T> {
129 pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
142 let total = numel(shape);
143 if total != storage.len() {
144 return Err(Error::shape_mismatch(&[storage.len()], shape));
145 }
146
147 let shape = Shape::from_slice(shape);
148 let strides = contiguous_strides(&shape);
149
150 Ok(Self {
151 storage,
152 shape,
153 strides,
154 offset: 0,
155 })
156 }
157
    /// Creates a contiguous CPU tensor that takes ownership of `data`.
    ///
    /// # Errors
    /// Fails when `data.len()` does not match the element count of `shape`.
    pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
        let storage = Storage::from_vec(data, Device::Cpu);
        Self::from_storage(storage, shape)
    }

    /// Creates a contiguous CPU tensor by copying from `data`.
    ///
    /// # Errors
    /// Fails when `data.len()` does not match the element count of `shape`.
    pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
        let storage = Storage::from_slice(data, Device::Cpu);
        Self::from_storage(storage, shape)
    }
183
    /// Creates a 0-dimensional (scalar) CPU tensor holding `value`.
    /// The shape and strides are empty; `numel()` is 1.
    pub fn scalar(value: T) -> Self {
        Self {
            storage: Storage::from_vec(vec![value], Device::Cpu),
            shape: Shape::new(),
            strides: Strides::new(),
            offset: 0,
        }
    }
199
    /// Zero-filled tensor of the given shape (delegates to `crate::creation`).
    #[must_use]
    pub fn zeros(shape: &[usize]) -> Self {
        crate::creation::zeros(shape)
    }

    /// One-filled tensor of the given shape.
    #[must_use]
    pub fn ones(shape: &[usize]) -> Self
    where
        T: Numeric,
    {
        crate::creation::ones(shape)
    }

    /// Tensor of the given shape with every element set to `value`.
    #[must_use]
    pub fn full(shape: &[usize], value: T) -> Self {
        crate::creation::full(shape, value)
    }

    /// Tensor with elements drawn from the standard normal distribution.
    #[must_use]
    pub fn randn(shape: &[usize]) -> Self
    where
        T: Float,
        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
    {
        crate::creation::randn(shape)
    }

    /// Tensor with elements drawn from the `Standard` uniform distribution.
    #[must_use]
    pub fn rand(shape: &[usize]) -> Self
    where
        T: Float,
        rand::distributions::Standard: rand::distributions::Distribution<T>,
    {
        crate::creation::rand(shape)
    }
240
    /// Dimension sizes of this view.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Per-dimension strides, in elements.
    #[must_use]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }

    /// Number of dimensions (0 for scalars).
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Total number of logical elements.
    #[must_use]
    pub fn numel(&self) -> usize {
        numel(&self.shape)
    }

    /// `true` when the tensor holds no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.numel() == 0
    }

    /// Size of dimension `dim` (negative indices count from the back).
    ///
    /// # Errors
    /// Fails when `dim` is out of range.
    pub fn size(&self, dim: i64) -> Result<usize> {
        let idx = normalize_dim(dim, self.ndim())?;
        Ok(self.shape[idx])
    }

    /// Device on which the storage lives.
    #[must_use]
    pub fn device(&self) -> Device {
        self.storage.device()
    }

    /// `true` when elements are laid out densely in row-major order.
    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape, &self.strides)
    }

    /// `true` for 0-dimensional tensors.
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }
301
302 pub fn get(&self, indices: &[usize]) -> Result<T> {
311 if indices.len() != self.ndim() {
312 return Err(Error::invalid_operation(format!(
313 "Expected {} indices, got {}",
314 self.ndim(),
315 indices.len()
316 )));
317 }
318
319 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
320 if idx >= dim {
321 return Err(Error::IndexOutOfBounds {
322 index: idx,
323 size: dim,
324 });
325 }
326 }
327
328 let offset = self.offset + linear_index(indices, &self.strides);
329 Ok(self.storage.as_slice()[offset])
330 }
331
    /// Writes `value` at `indices`, one index per dimension.
    ///
    /// NOTE(review): mutates through `&self` — `Storage::as_slice_mut`
    /// presumably uses interior mutability; confirm aliasing discipline.
    ///
    /// # Errors
    /// Fails when the index count differs from `ndim()` or any index exceeds
    /// its dimension size.
    pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
        if indices.len() != self.ndim() {
            return Err(Error::invalid_operation(format!(
                "Expected {} indices, got {}",
                self.ndim(),
                indices.len()
            )));
        }

        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
            if idx >= dim {
                return Err(Error::IndexOutOfBounds {
                    index: idx,
                    size: dim,
                });
            }
        }

        let offset = self.offset + linear_index(indices, &self.strides);
        self.storage.as_slice_mut()[offset] = value;
        Ok(())
    }
359
360 pub fn item(&self) -> Result<T> {
362 if self.numel() != 1 {
363 return Err(Error::invalid_operation(
364 "item() only works on single-element tensors",
365 ));
366 }
367
368 let data = self.to_vec();
370 if data.is_empty() {
371 Err(Error::invalid_operation("item() on empty tensor"))
372 } else {
373 Ok(data[0])
374 }
375 }
376
    /// Copies the tensor's logical elements into a host `Vec` in row-major
    /// order, resolving strides, offset, and (with `cuda`) device residency.
    #[must_use]
    pub fn to_vec(&self) -> Vec<T> {
        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let f32_vec = self_f32.to_vec_gpu();
            // SAFETY: T == f32 (asserted above), so reconstructing the Vec
            // with a `*mut T` pointer reinterprets identical layout;
            // ManuallyDrop prevents freeing the buffer we hand off.
            unsafe {
                let mut v = std::mem::ManuallyDrop::new(f32_vec);
                return Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity());
            }
        }

        if self.is_contiguous() {
            // Dense layout: a single slice copy suffices.
            let storage = self.storage.as_slice();
            storage[self.offset..self.offset + self.numel()].to_vec()
        } else {
            // Strided layout: gather element by element.
            let mut result = Vec::with_capacity(self.numel());
            self.copy_data_to(&mut result);
            result
        }
    }
405
406 fn copy_data_to(&self, dst: &mut Vec<T>) {
408 dst.clear();
409 let storage = self.storage.as_slice();
410
411 let total = self.numel();
413 for i in 0..total {
414 let indices = crate::shape::unravel_index(i, &self.shape);
415 let offset = self.offset + linear_index(&indices, &self.strides);
416 dst.push(storage[offset]);
417 }
418 }
419
    /// Returns a tensor with the same elements and a new shape. `-1` in
    /// `new_shape` infers that dimension (resolved by the `reshape` helper).
    /// Contiguous inputs are reshaped without copying; strided inputs are
    /// first materialized via `contiguous()`.
    ///
    /// # Errors
    /// Fails when the requested shape is incompatible with the element count.
    pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
        let shape = reshape(&self.shape, new_shape)?;

        if self.is_contiguous() {
            Ok(Self {
                storage: self.storage.clone(),
                strides: contiguous_strides(&shape),
                shape,
                offset: self.offset,
            })
        } else {
            // Non-contiguous: copy into a dense buffer, then view it.
            let contig = self.contiguous();
            Ok(Self {
                storage: contig.storage,
                strides: contiguous_strides(&shape),
                shape,
                offset: 0,
            })
        }
    }
453
    /// Collapses the tensor into one dimension (`reshape(&[-1])`).
    #[must_use]
    pub fn flatten(&self) -> Self {
        self.reshape(&[-1]).expect("Flatten should never fail")
    }
459
    /// Removes size-1 dimensions: the given one (if it is size 1) when
    /// `dim` is `Some`, otherwise all of them. Returns a view; no copy.
    ///
    /// # Errors
    /// Fails when `dim` is out of range.
    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
        let dim = match dim {
            Some(d) => Some(normalize_dim(d, self.ndim())?),
            None => None,
        };

        let new_shape = squeeze(&self.shape, dim);
        let new_strides: Strides = match dim {
            Some(d) => {
                // Drop the stride only when the dimension is actually removed,
                // mirroring what the `squeeze` shape helper does.
                let mut s = self.strides.clone();
                if d < self.shape.len() && self.shape[d] == 1 {
                    s.remove(d);
                }
                s
            }
            // Keep strides only for dimensions whose size is not 1.
            None => self
                .shape
                .iter()
                .zip(self.strides.iter())
                .filter(|(dim, _)| **dim != 1)
                .map(|(_, stride)| *stride)
                .collect(),
        };

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }
495
    /// Inserts a size-1 dimension at `dim` (negative values count from the
    /// back, with `-1` appending). Returns a view; no copy.
    ///
    /// # Errors
    /// Fails when the insertion position is out of range (via `unsqueeze`).
    pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
        // +1 because the valid insertion range is 0..=ndim, one longer than
        // the existing dimension range.
        let normalized = if dim < 0 {
            (dim + self.ndim() as i64 + 1) as usize
        } else {
            dim as usize
        };

        let new_shape = unsqueeze(&self.shape, normalized)?;
        let mut new_strides = Strides::with_capacity(new_shape.len());

        for (i, _) in new_shape.iter().enumerate() {
            if i < normalized {
                new_strides.push(self.strides.get(i).copied().unwrap_or(1));
            } else if i == normalized {
                // Stride of a size-1 dimension is never used to move, so any
                // value is valid; 1 is used here.
                new_strides.push(1);
            } else {
                // Dimensions after the insertion point shift right by one.
                new_strides.push(self.strides[i - 1]);
            }
        }

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }
528
529 pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
535 let d0 = normalize_dim(dim0, self.ndim())?;
536 let d1 = normalize_dim(dim1, self.ndim())?;
537
538 let new_shape = transpose_shape(&self.shape, d0, d1)?;
539 let new_strides = transpose_strides(&self.strides, d0, d1);
540
541 Ok(Self {
542 storage: self.storage.clone(),
543 shape: new_shape,
544 strides: new_strides,
545 offset: self.offset,
546 })
547 }
548
549 pub fn t(&self) -> Result<Self> {
551 if self.ndim() != 2 {
552 return Err(Error::invalid_operation("t() only works on 2D tensors"));
553 }
554 self.transpose(0, 1)
555 }
556
557 pub fn permute(&self, dims: &[usize]) -> Result<Self> {
562 if dims.len() != self.ndim() {
563 return Err(Error::invalid_operation(format!(
564 "Expected {} dimensions, got {}",
565 self.ndim(),
566 dims.len()
567 )));
568 }
569
570 let mut seen = vec![false; self.ndim()];
572 for &d in dims {
573 if d >= self.ndim() {
574 return Err(Error::InvalidDimension {
575 index: d as i64,
576 ndim: self.ndim(),
577 });
578 }
579 if seen[d] {
580 return Err(Error::invalid_operation("Duplicate dimension in permute"));
581 }
582 seen[d] = true;
583 }
584
585 let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
586 let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
587
588 Ok(Self {
589 storage: self.storage.clone(),
590 shape: new_shape,
591 strides: new_strides,
592 offset: self.offset,
593 })
594 }
595
    /// Returns a densely packed row-major copy of the tensor, or a cheap
    /// clone when it is already contiguous with a zero offset.
    #[must_use]
    pub fn contiguous(&self) -> Self {
        if self.is_contiguous() && self.offset == 0 {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.contiguous_gpu();
            return unsafe { gpu_into(result) };
        }

        // CPU fallback: gather into a fresh buffer and rewrap.
        let data = self.to_vec();
        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
    }
614
615 #[must_use]
624 pub fn map<F: Fn(T) -> T>(&self, f: F) -> Self {
625 let data = self.to_vec(); let result: Vec<T> = data.into_iter().map(f).collect();
627 Self::from_vec(result, &self.shape).unwrap()
628 }
629
630 #[must_use]
636 pub fn zip_map<F: Fn(T, T) -> T>(&self, other: &Self, f: F) -> Self {
637 let a = self.to_vec();
638 let b = other.to_vec();
639 debug_assert_eq!(
640 a.len(),
641 b.len(),
642 "zip_map requires same number of elements: {} vs {}",
643 a.len(),
644 b.len()
645 );
646 let result: Vec<T> = a.into_iter().zip(b).map(|(x, y)| f(x, y)).collect();
647 Self::from_vec(result, &self.shape).unwrap()
648 }
649
650 #[must_use]
652 pub fn zip_map3<F: Fn(T, T, T) -> T>(&self, b: &Self, c: &Self, f: F) -> Self {
653 let a_data = self.to_vec();
654 let b_data = b.to_vec();
655 let c_data = c.to_vec();
656 debug_assert_eq!(a_data.len(), b_data.len());
657 debug_assert_eq!(a_data.len(), c_data.len());
658 let result: Vec<T> = a_data
659 .into_iter()
660 .zip(b_data)
661 .zip(c_data)
662 .map(|((a, b), c)| f(a, b, c))
663 .collect();
664 Self::from_vec(result, &self.shape).unwrap()
665 }
666
    /// Moves the tensor to `device`, returning a cheap clone when it is
    /// already there. CPU-to-CPU transfers go through `Storage::to_device`
    /// after materializing a contiguous buffer.
    ///
    /// # Errors
    /// Propagates transfer failures from the storage/backend layer.
    pub fn to_device(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        #[cfg(feature = "cuda")]
        if self.storage.is_gpu() || device.is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = self_f32.to_device_f32(device)?;
            return Ok(unsafe { gpu_into(result) });
        }

        // Pack first so the destination holds a dense row-major buffer.
        let contig = self.contiguous();
        let new_storage = contig.storage.to_device(device)?;

        Ok(Self {
            storage: new_storage,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: 0,
        })
    }
698
    /// Moves the tensor to the CPU (shorthand for `to_device(Device::Cpu)`).
    ///
    /// # Errors
    /// Propagates transfer failures.
    pub fn cpu(&self) -> Result<Self> {
        self.to_device(Device::Cpu)
    }
703
    /// Deep copy: materializes the elements into fresh storage (on the CPU),
    /// then transfers back to the original device when it was a GPU.
    #[must_use]
    pub fn clone_deep(&self) -> Self {
        let data = self.to_vec();
        let cpu = Self::from_vec(data, &self.shape).expect("Deep clone should never fail");
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return cpu.to_device(self.device()).unwrap();
        }
        cpu
    }
719}
720
721impl<T: Numeric> Tensor<T> {
    /// Overwrites every storage element with `value` (CPU only).
    ///
    /// NOTE(review): mutates through `&self` via `as_slice_mut`, and writes
    /// the whole backing slice — views sharing this storage see the change.
    ///
    /// # Panics
    /// Panics when called on a GPU tensor.
    pub fn fill_(&self, value: T) {
        assert!(
            self.storage.is_cpu(),
            "fill_() not supported on GPU tensors — create a new tensor and transfer instead"
        );
        let mut data = self.storage.as_slice_mut();
        CpuBackend::fill(&mut data, value);
    }

    /// Overwrites every element with zero (CPU only; see `fill_`).
    pub fn zero_(&self) {
        self.fill_(T::zero());
    }
744
    /// Sums all elements into a scalar tensor (kept on the original device).
    #[must_use]
    pub fn sum(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            // Reduce one dimension at a time until a single element remains.
            let mut t = self_f32.clone();
            while t.ndim() > 1 {
                t = t.sum_dim_cuda(0);
            }
            if t.numel() > 1 {
                t = t.sum_dim_cuda(0);
            }
            return unsafe { gpu_into(t) };
        }

        let data = self.to_vec();
        let result = CpuBackend::sum(&data);
        Self::scalar(result)
    }
772
    /// Multiplies all elements into a scalar tensor. The reduction always
    /// runs on the CPU (`to_vec` pulls GPU data to host); the result is
    /// transferred back to a GPU source's device.
    #[must_use]
    pub fn prod(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::prod(&data);
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return s
                .to_device(self.device())
                .expect("prod: device transfer failed");
        }
        s
    }
789
    /// Maximum element as a scalar tensor (CPU reduction; result transferred
    /// back to a GPU source's device).
    ///
    /// # Errors
    /// Fails with `EmptyTensor` on an empty tensor.
    pub fn max(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::max(&data).expect("max on non-empty tensor");
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s
                .to_device(self.device())
                .expect("max: device transfer failed"));
        }
        Ok(s)
    }

    /// Minimum element as a scalar tensor (CPU reduction; result transferred
    /// back to a GPU source's device).
    ///
    /// # Errors
    /// Fails with `EmptyTensor` on an empty tensor.
    pub fn min(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::min(&data).expect("min on non-empty tensor");
        let s = Self::scalar(result);
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return Ok(s
                .to_device(self.device())
                .expect("min: device transfer failed"));
        }
        Ok(s)
    }
827
    /// Flat (row-major) index of the maximum element.
    ///
    /// # Errors
    /// Fails with `EmptyTensor` on an empty tensor.
    pub fn argmax(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmax(&data).unwrap())
    }

    /// Flat (row-major) index of the minimum element.
    ///
    /// # Errors
    /// Fails with `EmptyTensor` on an empty tensor.
    pub fn argmin(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmin(&data).unwrap())
    }
845
    /// Concatenates `tensors` along dimension `dim`. The copy runs on the
    /// CPU; the result is transferred to the first tensor's device when that
    /// device is a GPU.
    ///
    /// # Errors
    /// Fails when `tensors` is empty, `dim` is out of range, ranks differ,
    /// or non-concatenated dimension sizes disagree.
    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(Error::invalid_operation("cat requires at least one tensor"));
        }
        let ndim = tensors[0].ndim();
        if dim >= ndim {
            return Err(Error::invalid_operation("cat dimension out of range"));
        }

        for t in &tensors[1..] {
            if t.ndim() != ndim {
                return Err(Error::invalid_operation(
                    "cat: all tensors must have same ndim",
                ));
            }
            for d in 0..ndim {
                if d != dim && t.shape[d] != tensors[0].shape[d] {
                    return Err(Error::invalid_operation(
                        "cat: shapes must match on non-cat dims",
                    ));
                }
            }
        }

        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
        out_shape[dim] = total_dim_size;

        // View every tensor as (outer, cat-dim, inner); inner runs are
        // contiguous in row-major order and can be memcpy'd.
        let outer_size: usize = out_shape[..dim].iter().product();
        let inner_size: usize = out_shape[dim + 1..].iter().product();
        let total_numel: usize = out_shape.iter().product();
        let mut result = vec![T::zero(); total_numel];

        // Running start position of each tensor's slice along `dim`.
        let mut dim_offset = 0;
        for t in tensors {
            let t_data = t.contiguous().to_vec();
            let t_dim_size = t.shape[dim];
            for outer in 0..outer_size {
                for d in 0..t_dim_size {
                    let src_base = outer * t_dim_size * inner_size + d * inner_size;
                    let dst_base =
                        outer * total_dim_size * inner_size + (dim_offset + d) * inner_size;
                    result[dst_base..dst_base + inner_size]
                        .copy_from_slice(&t_data[src_base..src_base + inner_size]);
                }
            }
            dim_offset += t_dim_size;
        }

        let out = Self::from_vec(result, &out_shape)?;
        #[cfg(feature = "cuda")]
        if tensors[0].device().is_gpu() {
            return Ok(out.to_device(tensors[0].device()).unwrap());
        }
        Ok(out)
    }
905}
906
907impl<T: Float> Tensor<T> {
912 pub fn mean(&self) -> Result<Self> {
917 if self.is_empty() {
918 return Err(Error::EmptyTensor);
919 }
920 #[cfg(feature = "cuda")]
921 if self.device().is_gpu() {
922 let s = self.sum(); let n = self.numel() as f32;
924 return Ok(s.mul_scalar(T::from(1.0 / n as f64).unwrap_or(T::zero())));
926 }
927
928 let data = self.to_vec();
929 let result = CpuBackend::mean(&data).expect("mean on non-empty tensor");
930 Ok(Self::scalar(result))
931 }
932
    /// Element-wise `max(x, 0)`.
    #[must_use]
    pub fn relu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).relu_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::relu(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Element-wise logistic sigmoid.
    #[must_use]
    pub fn sigmoid(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).sigmoid_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sigmoid(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Element-wise hyperbolic tangent.
    #[must_use]
    pub fn tanh(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).tanh_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::tanh(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Element-wise natural exponential.
    #[must_use]
    pub fn exp(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).exp_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::exp(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Element-wise natural logarithm.
    #[must_use]
    pub fn ln(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).ln_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::ln(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Element-wise square root.
    #[must_use]
    pub fn sqrt(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).sqrt_cuda()) };
        }
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sqrt(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }
1020
    /// Element-wise power: raises every element to `exp`.
    #[must_use]
    pub fn pow(&self, exp: T) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            // SAFETY: T == f32 (asserted above), so reading `exp` as f32 is
            // an identity reinterpretation.
            let exp_f32: f32 = unsafe { *(&exp as *const T as *const f32) };
            return unsafe { gpu_into(gpu_ref(self).pow_cuda(exp_f32)) };
        }
        let data = self.to_vec();
        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
        Self::from_vec(result, &self.shape).unwrap()
    }
1034
    /// Element-wise GELU activation (CPU path delegates to `crate::ops`).
    #[must_use]
    pub fn gelu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).gelu_cuda()) };
        }
        crate::ops::gelu(self)
    }

    /// Element-wise SiLU activation (CPU path delegates to `crate::ops`).
    #[must_use]
    pub fn silu(&self) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            return unsafe { gpu_into(gpu_ref(self).silu_cuda()) };
        }
        crate::ops::silu(self)
    }
1056
    /// Softmax along dimension `dim`.
    ///
    /// NOTE(review): the CPU path swallows errors from `crate::ops::softmax`
    /// and returns `self` unchanged; callers cannot distinguish failure.
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Self {
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe { gpu_into(self_f32.softmax_cuda(dim).expect("CUDA softmax failed")) };
        }
        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
    }
1068
    /// Log-softmax along dimension `dim`.
    ///
    /// NOTE(review): computed as `ln(softmax(x))`, which can underflow to
    /// `-inf` for very negative logits; a fused shifted formulation would be
    /// more numerically stable but changes exact results.
    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Self {
        let softmax_result = self.softmax(dim);
        softmax_result.ln()
    }
1075
    /// Mean along one dimension. `keepdim` keeps the reduced dimension with
    /// size 1. An out-of-range `dim` (after negative-index normalization)
    /// returns a clone of `self` unchanged.
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let summed = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            // mean = sum * (1 / dim_size)
            let dim_size = self.shape[dim];
            let result = summed.mul_scalar_cuda(1.0 / dim_size as f32);
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // Reducing the last remaining dimension yields shape [1], not a
        // 0-dim scalar.
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the input as (outer, dim, inner) and average over `dim`.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let mean = sum / NumCast::from(dim_size).unwrap();
                let result_idx = outer * inner_size + inner;
                result[result_idx] = mean;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1140
    /// Sum along one dimension. `keepdim` keeps the reduced dimension with
    /// size 1. An out-of-range `dim` (after negative-index normalization)
    /// returns a clone of `self` unchanged.
    #[must_use]
    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            let result = if keepdim {
                self_f32.sum_dim_keepdim_cuda(dim)
            } else {
                self_f32.sum_dim_cuda(dim)
            };
            return unsafe { gpu_into(result) };
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        // Reducing the last remaining dimension yields shape [1], not a
        // 0-dim scalar.
        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the input as (outer, dim, inner) and sum over `dim`.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let result_idx = outer * inner_size + inner;
                result[result_idx] = sum;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }
1202
1203 #[must_use]
1205 pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
1206 let mean = self.mean_dim(dim, true);
1208 let sq = self.mul(self).unwrap_or_else(|_| self.clone());
1209 let mean_sq = sq.mean_dim(dim, keepdim);
1210 let mean_keepdim = if keepdim {
1211 mean.clone()
1212 } else {
1213 self.mean_dim(dim, keepdim)
1214 };
1215 let mean_squared = mean_keepdim
1216 .mul(&mean_keepdim)
1217 .unwrap_or_else(|_| mean_keepdim.clone());
1218 mean_sq
1219 .sub(&mean_squared)
1220 .unwrap_or_else(|_| mean_sq.clone())
1221 }
1222
    /// Materializes a copy of `self` broadcast to `shape`; returns a cheap
    /// clone when the shape already matches.
    ///
    /// NOTE(review): on broadcast failure the CPU path silently falls back
    /// to using `shape` as-is (`unwrap_or_else`), which may then index out
    /// of bounds — confirm intended behavior for incompatible shapes.
    #[must_use]
    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
        if self.shape.as_slice() == shape {
            return self.clone();
        }

        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
            let self_f32 = unsafe { gpu_ref(self) };
            return unsafe {
                gpu_into(
                    self_f32
                        .broadcast_to_cuda(shape)
                        .expect("CUDA broadcast_to failed"),
                )
            };
        }

        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
        // Zero-stride trick: broadcast dimensions repeat the same elements.
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];
        let self_data = self.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            result_data[i] = self_data[self_idx];
        }

        Self::from_vec(result_data, &result_shape).unwrap()
    }
1258
    /// Copies the sub-tensor selected by `ranges` (one half-open range per
    /// leading dimension; unspecified trailing dimensions are taken whole).
    ///
    /// NOTE(review): ranges are not validated against the shape — an
    /// out-of-bounds `range.end` will panic inside the copy; confirm callers
    /// guarantee validity.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
        let mut new_shape = Vec::with_capacity(self.ndim());
        for (i, range) in ranges.iter().enumerate() {
            if i < self.ndim() {
                new_shape.push(range.end - range.start);
            }
        }
        // Dimensions without an explicit range are kept in full.
        for i in ranges.len()..self.ndim() {
            new_shape.push(self.shape[i]);
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result_data = vec![T::zero(); new_numel];
        let self_data = self.to_vec();

        let mut result_idx = 0;
        Self::slice_recursive(
            &self_data,
            &self.shape,
            ranges,
            0,
            0,
            &mut result_data,
            &mut result_idx,
        );

        let out = Self::from_vec(result_data, &new_shape).unwrap();
        #[cfg(feature = "cuda")]
        if self.device().is_gpu() {
            return out.to_device(self.device()).unwrap();
        }
        out
    }
1296
    /// Depth-first walk over the selected index ranges, appending elements
    /// to `result` in row-major order. `offset` is the flat position in the
    /// (contiguous) source `data` accumulated so far; `result_idx` tracks the
    /// next write position in `result`.
    fn slice_recursive(
        data: &[T],
        shape: &[usize],
        ranges: &[std::ops::Range<usize>],
        dim: usize,
        offset: usize,
        result: &mut [T],
        result_idx: &mut usize,
    ) {
        // All dimensions resolved: copy one element.
        if dim == shape.len() {
            result[*result_idx] = data[offset];
            *result_idx += 1;
            return;
        }

        // Row-major flat stride of dimension `dim` in the source.
        let stride: usize = shape[dim + 1..].iter().product();
        // Dimensions beyond the given ranges are taken in full.
        let (start, end) = if dim < ranges.len() {
            (ranges[dim].start, ranges[dim].end)
        } else {
            (0, shape[dim])
        };

        for i in start..end {
            Self::slice_recursive(
                data,
                shape,
                ranges,
                dim + 1,
                offset + i * stride,
                result,
                result_idx,
            );
        }
    }
1331}
1332
1333impl<T: Numeric> Tensor<T> {
    /// Element-wise addition with broadcasting. Mixed-device operands are
    /// first moved to the GPU side, then re-dispatched.
    ///
    /// # Errors
    /// Fails when shapes cannot broadcast, or on GPU/transfer failures.
    pub fn add(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.add_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_add_cuda(o)?) });
                    }
                }
                // Exactly one operand is on the GPU: move the other there
                // and retry (recursion takes the both-GPU branch above).
                let target_device = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target_device)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target_device)?
                };
                return a_gpu.add(&b_gpu);
            }
        }
        // Fast path: identical shapes, dense layouts — straight linear loop.
        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
            let a = self.storage.as_slice();
            let b = other.storage.as_slice();
            let ao = self.offset;
            let bo = other.offset;
            let n = numel(&self.shape);
            let mut result_data = vec![T::zero(); n];
            for i in 0..n {
                result_data[i] = a[ao + i] + b[bo + i];
            }
            return Self::from_vec(result_data, &self.shape);
        }

        // General path: broadcast both operands via zero-stride views.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] + other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1406
    /// Element-wise subtraction with broadcasting. Mixed-device operands are
    /// first moved to the GPU side, then re-dispatched (mirrors `add`).
    ///
    /// # Errors
    /// Fails when shapes cannot broadcast, or on GPU/transfer failures.
    pub fn sub(&self, other: &Self) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let self_gpu = self.device().is_gpu();
            let other_gpu = other.device().is_gpu();
            if self_gpu || other_gpu {
                assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
                if self_gpu && other_gpu {
                    let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
                    if self.shape == other.shape {
                        return Ok(unsafe { gpu_into(s.sub_cuda(o)?) });
                    } else {
                        return Ok(unsafe { gpu_into(s.broadcast_sub_cuda(o)?) });
                    }
                }
                // Exactly one operand is on the GPU: move the other there
                // and retry (recursion takes the both-GPU branch above).
                let target = if self_gpu {
                    self.device()
                } else {
                    other.device()
                };
                let a_gpu = if self_gpu {
                    self.clone()
                } else {
                    self.to_device(target)?
                };
                let b_gpu = if other_gpu {
                    other.clone()
                } else {
                    other.to_device(target)?
                };
                return a_gpu.sub(&b_gpu);
            }
        }
        // Fast path: identical shapes, dense layouts — straight linear loop.
        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
            let a = self.storage.as_slice();
            let b = other.storage.as_slice();
            let (ao, bo) = (self.offset, other.offset);
            let n = numel(&self.shape);
            let mut r = vec![T::zero(); n];
            for i in 0..n {
                r[i] = a[ao + i] - b[bo + i];
            }
            return Self::from_vec(r, &self.shape);
        }

        // General path: broadcast both operands via zero-stride views.
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = self_data[self_idx] - other_data[other_idx];
        }

        Self::from_vec(result_data, &result_shape)
    }
1473
1474 pub fn mul(&self, other: &Self) -> Result<Self> {
1476 #[cfg(feature = "cuda")]
1477 {
1478 let self_gpu = self.device().is_gpu();
1479 let other_gpu = other.device().is_gpu();
1480 if self_gpu || other_gpu {
1481 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1482 if self_gpu && other_gpu {
1483 let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1484 if self.shape == other.shape {
1485 return Ok(unsafe { gpu_into(s.mul_cuda(o)?) });
1486 } else {
1487 return Ok(unsafe { gpu_into(s.broadcast_mul_cuda(o)?) });
1488 }
1489 }
1490 let target = if self_gpu {
1491 self.device()
1492 } else {
1493 other.device()
1494 };
1495 let a_gpu = if self_gpu {
1496 self.clone()
1497 } else {
1498 self.to_device(target)?
1499 };
1500 let b_gpu = if other_gpu {
1501 other.clone()
1502 } else {
1503 other.to_device(target)?
1504 };
1505 return a_gpu.mul(&b_gpu);
1506 }
1507 }
1508 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1510 let a = self.storage.as_slice();
1511 let b = other.storage.as_slice();
1512 let (ao, bo) = (self.offset, other.offset);
1513 let n = numel(&self.shape);
1514 let mut r = vec![T::zero(); n];
1515 for i in 0..n {
1516 r[i] = a[ao + i] * b[bo + i];
1517 }
1518 return Self::from_vec(r, &self.shape);
1519 }
1520
1521 let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1522 let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1523 let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1524
1525 let total = numel(&result_shape);
1526 let mut result_data = vec![T::zero(); total];
1527
1528 let self_data = self.storage.as_slice();
1529 let other_data = other.storage.as_slice();
1530
1531 for i in 0..total {
1532 let indices = crate::shape::unravel_index(i, &result_shape);
1533 let self_idx = self.offset + linear_index(&indices, &self_strides);
1534 let other_idx = other.offset + linear_index(&indices, &other_strides);
1535 result_data[i] = self_data[self_idx] * other_data[other_idx];
1536 }
1537
1538 Self::from_vec(result_data, &result_shape)
1539 }
1540
1541 pub fn div(&self, other: &Self) -> Result<Self> {
1543 #[cfg(feature = "cuda")]
1544 {
1545 let self_gpu = self.device().is_gpu();
1546 let other_gpu = other.device().is_gpu();
1547 if self_gpu || other_gpu {
1548 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1549 if self_gpu && other_gpu {
1550 let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1551 if self.shape == other.shape {
1552 return Ok(unsafe { gpu_into(s.div_cuda(o)?) });
1553 } else {
1554 return Ok(unsafe { gpu_into(s.broadcast_div_cuda(o)?) });
1555 }
1556 }
1557 let target = if self_gpu {
1558 self.device()
1559 } else {
1560 other.device()
1561 };
1562 let a_gpu = if self_gpu {
1563 self.clone()
1564 } else {
1565 self.to_device(target)?
1566 };
1567 let b_gpu = if other_gpu {
1568 other.clone()
1569 } else {
1570 other.to_device(target)?
1571 };
1572 return a_gpu.div(&b_gpu);
1573 }
1574 }
1575 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
1577 let a = self.storage.as_slice();
1578 let b = other.storage.as_slice();
1579 let (ao, bo) = (self.offset, other.offset);
1580 let n = numel(&self.shape);
1581 let mut r = vec![T::zero(); n];
1582 for i in 0..n {
1583 r[i] = a[ao + i] / b[bo + i];
1584 }
1585 return Self::from_vec(r, &self.shape);
1586 }
1587
1588 let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1589 let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1590 let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1591
1592 let total = numel(&result_shape);
1593 let mut result_data = vec![T::zero(); total];
1594
1595 let self_data = self.storage.as_slice();
1596 let other_data = other.storage.as_slice();
1597
1598 for i in 0..total {
1599 let indices = crate::shape::unravel_index(i, &result_shape);
1600 let self_idx = self.offset + linear_index(&indices, &self_strides);
1601 let other_idx = other.offset + linear_index(&indices, &other_strides);
1602 result_data[i] = self_data[self_idx] / other_data[other_idx];
1603 }
1604
1605 Self::from_vec(result_data, &result_shape)
1606 }
1607
1608 #[must_use]
1610 pub fn add_scalar(&self, scalar: T) -> Self {
1611 #[cfg(feature = "cuda")]
1612 if self.device().is_gpu() {
1613 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1614 let self_f32 = unsafe { gpu_ref(self) };
1615 let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1616 return unsafe { gpu_into(self_f32.add_scalar_cuda(scalar_f32)) };
1617 }
1618 let data = self.to_vec();
1619 let mut result = vec![T::zero(); data.len()];
1620 CpuBackend::add_scalar(&mut result, &data, scalar);
1621 Self::from_vec(result, &self.shape).unwrap()
1622 }
1623
1624 #[must_use]
1626 pub fn mul_scalar(&self, scalar: T) -> Self {
1627 #[cfg(feature = "cuda")]
1628 if self.device().is_gpu() {
1629 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1630 let self_f32 = unsafe { gpu_ref(self) };
1631 let scalar_f32: f32 = unsafe { *(&scalar as *const T as *const f32) };
1632 return unsafe { gpu_into(self_f32.mul_scalar_cuda(scalar_f32)) };
1633 }
1634 let data = self.to_vec();
1635 let mut result = vec![T::zero(); data.len()];
1636 CpuBackend::mul_scalar(&mut result, &data, scalar);
1637 Self::from_vec(result, &self.shape).unwrap()
1638 }
1639
1640 #[must_use]
1642 pub fn neg(&self) -> Self {
1643 #[cfg(feature = "cuda")]
1644 if self.device().is_gpu() {
1645 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1646 let self_f32 = unsafe { gpu_ref(self) };
1647 return unsafe { gpu_into(self_f32.neg_cuda()) };
1648 }
1649 let data = self.to_vec();
1650 let mut result = vec![T::zero(); data.len()];
1651 CpuBackend::neg(&mut result, &data);
1652 Self::from_vec(result, &self.shape).unwrap()
1653 }
1654
1655 pub fn matmul(&self, other: &Self) -> Result<Self> {
1662 #[cfg(feature = "cuda")]
1663 if self.device().is_gpu() {
1664 assert!(is_f32::<T>(), "GPU tensors are only supported for f32");
1665 let (s, o) = unsafe { (gpu_ref(self), gpu_ref(other)) };
1666 return Ok(unsafe { gpu_into(s.matmul_cuda(o)?) });
1667 }
1668 if self.ndim() < 2 || other.ndim() < 2 {
1669 return Err(Error::invalid_operation(
1670 "matmul requires at least 2D tensors",
1671 ));
1672 }
1673
1674 let m = self.shape[self.ndim() - 2];
1675 let k1 = self.shape[self.ndim() - 1];
1676 let k2 = other.shape[other.ndim() - 2];
1677 let n = other.shape[other.ndim() - 1];
1678
1679 if k1 != k2 {
1680 return Err(Error::invalid_operation(format!(
1681 "matmul inner dimensions must match: {k1} vs {k2}"
1682 )));
1683 }
1684
1685 if self.ndim() == 2 && other.ndim() == 2 {
1687 let a_data = self.contiguous().to_vec();
1688 let b_data = other.contiguous().to_vec();
1689
1690 #[cfg(feature = "cuda")]
1694 {
1695 let flops = m * n * k1;
1696 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
1697 && flops >= 4_000_000
1698 {
1699 debug_assert!(std::mem::size_of::<T>() == std::mem::size_of::<f32>());
1700 let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1702 let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1703 if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
1704 let c_t: Vec<T> = unsafe {
1706 let mut v = std::mem::ManuallyDrop::new(c_f32);
1707 Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
1708 };
1709 return Self::from_vec(c_t, &[m, n]);
1710 }
1711 }
1712 }
1713
1714 let mut c_data = vec![T::zero(); m * n];
1715 CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
1716 return Self::from_vec(c_data, &[m, n]);
1717 }
1718
1719 let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1721 let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1722
1723 let broadcast_batch = if batch_dims_self == batch_dims_other {
1725 None
1726 } else {
1727 let max_len = batch_dims_self.len().max(batch_dims_other.len());
1729 let pad_a = vec![1usize; max_len - batch_dims_self.len()];
1730 let pad_b = vec![1usize; max_len - batch_dims_other.len()];
1731 let a_dims: Vec<usize> = pad_a
1732 .iter()
1733 .chain(batch_dims_self.iter())
1734 .copied()
1735 .collect();
1736 let b_dims: Vec<usize> = pad_b
1737 .iter()
1738 .chain(batch_dims_other.iter())
1739 .copied()
1740 .collect();
1741
1742 let mut out_dims = Vec::with_capacity(max_len);
1743 for i in 0..max_len {
1744 if a_dims[i] == b_dims[i] {
1745 out_dims.push(a_dims[i]);
1746 } else if a_dims[i] == 1 {
1747 out_dims.push(b_dims[i]);
1748 } else if b_dims[i] == 1 {
1749 out_dims.push(a_dims[i]);
1750 } else {
1751 return Err(Error::invalid_operation(format!(
1752 "matmul batch dimensions not broadcastable: {:?} vs {:?}",
1753 batch_dims_self, batch_dims_other
1754 )));
1755 }
1756 }
1757 Some((a_dims, b_dims, out_dims))
1758 };
1759
1760 let (batch_size, a_batch_idx, b_batch_idx) =
1761 if let Some((a_dims, b_dims, out_dims)) = &broadcast_batch {
1762 let bs: usize = out_dims.iter().product();
1763 let mut a_idx = Vec::with_capacity(bs);
1765 let mut b_idx = Vec::with_capacity(bs);
1766 for flat in 0..bs {
1767 let mut remaining = flat;
1768 let mut ai = 0usize;
1769 let mut bi = 0usize;
1770 let mut a_stride_acc = 1usize;
1771 let mut b_stride_acc = 1usize;
1772 for d in (0..out_dims.len()).rev() {
1773 let out_d = out_dims[d];
1774 let idx = remaining % out_d;
1775 remaining /= out_d;
1776 let a_d = a_dims[d];
1777 let b_d = b_dims[d];
1778 ai += (idx % a_d) * a_stride_acc;
1779 bi += (idx % b_d) * b_stride_acc;
1780 a_stride_acc *= a_d;
1781 b_stride_acc *= b_d;
1782 }
1783 a_idx.push(ai);
1784 b_idx.push(bi);
1785 }
1786 (bs, a_idx, b_idx)
1787 } else {
1788 let bs: usize = batch_dims_self.iter().product();
1789 let idx: Vec<usize> = (0..bs).collect();
1790 (bs, idx.clone(), idx)
1791 };
1792
1793 let a_stride = m * k1;
1794 let b_stride = k1 * n;
1795 let c_stride = m * n;
1796
1797 let a_data = self.contiguous().to_vec();
1798 let b_data = other.contiguous().to_vec();
1799 let mut c_data = vec![T::zero(); batch_size * m * n];
1800
1801 #[cfg(feature = "cuda")]
1803 {
1804 let flops = m * n * k1;
1805 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && flops >= 4_000_000 {
1806 let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1807 let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1808 let mut gpu_ok = true;
1809 for batch in 0..batch_size {
1810 let ai = a_batch_idx[batch];
1811 let bi = b_batch_idx[batch];
1812 let a_slice = &a_f32[ai * a_stride..(ai + 1) * a_stride];
1813 let b_slice = &b_f32[bi * b_stride..(bi + 1) * b_stride];
1814 if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
1815 c_data[batch * c_stride..(batch + 1) * c_stride]
1816 .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
1817 } else {
1818 gpu_ok = false;
1819 break;
1820 }
1821 }
1822 if gpu_ok {
1823 let mut output_shape = batch_dims_self;
1824 output_shape.push(m);
1825 output_shape.push(n);
1826 return Self::from_vec(c_data, &output_shape);
1827 }
1828 c_data = vec![T::zero(); batch_size * m * n];
1830 }
1831 }
1832
1833 for batch in 0..batch_size {
1835 let ai = a_batch_idx[batch];
1836 let bi = b_batch_idx[batch];
1837 let a_slice = &a_data[ai * a_stride..(ai + 1) * a_stride];
1838 let b_slice = &b_data[bi * b_stride..(bi + 1) * b_stride];
1839 let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1840 CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1841 }
1842
1843 let mut output_shape = if let Some((_, _, ref out_dims)) = broadcast_batch {
1845 out_dims.clone()
1846 } else {
1847 batch_dims_self
1848 };
1849 output_shape.push(m);
1850 output_shape.push(n);
1851
1852 Self::from_vec(c_data, &output_shape)
1853 }
1854
1855 pub fn dot(&self, other: &Self) -> Result<Self> {
1857 if self.ndim() != 1 || other.ndim() != 1 {
1858 return Err(Error::invalid_operation("dot requires 1D tensors"));
1859 }
1860
1861 if self.shape[0] != other.shape[0] {
1862 return Err(Error::shape_mismatch(&self.shape, &other.shape));
1863 }
1864
1865 let a_data = self.to_vec();
1866 let b_data = other.to_vec();
1867 let result = CpuBackend::dot(&a_data, &b_data);
1868
1869 Ok(Self::scalar(result))
1870 }
1871}
1872
1873impl<T: Numeric> Add for &Tensor<T> {
1878 type Output = Tensor<T>;
1879
1880 fn add(self, other: Self) -> Self::Output {
1881 self.add(other).expect("Addition failed")
1882 }
1883}
1884
1885impl<T: Numeric> Sub for &Tensor<T> {
1886 type Output = Tensor<T>;
1887
1888 fn sub(self, other: Self) -> Self::Output {
1889 self.sub(other).expect("Subtraction failed")
1890 }
1891}
1892
1893impl<T: Numeric> Mul for &Tensor<T> {
1894 type Output = Tensor<T>;
1895
1896 fn mul(self, other: Self) -> Self::Output {
1897 self.mul(other).expect("Multiplication failed")
1898 }
1899}
1900
1901impl<T: Numeric> Div for &Tensor<T> {
1902 type Output = Tensor<T>;
1903
1904 fn div(self, other: Self) -> Self::Output {
1905 self.div(other).expect("Division failed")
1906 }
1907}
1908
1909impl<T: Numeric> Neg for &Tensor<T> {
1910 type Output = Tensor<T>;
1911
1912 fn neg(self) -> Self::Output {
1913 self.neg()
1914 }
1915}
1916
1917impl<T: Numeric> Add<T> for &Tensor<T> {
1919 type Output = Tensor<T>;
1920
1921 fn add(self, scalar: T) -> Self::Output {
1922 self.add_scalar(scalar)
1923 }
1924}
1925
1926impl<T: Numeric> Mul<T> for &Tensor<T> {
1927 type Output = Tensor<T>;
1928
1929 fn mul(self, scalar: T) -> Self::Output {
1930 self.mul_scalar(scalar)
1931 }
1932}
1933
1934impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
1939 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1940 write!(
1941 f,
1942 "Tensor(shape={:?}, device={}",
1943 self.shape(),
1944 self.device()
1945 )?;
1946 if self.numel() <= 10 {
1947 write!(f, ", data={:?}", self.to_vec())?;
1948 }
1949 write!(f, ")")
1950 }
1951}
1952
1953impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
1954 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1955 if self.is_scalar() {
1956 write!(f, "{}", self.item().unwrap())
1957 } else if self.ndim() == 1 {
1958 write!(f, "[")?;
1959 let data = self.to_vec();
1960 for (i, val) in data.iter().enumerate() {
1961 if i > 0 {
1962 write!(f, ", ")?;
1963 }
1964 write!(f, "{val}")?;
1965 }
1966 write!(f, "]")
1967 } else {
1968 write!(f, "Tensor(shape={:?})", self.shape())
1969 }
1970 }
1971}
1972
1973impl Tensor<f32> {
1978 #[must_use]
1987 pub fn to_f16_precision(&self) -> Self {
1988 let data = self.to_vec();
1989 let f16_data: Vec<f32> = data
1990 .iter()
1991 .map(|&v| {
1992 let h = half::f16::from_f32(v);
1993 h.to_f32()
1994 })
1995 .collect();
1996 Self::from_vec(f16_data, self.shape()).unwrap()
1997 }
1998
1999 #[must_use]
2004 pub fn to_f32_precision(&self) -> Self {
2005 self.clone()
2006 }
2007
2008 #[must_use]
2011 pub fn has_f16_rounding_error(&self) -> bool {
2012 let data = self.to_vec();
2013 data.iter().any(|&v| {
2014 let h = half::f16::from_f32(v);
2015 (h.to_f32() - v).abs() > f32::EPSILON
2016 })
2017 }
2018}
2019
#[cfg(test)]
mod tests {
    use super::*;

    // Construction from a flat Vec plus an explicit shape.
    #[test]
    fn test_from_vec() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
    }

    // Row-major element access. Note `set` works through a shared
    // reference (`t` is not `mut`) — presumably interior mutability in the
    // storage; confirm against Tensor's storage implementation.
    #[test]
    fn test_get_set() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(t.get(&[1, 0]).unwrap(), 3.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 4.0);

        t.set(&[0, 0], 99.0).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 99.0);
    }

    // Reshape with an explicit shape and with -1 (infer the dimension).
    #[test]
    fn test_reshape() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.reshape(&[3, 2]).expect("reshape failed");
        assert_eq!(r.shape(), &[3, 2]);

        let r = t.reshape(&[-1]).expect("reshape failed");
        assert_eq!(r.shape(), &[6]);
    }

    // 2-D transpose swaps the axes: element [i, j] becomes [j, i].
    #[test]
    fn test_transpose() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.t().unwrap();
        assert_eq!(r.shape(), &[3, 2]);
        assert_eq!(r.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(r.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(r.get(&[1, 0]).unwrap(), 2.0);
    }

    // Element-wise + and * through the operator overloads on &Tensor.
    #[test]
    fn test_arithmetic() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![5.0, 7.0, 9.0]);

        let d = &a * &b;
        assert_eq!(d.to_vec(), vec![4.0, 10.0, 18.0]);
    }

    // A [1]-shaped tensor broadcasts against a [3]-shaped one.
    #[test]
    fn test_broadcasting() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    // Full reduction to a scalar tensor.
    #[test]
    fn test_sum() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let s = t.sum();
        assert_eq!(s.item().unwrap(), 10.0);
    }

    // 2x2 matrix product with a hand-computed expected result.
    #[test]
    fn test_matmul() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let c = a.matmul(&b).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    // ReLU clamps negatives to zero and passes non-negatives through.
    #[test]
    fn test_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        let r = t.relu();
        assert_eq!(r.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    // Rank-0 (scalar) tensor round-trip through `scalar`/`item`.
    #[test]
    fn test_scalar() {
        let s = Tensor::<f32>::scalar(42.0);
        assert!(s.is_scalar());
        assert_eq!(s.numel(), 1);
        assert_eq!(s.item().unwrap(), 42.0);
    }
}