1use core::fmt;
18use core::ops::{Add, Div, Mul, Neg, Sub};
19
20use axonml_core::backends::CpuBackend;
21#[cfg(feature = "cuda")]
22use axonml_core::backends::CudaBackend;
23use axonml_core::dtype::{Float, Numeric, Scalar};
24use axonml_core::error::{Error, Result};
25use axonml_core::storage::Storage;
26use axonml_core::Device;
27use num_traits::NumCast;
28
#[cfg(feature = "cuda")]
/// Optional CUDA fast path for `f32` matrix multiplication.
///
/// The backend is created lazily, once per process, and any initialization
/// failure is cached so later calls fall back to the CPU without retrying.
mod cuda_accel {
    use super::*;
    use std::sync::OnceLock;

    // Process-wide backend handle. `None` is stored when device setup fails,
    // so initialization is attempted at most once.
    static CUDA_BACKEND: OnceLock<Option<CudaBackend>> = OnceLock::new();

    /// Returns the shared CUDA backend, initializing GPU 0 on first use.
    /// Returns `None` (permanently) if initialization failed.
    pub fn get_cuda() -> Option<&'static CudaBackend> {
        CUDA_BACKEND
            .get_or_init(|| {
                let backend = CudaBackend::new(0);
                if backend.is_some() {
                    eprintln!("[AxonML] CUDA backend initialized (GPU 0)");
                }
                backend
            })
            .as_ref()
    }

    /// Multiplies row-major `a` (m×k) by row-major `b` (k×n) on the GPU and
    /// returns the row-major (m×n) product, or `None` on any CUDA failure so
    /// callers can fall back to the CPU path.
    pub fn cuda_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Option<Vec<f32>> {
        let cuda = get_cuda()?;

        let a_gpu = cuda.htod_copy(a).ok()?;
        let b_gpu = cuda.htod_copy(b).ok()?;
        let mut c_gpu = cuda.alloc::<f32>(m * n).ok()?;

        // Operands are passed B-first with swapped dimensions: a column-major
        // GEMM computing Bᵀ·Aᵀ = Cᵀ leaves C readable as row-major.
        // NOTE(review): this assumes `gemm_f32` follows the column-major
        // cuBLAS convention — confirm against the backend implementation.
        cuda.gemm_f32(
            false, false, n, m, k, 1.0,
            &b_gpu, n, &a_gpu, k, 0.0,
            &mut c_gpu, n, ).ok()?;

        cuda.dtoh_copy(&c_gpu).ok()
    }
}
79
80use crate::shape::{
81 broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous, linear_index,
82 normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides, unsqueeze, Shape,
83 Strides,
84};
85
/// A dense n-dimensional array of `T` elements with a strided, view-based
/// layout (shape + strides + offset over a shared `Storage`).
///
/// NOTE(review): `#[derive(Clone)]` clones the `Storage` handle; since
/// `set`/`fill_` mutate through `&self`, `Storage` appears to use shared
/// interior mutability, making `clone` shallow — use `clone_deep` for an
/// independent copy. Confirm against the `Storage` implementation.
#[derive(Clone)]
pub struct Tensor<T: Scalar> {
    // Underlying (possibly shared) element buffer.
    pub(crate) storage: Storage<T>,
    // Size of each dimension; empty for a 0-d scalar.
    pub(crate) shape: Shape,
    // Per-dimension element strides into `storage` (elements, not bytes).
    pub(crate) strides: Strides,
    // Index of this view's first element within `storage`.
    pub(crate) offset: usize,
}
106
107impl<T: Scalar> Tensor<T> {
108 pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
121 let total = numel(shape);
122 if total != storage.len() {
123 return Err(Error::shape_mismatch(&[storage.len()], shape));
124 }
125
126 let shape = Shape::from_slice(shape);
127 let strides = contiguous_strides(&shape);
128
129 Ok(Self {
130 storage,
131 shape,
132 strides,
133 offset: 0,
134 })
135 }
136
137 pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
146 let storage = Storage::from_vec(data, Device::Cpu);
147 Self::from_storage(storage, shape)
148 }
149
150 pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
159 let storage = Storage::from_slice(data, Device::Cpu);
160 Self::from_storage(storage, shape)
161 }
162
163 pub fn scalar(value: T) -> Self {
171 Self {
172 storage: Storage::from_vec(vec![value], Device::Cpu),
173 shape: Shape::new(),
174 strides: Strides::new(),
175 offset: 0,
176 }
177 }
178
179 #[must_use]
181 pub fn zeros(shape: &[usize]) -> Self {
182 crate::creation::zeros(shape)
183 }
184
185 #[must_use]
187 pub fn ones(shape: &[usize]) -> Self
188 where
189 T: Numeric,
190 {
191 crate::creation::ones(shape)
192 }
193
194 #[must_use]
196 pub fn full(shape: &[usize], value: T) -> Self {
197 crate::creation::full(shape, value)
198 }
199
200 #[must_use]
202 pub fn randn(shape: &[usize]) -> Self
203 where
204 T: Float,
205 rand_distr::StandardNormal: rand::distributions::Distribution<T>,
206 {
207 crate::creation::randn(shape)
208 }
209
210 #[must_use]
212 pub fn rand(shape: &[usize]) -> Self
213 where
214 T: Float,
215 rand::distributions::Standard: rand::distributions::Distribution<T>,
216 {
217 crate::creation::rand(shape)
218 }
219
220 #[must_use]
226 pub fn shape(&self) -> &[usize] {
227 &self.shape
228 }
229
230 #[must_use]
232 pub fn strides(&self) -> &[isize] {
233 &self.strides
234 }
235
236 #[must_use]
238 pub fn ndim(&self) -> usize {
239 self.shape.len()
240 }
241
242 #[must_use]
244 pub fn numel(&self) -> usize {
245 numel(&self.shape)
246 }
247
248 #[must_use]
250 pub fn is_empty(&self) -> bool {
251 self.numel() == 0
252 }
253
254 pub fn size(&self, dim: i64) -> Result<usize> {
259 let idx = normalize_dim(dim, self.ndim())?;
260 Ok(self.shape[idx])
261 }
262
263 #[must_use]
265 pub fn device(&self) -> Device {
266 self.storage.device()
267 }
268
269 #[must_use]
271 pub fn is_contiguous(&self) -> bool {
272 is_contiguous(&self.shape, &self.strides)
273 }
274
275 #[must_use]
277 pub fn is_scalar(&self) -> bool {
278 self.shape.is_empty()
279 }
280
281 pub fn get(&self, indices: &[usize]) -> Result<T> {
290 if indices.len() != self.ndim() {
291 return Err(Error::invalid_operation(format!(
292 "Expected {} indices, got {}",
293 self.ndim(),
294 indices.len()
295 )));
296 }
297
298 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
299 if idx >= dim {
300 return Err(Error::IndexOutOfBounds {
301 index: idx,
302 size: dim,
303 });
304 }
305 }
306
307 let offset = self.offset + linear_index(indices, &self.strides);
308 Ok(self.storage.as_slice()[offset])
309 }
310
311 pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
317 if indices.len() != self.ndim() {
318 return Err(Error::invalid_operation(format!(
319 "Expected {} indices, got {}",
320 self.ndim(),
321 indices.len()
322 )));
323 }
324
325 for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
326 if idx >= dim {
327 return Err(Error::IndexOutOfBounds {
328 index: idx,
329 size: dim,
330 });
331 }
332 }
333
334 let offset = self.offset + linear_index(indices, &self.strides);
335 self.storage.as_slice_mut()[offset] = value;
336 Ok(())
337 }
338
339 pub fn item(&self) -> Result<T> {
341 if self.numel() != 1 {
342 return Err(Error::invalid_operation(
343 "item() only works on single-element tensors",
344 ));
345 }
346
347 if self.is_scalar() {
348 Ok(self.storage.as_slice()[self.offset])
349 } else {
350 let indices = vec![0; self.ndim()];
352 self.get(&indices)
353 }
354 }
355
356 #[must_use]
361 pub fn to_vec(&self) -> Vec<T> {
362 if self.is_contiguous() {
363 let storage = self.storage.as_slice();
364 storage[self.offset..self.offset + self.numel()].to_vec()
365 } else {
366 let mut result = Vec::with_capacity(self.numel());
367 self.copy_data_to(&mut result);
368 result
369 }
370 }
371
372 fn copy_data_to(&self, dst: &mut Vec<T>) {
374 dst.clear();
375 let storage = self.storage.as_slice();
376
377 let total = self.numel();
379 for i in 0..total {
380 let indices = crate::shape::unravel_index(i, &self.shape);
381 let offset = self.offset + linear_index(&indices, &self.strides);
382 dst.push(storage[offset]);
383 }
384 }
385
386 pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
398 let shape = reshape(&self.shape, new_shape)?;
399
400 if self.is_contiguous() {
401 Ok(Self {
403 storage: self.storage.clone(),
404 strides: contiguous_strides(&shape),
405 shape,
406 offset: self.offset,
407 })
408 } else {
409 let contig = self.contiguous();
411 Ok(Self {
412 storage: contig.storage,
413 strides: contiguous_strides(&shape),
414 shape,
415 offset: 0,
416 })
417 }
418 }
419
420 #[must_use]
422 pub fn flatten(&self) -> Self {
423 self.reshape(&[-1]).expect("Flatten should never fail")
424 }
425
426 pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
431 let dim = match dim {
432 Some(d) => Some(normalize_dim(d, self.ndim())?),
433 None => None,
434 };
435
436 let new_shape = squeeze(&self.shape, dim);
437 let new_strides: Strides = match dim {
438 Some(d) => {
439 let mut s = self.strides.clone();
440 if d < self.shape.len() && self.shape[d] == 1 {
441 s.remove(d);
442 }
443 s
444 }
445 None => self
446 .shape
447 .iter()
448 .zip(self.strides.iter())
449 .filter(|(&dim, _)| dim != 1)
450 .map(|(_, &stride)| stride)
451 .collect(),
452 };
453
454 Ok(Self {
455 storage: self.storage.clone(),
456 shape: new_shape,
457 strides: new_strides,
458 offset: self.offset,
459 })
460 }
461
462 pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
467 let normalized = if dim < 0 {
468 (dim + self.ndim() as i64 + 1) as usize
469 } else {
470 dim as usize
471 };
472
473 let new_shape = unsqueeze(&self.shape, normalized)?;
474 let mut new_strides = Strides::with_capacity(new_shape.len());
475
476 for (i, _) in new_shape.iter().enumerate() {
477 if i < normalized {
478 new_strides.push(self.strides.get(i).copied().unwrap_or(1));
479 } else if i == normalized {
480 new_strides.push(1);
482 } else {
483 new_strides.push(self.strides[i - 1]);
484 }
485 }
486
487 Ok(Self {
488 storage: self.storage.clone(),
489 shape: new_shape,
490 strides: new_strides,
491 offset: self.offset,
492 })
493 }
494
495 pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
501 let d0 = normalize_dim(dim0, self.ndim())?;
502 let d1 = normalize_dim(dim1, self.ndim())?;
503
504 let new_shape = transpose_shape(&self.shape, d0, d1)?;
505 let new_strides = transpose_strides(&self.strides, d0, d1);
506
507 Ok(Self {
508 storage: self.storage.clone(),
509 shape: new_shape,
510 strides: new_strides,
511 offset: self.offset,
512 })
513 }
514
515 pub fn t(&self) -> Result<Self> {
517 if self.ndim() != 2 {
518 return Err(Error::invalid_operation("t() only works on 2D tensors"));
519 }
520 self.transpose(0, 1)
521 }
522
523 pub fn permute(&self, dims: &[usize]) -> Result<Self> {
528 if dims.len() != self.ndim() {
529 return Err(Error::invalid_operation(format!(
530 "Expected {} dimensions, got {}",
531 self.ndim(),
532 dims.len()
533 )));
534 }
535
536 let mut seen = vec![false; self.ndim()];
538 for &d in dims {
539 if d >= self.ndim() {
540 return Err(Error::InvalidDimension {
541 index: d as i64,
542 ndim: self.ndim(),
543 });
544 }
545 if seen[d] {
546 return Err(Error::invalid_operation("Duplicate dimension in permute"));
547 }
548 seen[d] = true;
549 }
550
551 let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
552 let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();
553
554 Ok(Self {
555 storage: self.storage.clone(),
556 shape: new_shape,
557 strides: new_strides,
558 offset: self.offset,
559 })
560 }
561
562 #[must_use]
564 pub fn contiguous(&self) -> Self {
565 if self.is_contiguous() && self.offset == 0 {
566 return self.clone();
567 }
568
569 let data = self.to_vec();
570 Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
571 }
572
573 pub fn to_device(&self, device: Device) -> Result<Self> {
582 if self.device() == device {
583 return Ok(self.clone());
584 }
585
586 let contig = self.contiguous();
587 let new_storage = contig.storage.to_device(device)?;
588
589 Ok(Self {
590 storage: new_storage,
591 shape: self.shape.clone(),
592 strides: self.strides.clone(),
593 offset: 0,
594 })
595 }
596
597 pub fn cpu(&self) -> Result<Self> {
599 self.to_device(Device::Cpu)
600 }
601
602 #[must_use]
608 pub fn clone_deep(&self) -> Self {
609 let data = self.to_vec();
610 Self::from_vec(data, &self.shape).expect("Deep clone should never fail")
611 }
612}
613
impl<T: Numeric> Tensor<T> {
    /// Overwrites every element with `value`, in place.
    ///
    /// NOTE(review): this fills the *entire underlying storage* slice, not
    /// just the elements addressed by this view's shape/strides/offset —
    /// for an offset or narrowed view that would clobber neighboring data.
    /// Confirm whether such views can reach this method.
    pub fn fill_(&self, value: T) {
        let mut data = self.storage.as_slice_mut();
        CpuBackend::fill(&mut data, value);
    }

