1#![allow(clippy::redundant_closure_call)]
3use crate::backend::{BackendDevice, BackendStorage};
4use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
5use crate::scalar::TensorOrScalar;
6use crate::shape::{Dim, Dims, ShapeWithOneHole};
7use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
8use std::sync::{Arc, RwLock};
9
/// Identifier attached to every tensor when it is created; ids come from a global counter.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TensorId(usize);

impl TensorId {
    /// Draws the next id from a process-wide atomic counter that starts at 1.
    ///
    /// `Relaxed` ordering suffices: only the uniqueness of each `fetch_add` result matters,
    /// not its ordering relative to other memory operations.
    fn new() -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self(id)
    }
}
22
/// Shared payload behind a `Tensor` handle.
pub struct Tensor_ {
    // Unique id assigned at construction (see `TensorId::new`).
    id: TensorId,
    // Backing buffer. It is reference-counted so views created by `narrow` can share the
    // same storage with a different `layout`.
    storage: Arc<RwLock<Storage>>,
    // Describes how this tensor addresses `storage` (shape plus offsets; freshly built
    // tensors use `Layout::contiguous`).
    layout: Layout,
    // Operation that produced this tensor, used for backpropagation; `BackpropOp::none()`
    // for leaf tensors.
    op: BackpropOp,
    // Whether this tensor is a trainable variable; `track_op` returns true for variables.
    is_variable: bool,
    // Cached from the storage at construction time.
    dtype: DType,
    // Cached from the storage at construction time.
    device: Device,
}
44
45impl AsRef<Tensor> for Tensor {
46 fn as_ref(&self) -> &Tensor {
47 self
48 }
49}
50
51#[derive(Clone)]
55pub struct Tensor(Arc<Tensor_>);
69
70impl std::ops::Deref for Tensor {
71 type Target = Tensor_;
72
73 fn deref(&self) -> &Self::Target {
74 self.0.as_ref()
75 }
76}
77
// Generates `pub fn $fn_name(&self) -> Result<Self>` applying the element-wise unary op
// `crate::op::$op_name`. The result is a new contiguous tensor of the same shape; empty
// tensors are returned unchanged.
macro_rules! unary_op {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name(&self) -> Result<Self> {
            // Nothing to compute for an empty tensor: hand back a cheap clone.
            if self.shape().elem_count() == 0 {
                return Ok(self.clone());
            }
            let result = self
                .storage()
                .unary_impl::<crate::op::$op_name>(self.layout())?;
            let backprop = BackpropOp::new1(self, |arg| Op::Unary(arg, UnaryOp::$op_name));
            Ok(from_storage(result, self.shape().clone(), backprop, false))
        }
    };
}
93
// Generates `pub fn $fn_name(&self, rhs: &Self) -> Result<Self>` applying the element-wise
// binary op `crate::op::$op_name`. Both operands must have exactly the same shape (no
// broadcasting); an empty `self` is returned unchanged.
macro_rules! binary_op {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
            if shape.elem_count() == 0 {
                return Ok(self.clone());
            }
            let (lhs_layout, rhs_layout) = (self.layout(), rhs.layout());
            let result = self.storage().binary_impl::<crate::op::$op_name>(
                &*rhs.storage(),
                lhs_layout,
                rhs_layout,
            )?;
            let backprop =
                BackpropOp::new2(self, rhs, |a, b| Op::Binary(a, b, BinaryOp::$op_name));
            Ok(from_storage(result, shape.clone(), backprop, false))
        }
    };
}
111
// Like `binary_op!`, but the right-hand side may be a tensor or a scalar. A scalar is
// converted to `self`'s dtype and device and broadcast to `self`'s shape before the op.
macro_rules! binary_op_scalar {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
            use crate::scalar::TensorScalar;
            let rhs = match rhs.to_tensor_scalar()? {
                TensorScalar::Scalar(scalar) => scalar
                    .to_dtype(self.dtype())?
                    .to_device(self.device())?
                    .broadcast_as(self.shape())?,
                TensorScalar::Tensor(tensor) => tensor,
            };
            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
            if shape.elem_count() == 0 {
                return Ok(self.clone());
            }
            let (lhs_layout, rhs_layout) = (self.layout(), rhs.layout());
            let result = self.storage().binary_impl::<crate::op::$op_name>(
                &*rhs.storage(),
                lhs_layout,
                rhs_layout,
            )?;
            let backprop =
                BackpropOp::new2(self, &rhs, |a, b| Op::Binary(a, b, BinaryOp::$op_name));
            Ok(from_storage(result, shape.clone(), backprop, false))
        }
    };
}
136
// Generates a broadcasting wrapper `$fn_name` around the same-shape method `$inner_fn_name`.
// Both operands are broadcast to their common shape (as computed by
// `broadcast_shape_binary_op`); only operands whose shape actually changes get a broadcast
// view, and the lhs is handled before the rhs.
macro_rules! broadcast_binary_op {
    ($fn_name:ident, $inner_fn_name:ident) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            let target = self
                .shape()
                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
            let lhs_view;
            let lhs = if target == *self.shape() {
                self
            } else {
                lhs_view = self.broadcast_as(&target)?;
                &lhs_view
            };
            let rhs_view;
            let rhs = if target == *rhs.shape() {
                rhs
            } else {
                rhs_view = rhs.broadcast_as(&target)?;
                &rhs_view
            };
            lhs.$inner_fn_name(rhs)
        }
    };
}
157
158pub(crate) fn from_storage<S: Into<Shape>>(
160 storage: Storage,
161 shape: S,
162 op: BackpropOp,
163 is_variable: bool,
164) -> Tensor {
165 let dtype = storage.dtype();
166 let device = storage.device();
167 let tensor_ = Tensor_ {
168 id: TensorId::new(),
169 storage: Arc::new(RwLock::new(storage)),
170 layout: Layout::contiguous(shape),
171 op,
172 is_variable,
173 dtype,
174 device,
175 };
176 Tensor(Arc::new(tensor_))
177}
178
179impl Tensor {
180 pub(crate) fn ones_impl<S: Into<Shape>>(
181 shape: S,
182 dtype: DType,
183 device: &Device,
184 is_variable: bool,
185 ) -> Result<Self> {
186 let none = BackpropOp::none();
187 let shape = shape.into();
188 let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
189 let layout = Layout::contiguous(shape.clone());
190 storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
191 Ok(from_storage(storage, shape, none, is_variable))
192 }
193
194 pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
204 Self::ones_impl(shape, dtype, device, false)
205 }
206
207 pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
208 self.storage_mut().const_set(value, self.layout())
209 }
210
211 pub fn zero_set(&self) -> Result<()> {
212 self.const_set(crate::scalar::Scalar::zero(self.dtype()))
213 }
214
215 pub fn one_set(&self) -> Result<()> {
216 self.const_set(crate::scalar::Scalar::one(self.dtype()))
217 }
218
219 pub fn ones_like(&self) -> Result<Self> {
229 Tensor::ones(self.shape(), self.dtype(), self.device())
230 }
231
232 pub(crate) fn zeros_impl<S: Into<Shape>>(
235 shape: S,
236 dtype: DType,
237 device: &Device,
238 is_variable: bool,
239 ) -> Result<Self> {
240 let none = BackpropOp::none();
241 let shape = shape.into();
242 let storage = device.zeros(&shape, dtype)?;
243 Ok(from_storage(storage, shape, none, is_variable))
244 }
245
246 pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
256 Self::zeros_impl(shape, dtype, device, false)
257 }
258
259 pub fn zeros_like(&self) -> Result<Self> {
270 Tensor::zeros(self.shape(), self.dtype(), self.device())
271 }
272
273 pub(crate) unsafe fn empty_impl<S: Into<Shape>>(
276 shape: S,
277 dtype: DType,
278 device: &Device,
279 is_variable: bool,
280 ) -> Result<Self> {
281 let none = BackpropOp::none();
282 let shape = shape.into();
283 let storage = device.alloc_uninit(&shape, dtype)?;
284 Ok(from_storage(storage, shape, none, is_variable))
285 }
286
287 pub unsafe fn empty<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
299 Self::empty_impl(shape, dtype, device, false)
300 }
301
302 pub unsafe fn empty_like(&self) -> Result<Self> {
315 Tensor::empty(self.shape(), self.dtype(), self.device())
316 }
317
318 pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
319 lo: T,
320 up: T,
321 s: S,
322 device: &Device,
323 is_variable: bool,
324 ) -> Result<Self> {
325 let s = s.into();
326 let storage = device.rand_uniform(lo, up, &s)?;
327 let none = BackpropOp::none();
328 Ok(from_storage(storage, s, none, is_variable))
329 }
330
331 pub(crate) fn rand_f64_impl<S: Into<Shape>>(
332 lo: f64,
333 up: f64,
334 s: S,
335 dtype: DType,
336 device: &Device,
337 is_variable: bool,
338 ) -> Result<Self> {
339 let s = s.into();
340 let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
341 let none = BackpropOp::none();
342 Ok(from_storage(storage, s, none, is_variable))
343 }
344
345 pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
347 lo: T,
348 up: T,
349 s: S,
350 device: &Device,
351 ) -> Result<Self> {
352 Self::rand_impl(lo, up, s, device, false)
353 }
354
355 pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
356 Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
357 }
358
359 pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
360 mean: T,
361 std: T,
362 s: S,
363 device: &Device,
364 is_variable: bool,
365 ) -> Result<Self> {
366 let s = s.into();
367 let storage = device.rand_normal(mean, std, &s)?;
368 let none = BackpropOp::none();
369 Ok(from_storage(storage, s, none, is_variable))
370 }
371
372 pub(crate) fn randn_f64_impl<S: Into<Shape>>(
373 mean: f64,
374 std: f64,
375 s: S,
376 dtype: DType,
377 device: &Device,
378 is_variable: bool,
379 ) -> Result<Self> {
380 let s = s.into();
381 let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
382 let none = BackpropOp::none();
383 Ok(from_storage(storage, s, none, is_variable))
384 }
385
386 pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
387 Tensor::randn_f64_impl(
388 mean,
389 stdev,
390 self.shape(),
391 self.dtype(),
392 self.device(),
393 false,
394 )
395 }
396
397 pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
400 mean: T,
401 std: T,
402 s: S,
403 device: &Device,
404 ) -> Result<Self> {
405 Self::randn_impl(mean, std, s, device, false)
406 }
407
408 pub(crate) fn new_impl<A: crate::device::NdArray>(
409 array: A,
410 shape: Shape,
411 device: &Device,
412 is_variable: bool,
413 ) -> Result<Self> {
414 let n: usize = shape.elem_count();
415 let buffer_size: usize = array.shape()?.elem_count();
416 if buffer_size != n {
417 return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
418 }
419 let storage = device.storage(array)?;
420 let none = BackpropOp::none();
421 Ok(from_storage(storage, shape, none, is_variable))
422 }
423
424 pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
426 let shape = array.shape()?;
427 Self::new_impl(array, shape, device, false)
428 }
429
430 pub fn full<D: crate::WithDType, S: Into<Shape>>(
441 value: D,
442 shape: S,
443 device: &Device,
444 ) -> Result<Self> {
445 let none = BackpropOp::none();
446 let shape = shape.into();
447 let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
448 let layout = Layout::contiguous(shape.clone());
449 storage.const_set(value.to_scalar(), &layout)?;
450 Ok(from_storage(storage, shape, none, false))
451 }
452
453 pub fn from_iter<D: crate::WithDType>(
462 iter: impl IntoIterator<Item = D>,
463 device: &Device,
464 ) -> Result<Self> {
465 let data = iter.into_iter().collect::<Vec<_>>();
466 let len = data.len();
467 Self::from_vec_impl(data, len, device, false)
468 }
469
470 pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
480 Self::arange_step(start, end, D::one(), device)
481 }
482
483 pub fn arange_step<D: crate::WithDType>(
493 start: D,
494 end: D,
495 step: D,
496 device: &Device,
497 ) -> Result<Self> {
498 if D::is_zero(&step) {
499 bail!("step cannot be zero")
500 }
501 let mut data = vec![];
502 let mut current = start;
503 if step >= D::zero() {
504 while current < end {
505 data.push(current);
506 current += step;
507 }
508 } else {
509 while current > end {
510 data.push(current);
511 current += step;
512 }
513 }
514 let len = data.len();
515 Self::from_vec_impl(data, len, device, false)
516 }
517
518 pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
519 data: Vec<D>,
520 shape: S,
521 device: &Device,
522 is_variable: bool,
523 ) -> Result<Self> {
524 let shape = shape.into_shape(data.len())?;
525 let storage = device.storage_owned(data)?;
526 let none = BackpropOp::none();
527 Ok(from_storage(storage, shape, none, is_variable))
528 }
529
530 pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
544 data: Vec<D>,
545 shape: S,
546 device: &Device,
547 ) -> Result<Self> {
548 Self::from_vec_impl(data, shape, device, false)
549 }
550
551 pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
565 array: &[D],
566 shape: S,
567 device: &Device,
568 ) -> Result<Self> {
569 let shape = shape.into_shape(array.len())?;
570 let storage = device.storage_from_slice(array)?;
571 let none = BackpropOp::none();
572 Ok(from_storage(storage, shape, none, false))
573 }
574
575 pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
576 let lhs = self.shape();
577 let rhs = rhs.shape();
578 if lhs != rhs {
579 Err(Error::ShapeMismatchBinaryOp {
580 lhs: lhs.clone(),
581 rhs: rhs.clone(),
582 op,
583 }
584 .bt())
585 } else {
586 Ok(lhs)
587 }
588 }
589
590 pub fn track_op(&self) -> bool {
593 self.is_variable || self.op.is_some()
594 }
595
596 pub fn from_storage<S: Into<Shape>>(
602 storage: Storage,
603 shape: S,
604 op: BackpropOp,
605 is_variable: bool,
606 ) -> Tensor {
607 from_storage(storage, shape, op, is_variable)
608 }
609
// Element-wise binary ops on operands of identical shape (generated by `binary_op!`).
binary_op!(add, Add);
binary_op!(mul, Mul);
binary_op!(sub, Sub);
binary_op!(div, Div);
// Element-wise max/min whose rhs may be a tensor or a scalar (generated by `binary_op_scalar!`).
binary_op_scalar!(maximum, Maximum);
binary_op_scalar!(minimum, Minimum);
// Broadcasting wrappers: operands are broadcast to a common shape, then the same-shape
// method named by the second argument is called (generated by `broadcast_binary_op!`).
broadcast_binary_op!(broadcast_add, add);
broadcast_binary_op!(broadcast_mul, mul);
broadcast_binary_op!(broadcast_sub, sub);
broadcast_binary_op!(broadcast_div, div);
broadcast_binary_op!(broadcast_maximum, maximum);
broadcast_binary_op!(broadcast_minimum, minimum);
// Broadcasting comparisons, delegating to `eq`/`ne`/`lt`/`le`/`gt`/`ge` (i.e. `cmp`).
broadcast_binary_op!(broadcast_eq, eq);
broadcast_binary_op!(broadcast_ne, ne);
broadcast_binary_op!(broadcast_lt, lt);
broadcast_binary_op!(broadcast_le, le);
broadcast_binary_op!(broadcast_gt, gt);
broadcast_binary_op!(broadcast_ge, ge);

