use core::fmt;
use core::ops::{Add, Div, Mul, Neg, Sub};

use axonml_core::backends::CpuBackend;
use axonml_core::dtype::{Float, Numeric, Scalar};
use axonml_core::error::{Error, Result};
use axonml_core::storage::Storage;
use axonml_core::Device;
use num_traits::NumCast;

use crate::shape::{
    broadcast_shape, broadcast_strides, contiguous_strides, is_contiguous, linear_index,
    normalize_dim, numel, reshape, squeeze, transpose_shape, transpose_strides, unsqueeze, Shape,
    Strides,
};

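/// A dense, strided, n-dimensional array of scalar values.
///
/// A `Tensor` is a view: `shape` gives the extent of each dimension, `strides`
/// the element step per dimension, and `offset` the position of the first
/// element within the (possibly shared) `storage`. Cloning a `Tensor` shares
/// storage; see [`Tensor::clone_deep`] for a copying clone.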
#[derive(Clone)]
pub struct Tensor<T: Scalar> {
    /// Underlying (possibly shared) data buffer.
    pub(crate) storage: Storage<T>,
    /// Extent of each dimension.
    pub(crate) shape: Shape,
    /// Element step per dimension, in units of `T`.
    pub(crate) strides: Strides,
    /// Index of the first element within `storage`.
    pub(crate) offset: usize,
}

impl<T: Scalar> Tensor<T> {
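    /// Builds a tensor over an existing `storage` buffer with the given
    /// `shape`, using contiguous (row-major) strides.
    ///
    /// # Errors
    /// Returns a shape-mismatch error if `shape` does not describe exactly
    /// `storage.len()` elements.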
    pub fn from_storage(storage: Storage<T>, shape: &[usize]) -> Result<Self> {
        let total = numel(shape);
        if total != storage.len() {
            return Err(Error::shape_mismatch(&[storage.len()], shape));
        }

        let shape = Shape::from_slice(shape);
        let strides = contiguous_strides(&shape);

        Ok(Self {
            storage,
            shape,
            strides,
            offset: 0,
        })
    }

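    /// Builds a CPU tensor that takes ownership of `data`.
    ///
    /// Example (doctest sketch; the `axonml_tensor` crate name is assumed):
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
    /// assert_eq!(t.shape(), &[2, 2]);
    /// ```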
    pub fn from_vec(data: Vec<T>, shape: &[usize]) -> Result<Self> {
        let storage = Storage::from_vec(data, Device::Cpu);
        Self::from_storage(storage, shape)
    }

    pub fn from_slice(data: &[T], shape: &[usize]) -> Result<Self> {
        let storage = Storage::from_slice(data, Device::Cpu);
        Self::from_storage(storage, shape)
    }

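    /// Builds a zero-dimensional (scalar) tensor holding `value`.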
    pub fn scalar(value: T) -> Self {
        Self {
            storage: Storage::from_vec(vec![value], Device::Cpu),
            shape: Shape::new(),
            strides: Strides::new(),
            offset: 0,
        }
    }

    #[must_use]
    pub fn zeros(shape: &[usize]) -> Self {
        crate::creation::zeros(shape)
    }

    #[must_use]
    pub fn ones(shape: &[usize]) -> Self
    where
        T: Numeric,
    {
        crate::creation::ones(shape)
    }

    #[must_use]
    pub fn full(shape: &[usize], value: T) -> Self {
        crate::creation::full(shape, value)
    }

    #[must_use]
    pub fn randn(shape: &[usize]) -> Self
    where
        T: Float,
        rand_distr::StandardNormal: rand::distributions::Distribution<T>,
    {
        crate::creation::randn(shape)
    }

    #[must_use]
    pub fn rand(shape: &[usize]) -> Self
    where
        T: Float,
        rand::distributions::Standard: rand::distributions::Distribution<T>,
    {
        crate::creation::rand(shape)
    }

    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    #[must_use]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }

    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    #[must_use]
    pub fn numel(&self) -> usize {
        numel(&self.shape)
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.numel() == 0
    }

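    /// Returns the extent of dimension `dim`; negative values index from the
    /// end, as in `normalize_dim`.
    ///
    /// # Errors
    /// Returns an error if `dim` is out of range.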
    pub fn size(&self, dim: i64) -> Result<usize> {
        let idx = normalize_dim(dim, self.ndim())?;
        Ok(self.shape[idx])
    }

    #[must_use]
    pub fn device(&self) -> Device {
        self.storage.device()
    }

    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape, &self.strides)
    }

    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }

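    /// Reads the element at `indices` (one index per dimension).
    ///
    /// # Errors
    /// Returns an error if the number of indices does not match `ndim()` or
    /// any index is out of bounds.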
    pub fn get(&self, indices: &[usize]) -> Result<T> {
        if indices.len() != self.ndim() {
            return Err(Error::invalid_operation(format!(
                "Expected {} indices, got {}",
                self.ndim(),
                indices.len()
            )));
        }

        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
            if idx >= dim {
                return Err(Error::IndexOutOfBounds {
                    index: idx,
                    size: dim,
                });
            }
        }

        let offset = self.offset + linear_index(indices, &self.strides);
        Ok(self.storage.as_slice()[offset])
    }

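    /// Writes `value` at `indices`. Note that this takes `&self`: writes go
    /// through the storage's interior mutability, so they are visible to all
    /// views sharing the same buffer.
    ///
    /// # Errors
    /// Returns an error if the number of indices does not match `ndim()` or
    /// any index is out of bounds.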
    pub fn set(&self, indices: &[usize], value: T) -> Result<()> {
        if indices.len() != self.ndim() {
            return Err(Error::invalid_operation(format!(
                "Expected {} indices, got {}",
                self.ndim(),
                indices.len()
            )));
        }

        for (&idx, &dim) in indices.iter().zip(self.shape.iter()) {
            if idx >= dim {
                return Err(Error::IndexOutOfBounds {
                    index: idx,
                    size: dim,
                });
            }
        }

        let offset = self.offset + linear_index(indices, &self.strides);
        self.storage.as_slice_mut()[offset] = value;
        Ok(())
    }

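    /// Extracts the value of a single-element tensor.
    ///
    /// # Errors
    /// Returns an error if the tensor holds more than one element.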
    pub fn item(&self) -> Result<T> {
        if self.numel() != 1 {
            return Err(Error::invalid_operation(
                "item() only works on single-element tensors",
            ));
        }

        if self.is_scalar() {
            Ok(self.storage.as_slice()[self.offset])
        } else {
            let indices = vec![0; self.ndim()];
            self.get(&indices)
        }
    }

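    /// Copies the elements into a freshly allocated `Vec` in row-major
    /// (logical) order, regardless of the tensor's memory layout.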
    #[must_use]
    pub fn to_vec(&self) -> Vec<T> {
        if self.is_contiguous() {
            let storage = self.storage.as_slice();
            storage[self.offset..self.offset + self.numel()].to_vec()
        } else {
            let mut result = Vec::with_capacity(self.numel());
            self.copy_data_to(&mut result);
            result
        }
    }

    fn copy_data_to(&self, dst: &mut Vec<T>) {
        dst.clear();
        let storage = self.storage.as_slice();

        let total = self.numel();
        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &self.shape);
            let offset = self.offset + linear_index(&indices, &self.strides);
            dst.push(storage[offset]);
        }
    }

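    /// Returns a tensor with the same elements and a new shape. At most one
    /// entry of `new_shape` may be `-1`, in which case that extent is
    /// inferred from the element count. Contiguous tensors are reshaped
    /// without copying; non-contiguous tensors are materialized first.
    ///
    /// Example (doctest sketch; the `axonml_tensor` crate name is assumed):
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
    /// assert_eq!(t.reshape(&[-1]).unwrap().shape(), &[6]);
    /// ```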
    pub fn reshape(&self, new_shape: &[isize]) -> Result<Self> {
        let shape = reshape(&self.shape, new_shape)?;

        if self.is_contiguous() {
            Ok(Self {
                storage: self.storage.clone(),
                strides: contiguous_strides(&shape),
                shape,
                offset: self.offset,
            })
        } else {
            // Non-contiguous views cannot be reshaped in place; materialize a
            // contiguous copy first.
            let contig = self.contiguous();
            Ok(Self {
                storage: contig.storage,
                strides: contiguous_strides(&shape),
                shape,
                offset: 0,
            })
        }
    }

    #[must_use]
    pub fn flatten(&self) -> Self {
        self.reshape(&[-1]).expect("Flatten should never fail")
    }

    pub fn squeeze(&self, dim: Option<i64>) -> Result<Self> {
        let dim = dim.map(|d| normalize_dim(d, self.ndim())).transpose()?;

        let new_shape = squeeze(&self.shape, dim);
        let new_strides: Strides = match dim {
            Some(d) => {
                let mut s = self.strides.clone();
                if d < self.shape.len() && self.shape[d] == 1 {
                    s.remove(d);
                }
                s
            }
            None => self
                .shape
                .iter()
                .zip(self.strides.iter())
                .filter(|(&dim, _)| dim != 1)
                .map(|(_, &stride)| stride)
                .collect(),
        };

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }

    pub fn unsqueeze(&self, dim: i64) -> Result<Self> {
        // The insertion point ranges over 0..=ndim, so a negative `dim` is
        // offset by `ndim + 1` rather than by `ndim`.
        let normalized = if dim < 0 {
            (dim + self.ndim() as i64 + 1) as usize
        } else {
            dim as usize
        };

        let new_shape = unsqueeze(&self.shape, normalized)?;
        let mut new_strides = Strides::with_capacity(new_shape.len());

        for (i, _) in new_shape.iter().enumerate() {
            if i < normalized {
                new_strides.push(self.strides.get(i).copied().unwrap_or(1));
            } else if i == normalized {
                // A size-1 dimension's stride never affects indexing; choose
                // the value that keeps a contiguous tensor contiguous.
                let stride = self
                    .strides
                    .get(i)
                    .map_or(1, |&s| s * self.shape[i] as isize);
                new_strides.push(stride);
            } else {
                new_strides.push(self.strides[i - 1]);
            }
        }

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }

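    /// Swaps dimensions `dim0` and `dim1`, returning a view that shares
    /// storage with `self` (only shape and strides change).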
    pub fn transpose(&self, dim0: i64, dim1: i64) -> Result<Self> {
        let d0 = normalize_dim(dim0, self.ndim())?;
        let d1 = normalize_dim(dim1, self.ndim())?;

        let new_shape = transpose_shape(&self.shape, d0, d1)?;
        let new_strides = transpose_strides(&self.strides, d0, d1);

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }

    pub fn t(&self) -> Result<Self> {
        if self.ndim() != 2 {
            return Err(Error::invalid_operation("t() only works on 2D tensors"));
        }
        self.transpose(0, 1)
    }

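    /// Reorders the dimensions according to `dims`, which must name every
    /// dimension exactly once. Returns a view sharing storage with `self`.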
    pub fn permute(&self, dims: &[usize]) -> Result<Self> {
        if dims.len() != self.ndim() {
            return Err(Error::invalid_operation(format!(
                "Expected {} dimensions, got {}",
                self.ndim(),
                dims.len()
            )));
        }

        let mut seen = vec![false; self.ndim()];
        for &d in dims {
            if d >= self.ndim() {
                return Err(Error::InvalidDimension {
                    index: d as i64,
                    ndim: self.ndim(),
                });
            }
            if seen[d] {
                return Err(Error::invalid_operation("Duplicate dimension in permute"));
            }
            seen[d] = true;
        }

        let new_shape: Shape = dims.iter().map(|&d| self.shape[d]).collect();
        let new_strides: Strides = dims.iter().map(|&d| self.strides[d]).collect();

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }

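    /// Returns a tensor whose elements are laid out contiguously in row-major
    /// order, copying only when necessary.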
    #[must_use]
    pub fn contiguous(&self) -> Self {
        if self.is_contiguous() && self.offset == 0 {
            return self.clone();
        }

        let data = self.to_vec();
        Self::from_vec(data, &self.shape).expect("Contiguous should never fail")
    }

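    /// Moves the tensor to `device`, copying through a contiguous buffer.
    /// Returns a cheap clone when the tensor is already on `device`.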
    pub fn to_device(&self, device: Device) -> Result<Self> {
        if self.device() == device {
            return Ok(self.clone());
        }

        let contig = self.contiguous();
        let new_storage = contig.storage.to_device(device)?;

        Ok(Self {
            storage: new_storage,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: 0,
        })
    }

    pub fn cpu(&self) -> Result<Self> {
        self.to_device(Device::Cpu)
    }

    #[must_use]
    pub fn clone_deep(&self) -> Self {
        let data = self.to_vec();
        Self::from_vec(data, &self.shape).expect("Deep clone should never fail")
    }
}

impl<T: Numeric> Tensor<T> {
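    /// Overwrites every element of the underlying storage with `value`, in
    /// place (the trailing underscore marks the mutating variant).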
    pub fn fill_(&self, value: T) {
        let mut data = self.storage.as_slice_mut();
        CpuBackend::fill(&mut data, value);
    }

    pub fn zero_(&self) {
        self.fill_(T::zero());
    }

    #[must_use]
    pub fn sum(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::sum(&data);
        Self::scalar(result)
    }

    #[must_use]
    pub fn prod(&self) -> Self {
        let data = self.to_vec();
        let result = CpuBackend::prod(&data);
        Self::scalar(result)
    }

    pub fn max(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::max(&data).unwrap();
        Ok(Self::scalar(result))
    }

    pub fn min(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::min(&data).unwrap();
        Ok(Self::scalar(result))
    }

    pub fn argmax(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmax(&data).unwrap())
    }

    pub fn argmin(&self) -> Result<usize> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        Ok(CpuBackend::argmin(&data).unwrap())
    }
}

impl<T: Float> Tensor<T> {
    pub fn mean(&self) -> Result<Self> {
        if self.is_empty() {
            return Err(Error::EmptyTensor);
        }
        let data = self.to_vec();
        let result = CpuBackend::mean(&data).unwrap();
        Ok(Self::scalar(result))
    }

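    /// Applies the rectified linear unit, `max(x, 0)`, element-wise.
    ///
    /// Example (doctest sketch; the `axonml_tensor` crate name is assumed):
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 2.0], &[3]).unwrap();
    /// assert_eq!(t.relu().to_vec(), vec![0.0, 0.0, 2.0]);
    /// ```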
    #[must_use]
    pub fn relu(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::relu(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    #[must_use]
    pub fn sigmoid(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sigmoid(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    #[must_use]
    pub fn tanh(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::tanh(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    #[must_use]
    pub fn exp(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::exp(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    #[must_use]
    pub fn ln(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::ln(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    #[must_use]
    pub fn sqrt(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::sqrt(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

    #[must_use]
    pub fn pow(&self, exp: T) -> Self {
        let data = self.to_vec();
        let result: Vec<T> = data.iter().map(|&x| x.pow_value(exp)).collect();
        Self::from_vec(result, &self.shape).unwrap()
    }

    #[must_use]
    pub fn gelu(&self) -> Self {
        crate::ops::gelu(self)
    }

    #[must_use]
    pub fn silu(&self) -> Self {
        crate::ops::silu(self)
    }

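    /// Applies softmax along dimension `dim`. If `dim` is invalid, the error
    /// from `crate::ops::softmax` is swallowed and a copy of `self` is
    /// returned unchanged.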
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Self {
        crate::ops::softmax(self, i64::from(dim)).unwrap_or_else(|_| self.clone())
    }

    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Self {
        // Note: ln(softmax(x)) is less numerically stable than a fused
        // log-softmax implementation.
        self.softmax(dim).ln()
    }

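    /// Averages over dimension `dim` (negative values count from the end).
    /// With `keepdim`, the reduced dimension is kept with size 1; otherwise
    /// it is removed. An out-of-range `dim` returns `self` unchanged.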
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Self {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return self.clone();
        }

        let dim_size = self.shape[dim];
        let data = self.to_vec();
        let mut new_shape = self.shape.clone();

        if keepdim {
            new_shape[dim] = 1;
        } else {
            new_shape.remove(dim);
        }

        if new_shape.is_empty() {
            new_shape = smallvec::smallvec![1];
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result = vec![T::zero(); new_numel];

        let outer_size: usize = self.shape[..dim].iter().product();
        let inner_size: usize = self.shape[dim + 1..].iter().product();

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let mut sum = T::zero();
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    sum = sum + data[idx];
                }
                let mean = sum / NumCast::from(dim_size).unwrap();
                let result_idx = outer * inner_size + inner;
                result[result_idx] = mean;
            }
        }

        Self::from_vec(result, &new_shape).unwrap()
    }

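    /// Population (biased) variance over dimension `dim`, computed as the
    /// mean of squared deviations from the mean; no Bessel correction is
    /// applied.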
    #[must_use]
    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Self {
        let mean = self.mean_dim(dim, true);
        let diff = self.sub(&mean).unwrap_or_else(|_| self.clone());
        let squared = diff.mul(&diff).unwrap_or_else(|_| self.clone());
        squared.mean_dim(dim, keepdim)
    }

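    /// Materializes `self` broadcast to `shape`, copying the data. Note that
    /// an incompatible `shape` is not reported: the fallback uses `shape` as
    /// given, so indexing may panic downstream.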
    #[must_use]
    pub fn broadcast_to(&self, shape: &[usize]) -> Self {
        if self.shape.as_slice() == shape {
            return self.clone();
        }

        let result_shape = broadcast_shape(&self.shape, shape).unwrap_or_else(|_| shape.into());
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];
        let self_data = self.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            result_data[i] = self_data[self_idx];
        }

        Self::from_vec(result_data, &result_shape).unwrap()
    }

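    /// Copies the sub-tensor selected by `ranges` (one half-open range per
    /// leading dimension; unspecified trailing dimensions are taken whole).
    /// Ranges are not bounds-checked and must lie within the shape.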
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Self {
        let mut new_shape = Vec::with_capacity(self.ndim());
        for (i, range) in ranges.iter().enumerate() {
            if i < self.ndim() {
                new_shape.push(range.end - range.start);
            }
        }
        for i in ranges.len()..self.ndim() {
            new_shape.push(self.shape[i]);
        }

        let new_numel: usize = new_shape.iter().product();
        let mut result_data = vec![T::zero(); new_numel];
        let self_data = self.to_vec();

        let mut result_idx = 0;
        Self::slice_recursive(
            &self_data,
            &self.shape,
            ranges,
            0,
            0,
            &mut result_data,
            &mut result_idx,
        );

        Self::from_vec(result_data, &new_shape).unwrap()
    }

    fn slice_recursive(
        data: &[T],
        shape: &[usize],
        ranges: &[std::ops::Range<usize>],
        dim: usize,
        offset: usize,
        result: &mut [T],
        result_idx: &mut usize,
    ) {
        if dim == shape.len() {
            result[*result_idx] = data[offset];
            *result_idx += 1;
            return;
        }

        let stride: usize = shape[dim + 1..].iter().product();
        let (start, end) = if dim < ranges.len() {
            (ranges[dim].start, ranges[dim].end)
        } else {
            (0, shape[dim])
        };

        for i in start..end {
            Self::slice_recursive(
                data,
                shape,
                ranges,
                dim + 1,
                offset + i * stride,
                result,
                result_idx,
            );
        }
    }
}

impl<T: Numeric> Tensor<T> {
    /// Shared kernel for the element-wise binary operations below: broadcasts
    /// both operands to a common shape, then applies `op` pairwise.
    fn broadcast_binary_op(&self, other: &Self, op: impl Fn(T, T) -> T) -> Result<Self> {
        let result_shape = broadcast_shape(&self.shape, &other.shape)?;
        let self_strides = broadcast_strides(&self.shape, &self.strides, &result_shape);
        let other_strides = broadcast_strides(&other.shape, &other.strides, &result_shape);

        let total = numel(&result_shape);
        let mut result_data = vec![T::zero(); total];

        let self_data = self.storage.as_slice();
        let other_data = other.storage.as_slice();

        for i in 0..total {
            let indices = crate::shape::unravel_index(i, &result_shape);
            let self_idx = self.offset + linear_index(&indices, &self_strides);
            let other_idx = other.offset + linear_index(&indices, &other_strides);
            result_data[i] = op(self_data[self_idx], other_data[other_idx]);
        }

        Self::from_vec(result_data, &result_shape)
    }

    pub fn add(&self, other: &Self) -> Result<Self> {
        self.broadcast_binary_op(other, |a, b| a + b)
    }

    pub fn sub(&self, other: &Self) -> Result<Self> {
        self.broadcast_binary_op(other, |a, b| a - b)
    }

    pub fn mul(&self, other: &Self) -> Result<Self> {
        self.broadcast_binary_op(other, |a, b| a * b)
    }

    pub fn div(&self, other: &Self) -> Result<Self> {
        self.broadcast_binary_op(other, |a, b| a / b)
    }

    #[must_use]
    pub fn add_scalar(&self, scalar: T) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::add_scalar(&mut result, &data, scalar);
        Self::from_vec(result, &self.shape).unwrap()
    }

    #[must_use]
    pub fn mul_scalar(&self, scalar: T) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::mul_scalar(&mut result, &data, scalar);
        Self::from_vec(result, &self.shape).unwrap()
    }

    #[must_use]
    pub fn neg(&self) -> Self {
        let data = self.to_vec();
        let mut result = vec![T::zero(); data.len()];
        CpuBackend::neg(&mut result, &data);
        Self::from_vec(result, &self.shape).unwrap()
    }

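    /// Matrix product. Both operands must be at least 2-D; leading batch
    /// dimensions must match exactly and are multiplied pairwise.
    ///
    /// Example (doctest sketch; the `axonml_tensor` crate name is assumed):
    ///
    /// ```ignore
    /// let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
    /// let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
    /// assert_eq!(a.matmul(&b).unwrap().to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    /// ```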
    pub fn matmul(&self, other: &Self) -> Result<Self> {
        if self.ndim() < 2 || other.ndim() < 2 {
            return Err(Error::invalid_operation(
                "matmul requires at least 2D tensors",
            ));
        }

        let m = self.shape[self.ndim() - 2];
        let k1 = self.shape[self.ndim() - 1];
        let k2 = other.shape[other.ndim() - 2];
        let n = other.shape[other.ndim() - 1];

        if k1 != k2 {
            return Err(Error::invalid_operation(format!(
                "matmul inner dimensions must match: {k1} vs {k2}"
            )));
        }

        // Fast path: a plain 2-D GEMM needs no batching.
        if self.ndim() == 2 && other.ndim() == 2 {
            let a_data = self.contiguous().to_vec();
            let b_data = other.contiguous().to_vec();
            let mut c_data = vec![T::zero(); m * n];
            CpuBackend::matmul(&mut c_data, &a_data, &b_data, m, n, k1);
            return Self::from_vec(c_data, &[m, n]);
        }

        // Batched case: all leading dimensions must match exactly.
        let batch_dims_self: Vec<usize> = self.shape[..self.ndim() - 2].to_vec();
        let batch_dims_other: Vec<usize> = other.shape[..other.ndim() - 2].to_vec();

        if batch_dims_self != batch_dims_other {
            return Err(Error::invalid_operation(format!(
                "matmul batch dimensions must match: {batch_dims_self:?} vs {batch_dims_other:?}"
            )));
        }

        let batch_size: usize = batch_dims_self.iter().product();
        let a_stride = m * k1;
        let b_stride = k1 * n;
        let c_stride = m * n;

        let a_data = self.contiguous().to_vec();
        let b_data = other.contiguous().to_vec();
        let mut c_data = vec![T::zero(); batch_size * m * n];

        // One GEMM per batch over contiguous slices of the flattened buffers.
        for batch in 0..batch_size {
            let a_slice = &a_data[batch * a_stride..(batch + 1) * a_stride];
            let b_slice = &b_data[batch * b_stride..(batch + 1) * b_stride];
            let c_slice = &mut c_data[batch * c_stride..(batch + 1) * c_stride];
            CpuBackend::matmul(c_slice, a_slice, b_slice, m, n, k1);
        }

        let mut output_shape = batch_dims_self;
        output_shape.push(m);
        output_shape.push(n);

        Self::from_vec(c_data, &output_shape)
    }

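    /// Inner product of two 1-D tensors of equal length, returned as a
    /// scalar tensor.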
    pub fn dot(&self, other: &Self) -> Result<Self> {
        if self.ndim() != 1 || other.ndim() != 1 {
            return Err(Error::invalid_operation("dot requires 1D tensors"));
        }

        if self.shape[0] != other.shape[0] {
            return Err(Error::shape_mismatch(&self.shape, &other.shape));
        }

        let a_data = self.to_vec();
        let b_data = other.to_vec();
        let result = CpuBackend::dot(&a_data, &b_data);

        Ok(Self::scalar(result))
    }
}

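// Operator sugar. Each impl delegates to the corresponding fallible inherent
// method and panics if broadcasting fails; call `add`/`sub`/`mul`/`div`
// directly to handle the error instead.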
impl<T: Numeric> Add for &Tensor<T> {
    type Output = Tensor<T>;

    fn add(self, other: Self) -> Self::Output {
        self.add(other).expect("Addition failed")
    }
}

impl<T: Numeric> Sub for &Tensor<T> {
    type Output = Tensor<T>;

    fn sub(self, other: Self) -> Self::Output {
        self.sub(other).expect("Subtraction failed")
    }
}

impl<T: Numeric> Mul for &Tensor<T> {
    type Output = Tensor<T>;

    fn mul(self, other: Self) -> Self::Output {
        self.mul(other).expect("Multiplication failed")
    }
}

impl<T: Numeric> Div for &Tensor<T> {
    type Output = Tensor<T>;

    fn div(self, other: Self) -> Self::Output {
        self.div(other).expect("Division failed")
    }
}

impl<T: Numeric> Neg for &Tensor<T> {
    type Output = Tensor<T>;

    fn neg(self) -> Self::Output {
        self.neg()
    }
}

impl<T: Numeric> Add<T> for &Tensor<T> {
    type Output = Tensor<T>;

    fn add(self, scalar: T) -> Self::Output {
        self.add_scalar(scalar)
    }
}

impl<T: Numeric> Mul<T> for &Tensor<T> {
    type Output = Tensor<T>;

    fn mul(self, scalar: T) -> Self::Output {
        self.mul_scalar(scalar)
    }
}

impl<T: Scalar + fmt::Display> fmt::Debug for Tensor<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Tensor(shape={:?}, device={}",
            self.shape(),
            self.device()
        )?;
        if self.numel() <= 10 {
            write!(f, ", data={:?}", self.to_vec())?;
        }
        write!(f, ")")
    }
}

impl<T: Scalar + fmt::Display> fmt::Display for Tensor<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_scalar() {
            write!(f, "{}", self.item().unwrap())
        } else if self.ndim() == 1 {
            write!(f, "[")?;
            let data = self.to_vec();
            for (i, val) in data.iter().enumerate() {
                if i > 0 {
                    write!(f, ", ")?;
                }
                write!(f, "{val}")?;
            }
            write!(f, "]")
        } else {
            write!(f, "Tensor(shape={:?})", self.shape())
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_vec() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
    }

    #[test]
    fn test_get_set() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(t.get(&[1, 0]).unwrap(), 3.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 4.0);

        t.set(&[0, 0], 99.0).unwrap();
        assert_eq!(t.get(&[0, 0]).unwrap(), 99.0);
    }

    #[test]
    fn test_reshape() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.reshape(&[3, 2]).unwrap();
        assert_eq!(r.shape(), &[3, 2]);

        let r = t.reshape(&[-1]).unwrap();
        assert_eq!(r.shape(), &[6]);
    }

    #[test]
    fn test_transpose() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.t().unwrap();
        assert_eq!(r.shape(), &[3, 2]);
        assert_eq!(r.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(r.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(r.get(&[1, 0]).unwrap(), 2.0);
    }

    #[test]
    fn test_arithmetic() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![5.0, 7.0, 9.0]);

        let d = &a * &b;
        assert_eq!(d.to_vec(), vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_broadcasting() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![10.0], &[1]).unwrap();

        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_sum() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let s = t.sum();
        assert_eq!(s.item().unwrap(), 10.0);
    }

    #[test]
    fn test_matmul() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let c = a.matmul(&b).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap();
        let r = t.relu();
        assert_eq!(r.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_scalar() {
        let s = Tensor::<f32>::scalar(42.0);
        assert!(s.is_scalar());
        assert_eq!(s.numel(), 1);
        assert_eq!(s.item().unwrap(), 42.0);
    }
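
    // A sketch of an extra regression test for `mean_dim`, derived from the
    // row-major reduction layout used in its implementation.
    #[test]
    fn test_mean_dim() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let m = t.mean_dim(0, false);
        assert_eq!(m.shape(), &[2]);
        assert_eq!(m.to_vec(), vec![2.0, 3.0]);
    }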
}