    /// Sets every element to zero, in place.
    pub fn zero_(&self) {
        self.fill_(T::zero());
    }

    /// Sum of all elements, as a scalar tensor.
    #[must_use]
    pub fn sum(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::sum(&data);
        Self::scalar(result)
    }

    /// Product of all elements, as a scalar tensor.
    #[must_use]
    pub fn prod(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::prod(&data);
        Self::scalar(result)
    }

    /// Largest element, as a scalar tensor.
    ///
    /// # Errors
    /// Fails with `Error::EmptyTensor` on empty input.
    pub fn max(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        // Safe to unwrap: emptiness was checked above.
        let result = CpuBackend::max(&data).unwrap();
        Ok(Self::scalar(result))
    }

    /// Smallest element, as a scalar tensor.
    ///
    /// # Errors
    /// Fails with `Error::EmptyTensor` on empty input.
    pub fn min(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        // Safe to unwrap: emptiness was checked above.
        let result = CpuBackend::min(&data).unwrap();
        Ok(Self::scalar(result))
    }

    /// Flat (row-major) index of the largest element.
    ///
    /// # Errors
    /// Fails with `Error::EmptyTensor` on empty input.
    pub fn argmax(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmax(&data).unwrap())
    }

    /// Flat (row-major) index of the smallest element.
    ///
    /// # Errors
    /// Fails with `Error::EmptyTensor` on empty input.
    pub fn argmin(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmin(&data).unwrap())
    }

    /// Concatenates `tensors` along dimension `dim` into a new tensor.
    ///
    /// All inputs must share the same rank and agree on every dimension
    /// except `dim`.
    ///
    /// # Errors
    /// Fails when `tensors` is empty, `dim` is out of range, or the shapes
    /// are incompatible.
    pub fn cat(tensors: &[&Self], dim: usize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(Error::invalid_operation("cat requires at least one tensor"));
        }
        let ndim = tensors[0].ndim();
        if dim >= ndim {
            return Err(Error::invalid_operation("cat dimension out of range"));
        }

        // Validate ranks and non-cat dimension sizes against the first input.
        for t in &tensors[1..] {
            if t.ndim() != ndim {
                return Err(Error::invalid_operation("cat: all tensors must have same ndim"));
            }
            for d in 0..ndim {
                if d != dim && t.shape[d] != tensors[0].shape[d] {
                    return Err(Error::invalid_operation("cat: shapes must match on non-cat dims"));
                }
            }
        }

        let total_dim_size: usize = tensors.iter().map(|t| t.shape[dim]).sum();
        let mut out_shape: Vec<usize> = tensors[0].shape.to_vec();
        out_shape[dim] = total_dim_size;

        // Treat each tensor as (outer, cat-dim, inner) and copy slab by slab.
        let outer_size: usize = out_shape[..dim].iter().product();
        let inner_size: usize = out_shape[dim + 1..].iter().product();
        let total_numel: usize = out_shape.iter().product();
        let mut result = vec![T::zero(); total_numel];

        // Running start position of each input along the cat dimension.
        let mut dim_offset = 0;
        for t in tensors {
            let t_data = t.contiguous().to_vec();
            let t_dim_size = t.shape[dim];
            for outer in 0..outer_size {
                for d in 0..t_dim_size {
                    for inner in 0..inner_size {
                        let src_idx = outer * t_dim_size * inner_size + d * inner_size + inner;
                        let dst_idx = outer * total_dim_size * inner_size
                            + (dim_offset + d) * inner_size
                            + inner;
                        result[dst_idx] = t_data[src_idx];
                    }
                }
            }
            dim_offset += t_dim_size;
        }

        Self::from_vec(result, &out_shape)
    }
}
741
impl<T: Float> Tensor<T> {
    /// Arithmetic mean over all elements, as a scalar tensor.
    ///
    /// # Errors
    /// Fails with `Error::EmptyTensor` on empty input.
    pub fn mean(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        // Safe to unwrap: emptiness was checked above.
        let result = CpuBackend::mean(&data).unwrap();
        Ok(Self::scalar(result))
    }