// Element-wise unary ops (generated by `unary_op!`); empty tensors are returned as-is.
unary_op!(recip, Recip);
unary_op!(neg, Neg);
unary_op!(exp, Exp);
unary_op!(log, Log);
unary_op!(sin, Sin);
unary_op!(cos, Cos);
unary_op!(tanh, Tanh);
unary_op!(abs, Abs);
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
unary_op!(gelu_erf, GeluErf);
unary_op!(erf, Erf);
unary_op!(relu, Relu);
unary_op!(silu, Silu);
unary_op!(ceil, Ceil);
unary_op!(floor, Floor);
unary_op!(round, Round);
unary_op!(sign, Sign);
650
651 pub fn round_to(&self, decimals: i32) -> Result<Self> {
656 let mult = 10f64.powi(decimals);
657 (self * mult)?.round()? * (1f64 / mult)
658 }
659
660 pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
663 if self.rank() != 0 {
664 Err(Error::UnexpectedNumberOfDims {
665 expected: 0,
666 got: self.rank(),
667 shape: self.shape().clone(),
668 }
669 .bt())?
670 }
671 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
672 let data = S::cpu_storage_as_slice(cpu_storage)?;
673 Ok::<_, Error>(data[self.layout().start_offset()])
674 };
675 match &*self.storage() {
676 Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
677 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
678 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
679 #[cfg(feature = "rocm")]
680 Storage::Rocm(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
681 #[cfg(feature = "vulkan")]
682 Storage::Vulkan(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
683 }
684 }
685
686 pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
688 self.to_scalar::<S>()
689 }
690
691 pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
693 let repeats = shape.into();
695 let repeats = repeats.dims();
696 let mut inp = if self.rank() < repeats.len() {
697 let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
698 self.reshape(shape)?
699 } else {
700 self.clone()
701 };
702 for (idx, &repeat) in repeats.iter().enumerate() {
703 if repeat > 1 {
704 inp = Tensor::cat(&vec![&inp; repeat], idx)?
705 }
706 }
707 Ok(inp)
708 }
709
710 pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
747 if args.len() <= 1 {
748 Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
749 }
750 let args: Vec<_> = if xy_indexing {
751 args.iter().rev().collect()
752 } else {
753 args.iter().collect()
754 };
755
756 let mut shape = Vec::with_capacity(args.len());
757 for arg in args.iter() {
758 shape.push(arg.as_ref().dims1()?)
759 }
760
761 let mut grids = Vec::with_capacity(args.len());
762 for idx in 0..args.len() {
763 let mut ones = vec![1usize; args.len()];
764 ones[idx] = shape[idx];
765 let arg = args[idx].as_ref().reshape(ones)?;
766 let mut repeats = shape.clone();
767 repeats[idx] = 1;
768 let repeated_tensor = arg.repeat(repeats)?;
769 grids.push(repeated_tensor);
770 }
771 if xy_indexing {
772 grids.reverse();
773 }
774 Ok(grids)
775 }
776
777 pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
789 if self.elem_count() == 0 {
790 return Ok(self.clone());
791 }
792 let storage = self.storage().affine(self.layout(), mul, add)?;
793 let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
794 Ok(from_storage(storage, self.shape(), op, false))
795 }
796
797 pub fn elu(&self, alpha: f64) -> Result<Self> {
799 if self.elem_count() == 0 {
800 return Ok(self.clone());
801 }
802 let storage = self.storage().elu(self.layout(), alpha)?;
803 let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
804 Ok(from_storage(storage, self.shape(), op, false))
805 }
806
807 pub fn powf(&self, e: f64) -> Result<Self> {
809 if self.elem_count() == 0 {
810 return Ok(self.clone());
811 }
812 let storage = self.storage().powf(self.layout(), e)?;
813 let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
814 Ok(from_storage(storage, self.shape(), op, false))
815 }
816
817 pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
818 if dim >= self.dims().len() {
819 Err(Error::DimOutOfRange {
820 shape: self.shape().clone(),
821 dim: dim as i32,
822 op,
823 }
824 .bt())?
825 } else {
826 Ok(())
827 }
828 }
829
830 pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
833 let dim = dim.to_index(self.shape(), "chunk")?;
834 let size = self.dim(dim)?;
835 if size < chunks {
836 (0..size).map(|i| self.narrow(dim, i, 1)).collect()
837 } else {
838 let chunk_size = size / chunks;
839 let cnt_additional = size % chunks;
840 let mut tensors = vec![];
841 let mut sum_chunk_size = 0;
842 for i in 0..chunks {
843 let chunk_size = if i < cnt_additional {
844 chunk_size + 1
845 } else {
846 chunk_size
847 };
848 let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
849 tensors.push(tensor);
850 sum_chunk_size += chunk_size
851 }
852 Ok(tensors)
853 }
854 }
855
856 pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
883 let dims = self.dims();
884 let dim = dim.to_index(self.shape(), "narrow")?;
885 let err = |msg| {
886 Err::<(), _>(
887 Error::NarrowInvalidArgs {
888 shape: self.shape().clone(),
889 dim,
890 start,
891 len,
892 msg,
893 }
894 .bt(),
895 )
896 };
897 if start > dims[dim] {
898 err("start > dim_len")?
899 }
900 if start.saturating_add(len) > dims[dim] {
901 err("start + len > dim_len")?
902 }
903 if start == 0 && dims[dim] == len {
904 Ok(self.clone())
905 } else {
906 let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
907 let layout = self.layout().narrow(dim, start, len)?;
908 let tensor_ = Tensor_ {
909 id: TensorId::new(),
910 storage: self.storage.clone(),
911 layout,
912 op,
913 is_variable: false,
914 dtype: self.dtype,
915 device: self.device.clone(),
916 };
917 Ok(Tensor(Arc::new(tensor_)))
918 }
919 }
920
921 fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
922 match dims {
923 [] => Ok(self),
924 [i] => self.squeeze(*i),
925 dims => {
926 let dims = self
927 .dims()
928 .iter()
929 .enumerate()
930 .filter_map(|(dim_idx, &v)| {
931 if dims.contains(&dim_idx) {
932 None
933 } else {
934 Some(v)
935 }
936 })
937 .collect::<Vec<_>>();
938 self.reshape(dims)
939 }
940 }
941 }
942
943 fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
944 let dim = dim.to_index(self.shape(), op.name())?;
945 let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
946 let mut dims = self.dims().to_vec();
947 dims[dim] = 1;
948 let op = match op {
949 ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
950 BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
951 }
952 ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
953 };
954 let res = from_storage(storage, dims, op, false);
955 if keepdim {
956 Ok(res)
957 } else {
958 res.squeeze_dims(&[dim])
959 }
960 }
961
962 fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
963 let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
964 let storage = self
965 .storage()
966 .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
967 let mut dims = self.dims().to_vec();
968 for &sum_dim in sum_dims.iter() {
969 dims[sum_dim] = 1
970 }
971 let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
972 let sum = from_storage(storage, dims, op, false);
973 if keepdim {
974 Ok(sum)
975 } else {
976 sum.squeeze_dims(&sum_dims)
977 }
978 }
979
980 pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
994 where
995 D: Dim + Clone,
996 {
997 let dim = dim.to_index(self.shape(), "roll")?;
998 let dim_size = self.dim(dim)?;
999 let shift = shift.rem_euclid(dim_size as i32) as usize;
1000 if shift == 0 {
1001 Ok(self.clone())
1002 } else {
1003 let a = self.narrow(dim, 0, dim_size - shift)?;
1004 let b = self.narrow(dim, dim_size - shift, shift)?;
1005 Tensor::cat(&[&b, &a], dim)
1006 }
1007 }
1008
1009 pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
1027 self.sum_impl(sum_dims, true)
1028 }
1029
1030 pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
1034 self.sum_impl(sum_dims, false)
1035 }
1036
1037 pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
1055 let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
1056 let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1057 let scale = 1f64 / (reduced_dim as f64);
1058 self.sum_impl(mean_dims, true)? * scale
1059 }
1060
1061 pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
1065 let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
1066 let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1067 let scale = 1f64 / (reduced_dim as f64);
1068 self.sum_impl(mean_dims, false)? * scale
1069 }
1070
1071 pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1073 let dim = dim.to_index(self.shape(), "var")?;
1074 let mean = self.mean_keepdim(dim)?;
1075 let squares = self.broadcast_sub(&mean)?.sqr()?;
1076 squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
1077 }
1078
1079 pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
1081 let dim = dim.to_index(self.shape(), "var")?;
1082 self.var_keepdim(dim)?.squeeze(dim)
1083 }
1084
1085 pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1088 self.reduce_impl(dim, true, ReduceOp::Max)
1089 }
1090
1091 pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
1093 self.reduce_impl(dim, false, ReduceOp::Max)
1094 }
1095
1096 pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1099 self.reduce_impl(dim, true, ReduceOp::Min)
1100 }
1101
1102 pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
1104 self.reduce_impl(dim, false, ReduceOp::Min)
1105 }
1106
1107 pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1108 self.reduce_impl(dim, true, ReduceOp::ArgMax)
1109 }
1110
1111 pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
1113 self.reduce_impl(dim, false, ReduceOp::ArgMax)
1114 }
1115
1116 pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1117 self.reduce_impl(dim, true, ReduceOp::ArgMin)
1118 }
1119
1120 pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
1122 self.reduce_impl(dim, false, ReduceOp::ArgMin)
1123 }
1124
1125 pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
1130 let rhs = match rhs.to_tensor_scalar()? {
1131 crate::scalar::TensorScalar::Tensor(rhs) => rhs,
1132 crate::scalar::TensorScalar::Scalar(rhs) => rhs
1133 .to_dtype(self.dtype())?
1134 .to_device(self.device())?
1135 .broadcast_as(self.shape())?,
1136 };
1137 let shape = self.same_shape_binary_op(&rhs, "cmp")?;
1138 let storage = self
1139 .storage()
1140 .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
1141 let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
1142 Ok(from_storage(storage, shape.dims(), op, false))
1143 }
1144
1145 pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1147 self.cmp(rhs, CmpOp::Eq)
1148 }
1149
1150 pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1152 self.cmp(rhs, CmpOp::Ne)
1153 }
1154
1155 pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1158 self.cmp(rhs, CmpOp::Lt)
1159 }
1160
1161 pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1164 self.cmp(rhs, CmpOp::Gt)
1165 }
1166
1167 pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1170 self.cmp(rhs, CmpOp::Ge)
1171 }
1172
1173 pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1176 self.cmp(rhs, CmpOp::Le)
1177 }
1178
1179 pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
1181 self.maximum(min)?.minimum(max)
1182 }
1183
1184 pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
1189 let (n, c, _l) = self.dims3()?;
1190 let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
1191 let storage = self
1192 .storage()
1193 .upsample_nearest1d(self.layout(), target_size)?;
1194 Ok(from_storage(storage, (n, c, target_size), op, false))
1195 }
1196
1197 pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
1199 self.interpolate1d(target_size)
1200 }
1201
1202 pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1208 let (n, c, _h, _w) = self.dims4()?;
1209 let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
1210 arg,
1211 target_h,
1212 target_w,
1213 });
1214 let storage = self
1215 .storage()
1216 .upsample_nearest2d(self.layout(), target_h, target_w)?;
1217 Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1218 }
1219
1220 pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1222 self.interpolate2d(target_h, target_w)
1223 }
1224
1225 pub fn upsample_bilinear2d(
1249 &self,
1250 target_h: usize,
1251 target_w: usize,
1252 align_corners: bool,
1253 ) -> Result<Self> {
1254 let (n, c, _h, _w) = self.dims4()?;
1255 let op = BackpropOp::new1(self, |arg| Op::UpsampleBilinear2D {
1256 arg,
1257 target_h,
1258 target_w,
1259 align_corners,
1260 });
1261 let storage = self.storage().upsample_bilinear2d(
1263 self.layout(),
1264 target_h,
1265 target_w,
1266 align_corners,
1267 None,
1268 None,
1269 )?;
1270 Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1271 }
1272
1273 pub fn upsample_bilinear2d_with_scale(
1297 &self,
1298 scale_h: f64,
1299 scale_w: f64,
1300 align_corners: bool,
1301 ) -> Result<Self> {
1302 let (n, c, height_in, width_in) = self.dims4()?;
1303
1304 let height_out = (height_in as f64 * scale_h).floor() as usize;
1306 let width_out = (width_in as f64 * scale_w).floor() as usize;
1307
1308 if height_in == height_out && width_in == width_out {
1310 return Ok(self.clone());
1311 }
1312
1313 let op = BackpropOp::new1(self, |arg| Op::UpsampleBilinear2D {
1314 arg,
1315 target_h: height_out,
1316 target_w: width_out,
1317 align_corners,
1318 });
1319
1320 let storage = self.storage().upsample_bilinear2d(
1323 self.layout(),
1324 height_out,
1325 width_out,
1326 align_corners,
1327 Some(scale_h),
1328 Some(scale_w),
1329 )?;
1330 Ok(from_storage(
1331 storage,
1332 (n, c, height_out, width_out),
1333 op,
1334 false,
1335 ))
1336 }
1337
1338 pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1345 let sz = sz.to_usize2();
1346 self.avg_pool2d_with_stride(sz, sz)
1347 }
1348
1349 pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
1352 &self,
1353 kernel_size: T,
1354 stride: T,
1355 ) -> Result<Self> {
1356 let kernel_size = kernel_size.to_usize2();
1357 let stride = stride.to_usize2();
1358 let (n, c, h, w) = self.dims4()?;
1359 if h < kernel_size.0 || w < kernel_size.1 {
1360 bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1361 }
1362 let h_out = (h - kernel_size.0) / stride.0 + 1;
1364 let w_out = (w - kernel_size.1) / stride.1 + 1;
1365 let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
1366 arg,
1367 kernel_size,
1368 stride,
1369 });
1370 let storage = self
1371 .storage()
1372 .avg_pool2d(self.layout(), kernel_size, stride)?;
1373 Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1374 }
1375
1376 pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1383 let sz = sz.to_usize2();
1384 self.max_pool2d_with_stride(sz, sz)
1385 }
1386
1387 pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
1390 &self,
1391 kernel_size: T,
1392 stride: T,
1393 ) -> Result<Self> {
1394 let kernel_size = kernel_size.to_usize2();
1395 let stride = stride.to_usize2();
1396 let (n, c, h, w) = self.dims4()?;
1397 if h < kernel_size.0 || w < kernel_size.1 {
1398 bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1399 }
1400 let h_out = (h - kernel_size.0) / stride.0 + 1;
1402 let w_out = (w - kernel_size.1) / stride.1 + 1;
1403 let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
1404 arg,
1405 kernel_size,
1406 stride,
1407 });
1408 let storage = self
1409 .storage()
1410 .max_pool2d(self.layout(), kernel_size, stride)?;
1411 Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1412 }
1413
1414 pub fn dot(&self, rhs: &Self) -> Result<Self> {
1430 if self.dims().len() != 1 || rhs.dims().len() != 1 {
1431 return Err(Error::ShapeMismatchBinaryOp {
1432 lhs: self.shape().clone(),
1433 rhs: rhs.shape().clone(),
1434 op: "dot",
1435 });
1436 }
1437
1438 (self * rhs).and_then(|ret| ret.sum_all())
1439 }
1440
1441 pub fn norm(&self) -> Result<Self> {
1454 if self.dtype().is_int() {
1455 bail!("norm not supported for integer dtypes");
1456 }
1457
1458 self.sqr().and_then(|x| x.sum_all()).and_then(|x| x.sqrt())
1459 }
1460
1461 pub fn mv(&self, rhs: &Self) -> Result<Self> {
1476 let lhs_dims = self.dims();
1478 let rhs_dims = rhs.dims();
1479 if lhs_dims.len() != 2 || rhs_dims.len() != 1 || lhs_dims[1] != rhs_dims[0] {
1480 return Err(Error::ShapeMismatchBinaryOp {
1481 lhs: self.shape().clone(),
1482 rhs: rhs.shape().clone(),
1483 op: "mv",
1484 });
1485 }
1486
1487 self.matmul(&rhs.unsqueeze(1)?)?.squeeze(1)
1489 }
1490
1491 pub fn matmul(&self, rhs: &Self) -> Result<Self> {
1500 let a_dims = self.shape().dims();
1501 let b_dims = rhs.shape().dims();
1502
1503 let dim = a_dims.len();
1504
1505 if dim < 2 || b_dims.len() != dim {
1506 Err(Error::ShapeMismatchBinaryOp {
1507 lhs: self.shape().clone(),
1508 rhs: rhs.shape().clone(),
1509 op: "matmul",
1510 }
1511 .bt())?
1512 }
1513
1514 let m = a_dims[dim - 2];
1515 let k = a_dims[dim - 1];
1516 let k2 = b_dims[dim - 2];
1517 let n = b_dims[dim - 1];
1518
1519 let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
1520 if c_shape.elem_count() == 0 || k == 0 {
1521 return Tensor::zeros(c_shape, self.dtype(), self.device());
1522 }
1523 let batching: usize = a_dims[..dim - 2].iter().product();
1524 let batching_b: usize = b_dims[..dim - 2].iter().product();
1525 if k != k2 || batching != batching_b {
1526 Err(Error::ShapeMismatchBinaryOp {
1527 lhs: self.shape().clone(),
1528 rhs: rhs.shape().clone(),
1529 op: "matmul",
1530 }
1531 .bt())?
1532 }
1533
1534 let storage = self.storage().matmul(
1535 &rhs.storage(),
1536 (batching, m, n, k),
1537 self.layout(),
1538 rhs.layout(),
1539 )?;
1540 let op = BackpropOp::new2(self, rhs, Op::Matmul);
1541 Ok(from_storage(storage, c_shape, op, false))
1542 }
1543
1544 pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
1550 let lhs = self;
1551 let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
1552 let l_broadcast = l_shape != *lhs.shape();
1553 let r_broadcast = r_shape != *rhs.shape();
1554 match (l_broadcast, r_broadcast) {
1556 (true, true) => lhs
1557 .broadcast_as(&l_shape)?
1558 .contiguous()?
1559 .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1560 (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1561 (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
1562 (false, false) => lhs.matmul(rhs),
1563 }
1564 }
1565
1566 pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
1570 let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
1571 let shape = self.same_shape_binary_op(on_false, "where_cond")?;
1572 let storage = self.storage().where_cond(
1573 self.layout(),
1574 &on_true.storage(),
1575 on_true.layout(),
1576 &on_false.storage(),
1577 on_false.layout(),
1578 )?;
1579 let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
1580 Ok(from_storage(storage, shape, op, false))
1581 }
1582
1583 pub fn embedding(&self, ids: &Self) -> Result<Self> {
1603 if self.rank() != 2 || ids.rank() != 1 {
1604 Err(Error::ShapeMismatchBinaryOp {
1605 lhs: self.shape().clone(),
1606 rhs: ids.shape().clone(),
1607 op: "embedding",
1608 }
1609 .bt())?
1610 }
1611 self.index_select(ids, 0)
1612 }
1613
1614 fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
1615 let source_dims = source.dims();
1616 let self_dims = self.dims();
1617 let mismatch = if source_dims.len() != self_dims.len() {
1618 true
1619 } else {
1620 let mut mismatch = false;
1621 for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1622 if i != dim && d1 != d2 {
1623 mismatch = true;
1624 break;
1625 }
1626 }
1627 mismatch
1628 };
1629 if mismatch {
1630 Err(Error::ShapeMismatchBinaryOp {
1631 op: "scatter (self, src)",
1632 lhs: self.shape().clone(),
1633 rhs: source.shape().clone(),
1634 }
1635 .bt())?
1636 }
1637 if indexes.dims() != source.dims() {
1638 Err(Error::ShapeMismatchBinaryOp {
1639 op: "scatter (indexes, src)",
1640 lhs: indexes.shape().clone(),
1641 rhs: source.shape().clone(),
1642 }
1643 .bt())?
1644 }
1645 Ok(())
1646 }
1647
1648 pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1649 let dim = dim.to_index(self.shape(), "scatter")?;
1650 self.scatter_checks(indexes, source, dim)?;
1651 let shape = self.shape();
1652 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1653 self.storage()
1654 .copy_strided_src(&mut storage, 0, self.layout())?;
1655 let layout = Layout::contiguous(shape);
1656 storage.scatter_set(
1657 &layout,
1658 &indexes.storage(),
1659 indexes.layout(),
1660 &source.storage(),
1661 source.layout(),
1662 dim,
1663 )?;
1664 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1665 Op::Scatter(t1, t2, t3, dim)
1666 });
1667 Ok(from_storage(storage, self.shape(), op, false))
1668 }
1669
1670 pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1671 if self.same_storage(source) {
1672 crate::bail!("cannot use slice_set when self and src share their storage")
1673 }
1674 let dim = dim.to_index(self.shape(), "scatter-set")?;
1675 self.scatter_checks(indexes, source, dim)?;
1676 self.storage_mut().scatter_set(
1677 self.layout(),
1678 &indexes.storage(),
1679 indexes.layout(),
1680 &source.storage(),
1681 source.layout(),
1682 dim,
1683 )?;
1684 Ok(())
1685 }
1686
1687 pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1688 let dim = dim.to_index(self.shape(), "scatter-add")?;
1689 self.scatter_checks(indexes, source, dim)?;
1690 let shape = self.shape();
1691 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1692 self.storage()
1693 .copy_strided_src(&mut storage, 0, self.layout())?;
1694 let layout = Layout::contiguous(shape);
1695 storage.scatter_add(
1696 &layout,
1697 &indexes.storage(),
1698 indexes.layout(),
1699 &source.storage(),
1700 source.layout(),
1701 dim,
1702 )?;
1703 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1704 Op::ScatterAdd(t1, t2, t3, dim)
1705 });
1706 Ok(from_storage(storage, self.shape(), op, false))
1707 }
1708
1709 pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1710 if self.same_storage(source) {
1711 crate::bail!("cannot use slice_set when self and src share their storage")
1712 }
1713 let dim = dim.to_index(self.shape(), "scatter-add-set")?;
1714 self.scatter_checks(indexes, source, dim)?;
1715 self.storage_mut().scatter_add(
1716 self.layout(),
1717 &indexes.storage(),
1718 indexes.layout(),
1719 &source.storage(),
1720 source.layout(),
1721 dim,
1722 )?;
1723 Ok(())
1724 }
1725
1726 pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
1728 let dim = dim.to_index(self.shape(), "slice-scatter")?;
1729 if dim == 0 {
1730 self.slice_scatter0(src, start)
1731 } else {
1732 self.transpose(0, dim)?
1734 .slice_scatter0(&src.transpose(0, dim)?, start)?
1735 .transpose(0, dim)
1736 }
1737 }
1738
1739 pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
1741 if self.dtype() != src.dtype() {
1742 Err(Error::DTypeMismatchBinaryOp {
1743 lhs: self.dtype(),
1744 rhs: src.dtype(),
1745 op: "slice-scatter",
1746 }
1747 .bt())?
1748 }
1749 if self.device().location() != src.device.location() {
1750 Err(Error::DeviceMismatchBinaryOp {
1751 lhs: self.device().location(),
1752 rhs: src.device().location(),
1753 op: "slice-scatter",
1754 }
1755 .bt())?
1756 }
1757 if self.rank() != src.rank() {
1758 Err(Error::UnexpectedNumberOfDims {
1759 expected: self.rank(),
1760 got: src.rank(),
1761 shape: src.shape().clone(),
1762 }
1763 .bt())?
1764 }
1765 let shape_ok =
1766 self.dims()
1767 .iter()
1768 .zip(src.dims().iter())
1769 .enumerate()
1770 .all(|(dim_idx, (&d1, &d2))| {
1771 if 0 == dim_idx {
1772 d2 + start <= d1
1773 } else {
1774 d1 == d2
1775 }
1776 });
1777 if !shape_ok {
1778 Err(Error::ShapeMismatchBinaryOp {
1779 op: "slice-scatter (self, src)",
1780 lhs: self.shape().clone(),
1781 rhs: src.shape().clone(),
1782 }
1783 .bt())?
1784 }
1785 let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
1786 self.storage()
1787 .copy_strided_src(&mut storage, 0, self.layout())?;
1788 let offset = start * src.dims()[1..].iter().product::<usize>();
1789 src.storage()
1790 .copy_strided_src(&mut storage, offset, src.layout())?;
1791 let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
1792 Ok(from_storage(storage, self.shape(), op, false))
1793 }
1794
1795 pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1797 let dim = dim.to_index(self.shape(), "index-add")?;
1798 let source_dims = source.dims();
1799 let self_dims = self.dims();
1800 let mismatch = if source_dims.len() != self_dims.len() {
1801 true
1802 } else {
1803 let mut mismatch = false;
1804 for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1805 if i != dim && d1 != d2 {
1806 mismatch = true;
1807 break;
1808 }
1809 }
1810 mismatch
1811 };
1812 if mismatch {
1813 Err(Error::ShapeMismatchBinaryOp {
1814 op: "index-add (self, source)",
1815 lhs: self.shape().clone(),
1816 rhs: source.shape().clone(),
1817 }
1818 .bt())?
1819 }
1820 let indexes_len = indexes.dims1()?;
1824 if source_dims[dim] != indexes_len {
1825 Err(Error::ShapeMismatchBinaryOp {
1826 op: "index-add (ids, source))",
1827 lhs: indexes.shape().clone(),
1828 rhs: source.shape().clone(),
1829 }
1830 .bt())?
1831 }
1832 let storage = self.storage().index_add(
1833 self.layout(),
1834 &indexes.storage(),
1835 indexes.layout(),
1836 &source.storage(),
1837 source.layout(),
1838 dim,
1839 )?;
1840 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1841 Op::IndexAdd(t1, t2, t3, dim)
1842 });
1843 Ok(from_storage(storage, self.shape(), op, false))
1844 }
1845
1846 pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1858 let dim = dim.to_index(self.shape(), "gather")?;
1859
1860 let self_dims = self.dims();
1861 let indexes_dims = indexes.dims();
1862 let mismatch = if indexes_dims.len() != self_dims.len() {
1863 true
1864 } else {
1865 let mut mismatch = false;
1866 for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
1867 if i != dim && d1 < d2 {
1868 mismatch = true;
1869 break;
1870 }
1871 }
1872 mismatch
1873 };
1874 if mismatch {
1875 Err(Error::ShapeMismatchBinaryOp {
1876 op: "gather",
1877 lhs: self.shape().clone(),
1878 rhs: indexes.shape().clone(),
1879 }
1880 .bt())?
1881 }
1882 let storage =
1883 self.storage()
1884 .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
1885 let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
1886 Ok(from_storage(storage, indexes.shape(), op, false))
1887 }
1888
1889 pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1897 let dim = dim.to_index(self.shape(), "index-select")?;
1898 let indexes_len = match indexes.dims() {
1899 [l] => *l,
1900 _ => Err(Error::ShapeMismatchBinaryOp {
1901 lhs: self.shape().clone(),
1902 rhs: indexes.shape().clone(),
1903 op: "index-select",
1904 }
1905 .bt())?,
1906 };
1907 let storage = self.storage().index_select(
1908 &indexes.storage(),
1909 self.layout(),
1910 indexes.layout(),
1911 dim,
1912 )?;
1913 let mut dims = self.dims().to_vec();
1914 dims[dim] = indexes_len;
1915 let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
1916 Ok(from_storage(storage, dims, op, false))
1917 }
1918
1919 pub fn strided_index(&self) -> crate::StridedIndex<'_> {
1922 self.layout.strided_index()
1923 }
1924
1925 pub fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
1930 self.layout.strided_blocks()
1931 }
1932
1933 pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
1935 if self.rank() != 1 {
1936 Err(Error::UnexpectedNumberOfDims {
1937 expected: 1,
1938 got: self.rank(),
1939 shape: self.shape().clone(),
1940 }
1941 .bt())?
1942 }
1943 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1944 let data = S::cpu_storage_as_slice(cpu_storage)?;
1945 let data = match self.layout.contiguous_offsets() {
1946 Some((o1, o2)) => data[o1..o2].to_vec(),
1947 None => self.strided_index().map(|i| data[i]).collect(),
1948 };
1949 Ok::<Vec<_>, Error>(data)
1950 };
1951 match &*self.storage() {
1952 Storage::Cpu(storage) => from_cpu_storage(storage),
1953 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1954 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1955 #[cfg(feature = "rocm")]
1956 Storage::Rocm(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1957 #[cfg(feature = "vulkan")]
1958 Storage::Vulkan(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1959 }
1960 }
1961
1962 pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
1964 let (dim1, dim2) = self.dims2()?;
1965 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1966 let data = S::cpu_storage_as_slice(cpu_storage)?;
1967 let mut rows = vec![];
1968 match self.layout.contiguous_offsets() {
1969 Some((o1, o2)) => {
1970 let data = &data[o1..o2];
1971 for idx_row in 0..dim1 {
1972 rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
1973 }
1974 }
1975 None => {
1976 let mut src_index = self.strided_index();
1977 for _idx_row in 0..dim1 {
1978 let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
1979 rows.push(row)
1980 }
1981 assert!(src_index.next().is_none());
1982 }
1983 }
1984 Ok(rows)
1985 };
1986 match &*self.storage() {
1987 Storage::Cpu(storage) => from_cpu_storage(storage),
1988 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1989 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1990 #[cfg(feature = "rocm")]
1991 Storage::Rocm(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1992 #[cfg(feature = "vulkan")]
1993 Storage::Vulkan(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1994 }
1995 }
1996
1997 pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
1999 let (dim1, dim2, dim3) = self.dims3()?;
2000 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
2001 let data = S::cpu_storage_as_slice(cpu_storage)?;
2002 let mut top_rows = vec![];
2003 match self.layout.contiguous_offsets() {
2004 Some((o1, o2)) => {
2005 let data = &data[o1..o2];
2006 let dim23 = dim2 * dim3;
2007 for idx1 in 0..dim1 {
2008 let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
2009 let mut rows = vec![];
2010 for idx2 in 0..dim2 {
2011 rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
2012 }
2013 top_rows.push(rows);
2014 }
2015 }
2016 None => {
2017 let mut src_index = self.strided_index();
2018 for _idx in 0..dim1 {
2019 let mut rows = vec![];
2020 for _jdx in 0..dim2 {
2021 let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
2022 rows.push(row)
2023 }
2024 top_rows.push(rows);
2025 }
2026 assert!(src_index.next().is_none());
2027 }
2028 }
2029 Ok(top_rows)
2030 };
2031 match &*self.storage() {
2032 Storage::Cpu(storage) => from_cpu_storage(storage),
2033 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2034 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2035 #[cfg(feature = "rocm")]
2036 Storage::Rocm(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2037 #[cfg(feature = "vulkan")]
2038 Storage::Vulkan(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2039 }
2040 }
2041
2042 pub fn dtype(&self) -> DType {
2044 self.dtype
2045 }
2046
2047 pub fn device(&self) -> &Device {
2049 &self.device
2050 }
2051
2052 pub fn shape(&self) -> &Shape {
2054 self.layout().shape()
2055 }
2056
2057 pub fn dims(&self) -> &[usize] {
2059 self.shape().dims()
2060 }
2061
2062 pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
2064 let dim = dim.to_index(self.shape(), "dim")?;
2065 Ok(self.dims()[dim])
2066 }
2067
2068 pub fn layout(&self) -> &Layout {
2071 &self.layout
2072 }
2073
2074 pub fn stride(&self) -> &[usize] {
2075 self.layout.stride()
2076 }
2077
2078 pub fn rank(&self) -> usize {
2080 self.shape().rank()
2081 }
2082
2083 pub fn elem_count(&self) -> usize {
2085 self.shape().elem_count()
2086 }
2087
2088 pub fn id(&self) -> TensorId {
2090 self.id
2091 }
2092
2093 pub fn is_variable(&self) -> bool {
2096 self.is_variable
2097 }
2098
2099 pub(crate) fn op(&self) -> &Option<Op> {
2100 &self.op
2101 }
2102
2103 pub fn max_all(&self) -> Result<Tensor> {
2114 if self.rank() == 0 {
2115 Ok(self.clone())
2116 } else {
2117 self.flatten_all()?.max(0)
2118 }
2119 }
2120
2121 pub fn min_all(&self) -> Result<Tensor> {
2132 if self.rank() == 0 {
2133 Ok(self.clone())
2134 } else {
2135 self.flatten_all()?.min(0)
2136 }
2137 }
2138
2139 pub fn sum_all(&self) -> Result<Tensor> {
2150 let dims: Vec<_> = (0..self.rank()).collect();
2151 self.sum(dims)
2152 }
2153
2154 pub fn mean_all(&self) -> Result<Tensor> {
2155 self.sum_all()? / self.elem_count() as f64
2156 }
2157
2158 fn flatten_<D1: Dim, D2: Dim>(
2159 &self,
2160 start_dim: Option<D1>,
2161 end_dim: Option<D2>,
2162 ) -> Result<Tensor> {
2163 if self.rank() == 0 {
2164 self.reshape(1)
2165 } else {
2166 let start_dim = match start_dim {
2167 None => 0,
2168 Some(dim) => dim.to_index(self.shape(), "flatten")?,
2169 };
2170 let end_dim = match end_dim {
2171 None => self.rank() - 1,
2172 Some(dim) => dim.to_index(self.shape(), "flatten")?,
2173 };
2174 if start_dim < end_dim {
2175 let dims = self.dims();
2176 let mut dst_dims = dims[..start_dim].to_vec();
2177 dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
2178 if end_dim + 1 < dims.len() {
2179 dst_dims.extend(&dims[end_dim + 1..]);
2180 }
2181 self.reshape(dst_dims)
2182 } else {
2183 Ok(self.clone())
2184 }
2185 }
2186 }
2187
2188 pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
2191 self.flatten_(Some(start_dim), Some(end_dim))
2192 }
2193
2194 pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
2196 self.flatten_(None::<usize>, Some(end_dim))
2197 }
2198
2199 pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
2202 self.flatten_(Some(start_dim), None::<usize>)
2203 }
2204
2205 pub fn flatten_all(&self) -> Result<Tensor> {
2215 self.flatten_(None::<usize>, None::<usize>)
2216 }
2217
2218 pub fn get(&self, i: usize) -> Result<Tensor> {
2230 let dims = self.dims();
2231 if dims.is_empty() {
2232 Ok(self.clone())
2233 } else {
2234 self.narrow(0, i, 1)?.reshape(&dims[1..])
2235 }
2236 }
2237
2238 pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
2252 let dim = dim.to_index(self.shape(), "get_on_dim")?;
2253 self.narrow(dim, index, 1)?.squeeze(dim)
2254 }
2255
2256 pub fn t(&self) -> Result<Tensor> {
2267 let rank = self.rank();
2268 if rank < 2 {
2269 Err(Error::UnexpectedNumberOfDims {
2270 expected: 2,
2271 got: rank,
2272 shape: self.shape().clone(),
2273 }
2274 .bt())?
2275 }
2276 self.transpose(rank - 2, rank - 1)
2277 }
2278
2279 pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
2282 let dim1 = dim1.to_index(self.shape(), "transpose")?;
2283 let dim2 = dim2.to_index(self.shape(), "transpose")?;
2284 if dim1 == dim2 {
2285 return Ok(self.clone());
2286 }
2287 let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
2288 let tensor_ = Tensor_ {
2289 id: TensorId::new(),
2290 storage: self.storage.clone(),
2291 layout: self.layout.transpose(dim1, dim2)?,
2292 op,
2293 is_variable: false,
2294 dtype: self.dtype,
2295 device: self.device.clone(),
2296 };
2297 Ok(Tensor(Arc::new(tensor_)))
2298 }
2299
2300 pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
2312 let dims = dims.to_indexes(self.shape(), "permute")?;
2313 let is_permutation =
2315 dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
2316 if !is_permutation {
2317 bail!(
2318 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
2319 self.dims(),
2320 dims
2321 )
2322 }
2323 let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
2324 let tensor_ = Tensor_ {
2325 id: TensorId::new(),
2326 storage: self.storage.clone(),
2327 layout: self.layout.permute(&dims)?,
2328 op,
2329 is_variable: false,
2330 dtype: self.dtype,
2331 device: self.device.clone(),
2332 };
2333 Ok(Tensor(Arc::new(tensor_)))
2334 }
2335
2336 pub fn is_contiguous(&self) -> bool {
2338 self.layout.is_contiguous()
2339 }
2340
2341 pub fn is_fortran_contiguous(&self) -> bool {
2343 self.layout.is_fortran_contiguous()
2344 }
2345
2346 pub fn copy(&self) -> Result<Tensor> {
2349 let op = BackpropOp::new1(self, Op::Copy);
2350 let tensor_ = Tensor_ {
2351 id: TensorId::new(),
2352 storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
2353 layout: self.layout.clone(),
2354 op,
2355 is_variable: false,
2356 dtype: self.dtype,
2357 device: self.device.clone(),
2358 };
2359 Ok(Tensor(Arc::new(tensor_)))
2360 }
2361
2362 pub fn detach(&self) -> Tensor {
2367 if self.op.is_none() && !self.is_variable {
2368 self.clone()
2369 } else {
2370 let tensor_ = Tensor_ {
2371 id: TensorId::new(),
2372 storage: self.storage.clone(),
2373 layout: self.layout.clone(),
2374 op: BackpropOp::none(),
2375 is_variable: false,
2376 dtype: self.dtype,
2377 device: self.device.clone(),
2378 };
2379 Tensor(Arc::new(tensor_))
2380 }
2381 }
2382
2383 pub fn to_device(&self, device: &Device) -> Result<Tensor> {
2385 if self.device().same_device(device) {
2386 Ok(self.clone())
2387 } else {
2388 let storage = match (&*self.storage(), device) {
2389 (Storage::Cpu(storage), Device::Cuda(cuda)) => {
2390 Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
2391 }
2392 (Storage::Cpu(storage), Device::Metal(metal)) => {
2393 Storage::Metal(metal.storage_from_cpu_storage(storage)?)
2394 }
2395 (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2396 (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2397 #[cfg(feature = "rocm")]
2398 (Storage::Rocm(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2399 #[cfg(feature = "vulkan")]
2400 (Storage::Vulkan(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2401 #[cfg(feature = "rocm")]
2402 (Storage::Cpu(storage), Device::Rocm(rocm)) => {
2403 Storage::Rocm(rocm.storage_from_cpu_storage(storage)?)
2404 }
2405 #[cfg(feature = "vulkan")]
2406 (Storage::Cpu(storage), Device::Vulkan(vulkan)) => {
2407 Storage::Vulkan(vulkan.storage_from_cpu_storage(storage)?)
2408 }
2409 #[cfg(feature = "rocm")]
2410 (Storage::Rocm(storage), Device::Rocm(rocm)) => {
2411 let cpu_storage = storage.to_cpu_storage()?;
2412 Storage::Rocm(rocm.storage_from_cpu_storage(&cpu_storage)?)
2413 }
2414 #[cfg(feature = "vulkan")]
2415 (Storage::Vulkan(storage), Device::Vulkan(vulkan)) => {
2416 let cpu_storage = storage.to_cpu_storage()?;
2417 Storage::Vulkan(vulkan.storage_from_cpu_storage(&cpu_storage)?)
2418 }
2419 (Storage::Cuda(storage), Device::Cuda(cuda)) => {
2420 let cpu_storage = storage.to_cpu_storage()?;
2423 Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
2424 }
2425 (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
2426 _ => {
2427 bail!(
2428 "not implemented yet, self.device: {:?}, device: {:?}",
2429 self.device(),
2430 device
2431 )
2432 }
2433 };
2434 let op = BackpropOp::new1(self, Op::ToDevice);
2435 let tensor_ = Tensor_ {
2436 id: TensorId::new(),
2437 storage: Arc::new(RwLock::new(storage)),
2438 layout: self.layout.clone(),
2439 op,
2440 is_variable: false,
2441 dtype: self.dtype,
2442 device: device.clone(),
2443 };
2444 Ok(Tensor(Arc::new(tensor_)))
2445 }
2446 }
2447
2448 pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
2451 let left_shape = left_shape.into();
2452 let mut dims = left_shape.into_dims();
2453 dims.extend(self.dims());
2454 self.broadcast_as(dims)
2455 }
2456
2457 pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2465 let tensor_ = Tensor_ {
2466 id: TensorId::new(),
2467 storage: self.storage.clone(),
2468 layout: self.layout.broadcast_as(shape)?,
2469 op: BackpropOp::new1(self, Op::Broadcast),
2470 is_variable: false,
2471 dtype: self.dtype,
2472 device: self.device.clone(),
2473 };
2474 Ok(Tensor(Arc::new(tensor_)))
2475 }
2476
2477 pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2479 self.broadcast_as(shape)
2480 }
2481
2482 pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2493 if self.dtype() == dtype {
2494 Ok(self.clone())
2495 } else {
2496 let shape = self.shape();
2497 let storage = self.storage().to_dtype(self.layout(), dtype)?;
2498 let op = BackpropOp::new1(self, Op::ToDType);
2499 Ok(from_storage(storage, shape.clone(), op, false))
2500 }
2501 }
2502
2503 pub fn contiguous(&self) -> Result<Tensor> {
2506 if self.is_contiguous() {
2507 Ok(self.clone())
2508 } else {
2509 let shape = self.shape();
2510 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2511 self.storage()
2512 .copy_strided_src(&mut storage, 0, self.layout())?;
2513 let op = BackpropOp::new1(self, Op::Copy);
2514 Ok(from_storage(storage, shape.clone(), op, false))
2515 }
2516 }
2517
2518 pub fn force_contiguous(&self) -> Result<Tensor> {
2520 let shape = self.shape();
2521 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2522 self.storage()
2523 .copy_strided_src(&mut storage, 0, self.layout())?;
2524 let op = BackpropOp::new1(self, Op::Copy);
2525 Ok(from_storage(storage, shape.clone(), op, false))
2526 }
2527
2528 pub(crate) fn make_var(&self) -> Result<Tensor> {
2531 let shape = self.shape().clone();
2532 let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2533 self.storage()
2534 .copy_strided_src(&mut storage, 0, self.layout())?;
2535 Ok(from_storage(storage, shape, BackpropOp::none(), true))
2536 }
2537
2538 pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
2563 let shape = s.into_shape(self.elem_count())?;
2564 if shape.elem_count() != self.elem_count() {
2565 return Err(Error::ShapeMismatchBinaryOp {
2566 lhs: self.shape().clone(),
2567 rhs: shape,
2568 op: "reshape",
2569 }
2570 .bt());
2571 }
2572 let op = BackpropOp::new1(self, Op::Reshape);
2573 if self.is_contiguous() {
2574 let tensor_ = Tensor_ {
2575 id: TensorId::new(),
2576 storage: self.storage.clone(),
2577 layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
2578 op,
2579 is_variable: false,
2580 dtype: self.dtype,
2581 device: self.device.clone(),
2582 };
2583 Ok(Tensor(Arc::new(tensor_)))
2584 } else {
2585 let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2586 self.storage()
2587 .copy_strided_src(&mut storage, 0, self.layout())?;
2588 Ok(from_storage(storage, shape, op, false))
2589 }
2590 }
2591
2592 pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
2606 let dims = self.dims();
2609 let dim = dim.to_index(self.shape(), "squeeze")?;
2610 if dims[dim] == 1 {
2611 let mut dims = dims.to_vec();
2612 let mut strides = self.stride().to_vec();
2613 dims.remove(dim);
2614 strides.remove(dim);
2615 let tensor_ = Tensor_ {
2616 id: TensorId::new(),
2617 storage: self.storage.clone(),
2618 layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2619 op: BackpropOp::new1(self, Op::Reshape),
2620 is_variable: false,
2621 dtype: self.dtype,
2622 device: self.device.clone(),
2623 };
2624 Ok(Tensor(Arc::new(tensor_)))
2625 } else {
2626 Ok(self.clone())
2627 }
2628 }
2629
2630 pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
2644 let mut dims = self.dims().to_vec();
2645 let mut strides = self.stride().to_vec();
2646 let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
2647 dims.insert(dim, 1);
2649 let stride = if dim < strides.len() { strides[dim] } else { 1 };
2652 strides.insert(dim, stride);
2653 let tensor_ = Tensor_ {
2654 id: TensorId::new(),
2655 storage: self.storage.clone(),
2656 layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2657 op: BackpropOp::new1(self, Op::Reshape),
2658 is_variable: false,
2659 dtype: self.dtype,
2660 device: self.device.clone(),
2661 };
2662 Ok(Tensor(Arc::new(tensor_)))
2663 }
2664
2665 pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
2682 if args.is_empty() {
2683 Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
2684 }
2685 let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
2686 let args = args
2687 .iter()
2688 .map(|t| t.as_ref().unsqueeze(dim))
2689 .collect::<Result<Vec<_>>>()?;
2690 Self::cat(&args, dim)
2691 }
2692
2693 pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2696 if left == 0 && right == 0 {
2697 Ok(self.clone())
2698 } else if left == 0 {
2699 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2700 let mut dims = self.dims().to_vec();
2701 dims[dim] = right;
2702 let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2703 Tensor::cat(&[self, &right], dim)
2704 } else if right == 0 {
2705 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2706 let mut dims = self.dims().to_vec();
2707 dims[dim] = left;
2708 let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2709 Tensor::cat(&[&left, self], dim)
2710 } else {
2711 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2712 let mut dims = self.dims().to_vec();
2713 dims[dim] = left;
2714 let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2715 dims[dim] = right;
2716 let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2717 Tensor::cat(&[&left, self, &right], dim)
2718 }
2719 }
2720
2721 pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2724 if left == 0 && right == 0 {
2725 Ok(self.clone())
2726 } else if self.elem_count() == 0 {
2727 bail!("cannot use pad_with_same on an empty tensor")
2728 } else if left == 0 {
2729 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2730 let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2731 let mut v = vec![self];
2732 for _ in 0..right {
2733 v.push(&r)
2734 }
2735 Tensor::cat(&v, dim)
2736 } else if right == 0 {
2737 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2738 let l = self.narrow(dim, 0, 1)?;
2739 let mut v = vec![];
2740 for _ in 0..left {
2741 v.push(&l)
2742 }
2743 v.push(self);
2744 Tensor::cat(&v, dim)
2745 } else {
2746 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2747 let l = self.narrow(dim, 0, 1)?;
2748 let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2749 let mut v = vec![];
2750 for _ in 0..left {
2751 v.push(&l)
2752 }
2753 v.push(self);
2754 for _ in 0..right {
2755 v.push(&r)
2756 }
2757 Tensor::cat(&v, dim)
2758 }
2759 }
2760
2761 pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
2763 m.forward(self)
2764 }
2765
2766 pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
2768 m.forward_t(self, train)
2769 }
2770
2771 pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
2772 self.storage.read().unwrap()
2773 }
2774
2775 pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
2776 self.storage.write().unwrap()
2777 }
2778
2779 pub(crate) fn storage_mut_and_layout(
2782 &self,
2783 ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
2784 let storage = self.storage.write().unwrap();
2785 (storage, &self.layout)
2786 }
2787
2788 pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
2790 let storage = self.storage.read().unwrap();
2791 (storage, &self.layout)
2792 }
2793
2794 pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
2795 let lhs: &RwLock<Storage> = self.storage.as_ref();
2796 let rhs: &RwLock<Storage> = rhs.storage.as_ref();
2797 std::ptr::eq(lhs, rhs)
2798 }
2799
2800 pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
2803 let rank = self.rank() as i64;
2804 if rank <= axis {
2805 bail!("axis {axis} is too large, tensor rank {rank}")
2806 } else if 0 <= axis {
2807 Ok(axis as usize)
2808 } else {
2809 let naxis = rank + axis;
2810 if naxis < 0 {
2811 bail!("axis {axis} is too small, tensor rank {rank}")
2812 }
2813 Ok(naxis as usize)
2814 }
2815 }
2816
2817 pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2819 let t = Tensor::arange(0u32, n as u32, device)?;
2820 let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2821 let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2822 t1.le(&t2)?.to_dtype(dtype)
2823 }
2824
2825 pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2827 let t = Tensor::arange(0u32, n as u32, device)?;
2828 let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2829 let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2830 t1.ge(&t2)?.to_dtype(dtype)
2831 }
2832
2833 pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2835 let t = Tensor::arange(0u32, n as u32, device)?;
2836 let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2837 let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2838 t1.eq(&t2)?.to_dtype(dtype)
2839 }
2840
2841 pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
2846 let dim = dim.to_index(self.shape(), "cumsum")?;
2847 let rank = self.rank();
2848 if rank == 0 {
2849 return Ok(self.clone());
2850 }
2851 let n_axis = self.dim(dim)?;
2852 let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
2853 if rank == 1 {
2854 self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
2855 } else {
2856 let last = rank - 1;
2857 let t = self.transpose(dim, last)?;
2858 let t = t.broadcast_matmul(&triu)?;
2859 t.transpose(dim, last)
2860 }
2861 }
2862
2863 pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
2866 &self,
2867 ranges: &[D],
2868 src: &Tensor,
2869 ) -> Result<Self> {
2870 let src_dims = src.dims();
2871 let self_dims = self.dims();
2872 if self_dims.len() != src_dims.len() {
2873 bail!(
2874 "slice-assign requires input with the same rank {} <> {}",
2875 self_dims.len(),
2876 src_dims.len()
2877 )
2878 }
2879 if self_dims.len() != ranges.len() {
2880 bail!(
2881 "slice-assign requires input with the same rank as there are ranges {} <> {}",
2882 self_dims.len(),
2883 ranges.len()
2884 )
2885 }
2886 let mut src = src.clone();
2887 let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
2888 for (i, range) in ranges.iter().enumerate() {
2889 let start_included = match range.start_bound() {
2890 std::ops::Bound::Unbounded => 0,
2891 std::ops::Bound::Included(v) => *v,
2892 std::ops::Bound::Excluded(v) => *v + 1,
2893 };
2894 let end_excluded = match range.end_bound() {
2895 std::ops::Bound::Unbounded => self_dims[i],
2896 std::ops::Bound::Included(v) => *v + 1,
2897 std::ops::Bound::Excluded(v) => *v,
2898 };
2899 if end_excluded <= start_included {
2900 bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
2901 }
2902 if self_dims[i] < end_excluded {
2903 bail!(
2904 "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
2905 self_dims[i]
2906 )
2907 }
2908 if end_excluded - start_included != src_dims[i] {
2909 bail!(
2910 "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
2911 )
2912 }
2913 src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
2914 mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
2915 }
2916 mask.where_cond(&src, self)
2917 }
2918
2919 pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
2921 let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
2922 if sum_dims.is_empty() {
2923 return Ok(self.clone());
2924 }
2925 let max = sum_dims[1..]
2926 .iter()
2927 .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
2928 max.max_keepdim(dim)
2929 })?;
2930 let exp = self.broadcast_sub(&max)?.exp()?;
2931 let sum = exp.sum(sum_dims.clone())?;
2932
2933 sum.log()? + max.squeeze_dims(&sum_dims)
2934 }
2935
2936 pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
2938 rhs.mul(&self.log()?)?.exp()
2939 }
2940
2941 pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
2943 rhs.broadcast_mul(&self.log()?)?.exp()
2944 }
2945
2946 pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
2958 let mut result = self.clone();
2959 for &dim in dims.iter() {
2960 let size = result.dim(dim)?;
2961 let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
2962 let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
2963 result = result.index_select(&indices_tensor, dim)?;
2964 }
2965 Ok(result)
2966 }
2967
2968 pub fn unfold<D: Dim>(&self, dim: D, size: usize, step: usize) -> Result<Self> {
2971 let mut sizes = self.dims().to_vec();
2973 let mut strides = self.stride().to_vec();
2974
2975 let dim = dim.to_index(self.shape(), "unfold")?;
2976
2977 let max_len = if self.dims().is_empty() {
2978 1
2979 } else {
2980 sizes[dim]
2981 };
2982 if size > max_len {
2983 bail!(
2984 "unsqueeze: maximum size for tensor at dimension {dim} is {max_len} but size is {size}"
2985 )
2986 }
2987 sizes.push(size);
2988 strides.push(if self.dims().is_empty() {
2989 1
2990 } else {
2991 strides[dim]
2992 });
2993
2994 if !self.dims().is_empty() {
2995 sizes[dim] = ((sizes[dim] as f32 - size as f32) / step as f32 + 1.) as usize;
2996 strides[dim] *= step;
2997 }
2998
2999 let tensor_ = Tensor_ {
3000 id: TensorId::new(),
3001 storage: self.storage.clone(),
3002 layout: Layout::new(sizes.into(), strides, self.layout.start_offset()),
3003 op: BackpropOp::new1(self, Op::Reshape),
3004 is_variable: false,
3005 dtype: self.dtype,
3006 device: self.device.clone(),
3007 };
3008 Ok(Tensor(Arc::new(tensor_)))
3009 }
3010}
3011
3012macro_rules! bin_trait {
3013 ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => {
3014 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for Tensor {
3015 type Output = Result<Tensor>;
3016
3017 fn $fn1(self, rhs: B) -> Self::Output {
3018 Tensor::$fn1(&self, rhs.borrow())
3019 }
3020 }
3021
3022 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for &Tensor {
3023 type Output = Result<Tensor>;
3024
3025 fn $fn1(self, rhs: B) -> Self::Output {
3026 Tensor::$fn1(&self, rhs.borrow())
3027 }
3028 }
3029
3030 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
3031 type Output = Result<Tensor>;
3032
3033 fn $fn1(self, rhs: Tensor) -> Self::Output {
3034 Tensor::$fn1(self?.borrow(), &rhs)
3035 }
3036 }
3037
3038 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
3039 type Output = Result<Tensor>;
3040
3041 fn $fn1(self, rhs: &Tensor) -> Self::Output {
3042 Tensor::$fn1(self?.borrow(), rhs)
3043 }
3044 }
3045
3046 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
3047 type Output = Result<Tensor>;
3048
3049 fn $fn1(self, rhs: Result<B>) -> Self::Output {
3050 Tensor::$fn1(&self, rhs?.borrow())
3051 }
3052 }
3053
3054 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for &Tensor {
3055 type Output = Result<Tensor>;
3056
3057 fn $fn1(self, rhs: Result<B>) -> Self::Output {
3058 Tensor::$fn1(&self, rhs?.borrow())
3059 }
3060 }
3061
3062 impl std::ops::$trait<f64> for Tensor {
3063 type Output = Result<Tensor>;
3064
3065 fn $fn1(self, rhs: f64) -> Self::Output {
3066 self.affine($mul(rhs), $add(rhs))
3067 }
3068 }
3069
3070 impl std::ops::$trait<f64> for &Tensor {
3071 type Output = Result<Tensor>;
3072
3073 fn $fn1(self, rhs: f64) -> Self::Output {
3074 self.affine($mul(rhs), $add(rhs))
3075 }
3076 }
3077 };
3078}
3079
3080bin_trait!(Add, add, |_| 1., |v| v);
3081bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
3082bin_trait!(Mul, mul, |v| v, |_| 0.);
3083bin_trait!(Div, div, |v| 1. / v, |_| 0.);
3084
3085impl std::ops::Add<Tensor> for f64 {
3086 type Output = Result<Tensor>;
3087
3088 fn add(self, rhs: Tensor) -> Self::Output {
3089 rhs + self
3090 }
3091}
3092
3093impl std::ops::Add<&Tensor> for f64 {
3094 type Output = Result<Tensor>;
3095
3096 fn add(self, rhs: &Tensor) -> Self::Output {
3097 rhs + self
3098 }
3099}
3100
3101impl std::ops::Mul<Tensor> for f64 {
3102 type Output = Result<Tensor>;
3103
3104 fn mul(self, rhs: Tensor) -> Self::Output {
3105 rhs * self
3106 }
3107}
3108
3109impl std::ops::Mul<&Tensor> for f64 {
3110 type Output = Result<Tensor>;
3111
3112 fn mul(self, rhs: &Tensor) -> Self::Output {
3113 rhs * self
3114 }
3115}
3116
3117impl std::ops::Sub<Tensor> for f64 {
3118 type Output = Result<Tensor>;
3119
3120 fn sub(self, rhs: Tensor) -> Self::Output {
3121 rhs.affine(-1., self)
3122 }
3123}
3124
3125impl std::ops::Sub<&Tensor> for f64 {
3126 type Output = Result<Tensor>;
3127
3128 fn sub(self, rhs: &Tensor) -> Self::Output {
3129 rhs.affine(-1., self)
3130 }
3131}
3132
3133impl std::ops::Div<Tensor> for f64 {
3134 type Output = Result<Tensor>;
3135
3136 #[allow(clippy::suspicious_arithmetic_impl)]
3137 fn div(self, rhs: Tensor) -> Self::Output {
3138 rhs.recip()? * self
3139 }
3140}
3141
3142impl std::ops::Div<&Tensor> for f64 {
3143 type Output = Result<Tensor>;
3144
3145 #[allow(clippy::suspicious_arithmetic_impl)]
3146 fn div(self, rhs: &Tensor) -> Self::Output {
3147 rhs.recip()? * self
3148 }
3149}
3150
3151impl<S: Into<Shape>> From<(Storage, S)> for Tensor {
3152 fn from((storage, shape): (Storage, S)) -> Self {
3153 from_storage(storage, shape, BackpropOp::none(), false)
3154 }
3155}