    /// Element-wise rectified linear unit.
    #[must_use]
    pub fn relu(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::relu(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Element-wise logistic sigmoid.
    #[must_use]
    pub fn sigmoid(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sigmoid(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Element-wise hyperbolic tangent.
    #[must_use]
    pub fn tanh(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::tanh(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Element-wise exponential.
    #[must_use]
    pub fn exp(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::exp(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Element-wise natural logarithm.
    #[must_use]
    pub fn ln(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::ln(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Element-wise square root.
    #[must_use]
    pub fn sqrt(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sqrt(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Raises every element to the power `exp`.
    #[must_use]
    pub fn pow(&self, exp: T) -> Self {
        let data = self.to_vec();
        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
        Self::from_vec(result, &self.shape).unwrap()
    }

    /// Element-wise GELU activation (delegates to `crate::ops`).
    #[must_use]
    pub fn gelu(&self) -> Self {
        crate::ops::gelu(self)
    }

    /// Element-wise SiLU/Swish activation (delegates to `crate::ops`).
    #[must_use]
    pub fn silu(&self) -> Self {
        crate::ops::silu(self)
    }

    /// Softmax along `dim`.
    ///
    /// NOTE(review): an error from `crate::ops::softmax` (e.g. a bad `dim`)
    /// is silently swallowed by returning a clone of the input — confirm
    /// that this best-effort behavior is intended.
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Self {
        crate::ops::softmax(self, dim as i64).unwrap_or_else(|_| self.clone())
    }

    /// Log-softmax along `dim`, computed as `ln(softmax(x))`.
    ///
    /// NOTE(review): taking `ln` of softmax outputs can underflow to
    /// `-inf`; a fused `x - logsumexp(x)` formulation would be more
    /// numerically robust.
    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Self {
        let softmax_result = self.softmax(dim);
        softmax_result.ln()
    }

    /// Mean along dimension `dim` (negative counts from the back).
    ///
    /// When `keepdim` is true the reduced dimension stays with size 1;
    /// otherwise it is removed (a fully reduced result keeps shape `[1]`).
    /// An out-of-range `dim` silently returns a clone of `self`.
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        // Negative dims wrap; a dim below -ndim wraps to a huge usize and is
        // caught by the range check below.
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the data as (outer, dim, inner) and average over `dim`.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let mean = sum / NumCast::from(dim_size).unwrap();
                let result_idx = outer * inner_size + inner;
                result[result_idx] = mean;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }

    /// Sum along dimension `dim`; same dim/keepdim semantics as `mean_dim`.
    #[must_use]
    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        // View the data as (outer, dim, inner) and sum over `dim`.
        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let result_idx = outer * inner_size + inner;
                result[result_idx] = sum;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }

    /// Variance along `dim`, computed as `mean((x - mean(x))^2)`.
    ///
    /// Note: this is the population variance (divides by N, no Bessel
    /// correction). Broadcast failures silently fall back to `self`.
    #[must_use]
    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
        let mean = self.mean_dim(dim, true);
        let diff = self.sub(&mean).unwrap_or_else(|_| self.clone());
        let squared = diff.mul(&diff).unwrap_or_else(|_| self.clone());
        squared.mean_dim(dim, keepdim)
    }

    /// Materializes this tensor broadcast to `shape`.
    ///
    /// NOTE(review): an incompatible `shape` is not rejected — the fallback
    /// `unwrap_or_else(|_| shape.into())` forces the target shape and may
    /// panic later on out-of-bounds reads. Confirm intended behavior.
    #[must_use]
    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
        if self.shape.as_slice() == shape {
            return self.clone();
        }

        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];
        let self_data = self.storage.as_slice();

        // Read through broadcast strides (stride 0 repeats an axis).
        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            result_data[i] = self_data[self_idx];
        }

        Self::from_vec(result_data, &result_shape).unwrap()
    }

    /// Copies the sub-tensor selected by `ranges` (one half-open range per
    /// leading dimension; trailing dimensions are taken whole).
    ///
    /// NOTE(review): ranges are not validated — an empty or out-of-bounds
    /// range (end < start, or end > dim size) will underflow or panic.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
        // Sliced dimensions shrink to their range length...
        let mut new_shape = Vec::with_capacity(self.ndim());
        for (i, range) in ranges.iter().enumerate() {
            if i < self.ndim() {
                new_shape.push(range.end - range.start);
            }
        }
        // ...and unspecified trailing dimensions keep their full size.
        for i in ranges.len()..self.ndim() {
            new_shape.push(self.shape[i]);
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result_data = vec![T::zero(); new_numel];
        let self_data = self.to_vec();

        let mut result_idx = 0;
        Self::slice_recursive(
            &self_data,
            &self.shape,
            ranges,
            0,
            0,
            &mut result_data,
            &mut result_idx,
        );

        Self::from_vec(result_data, &new_shape).unwrap()
    }

    /// Depth-first walk over the selected index ranges, appending elements
    /// to `result` in row-major order. `offset` is the flat position in
    /// `data` accumulated from the dimensions above `dim`.
    fn slice_recursive(
        data: &[T],
        shape: &[usize],
        ranges: &[std::ops::Range<usize>],
        dim: usize,
        offset: usize,
        result: &mut [T],
        result_idx: &mut usize,
    ) {
        // Base case: all dimensions fixed — copy one element.
        if dim == shape.len() {
            result[*result_idx] = data[offset];
            *result_idx += 1;
            return;
        }

        // Row-major stride of this dimension in the (contiguous) source.
        let stride: usize = shape[dim + 1..].iter().product();
        let (start, end) = if dim < ranges.len() {
            (ranges[dim].start, ranges[dim].end)
        } else {
            (0, shape[dim])
        };

        for i in start..end {
            Self::slice_recursive(
                data,
                shape,
                ranges,
                dim + 1,
                offset + i * stride,
                result,
                result_idx,
            );
        }
    }
}
1048
1049impl<T: Numeric> Tensor<T> {
1054 pub fn add(&self, other: &Self) -> Result<Self> {
1056 let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1057 let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1058 let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1059
1060 let total = numel(&result_shape);
1061 let mut result_data = vec![T::zero(); total];
1062
1063 let self_data = self.storage.as_slice();
1064 let other_data = other.storage.as_slice();
1065
1066 for i in 0..total {
1067 let indices = crate::shape::unravel_index(i, &result_shape);
1068 let self_idx = self.offset + linear_index(&indices, &self_strides);
1069 let other_idx = other.offset + linear_index(&indices, &other_strides);
1070 result_data[i] = self_data[self_idx] + other_data[other_idx];
1071 }
1072
1073 Self::from_vec(result_data, &result_shape)
1074 }
1075
1076 pub fn sub(&self, other: &Self) -> Result<Self> {
1078 let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1079 let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1080 let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1081
1082 let total = numel(&result_shape);
1083 let mut result_data = vec![T::zero(); total];
1084
1085 let self_data = self.storage.as_slice();
1086 let other_data = other.storage.as_slice();
1087
1088 for i in 0..total {
1089 let indices = crate::shape::unravel_index(i, &result_shape);
1090 let self_idx = self.offset + linear_index(&indices, &self_strides);
1091 let other_idx = other.offset + linear_index(&indices, &other_strides);
1092 result_data[i] = self_data[self_idx] - other_data[other_idx];
1093 }
1094
1095 Self::from_vec(result_data, &result_shape)
1096 }
1097
1098 pub fn mul(&self, other: &Self) -> Result<Self> {
1100 let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1101 let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1102 let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1103
1104 let total = numel(&result_shape);
1105 let mut result_data = vec![T::zero(); total];
1106
1107 let self_data = self.storage.as_slice();
1108 let other_data = other.storage.as_slice();
1109
1110 for i in 0..total {
1111 let indices = crate::shape::unravel_index(i, &result_shape);
1112 let self_idx = self.offset + linear_index(&indices, &self_strides);
1113 let other_idx = other.offset + linear_index(&indices, &other_strides);
1114 result_data[i] = self_data[self_idx] * other_data[other_idx];
1115 }
1116
1117 Self::from_vec(result_data, &result_shape)
1118 }
1119
1120 pub fn div(&self, other: &Self) -> Result<Self> {
1122 let result_shape = broadcast_shape(&self.shape, &other.shape)?;
1123 let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
1124 let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);
1125
1126 let total = numel(&result_shape);
1127 let mut result_data = vec![T::zero(); total];
1128
1129 let self_data = self.storage.as_slice();
1130 let other_data = other.storage.as_slice();
1131
1132 for i in 0..total {
1133 let indices = crate::shape::unravel_index(i, &result_shape);
1134 let self_idx = self.offset + linear_index(&indices, &self_strides);
1135 let other_idx = other.offset + linear_index(&indices, &other_strides);
1136 result_data[i] = self_data[self_idx] / other_data[other_idx];
1137 }
1138
1139 Self::from_vec(result_data, &result_shape)
1140 }
1141
1142 #[must_use]
1144 pub fn add_scalar(&self, scalar: T) -> Self {
1145 let data = self.to_vec();
1146 let mut result = vec![T::zero(); data.len()];
1147 CpuBackend::add_scalar(&mut result, &data, scalar);
1148 Self::from_vec(result, &self.shape).unwrap()
1149 }
1150
1151 #[must_use]
1153 pub fn mul_scalar(&self, scalar: T) -> Self {
1154 let data = self.to_vec();
1155 let mut result = vec![T::zero(); data.len()];
1156 CpuBackend::mul_scalar(&mut result, &data, scalar);
1157 Self::from_vec(result, &self.shape).unwrap()
1158 }
1159
1160 #[must_use]
1162 pub fn neg(&self) -> Self {
1163 let data = self.to_vec();
1164 let mut result = vec![T::zero(); data.len()];
1165 CpuBackend::neg(&mut result, &data);
1166 Self::from_vec(result, &self.shape).unwrap()
1167 }
1168
1169 pub fn matmul(&self, other: &Self) -> Result<Self> {
1176 if self.ndim() < 2 || other.ndim() < 2 {
1177 return Err(Error::invalid_operation(
1178 "matmul requires at least 2D tensors",
1179 ));
1180 }
1181
1182 let m = self.shape[self.ndim() - 2];
1183 let k1 = self.shape[self.ndim() - 1];
1184 let k2 = other.shape[other.ndim() - 2];
1185 let n = other.shape[other.ndim() - 1];
1186
1187 if k1 != k2 {
1188 return Err(Error::invalid_operation(format!(
1189 "matmul inner dimensions must match: {k1} vs {k2}"
1190 )));
1191 }
1192
1193 if self.ndim() == 2 && other.ndim() == 2 {
1195 let a_data = self.contiguous().to_vec();
1196 let b_data = other.contiguous().to_vec();
1197
1198 #[cfg(feature = "cuda")]
1200 {
1201 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
1202 let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1204 let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1205 if let Some(c_f32) = cuda_accel::cuda_matmul(a_f32, b_f32, m, n, k1) {
1206 let c_t: Vec<T> = unsafe {
1207 let mut v = std::mem::ManuallyDrop::new(c_f32);
1208 Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
1209 };
1210 return Self::from_vec(c_t, &[m, n]);
1211 }
1212 }
1213 }
1214
1215 let mut c_data = vec![T::zero(); m * n];
1216 CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
1217 return Self::from_vec(c_data, &[m, n]);
1218 }
1219
1220 let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
1222 let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();
1223
1224 if batch_dims_self != batch_dims_other {
1225 return Err(Error::invalid_operation(format!(
1226 "matmul batch dimensions must match: {:?} vs {:?}",
1227 batch_dims_self, batch_dims_other
1228 )));
1229 }
1230
1231 let batch_size: usize = batch_dims_self.iter().product();
1232 let a_stride = m * k1;
1233 let b_stride = k1 * n;
1234 let c_stride = m * n;
1235
1236 let a_data = self.contiguous().to_vec();
1237 let b_data = other.contiguous().to_vec();
1238 let mut c_data = vec![T::zero(); batch_size * m * n];
1239
1240 #[cfg(feature = "cuda")]
1242 {
1243 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
1244 let a_f32: &[f32] = unsafe { std::mem::transmute(a_data.as_slice()) };
1245 let b_f32: &[f32] = unsafe { std::mem::transmute(b_data.as_slice()) };
1246 let mut gpu_ok = true;
1247 for batch in 0..batch_size {
1248 let a_slice = &a_f32[batch * a_stride..(batch + 1) * a_stride];
1249 let b_slice = &b_f32[batch * b_stride..(batch + 1) * b_stride];
1250 if let Some(c_batch) = cuda_accel::cuda_matmul(a_slice, b_slice, m, n, k1) {
1251 c_data[batch * c_stride..(batch + 1) * c_stride]
1252 .copy_from_slice(unsafe { std::mem::transmute(c_batch.as_slice()) });
1253 } else {
1254 gpu_ok = false;
1255 break;
1256 }
1257 }
1258 if gpu_ok {
1259 let mut output_shape = batch_dims_self;
1260 output_shape.push(m);
1261 output_shape.push(n);
1262 return Self::from_vec(c_data, &output_shape);
1263 }
1264 c_data = vec![T::zero(); batch_size * m * n];
1266 }
1267 }
1268
1269 for batch in 0..batch_size {
1271 let a_slice = &a_data[batch * a_stride..(batch + 1) * a_stride];
1272 let b_slice = &b_data[batch * b_stride..(batch + 1) * b_stride];
1273 let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
1274 CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
1275 }
1276
1277 let mut output_shape = batch_dims_self;
1279 output_shape.push(m);
1280 output_shape.push(n);
1281
1282 Self::from_vec(c_data, &output_shape)
1283 }
1284
1285 pub fn dot(&self, other: &Self) -> Result<Self> {
1287 if self.ndim() != 1 || other.ndim() != 1 {
1288 return Err(Error::invalid_operation("dot requires 1D tensors"));
1289 }
1290
1291 if self.shape[0] != other.shape[0] {
1292 return Err(Error::shape_mismatch(&self.shape, &other.shape));
1293 }
1294
1295 let a_data = self.to_vec();
1296 let b_data = other.to_vec();
1297 let result = CpuBackend::dot(&a_data, &b_data);
1298
1299 Ok(Self::scalar(result))
1300 }
1301}
1302
1303impl<T: Numeric> Add for &Tensor<T> {
1308 type Output = Tensor<T>;
1309
1310 fn add(self, other: Self) -> Self::Output {
1311 self.add(other).expect("Addition failed")
1312 }
1313}
1314
1315impl<T: Numeric> Sub for &Tensor<T> {
1316 type Output = Tensor<T>;
1317
1318 fn sub(self, other: Self) -> Self::Output {
1319 self.sub(other).expect("Subtraction failed")
1320 }
1321}
1322
1323impl<T: Numeric> Mul for &Tensor<T> {
1324 type Output = Tensor<T>;
1325
1326 fn mul(self, other: Self) -> Self::Output {
1327 self.mul(other).expect("Multiplication failed")
1328 }
1329}
1330
1331impl<T: Numeric> Div for &Tensor<T> {
1332 type Output = Tensor<T>;
1333
1334 fn div(self, other: Self) -> Self::Output {
1335 self.div(other).expect("Division failed")
1336 }
1337}
1338
1339impl<T: Numeric> Neg for &Tensor<T> {
1340 type Output = Tensor<T>;
1341
1342 fn neg(self) -> Self::Output {
1343 self.neg()
1344 }
1345}
1346
1347impl<T: Numeric> Add<T> for &Tensor<T> {
1349 type Output = Tensor<T>;
1350
1351 fn add(self, scalar: T) -> Self::Output {
1352 self.add_scalar(scalar)
1353 }
1354}
1355
1356impl<T: Numeric> Mul<T> for &Tensor<T> {
1357 type Output = Tensor<T>;
1358
1359 fn mul(self, scalar: T) -> Self::Output {
1360 self.mul_scalar(scalar)
1361 }
1362}
1363
1364impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
1369 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1370 write!(
1371 f,
1372 "Tensor(shape={:?}, device={}",
1373 self.shape(),
1374 self.device()
1375 )?;
1376 if self.numel() <= 10 {
1377 write!(f, ", data={:?}", self.to_vec())?;
1378 }
1379 write!(f, ")")
1380 }
1381}
1382
1383impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
1384 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1385 if self.is_scalar() {
1386 write!(f, "{}", self.item().unwrap())
1387 } else if self.ndim() == 1 {
1388 write!(f, "[")?;
1389 let data = self.to_vec();
1390 for (i, val) in data.iter().enumerate() {
1391 if i > 0 {
1392 write!(f, ", ")?;
1393 }
1394 write!(f, "{val}")?;
1395 }
1396 write!(f, "]")
1397 } else {
1398 write!(f, "Tensor(shape={:?})", self.shape())
1399 }
1400 }
1401}
1402
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_vec() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
    }

    #[test]
    fn test_get_set() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        for (indices, expected) in [([0, 0], 1.0), ([0, 1], 2.0), ([1, 0], 3.0), ([1, 1], 4.0)] {
            assert_eq!(tensor.get(&indices).unwrap(), expected);
        }

        tensor.set(&[0, 0], 99.0).unwrap();
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 99.0);
    }

    #[test]
    fn test_reshape() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(tensor.reshape(&[3, 2]).unwrap().shape(), &[3, 2]);
        // A single -1 infers the full element count.
        assert_eq!(tensor.reshape(&[-1]).unwrap().shape(), &[6]);
    }

    #[test]
    fn test_transpose() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = tensor.t().unwrap();
        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(transposed.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(transposed.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(transposed.get(&[1, 0]).unwrap(), 2.0);
    }

    #[test]
    fn test_arithmetic() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        assert_eq!((&lhs + &rhs).to_vec(), vec![5.0, 7.0, 9.0]);
        assert_eq!((&lhs * &rhs).to_vec(), vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_broadcasting() {
        let vector = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let one_element = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        assert_eq!((&vector + &one_element).to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_sum() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        assert_eq!(tensor.sum().item().unwrap(), 10.0);
    }

    #[test]
    fn test_matmul() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let product = lhs.matmul(&rhs).unwrap();
        assert_eq!(product.shape(), &[2, 2]);
        assert_eq!(product.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_relu() {
        let tensor = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        assert_eq!(tensor.relu().to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_scalar() {
        let scalar = Tensor::<f32>::scalar(42.0);
        assert!(scalar.is_scalar());
        assert_eq!(scalar.numel(), 1);
        assert_eq!(scalar.item().unwrap(), 42.0);
    }
}