#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims, ShapeWithOneHole};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TensorId(usize);

impl TensorId {
    fn new() -> Self {
        use std::sync::atomic;
        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
    }
}

pub struct Tensor_ {
    id: TensorId,
    storage: Arc<RwLock<Storage>>,
    layout: Layout,
    op: BackpropOp,
    is_variable: bool,
    dtype: DType,
    device: Device,
}

impl AsRef<Tensor> for Tensor {
    fn as_ref(&self) -> &Tensor {
        self
    }
}

#[derive(Clone)]
pub struct Tensor(Arc<Tensor_>);

impl std::ops::Deref for Tensor {
    type Target = Tensor_;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

78macro_rules! unary_op {
79 ($fn_name:ident, $op_name:ident) => {
80 pub fn $fn_name(&self) -> Result<Self> {
81 let shape = self.shape();
82 if shape.elem_count() == 0 {
83 return Ok(self.clone());
84 }
85 let storage = self
86 .storage()
87 .unary_impl::<crate::op::$op_name>(self.layout())?;
88 let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
89 Ok(from_storage(storage, shape.clone(), op, false))
90 }
91 };
92}
93
94macro_rules! binary_op {
95 ($fn_name:ident, $op_name:ident) => {
96 pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
97 let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
98 if shape.elem_count() == 0 {
99 return Ok(self.clone());
100 }
101 let storage = self.storage().binary_impl::<crate::op::$op_name>(
102 &*rhs.storage(),
103 self.layout(),
104 rhs.layout(),
105 )?;
106 let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
107 Ok(from_storage(storage, shape.clone(), op, false))
108 }
109 };
110}
111
112macro_rules! binary_op_scalar {
113 ($fn_name:ident, $op_name:ident) => {
114 pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
115 let rhs = match rhs.to_tensor_scalar()? {
116 crate::scalar::TensorScalar::Tensor(rhs) => rhs,
117 crate::scalar::TensorScalar::Scalar(rhs) => rhs
118 .to_dtype(self.dtype())?
119 .to_device(self.device())?
120 .broadcast_as(self.shape())?,
121 };
122 let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
123 if self.elem_count() == 0 {
124 return Ok(self.clone());
125 }
126 let storage = self.storage().binary_impl::<crate::op::$op_name>(
127 &*rhs.storage(),
128 self.layout(),
129 rhs.layout(),
130 )?;
131 let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
132 Ok(from_storage(storage, shape.clone(), op, false))
133 }
134 };
135}
136
137macro_rules! broadcast_binary_op {
138 ($fn_name:ident, $inner_fn_name:ident) => {
139 pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
140 let lhs = self;
141 let shape = lhs
142 .shape()
143 .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
144 let l_broadcast = shape != *lhs.shape();
145 let r_broadcast = shape != *rhs.shape();
146 match (l_broadcast, r_broadcast) {
147 (true, true) => lhs
148 .broadcast_as(&shape)?
149 .$inner_fn_name(&rhs.broadcast_as(&shape)?),
150 (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
151 (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
152 (false, false) => lhs.$inner_fn_name(rhs),
153 }
154 }
155 };
156}
157
158pub(crate) fn from_storage<S: Into<Shape>>(
160 storage: Storage,
161 shape: S,
162 op: BackpropOp,
163 is_variable: bool,
164) -> Tensor {
165 let dtype = storage.dtype();
166 let device = storage.device();
167 let tensor_ = Tensor_ {
168 id: TensorId::new(),
169 storage: Arc::new(RwLock::new(storage)),
170 layout: Layout::contiguous(shape),
171 op,
172 is_variable,
173 dtype,
174 device,
175 };
176 Tensor(Arc::new(tensor_))
177}
178
179impl Tensor {
180 pub(crate) fn ones_impl<S: Into<Shape>>(
181 shape: S,
182 dtype: DType,
183 device: &Device,
184 is_variable: bool,
185 ) -> Result<Self> {
186 let none = BackpropOp::none();
187 let shape = shape.into();
188 let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
189 let layout = Layout::contiguous(shape.clone());
190 storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
191 Ok(from_storage(storage, shape, none, is_variable))
192 }
193
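    /// Creates a new tensor filled with ones. Illustrative usage sketch, assuming this crate is
    /// consumed as `candle_core`:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    /// assert_eq!(a.to_vec2::<f32>()?, vec![vec![1f32; 3]; 2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```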
194 pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
204 Self::ones_impl(shape, dtype, device, false)
205 }
206
207 pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
208 self.storage_mut().const_set(value, self.layout())
209 }
210
211 pub fn zero_set(&self) -> Result<()> {
212 self.const_set(crate::scalar::Scalar::zero(self.dtype()))
213 }
214
215 pub fn one_set(&self) -> Result<()> {
216 self.const_set(crate::scalar::Scalar::one(self.dtype()))
217 }
218
219 pub fn ones_like(&self) -> Result<Self> {
229 Tensor::ones(self.shape(), self.dtype(), self.device())
230 }
231
232 pub(crate) fn zeros_impl<S: Into<Shape>>(
235 shape: S,
236 dtype: DType,
237 device: &Device,
238 is_variable: bool,
239 ) -> Result<Self> {
240 let none = BackpropOp::none();
241 let shape = shape.into();
242 let storage = device.zeros(&shape, dtype)?;
243 Ok(from_storage(storage, shape, none, is_variable))
244 }
245
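    /// Creates a new tensor filled with zeros. Illustrative usage sketch, assuming the crate is
    /// used as `candle_core`:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// assert_eq!(a.to_vec2::<f32>()?, vec![vec![0f32; 3]; 2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```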
246 pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
256 Self::zeros_impl(shape, dtype, device, false)
257 }
258
259 pub fn zeros_like(&self) -> Result<Self> {
270 Tensor::zeros(self.shape(), self.dtype(), self.device())
271 }
272
273 pub(crate) unsafe fn empty_impl<S: Into<Shape>>(
276 shape: S,
277 dtype: DType,
278 device: &Device,
279 is_variable: bool,
280 ) -> Result<Self> {
281 let none = BackpropOp::none();
282 let shape = shape.into();
283 let storage = device.alloc_uninit(&shape, dtype)?;
284 Ok(from_storage(storage, shape, none, is_variable))
285 }
286
287 pub unsafe fn empty<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
299 Self::empty_impl(shape, dtype, device, false)
300 }
301
302 pub unsafe fn empty_like(&self) -> Result<Self> {
315 Tensor::empty(self.shape(), self.dtype(), self.device())
316 }
317
318 pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
319 lo: T,
320 up: T,
321 s: S,
322 device: &Device,
323 is_variable: bool,
324 ) -> Result<Self> {
325 let s = s.into();
326 let storage = device.rand_uniform(lo, up, &s)?;
327 let none = BackpropOp::none();
328 Ok(from_storage(storage, s, none, is_variable))
329 }
330
331 pub(crate) fn rand_f64_impl<S: Into<Shape>>(
332 lo: f64,
333 up: f64,
334 s: S,
335 dtype: DType,
336 device: &Device,
337 is_variable: bool,
338 ) -> Result<Self> {
339 let s = s.into();
340 let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
341 let none = BackpropOp::none();
342 Ok(from_storage(storage, s, none, is_variable))
343 }
344
345 pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
347 lo: T,
348 up: T,
349 s: S,
350 device: &Device,
351 ) -> Result<Self> {
352 Self::rand_impl(lo, up, s, device, false)
353 }
354
355 pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
356 Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
357 }
358
359 pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
360 mean: T,
361 std: T,
362 s: S,
363 device: &Device,
364 is_variable: bool,
365 ) -> Result<Self> {
366 let s = s.into();
367 let storage = device.rand_normal(mean, std, &s)?;
368 let none = BackpropOp::none();
369 Ok(from_storage(storage, s, none, is_variable))
370 }
371
372 pub(crate) fn randn_f64_impl<S: Into<Shape>>(
373 mean: f64,
374 std: f64,
375 s: S,
376 dtype: DType,
377 device: &Device,
378 is_variable: bool,
379 ) -> Result<Self> {
380 let s = s.into();
381 let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
382 let none = BackpropOp::none();
383 Ok(from_storage(storage, s, none, is_variable))
384 }
385
386 pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
387 Tensor::randn_f64_impl(
388 mean,
389 stdev,
390 self.shape(),
391 self.dtype(),
392 self.device(),
393 false,
394 )
395 }
396
397 pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
400 mean: T,
401 std: T,
402 s: S,
403 device: &Device,
404 ) -> Result<Self> {
405 Self::randn_impl(mean, std, s, device, false)
406 }
407
408 pub(crate) fn new_impl<A: crate::device::NdArray>(
409 array: A,
410 shape: Shape,
411 device: &Device,
412 is_variable: bool,
413 ) -> Result<Self> {
414 let n: usize = shape.elem_count();
415 let buffer_size: usize = array.shape()?.elem_count();
416 if buffer_size != n {
417 return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
418 }
419 let storage = device.storage(array)?;
420 let none = BackpropOp::none();
421 Ok(from_storage(storage, shape, none, is_variable))
422 }
423
424 pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
426 let shape = array.shape()?;
427 Self::new_impl(array, shape, device, false)
428 }
429
430 pub fn full<D: crate::WithDType, S: Into<Shape>>(
441 value: D,
442 shape: S,
443 device: &Device,
444 ) -> Result<Self> {
445 let none = BackpropOp::none();
446 let shape = shape.into();
447 let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
448 let layout = Layout::contiguous(shape.clone());
449 storage.const_set(value.to_scalar(), &layout)?;
450 Ok(from_storage(storage, shape, none, false))
451 }
452
453 pub fn from_iter<D: crate::WithDType>(
462 iter: impl IntoIterator<Item = D>,
463 device: &Device,
464 ) -> Result<Self> {
465 let data = iter.into_iter().collect::<Vec<_>>();
466 let len = data.len();
467 Self::from_vec_impl(data, len, device, false)
468 }
469
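    /// Returns a 1-D tensor with values going from `start` (inclusive) to `end` (exclusive) in
    /// steps of one; `arange_step` below takes an explicit step. Usage sketch, assuming
    /// `candle_core` paths:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0u32, 5u32, &Device::Cpu)?;
    /// assert_eq!(t.to_vec1::<u32>()?, vec![0u32, 1, 2, 3, 4]);
    /// let t = Tensor::arange_step(0f32, 1f32, 0.25, &Device::Cpu)?;
    /// assert_eq!(t.to_vec1::<f32>()?, vec![0f32, 0.25, 0.5, 0.75]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```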
470 pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
480 Self::arange_step(start, end, D::one(), device)
481 }
482
483 pub fn arange_step<D: crate::WithDType>(
493 start: D,
494 end: D,
495 step: D,
496 device: &Device,
497 ) -> Result<Self> {
498 if D::is_zero(&step) {
499 bail!("step cannot be zero")
500 }
501 let mut data = vec![];
502 let mut current = start;
503 if step >= D::zero() {
504 while current < end {
505 data.push(current);
506 current += step;
507 }
508 } else {
509 while current > end {
510 data.push(current);
511 current += step;
512 }
513 }
514 let len = data.len();
515 Self::from_vec_impl(data, len, device, false)
516 }
517
518 pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
519 data: Vec<D>,
520 shape: S,
521 device: &Device,
522 is_variable: bool,
523 ) -> Result<Self> {
524 let shape = shape.into_shape(data.len())?;
525 let storage = device.storage_owned(data)?;
526 let none = BackpropOp::none();
527 Ok(from_storage(storage, shape, none, is_variable))
528 }
529
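    /// Builds a tensor from a `Vec`, reinterpreting the data with the given shape; the shape may
    /// leave one dimension as a hole to be inferred from the element count. Sketch only: the
    /// `()` hole syntax and the `candle_core` paths are assumptions.
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::from_vec(vec![1f32, 2., 3., 4., 5., 6.], (2, 3), &Device::Cpu)?;
    /// assert_eq!(t.to_vec2::<f32>()?, vec![vec![1f32, 2., 3.], vec![4., 5., 6.]]);
    /// // The remaining dimension can be inferred: (2, ()) resolves to (2, 3) here.
    /// let t = Tensor::from_vec(vec![1f32, 2., 3., 4., 5., 6.], (2, ()), &Device::Cpu)?;
    /// assert_eq!(t.dims(), &[2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```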
530 pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
544 data: Vec<D>,
545 shape: S,
546 device: &Device,
547 ) -> Result<Self> {
548 Self::from_vec_impl(data, shape, device, false)
549 }
550
551 pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
565 array: &[D],
566 shape: S,
567 device: &Device,
568 ) -> Result<Self> {
569 let shape = shape.into_shape(array.len())?;
570 let storage = device.storage_from_slice(array)?;
571 let none = BackpropOp::none();
572 Ok(from_storage(storage, shape, none, false))
573 }
574
575 pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
576 let lhs = self.shape();
577 let rhs = rhs.shape();
578 if lhs != rhs {
579 Err(Error::ShapeMismatchBinaryOp {
580 lhs: lhs.clone(),
581 rhs: rhs.clone(),
582 op,
583 }
584 .bt())
585 } else {
586 Ok(lhs)
587 }
588 }
589
590 pub fn track_op(&self) -> bool {
593 self.is_variable || self.op.is_some()
594 }
595
596 pub fn from_storage<S: Into<Shape>>(
602 storage: Storage,
603 shape: S,
604 op: BackpropOp,
605 is_variable: bool,
606 ) -> Tensor {
607 from_storage(storage, shape, op, is_variable)
608 }
609
610 binary_op!(add, Add);
613 binary_op!(mul, Mul);
614 binary_op!(sub, Sub);
615 binary_op!(div, Div);
616 binary_op_scalar!(maximum, Maximum);
617 binary_op_scalar!(minimum, Minimum);
618 broadcast_binary_op!(broadcast_add, add);
619 broadcast_binary_op!(broadcast_mul, mul);
620 broadcast_binary_op!(broadcast_sub, sub);
621 broadcast_binary_op!(broadcast_div, div);
622 broadcast_binary_op!(broadcast_maximum, maximum);
623 broadcast_binary_op!(broadcast_minimum, minimum);
624 broadcast_binary_op!(broadcast_eq, eq);
625 broadcast_binary_op!(broadcast_ne, ne);
626 broadcast_binary_op!(broadcast_lt, lt);
627 broadcast_binary_op!(broadcast_le, le);
628 broadcast_binary_op!(broadcast_gt, gt);
629 broadcast_binary_op!(broadcast_ge, ge);
630
631 unary_op!(recip, Recip);
632 unary_op!(neg, Neg);
633 unary_op!(exp, Exp);
634 unary_op!(log, Log);
635 unary_op!(sin, Sin);
636 unary_op!(cos, Cos);
637 unary_op!(tanh, Tanh);
638 unary_op!(abs, Abs);
639 unary_op!(sqr, Sqr);
640 unary_op!(sqrt, Sqrt);
641 unary_op!(gelu, Gelu);
642 unary_op!(gelu_erf, GeluErf);
643 unary_op!(erf, Erf);
644 unary_op!(relu, Relu);
645 unary_op!(silu, Silu);
646 unary_op!(ceil, Ceil);
647 unary_op!(floor, Floor);
648 unary_op!(round, Round);
649 unary_op!(sign, Sign);
650
651 pub fn round_to(&self, decimals: i32) -> Result<Self> {
656 let mult = 10f64.powi(decimals);
657 (self * mult)?.round()? * (1f64 / mult)
658 }
659
660 pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
663 if self.rank() != 0 {
664 Err(Error::UnexpectedNumberOfDims {
665 expected: 0,
666 got: self.rank(),
667 shape: self.shape().clone(),
668 }
669 .bt())?
670 }
671 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
672 let data = S::cpu_storage_as_slice(cpu_storage)?;
673 Ok::<_, Error>(data[self.layout().start_offset()])
674 };
675 match &*self.storage() {
676 Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
677 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
678 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
679 }
680 }
681
682 pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
684 self.to_scalar::<S>()
685 }
686
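    /// Repeats (tiles) the tensor along each dimension, so dimension `i` of the result is
    /// `self.dims()[i] * repeats[i]` once the rank has been padded with leading ones.
    /// Illustrative sketch, assuming `candle_core` paths:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let r = t.repeat((2, 2))?;
    /// assert_eq!(r.dims(), &[4, 4]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```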
687 pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
689 let repeats = shape.into();
691 let repeats = repeats.dims();
692 let mut inp = if self.rank() < repeats.len() {
693 let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
694 self.reshape(shape)?
695 } else {
696 self.clone()
697 };
698 for (idx, &repeat) in repeats.iter().enumerate() {
699 if repeat > 1 {
700 inp = Tensor::cat(&vec![&inp; repeat], idx)?
701 }
702 }
703 Ok(inp)
704 }
705
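    /// Builds coordinate grids from 1-D tensors, similar to `numpy.meshgrid`. With
    /// `xy_indexing = true` the grids use Cartesian (x, y) indexing, otherwise matrix (i, j)
    /// indexing. Usage sketch, assuming `candle_core` paths:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let y = Tensor::new(&[4f32, 5.], &Device::Cpu)?;
    /// let grids = Tensor::meshgrid(&[&x, &y], true)?;
    /// assert_eq!(grids[0].dims(), &[2, 3]);
    /// assert_eq!(grids[1].to_vec2::<f32>()?, vec![vec![4f32, 4., 4.], vec![5., 5., 5.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```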
706 pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
743 if args.len() <= 1 {
744 Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
745 }
746 let args: Vec<_> = if xy_indexing {
747 args.iter().rev().collect()
748 } else {
749 args.iter().collect()
750 };
751
752 let mut shape = Vec::with_capacity(args.len());
753 for arg in args.iter() {
754 shape.push(arg.as_ref().dims1()?)
755 }
756
757 let mut grids = Vec::with_capacity(args.len());
758 for idx in 0..args.len() {
759 let mut ones = vec![1usize; args.len()];
760 ones[idx] = shape[idx];
761 let arg = args[idx].as_ref().reshape(ones)?;
762 let mut repeats = shape.clone();
763 repeats[idx] = 1;
764 let repeated_tensor = arg.repeat(repeats)?;
765 grids.push(repeated_tensor);
766 }
767 if xy_indexing {
768 grids.reverse();
769 }
770 Ok(grids)
771 }
772
773 pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
785 if self.elem_count() == 0 {
786 return Ok(self.clone());
787 }
788 let storage = self.storage().affine(self.layout(), mul, add)?;
789 let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
790 Ok(from_storage(storage, self.shape(), op, false))
791 }
792
793 pub fn elu(&self, alpha: f64) -> Result<Self> {
795 if self.elem_count() == 0 {
796 return Ok(self.clone());
797 }
798 let storage = self.storage().elu(self.layout(), alpha)?;
799 let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
800 Ok(from_storage(storage, self.shape(), op, false))
801 }
802
803 pub fn powf(&self, e: f64) -> Result<Self> {
805 if self.elem_count() == 0 {
806 return Ok(self.clone());
807 }
808 let storage = self.storage().powf(self.layout(), e)?;
809 let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
810 Ok(from_storage(storage, self.shape(), op, false))
811 }
812
813 pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
814 if dim >= self.dims().len() {
815 Err(Error::DimOutOfRange {
816 shape: self.shape().clone(),
817 dim: dim as i32,
818 op,
819 }
820 .bt())?
821 } else {
822 Ok(())
823 }
824 }
825
826 pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
829 let dim = dim.to_index(self.shape(), "chunk")?;
830 let size = self.dim(dim)?;
831 if size < chunks {
832 (0..size).map(|i| self.narrow(dim, i, 1)).collect()
833 } else {
834 let chunk_size = size / chunks;
835 let cnt_additional = size % chunks;
836 let mut tensors = vec![];
837 let mut sum_chunk_size = 0;
838 for i in 0..chunks {
839 let chunk_size = if i < cnt_additional {
840 chunk_size + 1
841 } else {
842 chunk_size
843 };
844 let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
845 tensors.push(tensor);
846 sum_chunk_size += chunk_size
847 }
848 Ok(tensors)
849 }
850 }
851
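    /// Returns a narrowed (sliced) view of the tensor: dimension `dim` of the result ranges over
    /// `[start, start + len)` and the storage is shared with the original. Usage sketch,
    /// assuming `candle_core` paths:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0f32, 6., &Device::Cpu)?.reshape((2, 3))?;
    /// let n = t.narrow(1, 1, 2)?;
    /// assert_eq!(n.to_vec2::<f32>()?, vec![vec![1f32, 2.], vec![4., 5.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```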
852 pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
879 let dims = self.dims();
880 let dim = dim.to_index(self.shape(), "narrow")?;
881 let err = |msg| {
882 Err::<(), _>(
883 Error::NarrowInvalidArgs {
884 shape: self.shape().clone(),
885 dim,
886 start,
887 len,
888 msg,
889 }
890 .bt(),
891 )
892 };
893 if start > dims[dim] {
894 err("start > dim_len")?
895 }
896 if start.saturating_add(len) > dims[dim] {
897 err("start + len > dim_len")?
898 }
899 if start == 0 && dims[dim] == len {
900 Ok(self.clone())
901 } else {
902 let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
903 let layout = self.layout().narrow(dim, start, len)?;
904 let tensor_ = Tensor_ {
905 id: TensorId::new(),
906 storage: self.storage.clone(),
907 layout,
908 op,
909 is_variable: false,
910 dtype: self.dtype,
911 device: self.device.clone(),
912 };
913 Ok(Tensor(Arc::new(tensor_)))
914 }
915 }
916
917 fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
918 match dims {
919 [] => Ok(self),
920 [i] => self.squeeze(*i),
921 dims => {
922 let dims = self
923 .dims()
924 .iter()
925 .enumerate()
926 .filter_map(|(dim_idx, &v)| {
927 if dims.contains(&dim_idx) {
928 None
929 } else {
930 Some(v)
931 }
932 })
933 .collect::<Vec<_>>();
934 self.reshape(dims)
935 }
936 }
937 }
938
939 fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
940 let dim = dim.to_index(self.shape(), op.name())?;
941 let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
942 let mut dims = self.dims().to_vec();
943 dims[dim] = 1;
944 let op = match op {
945 ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
946 BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
947 }
948 ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
949 };
950 let res = from_storage(storage, dims, op, false);
951 if keepdim {
952 Ok(res)
953 } else {
954 res.squeeze_dims(&[dim])
955 }
956 }
957
958 fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
959 let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
960 let storage = self
961 .storage()
962 .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
963 let mut dims = self.dims().to_vec();
964 for &sum_dim in sum_dims.iter() {
965 dims[sum_dim] = 1
966 }
967 let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
968 let sum = from_storage(storage, dims, op, false);
969 if keepdim {
970 Ok(sum)
971 } else {
972 sum.squeeze_dims(&sum_dims)
973 }
974 }
975
976 pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
990 where
991 D: Dim + Clone,
992 {
993 let dim = dim.to_index(self.shape(), "roll")?;
994 let dim_size = self.dim(dim)?;
995 let shift = shift.rem_euclid(dim_size as i32) as usize;
996 if shift == 0 {
997 Ok(self.clone())
998 } else {
999 let a = self.narrow(dim, 0, dim_size - shift)?;
1000 let b = self.narrow(dim, dim_size - shift, shift)?;
1001 Tensor::cat(&[&b, &a], dim)
1002 }
1003 }
1004
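    /// Sums over the given dimensions, keeping them in the output with a size of one
    /// (`sum_keepdim`) or squeezing them away (`sum`). Usage sketch, assuming `candle_core`
    /// paths:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// assert_eq!(t.sum_keepdim(0)?.to_vec2::<f32>()?, vec![vec![4f32, 6.]]);
    /// assert_eq!(t.sum(0)?.to_vec1::<f32>()?, vec![4f32, 6.]);
    /// assert_eq!(t.sum_all()?.to_scalar::<f32>()?, 10f32);
    /// # Ok::<(), candle_core::Error>(())
    /// ```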
1005 pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
1023 self.sum_impl(sum_dims, true)
1024 }
1025
1026 pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
1030 self.sum_impl(sum_dims, false)
1031 }
1032
1033 pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
1051 let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
1052 let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1053 let scale = 1f64 / (reduced_dim as f64);
1054 self.sum_impl(mean_dims, true)? * scale
1055 }
1056
1057 pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
1061 let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
1062 let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1063 let scale = 1f64 / (reduced_dim as f64);
1064 self.sum_impl(mean_dims, false)? * scale
1065 }
1066
1067 pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1069 let dim = dim.to_index(self.shape(), "var")?;
1070 let mean = self.mean_keepdim(dim)?;
1071 let squares = self.broadcast_sub(&mean)?.sqr()?;
1072 squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
1073 }
1074
1075 pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
1077 let dim = dim.to_index(self.shape(), "var")?;
1078 self.var_keepdim(dim)?.squeeze(dim)
1079 }
1080
1081 pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1084 self.reduce_impl(dim, true, ReduceOp::Max)
1085 }
1086
1087 pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
1089 self.reduce_impl(dim, false, ReduceOp::Max)
1090 }
1091
1092 pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1095 self.reduce_impl(dim, true, ReduceOp::Min)
1096 }
1097
1098 pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
1100 self.reduce_impl(dim, false, ReduceOp::Min)
1101 }
1102
1103 pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1104 self.reduce_impl(dim, true, ReduceOp::ArgMax)
1105 }
1106
1107 pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
1109 self.reduce_impl(dim, false, ReduceOp::ArgMax)
1110 }
1111
1112 pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1113 self.reduce_impl(dim, true, ReduceOp::ArgMin)
1114 }
1115
1116 pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
1118 self.reduce_impl(dim, false, ReduceOp::ArgMin)
1119 }
1120
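    /// Element-wise comparison between this tensor and another tensor or a scalar; the result is
    /// a mask holding 1 where the comparison holds and 0 elsewhere (read back as `u8` in the
    /// sketch below, which assumes `candle_core` paths):
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[1u32, 2, 3], &Device::Cpu)?;
    /// let b = Tensor::new(&[1u32, 5, 3], &Device::Cpu)?;
    /// assert_eq!(a.eq(&b)?.to_vec1::<u8>()?, vec![1u8, 0, 1]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```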
1121 pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
1126 let rhs = match rhs.to_tensor_scalar()? {
1127 crate::scalar::TensorScalar::Tensor(rhs) => rhs,
1128 crate::scalar::TensorScalar::Scalar(rhs) => rhs
1129 .to_dtype(self.dtype())?
1130 .to_device(self.device())?
1131 .broadcast_as(self.shape())?,
1132 };
1133 let shape = self.same_shape_binary_op(&rhs, "cmp")?;
1134 let storage = self
1135 .storage()
1136 .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
1137 let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
1138 Ok(from_storage(storage, shape.dims(), op, false))
1139 }
1140
1141 pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1143 self.cmp(rhs, CmpOp::Eq)
1144 }
1145
1146 pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1148 self.cmp(rhs, CmpOp::Ne)
1149 }
1150
1151 pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1154 self.cmp(rhs, CmpOp::Lt)
1155 }
1156
1157 pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1160 self.cmp(rhs, CmpOp::Gt)
1161 }
1162
1163 pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1166 self.cmp(rhs, CmpOp::Ge)
1167 }
1168
1169 pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1172 self.cmp(rhs, CmpOp::Le)
1173 }
1174
1175 pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
1177 self.maximum(min)?.minimum(max)
1178 }
1179
1180 pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
1185 let (n, c, _l) = self.dims3()?;
1186 let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
1187 let storage = self
1188 .storage()
1189 .upsample_nearest1d(self.layout(), target_size)?;
1190 Ok(from_storage(storage, (n, c, target_size), op, false))
1191 }
1192
1193 pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
1195 self.interpolate1d(target_size)
1196 }
1197
1198 pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1204 let (n, c, _h, _w) = self.dims4()?;
1205 let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
1206 arg,
1207 target_h,
1208 target_w,
1209 });
1210 let storage = self
1211 .storage()
1212 .upsample_nearest2d(self.layout(), target_h, target_w)?;
1213 Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1214 }
1215
1216 pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1218 self.interpolate2d(target_h, target_w)
1219 }
1220
1221 pub fn upsample_bilinear2d(
1245 &self,
1246 target_h: usize,
1247 target_w: usize,
1248 align_corners: bool,
1249 ) -> Result<Self> {
1250 let (n, c, _h, _w) = self.dims4()?;
1251 let op = BackpropOp::new1(self, |arg| Op::UpsampleBilinear2D {
1252 arg,
1253 target_h,
1254 target_w,
1255 align_corners,
1256 });
1257 let storage = self.storage().upsample_bilinear2d(
1259 self.layout(),
1260 target_h,
1261 target_w,
1262 align_corners,
1263 None,
1264 None,
1265 )?;
1266 Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1267 }
1268
1269 pub fn upsample_bilinear2d_with_scale(
1293 &self,
1294 scale_h: f64,
1295 scale_w: f64,
1296 align_corners: bool,
1297 ) -> Result<Self> {
1298 let (n, c, height_in, width_in) = self.dims4()?;
1299
1300 let height_out = (height_in as f64 * scale_h).floor() as usize;
1302 let width_out = (width_in as f64 * scale_w).floor() as usize;
1303
1304 if height_in == height_out && width_in == width_out {
1306 return Ok(self.clone());
1307 }
1308
1309 let op = BackpropOp::new1(self, |arg| Op::UpsampleBilinear2D {
1310 arg,
1311 target_h: height_out,
1312 target_w: width_out,
1313 align_corners,
1314 });
1315
1316 let storage = self.storage().upsample_bilinear2d(
1319 self.layout(),
1320 height_out,
1321 width_out,
1322 align_corners,
1323 Some(scale_h),
1324 Some(scale_w),
1325 )?;
1326 Ok(from_storage(
1327 storage,
1328 (n, c, height_out, width_out),
1329 op,
1330 false,
1331 ))
1332 }
1333
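    /// 2D average pooling over the last two dimensions of an `(n, c, h, w)` tensor, using the
    /// same value for kernel size and stride. Usage sketch, assuming `candle_core` paths:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0f32, 16., &Device::Cpu)?.reshape((1, 1, 4, 4))?;
    /// let p = t.avg_pool2d(2)?;
    /// assert_eq!(p.dims(), &[1, 1, 2, 2]);
    /// assert_eq!(p.flatten_all()?.to_vec1::<f32>()?, vec![2.5f32, 4.5, 10.5, 12.5]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```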
1334 pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1341 let sz = sz.to_usize2();
1342 self.avg_pool2d_with_stride(sz, sz)
1343 }
1344
1345 pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
1348 &self,
1349 kernel_size: T,
1350 stride: T,
1351 ) -> Result<Self> {
1352 let kernel_size = kernel_size.to_usize2();
1353 let stride = stride.to_usize2();
1354 let (n, c, h, w) = self.dims4()?;
1355 if h < kernel_size.0 || w < kernel_size.1 {
1356 bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1357 }
1358 let h_out = (h - kernel_size.0) / stride.0 + 1;
1360 let w_out = (w - kernel_size.1) / stride.1 + 1;
1361 let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
1362 arg,
1363 kernel_size,
1364 stride,
1365 });
1366 let storage = self
1367 .storage()
1368 .avg_pool2d(self.layout(), kernel_size, stride)?;
1369 Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1370 }
1371
1372 pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1379 let sz = sz.to_usize2();
1380 self.max_pool2d_with_stride(sz, sz)
1381 }
1382
1383 pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
1386 &self,
1387 kernel_size: T,
1388 stride: T,
1389 ) -> Result<Self> {
1390 let kernel_size = kernel_size.to_usize2();
1391 let stride = stride.to_usize2();
1392 let (n, c, h, w) = self.dims4()?;
1393 if h < kernel_size.0 || w < kernel_size.1 {
1394 bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1395 }
1396 let h_out = (h - kernel_size.0) / stride.0 + 1;
1398 let w_out = (w - kernel_size.1) / stride.1 + 1;
1399 let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
1400 arg,
1401 kernel_size,
1402 stride,
1403 });
1404 let storage = self
1405 .storage()
1406 .max_pool2d(self.layout(), kernel_size, stride)?;
1407 Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1408 }
1409
1410 pub fn dot(&self, rhs: &Self) -> Result<Self> {
1426 if self.dims().len() != 1 || rhs.dims().len() != 1 {
1427 return Err(Error::ShapeMismatchBinaryOp {
1428 lhs: self.shape().clone(),
1429 rhs: rhs.shape().clone(),
1430 op: "dot",
            }
            .bt());
1432 }
1433
1434 (self * rhs).and_then(|ret| ret.sum_all())
1435 }
1436
1437 pub fn norm(&self) -> Result<Self> {
1450 if self.dtype().is_int() {
1451 bail!("norm not supported for integer dtypes");
1452 }
1453
1454 self.sqr().and_then(|x| x.sum_all()).and_then(|x| x.sqrt())
1455 }
1456
1457 pub fn mv(&self, rhs: &Self) -> Result<Self> {
1472 let lhs_dims = self.dims();
1474 let rhs_dims = rhs.dims();
1475 if lhs_dims.len() != 2 || rhs_dims.len() != 1 || lhs_dims[1] != rhs_dims[0] {
1476 return Err(Error::ShapeMismatchBinaryOp {
1477 lhs: self.shape().clone(),
1478 rhs: rhs.shape().clone(),
1479 op: "mv",
1480 });
1481 }
1482
1483 self.matmul(&rhs.unsqueeze(1)?)?.squeeze(1)
1485 }
1486
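    /// Matrix multiplication of two (possibly batched) tensors: inputs of shape `(b, m, k)` and
    /// `(b, k, n)` produce a result of shape `(b, m, n)`. Usage sketch, assuming `candle_core`
    /// paths:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let b = Tensor::new(&[[5f32, 6.], [7., 8.]], &Device::Cpu)?;
    /// let c = a.matmul(&b)?;
    /// assert_eq!(c.to_vec2::<f32>()?, vec![vec![19f32, 22.], vec![43., 50.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```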
1487 pub fn matmul(&self, rhs: &Self) -> Result<Self> {
1496 let a_dims = self.shape().dims();
1497 let b_dims = rhs.shape().dims();
1498
1499 let dim = a_dims.len();
1500
1501 if dim < 2 || b_dims.len() != dim {
1502 Err(Error::ShapeMismatchBinaryOp {
1503 lhs: self.shape().clone(),
1504 rhs: rhs.shape().clone(),
1505 op: "matmul",
1506 }
1507 .bt())?
1508 }
1509
1510 let m = a_dims[dim - 2];
1511 let k = a_dims[dim - 1];
1512 let k2 = b_dims[dim - 2];
1513 let n = b_dims[dim - 1];
1514
1515 let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
1516 if c_shape.elem_count() == 0 || k == 0 {
1517 return Tensor::zeros(c_shape, self.dtype(), self.device());
1518 }
1519 let batching: usize = a_dims[..dim - 2].iter().product();
1520 let batching_b: usize = b_dims[..dim - 2].iter().product();
1521 if k != k2 || batching != batching_b {
1522 Err(Error::ShapeMismatchBinaryOp {
1523 lhs: self.shape().clone(),
1524 rhs: rhs.shape().clone(),
1525 op: "matmul",
1526 }
1527 .bt())?
1528 }
1529
1530 let storage = self.storage().matmul(
1531 &rhs.storage(),
1532 (batching, m, n, k),
1533 self.layout(),
1534 rhs.layout(),
1535 )?;
1536 let op = BackpropOp::new2(self, rhs, Op::Matmul);
1537 Ok(from_storage(storage, c_shape, op, false))
1538 }
1539
1540 pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
1546 let lhs = self;
1547 let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
1548 let l_broadcast = l_shape != *lhs.shape();
1549 let r_broadcast = r_shape != *rhs.shape();
1550 match (l_broadcast, r_broadcast) {
1552 (true, true) => lhs
1553 .broadcast_as(&l_shape)?
1554 .contiguous()?
1555 .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1556 (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1557 (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
1558 (false, false) => lhs.matmul(rhs),
1559 }
1560 }
1561
1562 pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
        let _shape = self.same_shape_binary_op(on_true, "where_cond")?;
1567 let shape = self.same_shape_binary_op(on_false, "where_cond")?;
1568 let storage = self.storage().where_cond(
1569 self.layout(),
1570 &on_true.storage(),
1571 on_true.layout(),
1572 &on_false.storage(),
1573 on_false.layout(),
1574 )?;
1575 let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
1576 Ok(from_storage(storage, shape, op, false))
1577 }
1578
1579 pub fn embedding(&self, ids: &Self) -> Result<Self> {
1599 if self.rank() != 2 || ids.rank() != 1 {
1600 Err(Error::ShapeMismatchBinaryOp {
1601 lhs: self.shape().clone(),
1602 rhs: ids.shape().clone(),
1603 op: "embedding",
1604 }
1605 .bt())?
1606 }
1607 self.index_select(ids, 0)
1608 }
1609
1610 fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
1611 let source_dims = source.dims();
1612 let self_dims = self.dims();
1613 let mismatch = if source_dims.len() != self_dims.len() {
1614 true
1615 } else {
1616 let mut mismatch = false;
1617 for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1618 if i != dim && d1 != d2 {
1619 mismatch = true;
1620 break;
1621 }
1622 }
1623 mismatch
1624 };
1625 if mismatch {
1626 Err(Error::ShapeMismatchBinaryOp {
1627 op: "scatter (self, src)",
1628 lhs: self.shape().clone(),
1629 rhs: source.shape().clone(),
1630 }
1631 .bt())?
1632 }
1633 if indexes.dims() != source.dims() {
1634 Err(Error::ShapeMismatchBinaryOp {
1635 op: "scatter (indexes, src)",
1636 lhs: indexes.shape().clone(),
1637 rhs: source.shape().clone(),
1638 }
1639 .bt())?
1640 }
1641 Ok(())
1642 }
1643
1644 pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1645 let dim = dim.to_index(self.shape(), "scatter")?;
1646 self.scatter_checks(indexes, source, dim)?;
1647 let shape = self.shape();
1648 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1649 self.storage()
1650 .copy_strided_src(&mut storage, 0, self.layout())?;
1651 let layout = Layout::contiguous(shape);
1652 storage.scatter_set(
1653 &layout,
1654 &indexes.storage(),
1655 indexes.layout(),
1656 &source.storage(),
1657 source.layout(),
1658 dim,
1659 )?;
1660 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1661 Op::Scatter(t1, t2, t3, dim)
1662 });
1663 Ok(from_storage(storage, self.shape(), op, false))
1664 }
1665
1666 pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1667 if self.same_storage(source) {
            crate::bail!("cannot use scatter_set when self and src share their storage")
1669 }
1670 let dim = dim.to_index(self.shape(), "scatter-set")?;
1671 self.scatter_checks(indexes, source, dim)?;
1672 self.storage_mut().scatter_set(
1673 self.layout(),
1674 &indexes.storage(),
1675 indexes.layout(),
1676 &source.storage(),
1677 source.layout(),
1678 dim,
1679 )?;
1680 Ok(())
1681 }
1682
1683 pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1684 let dim = dim.to_index(self.shape(), "scatter-add")?;
1685 self.scatter_checks(indexes, source, dim)?;
1686 let shape = self.shape();
1687 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1688 self.storage()
1689 .copy_strided_src(&mut storage, 0, self.layout())?;
1690 let layout = Layout::contiguous(shape);
1691 storage.scatter_add(
1692 &layout,
1693 &indexes.storage(),
1694 indexes.layout(),
1695 &source.storage(),
1696 source.layout(),
1697 dim,
1698 )?;
1699 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1700 Op::ScatterAdd(t1, t2, t3, dim)
1701 });
1702 Ok(from_storage(storage, self.shape(), op, false))
1703 }
1704
1705 pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1706 if self.same_storage(source) {
            crate::bail!("cannot use scatter_add_set when self and src share their storage")
1708 }
1709 let dim = dim.to_index(self.shape(), "scatter-add-set")?;
1710 self.scatter_checks(indexes, source, dim)?;
1711 self.storage_mut().scatter_add(
1712 self.layout(),
1713 &indexes.storage(),
1714 indexes.layout(),
1715 &source.storage(),
1716 source.layout(),
1717 dim,
1718 )?;
1719 Ok(())
1720 }
1721
1722 pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
1724 let dim = dim.to_index(self.shape(), "slice-scatter")?;
1725 if dim == 0 {
1726 self.slice_scatter0(src, start)
1727 } else {
1728 self.transpose(0, dim)?
1730 .slice_scatter0(&src.transpose(0, dim)?, start)?
1731 .transpose(0, dim)
1732 }
1733 }
1734
1735 pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
1737 if self.dtype() != src.dtype() {
1738 Err(Error::DTypeMismatchBinaryOp {
1739 lhs: self.dtype(),
1740 rhs: src.dtype(),
1741 op: "slice-scatter",
1742 }
1743 .bt())?
1744 }
        if self.device().location() != src.device().location() {
1746 Err(Error::DeviceMismatchBinaryOp {
1747 lhs: self.device().location(),
1748 rhs: src.device().location(),
1749 op: "slice-scatter",
1750 }
1751 .bt())?
1752 }
1753 if self.rank() != src.rank() {
1754 Err(Error::UnexpectedNumberOfDims {
1755 expected: self.rank(),
1756 got: src.rank(),
1757 shape: src.shape().clone(),
1758 }
1759 .bt())?
1760 }
1761 let shape_ok =
1762 self.dims()
1763 .iter()
1764 .zip(src.dims().iter())
1765 .enumerate()
1766 .all(|(dim_idx, (&d1, &d2))| {
1767 if 0 == dim_idx {
1768 d2 + start <= d1
1769 } else {
1770 d1 == d2
1771 }
1772 });
1773 if !shape_ok {
1774 Err(Error::ShapeMismatchBinaryOp {
1775 op: "slice-scatter (self, src)",
1776 lhs: self.shape().clone(),
1777 rhs: src.shape().clone(),
1778 }
1779 .bt())?
1780 }
1781 let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
1782 self.storage()
1783 .copy_strided_src(&mut storage, 0, self.layout())?;
1784 let offset = start * src.dims()[1..].iter().product::<usize>();
1785 src.storage()
1786 .copy_strided_src(&mut storage, offset, src.layout())?;
1787 let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
1788 Ok(from_storage(storage, self.shape(), op, false))
1789 }
1790
1791 pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1793 let dim = dim.to_index(self.shape(), "index-add")?;
1794 let source_dims = source.dims();
1795 let self_dims = self.dims();
1796 let mismatch = if source_dims.len() != self_dims.len() {
1797 true
1798 } else {
1799 let mut mismatch = false;
1800 for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1801 if i != dim && d1 != d2 {
1802 mismatch = true;
1803 break;
1804 }
1805 }
1806 mismatch
1807 };
1808 if mismatch {
1809 Err(Error::ShapeMismatchBinaryOp {
1810 op: "index-add (self, source)",
1811 lhs: self.shape().clone(),
1812 rhs: source.shape().clone(),
1813 }
1814 .bt())?
1815 }
1816 let indexes_len = indexes.dims1()?;
1820 if source_dims[dim] != indexes_len {
1821 Err(Error::ShapeMismatchBinaryOp {
1822 op: "index-add (ids, source))",
1823 lhs: indexes.shape().clone(),
1824 rhs: source.shape().clone(),
1825 }
1826 .bt())?
1827 }
1828 let storage = self.storage().index_add(
1829 self.layout(),
1830 &indexes.storage(),
1831 indexes.layout(),
1832 &source.storage(),
1833 source.layout(),
1834 dim,
1835 )?;
1836 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1837 Op::IndexAdd(t1, t2, t3, dim)
1838 });
1839 Ok(from_storage(storage, self.shape(), op, false))
1840 }
1841
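    /// Gathers values along `dim` using an index tensor of the same rank; for a 2-D tensor and
    /// `dim == 1`, `output[i][j] = self[i][indexes[i][j]]`. Usage sketch, assuming `candle_core`
    /// paths:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    /// let idx = Tensor::new(&[[0u32, 2], [1, 0]], &Device::Cpu)?;
    /// let g = t.gather(&idx, 1)?;
    /// assert_eq!(g.to_vec2::<f32>()?, vec![vec![1f32, 3.], vec![5., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```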
1842 pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1854 let dim = dim.to_index(self.shape(), "gather")?;
1855
1856 let self_dims = self.dims();
1857 let indexes_dims = indexes.dims();
1858 let mismatch = if indexes_dims.len() != self_dims.len() {
1859 true
1860 } else {
1861 let mut mismatch = false;
1862 for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
1863 if i != dim && d1 < d2 {
1864 mismatch = true;
1865 break;
1866 }
1867 }
1868 mismatch
1869 };
1870 if mismatch {
1871 Err(Error::ShapeMismatchBinaryOp {
1872 op: "gather",
1873 lhs: self.shape().clone(),
1874 rhs: indexes.shape().clone(),
1875 }
1876 .bt())?
1877 }
1878 let storage =
1879 self.storage()
1880 .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
1881 let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
1882 Ok(from_storage(storage, indexes.shape(), op, false))
1883 }
1884
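    /// Selects slices along `dim` according to a 1-D index tensor. Usage sketch, assuming
    /// `candle_core` paths:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[2u32, 0], &Device::Cpu)?;
    /// let s = t.index_select(&ids, 0)?;
    /// assert_eq!(s.to_vec2::<f32>()?, vec![vec![5f32, 6.], vec![1., 2.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```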
1885 pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1893 let dim = dim.to_index(self.shape(), "index-select")?;
1894 let indexes_len = match indexes.dims() {
1895 [l] => *l,
1896 _ => Err(Error::ShapeMismatchBinaryOp {
1897 lhs: self.shape().clone(),
1898 rhs: indexes.shape().clone(),
1899 op: "index-select",
1900 }
1901 .bt())?,
1902 };
1903 let storage = self.storage().index_select(
1904 &indexes.storage(),
1905 self.layout(),
1906 indexes.layout(),
1907 dim,
1908 )?;
1909 let mut dims = self.dims().to_vec();
1910 dims[dim] = indexes_len;
1911 let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
1912 Ok(from_storage(storage, dims, op, false))
1913 }
1914
1915 pub fn strided_index(&self) -> crate::StridedIndex<'_> {
1918 self.layout.strided_index()
1919 }
1920
1921 pub fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
1926 self.layout.strided_blocks()
1927 }
1928
1929 pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
1931 if self.rank() != 1 {
1932 Err(Error::UnexpectedNumberOfDims {
1933 expected: 1,
1934 got: self.rank(),
1935 shape: self.shape().clone(),
1936 }
1937 .bt())?
1938 }
1939 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1940 let data = S::cpu_storage_as_slice(cpu_storage)?;
1941 let data = match self.layout.contiguous_offsets() {
1942 Some((o1, o2)) => data[o1..o2].to_vec(),
1943 None => self.strided_index().map(|i| data[i]).collect(),
1944 };
1945 Ok::<Vec<_>, Error>(data)
1946 };
1947 match &*self.storage() {
1948 Storage::Cpu(storage) => from_cpu_storage(storage),
1949 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1950 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1951 }
1952 }
1953
1954 pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
1956 let (dim1, dim2) = self.dims2()?;
1957 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1958 let data = S::cpu_storage_as_slice(cpu_storage)?;
1959 let mut rows = vec![];
1960 match self.layout.contiguous_offsets() {
1961 Some((o1, o2)) => {
1962 let data = &data[o1..o2];
1963 for idx_row in 0..dim1 {
1964 rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
1965 }
1966 }
1967 None => {
1968 let mut src_index = self.strided_index();
1969 for _idx_row in 0..dim1 {
1970 let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
1971 rows.push(row)
1972 }
1973 assert!(src_index.next().is_none());
1974 }
1975 }
1976 Ok(rows)
1977 };
1978 match &*self.storage() {
1979 Storage::Cpu(storage) => from_cpu_storage(storage),
1980 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1981 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1982 }
1983 }
1984
1985 pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
1987 let (dim1, dim2, dim3) = self.dims3()?;
1988 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1989 let data = S::cpu_storage_as_slice(cpu_storage)?;
1990 let mut top_rows = vec![];
1991 match self.layout.contiguous_offsets() {
1992 Some((o1, o2)) => {
1993 let data = &data[o1..o2];
1994 let dim23 = dim2 * dim3;
1995 for idx1 in 0..dim1 {
1996 let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
1997 let mut rows = vec![];
1998 for idx2 in 0..dim2 {
1999 rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
2000 }
2001 top_rows.push(rows);
2002 }
2003 }
2004 None => {
2005 let mut src_index = self.strided_index();
2006 for _idx in 0..dim1 {
2007 let mut rows = vec![];
2008 for _jdx in 0..dim2 {
2009 let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
2010 rows.push(row)
2011 }
2012 top_rows.push(rows);
2013 }
2014 assert!(src_index.next().is_none());
2015 }
2016 }
2017 Ok(top_rows)
2018 };
2019 match &*self.storage() {
2020 Storage::Cpu(storage) => from_cpu_storage(storage),
2021 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2022 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2023 }
2024 }
2025
2026 pub fn dtype(&self) -> DType {
2028 self.dtype
2029 }
2030
2031 pub fn device(&self) -> &Device {
2033 &self.device
2034 }
2035
2036 pub fn shape(&self) -> &Shape {
2038 self.layout().shape()
2039 }
2040
2041 pub fn dims(&self) -> &[usize] {
2043 self.shape().dims()
2044 }
2045
2046 pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
2048 let dim = dim.to_index(self.shape(), "dim")?;
2049 Ok(self.dims()[dim])
2050 }
2051
2052 pub fn layout(&self) -> &Layout {
2055 &self.layout
2056 }
2057
2058 pub fn stride(&self) -> &[usize] {
2059 self.layout.stride()
2060 }
2061
2062 pub fn rank(&self) -> usize {
2064 self.shape().rank()
2065 }
2066
2067 pub fn elem_count(&self) -> usize {
2069 self.shape().elem_count()
2070 }
2071
2072 pub fn id(&self) -> TensorId {
2074 self.id
2075 }
2076
2077 pub fn is_variable(&self) -> bool {
2080 self.is_variable
2081 }
2082
2083 pub(crate) fn op(&self) -> &Option<Op> {
2084 &self.op
2085 }
2086
2087 pub fn max_all(&self) -> Result<Tensor> {
2098 if self.rank() == 0 {
2099 Ok(self.clone())
2100 } else {
2101 self.flatten_all()?.max(0)
2102 }
2103 }
2104
2105 pub fn min_all(&self) -> Result<Tensor> {
2116 if self.rank() == 0 {
2117 Ok(self.clone())
2118 } else {
2119 self.flatten_all()?.min(0)
2120 }
2121 }
2122
2123 pub fn sum_all(&self) -> Result<Tensor> {
2134 let dims: Vec<_> = (0..self.rank()).collect();
2135 self.sum(dims)
2136 }
2137
2138 pub fn mean_all(&self) -> Result<Tensor> {
2139 self.sum_all()? / self.elem_count() as f64
2140 }
2141
2142 fn flatten_<D1: Dim, D2: Dim>(
2143 &self,
2144 start_dim: Option<D1>,
2145 end_dim: Option<D2>,
2146 ) -> Result<Tensor> {
2147 if self.rank() == 0 {
2148 self.reshape(1)
2149 } else {
2150 let start_dim = match start_dim {
2151 None => 0,
2152 Some(dim) => dim.to_index(self.shape(), "flatten")?,
2153 };
2154 let end_dim = match end_dim {
2155 None => self.rank() - 1,
2156 Some(dim) => dim.to_index(self.shape(), "flatten")?,
2157 };
2158 if start_dim < end_dim {
2159 let dims = self.dims();
2160 let mut dst_dims = dims[..start_dim].to_vec();
2161 dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
2162 if end_dim + 1 < dims.len() {
2163 dst_dims.extend(&dims[end_dim + 1..]);
2164 }
2165 self.reshape(dst_dims)
2166 } else {
2167 Ok(self.clone())
2168 }
2169 }
2170 }
2171
2172 pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
2175 self.flatten_(Some(start_dim), Some(end_dim))
2176 }
2177
2178 pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
2180 self.flatten_(None::<usize>, Some(end_dim))
2181 }
2182
2183 pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
2186 self.flatten_(Some(start_dim), None::<usize>)
2187 }
2188
2189 pub fn flatten_all(&self) -> Result<Tensor> {
2199 self.flatten_(None::<usize>, None::<usize>)
2200 }
2201
2202 pub fn get(&self, i: usize) -> Result<Tensor> {
2214 let dims = self.dims();
2215 if dims.is_empty() {
2216 Ok(self.clone())
2217 } else {
2218 self.narrow(0, i, 1)?.reshape(&dims[1..])
2219 }
2220 }
2221
2222 pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
2236 let dim = dim.to_index(self.shape(), "get_on_dim")?;
2237 self.narrow(dim, index, 1)?.squeeze(dim)
2238 }
2239
2240 pub fn t(&self) -> Result<Tensor> {
2251 let rank = self.rank();
2252 if rank < 2 {
2253 Err(Error::UnexpectedNumberOfDims {
2254 expected: 2,
2255 got: rank,
2256 shape: self.shape().clone(),
2257 }
2258 .bt())?
2259 }
2260 self.transpose(rank - 2, rank - 1)
2261 }
2262
2263 pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
2266 let dim1 = dim1.to_index(self.shape(), "transpose")?;
2267 let dim2 = dim2.to_index(self.shape(), "transpose")?;
2268 if dim1 == dim2 {
2269 return Ok(self.clone());
2270 }
2271 let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
2272 let tensor_ = Tensor_ {
2273 id: TensorId::new(),
2274 storage: self.storage.clone(),
2275 layout: self.layout.transpose(dim1, dim2)?,
2276 op,
2277 is_variable: false,
2278 dtype: self.dtype,
2279 device: self.device.clone(),
2280 };
2281 Ok(Tensor(Arc::new(tensor_)))
2282 }
2283
2284 pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
2296 let dims = dims.to_indexes(self.shape(), "permute")?;
2297 let is_permutation =
2299 dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
2300 if !is_permutation {
2301 bail!(
2302 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
2303 self.dims(),
2304 dims
2305 )
2306 }
2307 let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
2308 let tensor_ = Tensor_ {
2309 id: TensorId::new(),
2310 storage: self.storage.clone(),
2311 layout: self.layout.permute(&dims)?,
2312 op,
2313 is_variable: false,
2314 dtype: self.dtype,
2315 device: self.device.clone(),
2316 };
2317 Ok(Tensor(Arc::new(tensor_)))
2318 }
2319
2320 pub fn is_contiguous(&self) -> bool {
2322 self.layout.is_contiguous()
2323 }
2324
2325 pub fn is_fortran_contiguous(&self) -> bool {
2327 self.layout.is_fortran_contiguous()
2328 }
2329
2330 pub fn copy(&self) -> Result<Tensor> {
2333 let op = BackpropOp::new1(self, Op::Copy);
2334 let tensor_ = Tensor_ {
2335 id: TensorId::new(),
2336 storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
2337 layout: self.layout.clone(),
2338 op,
2339 is_variable: false,
2340 dtype: self.dtype,
2341 device: self.device.clone(),
2342 };
2343 Ok(Tensor(Arc::new(tensor_)))
2344 }
2345
2346 pub fn detach(&self) -> Tensor {
2351 if self.op.is_none() && !self.is_variable {
2352 self.clone()
2353 } else {
2354 let tensor_ = Tensor_ {
2355 id: TensorId::new(),
2356 storage: self.storage.clone(),
2357 layout: self.layout.clone(),
2358 op: BackpropOp::none(),
2359 is_variable: false,
2360 dtype: self.dtype,
2361 device: self.device.clone(),
2362 };
2363 Tensor(Arc::new(tensor_))
2364 }
2365 }
2366
2367 pub fn to_device(&self, device: &Device) -> Result<Tensor> {
2369 if self.device().same_device(device) {
2370 Ok(self.clone())
2371 } else {
2372 let storage = match (&*self.storage(), device) {
2373 (Storage::Cpu(storage), Device::Cuda(cuda)) => {
2374 Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
2375 }
2376 (Storage::Cpu(storage), Device::Metal(metal)) => {
2377 Storage::Metal(metal.storage_from_cpu_storage(storage)?)
2378 }
2379 (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2380 (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2381 (Storage::Cuda(storage), Device::Cuda(cuda)) => {
2382 let dst_storage = storage.transfer_to_device(cuda)?;
2384 Storage::Cuda(dst_storage)
2385 }
2386 (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
2387 _ => {
2388 bail!(
2389 "not implemented yet, self.device: {:?}, device: {:?}",
2390 self.device(),
2391 device
2392 )
2393 }
2394 };
2395 let op = BackpropOp::new1(self, Op::ToDevice);
2396 let tensor_ = Tensor_ {
2397 id: TensorId::new(),
2398 storage: Arc::new(RwLock::new(storage)),
2399 layout: self.layout.clone(),
2400 op,
2401 is_variable: false,
2402 dtype: self.dtype,
2403 device: device.clone(),
2404 };
2405 Ok(Tensor(Arc::new(tensor_)))
2406 }
2407 }
2408
2409 pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
2412 let left_shape = left_shape.into();
2413 let mut dims = left_shape.into_dims();
2414 dims.extend(self.dims());
2415 self.broadcast_as(dims)
2416 }
2417
2418 pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2426 let tensor_ = Tensor_ {
2427 id: TensorId::new(),
2428 storage: self.storage.clone(),
2429 layout: self.layout.broadcast_as(shape)?,
2430 op: BackpropOp::new1(self, Op::Broadcast),
2431 is_variable: false,
2432 dtype: self.dtype,
2433 device: self.device.clone(),
2434 };
2435 Ok(Tensor(Arc::new(tensor_)))
2436 }
2437
2438 pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2440 self.broadcast_as(shape)
2441 }
2442
2443 pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2454 if self.dtype() == dtype {
2455 Ok(self.clone())
2456 } else {
2457 let shape = self.shape();
2458 let storage = self.storage().to_dtype(self.layout(), dtype)?;
2459 let op = BackpropOp::new1(self, Op::ToDType);
2460 Ok(from_storage(storage, shape.clone(), op, false))
2461 }
2462 }
2463
2464 pub fn contiguous(&self) -> Result<Tensor> {
2467 if self.is_contiguous() {
2468 Ok(self.clone())
2469 } else {
2470 let shape = self.shape();
2471 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2472 self.storage()
2473 .copy_strided_src(&mut storage, 0, self.layout())?;
2474 let op = BackpropOp::new1(self, Op::Copy);
2475 Ok(from_storage(storage, shape.clone(), op, false))
2476 }
2477 }
2478
2479 pub fn force_contiguous(&self) -> Result<Tensor> {
2481 let shape = self.shape();
2482 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2483 self.storage()
2484 .copy_strided_src(&mut storage, 0, self.layout())?;
2485 let op = BackpropOp::new1(self, Op::Copy);
2486 Ok(from_storage(storage, shape.clone(), op, false))
2487 }
2488
2489 pub(crate) fn make_var(&self) -> Result<Tensor> {
2492 let shape = self.shape().clone();
2493 let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2494 self.storage()
2495 .copy_strided_src(&mut storage, 0, self.layout())?;
2496 Ok(from_storage(storage, shape, BackpropOp::none(), true))
2497 }
2498
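    /// Reshapes the tensor to the given shape; the element count must be preserved and at most
    /// one dimension may be left as a hole to be inferred. Contiguous tensors share their
    /// storage with the result. Sketch only: the `()` hole syntax and the `candle_core` paths
    /// are assumptions.
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0f32, 6., &Device::Cpu)?;
    /// let r = t.reshape((2, 3))?;
    /// assert_eq!(r.to_vec2::<f32>()?, vec![vec![0f32, 1., 2.], vec![3., 4., 5.]]);
    /// let r = t.reshape(((), 3))?;
    /// assert_eq!(r.dims(), &[2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```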
2499 pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
2524 let shape = s.into_shape(self.elem_count())?;
2525 if shape.elem_count() != self.elem_count() {
2526 return Err(Error::ShapeMismatchBinaryOp {
2527 lhs: self.shape().clone(),
2528 rhs: shape,
2529 op: "reshape",
2530 }
2531 .bt());
2532 }
2533 let op = BackpropOp::new1(self, Op::Reshape);
2534 if self.is_contiguous() {
2535 let tensor_ = Tensor_ {
2536 id: TensorId::new(),
2537 storage: self.storage.clone(),
2538 layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
2539 op,
2540 is_variable: false,
2541 dtype: self.dtype,
2542 device: self.device.clone(),
2543 };
2544 Ok(Tensor(Arc::new(tensor_)))
2545 } else {
2546 let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2547 self.storage()
2548 .copy_strided_src(&mut storage, 0, self.layout())?;
2549 Ok(from_storage(storage, shape, op, false))
2550 }
2551 }
2552
2553 pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
2567 let dims = self.dims();
2570 let dim = dim.to_index(self.shape(), "squeeze")?;
2571 if dims[dim] == 1 {
2572 let mut dims = dims.to_vec();
2573 let mut strides = self.stride().to_vec();
2574 dims.remove(dim);
2575 strides.remove(dim);
2576 let tensor_ = Tensor_ {
2577 id: TensorId::new(),
2578 storage: self.storage.clone(),
2579 layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2580 op: BackpropOp::new1(self, Op::Reshape),
2581 is_variable: false,
2582 dtype: self.dtype,
2583 device: self.device.clone(),
2584 };
2585 Ok(Tensor(Arc::new(tensor_)))
2586 } else {
2587 Ok(self.clone())
2588 }
2589 }
2590
2591 pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
2605 let mut dims = self.dims().to_vec();
2606 let mut strides = self.stride().to_vec();
2607 let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
2608 dims.insert(dim, 1);
2610 let stride = if dim < strides.len() { strides[dim] } else { 1 };
2613 strides.insert(dim, stride);
2614 let tensor_ = Tensor_ {
2615 id: TensorId::new(),
2616 storage: self.storage.clone(),
2617 layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2618 op: BackpropOp::new1(self, Op::Reshape),
2619 is_variable: false,
2620 dtype: self.dtype,
2621 device: self.device.clone(),
2622 };
2623 Ok(Tensor(Arc::new(tensor_)))
2624 }
2625
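    /// Stacks tensors of identical shape along a new dimension inserted at `dim`; with two
    /// `(2, 3)` inputs and `dim = 0` the result has shape `(2, 2, 3)`. Usage sketch, assuming
    /// `candle_core` paths:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let b = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    /// let s = Tensor::stack(&[&a, &b], 0)?;
    /// assert_eq!(s.dims(), &[2, 2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```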
2626 pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
2643 if args.is_empty() {
2644 Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
2645 }
2646 let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
2647 let args = args
2648 .iter()
2649 .map(|t| t.as_ref().unsqueeze(dim))
2650 .collect::<Result<Vec<_>>>()?;
2651 Self::cat(&args, dim)
2652 }
2653
    pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
        if left == 0 && right == 0 {
            Ok(self.clone())
        } else if left == 0 {
            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
            let mut dims = self.dims().to_vec();
            dims[dim] = right;
            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
            Tensor::cat(&[self, &right], dim)
        } else if right == 0 {
            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
            let mut dims = self.dims().to_vec();
            dims[dim] = left;
            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
            Tensor::cat(&[&left, self], dim)
        } else {
            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
            let mut dims = self.dims().to_vec();
            dims[dim] = left;
            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
            dims[dim] = right;
            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
            Tensor::cat(&[&left, self, &right], dim)
        }
    }

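    /// Pads the tensor along dimension `dim` by replicating the first slice `left` times before
    /// the data and the last slice `right` times after it ("same"/edge padding). Fails on an
    /// empty tensor since there is no slice to replicate.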
    pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
        if left == 0 && right == 0 {
            Ok(self.clone())
        } else if self.elem_count() == 0 {
            bail!("cannot use pad_with_same on an empty tensor")
        } else if left == 0 {
            let dim = dim.to_index(self.shape(), "pad_with_same")?;
            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
            let mut v = vec![self];
            for _ in 0..right {
                v.push(&r)
            }
            Tensor::cat(&v, dim)
        } else if right == 0 {
            let dim = dim.to_index(self.shape(), "pad_with_same")?;
            let l = self.narrow(dim, 0, 1)?;
            let mut v = vec![];
            for _ in 0..left {
                v.push(&l)
            }
            v.push(self);
            Tensor::cat(&v, dim)
        } else {
            let dim = dim.to_index(self.shape(), "pad_with_same")?;
            let l = self.narrow(dim, 0, 1)?;
            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
            let mut v = vec![];
            for _ in 0..left {
                v.push(&l)
            }
            v.push(self);
            for _ in 0..right {
                v.push(&r)
            }
            Tensor::cat(&v, dim)
        }
    }

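    /// Runs the tensor through the provided [`crate::Module`], i.e. `m.forward(self)`, which is
    /// convenient for chaining layers in a builder style.
    ///
    /// Illustrative usage (a sketch; `layer1` and `layer2` are hypothetical `Module` values):
    ///
    /// ```ignore
    /// let ys = xs.apply(&layer1)?.apply(&layer2)?;
    /// ```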
    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
        m.forward(self)
    }

    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
        m.forward_t(self, train)
    }

    pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
        self.storage.read().unwrap()
    }

    pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
        self.storage.write().unwrap()
    }

    pub(crate) fn storage_mut_and_layout(
        &self,
    ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
        let storage = self.storage.write().unwrap();
        (storage, &self.layout)
    }

    pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
        let storage = self.storage.read().unwrap();
        (storage, &self.layout)
    }

    pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
        let lhs: &RwLock<Storage> = self.storage.as_ref();
        let rhs: &RwLock<Storage> = rhs.storage.as_ref();
        std::ptr::eq(lhs, rhs)
    }

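    /// Converts a possibly negative axis index (counting from the end, as in `-1` for the last
    /// dimension) into a non-negative dimension index, returning an error when the axis is out
    /// of range for the tensor rank.
    ///
    /// Illustrative usage (a sketch assuming the `candle_core` crate name and a CPU device):
    ///
    /// ```ignore
    /// let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    /// assert_eq!(t.normalize_axis(-1)?, 2);
    /// assert_eq!(t.normalize_axis(1)?, 1);
    /// ```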
    pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
        let rank = self.rank() as i64;
        if rank <= axis {
            bail!("axis {axis} is too large, tensor rank {rank}")
        } else if 0 <= axis {
            Ok(axis as usize)
        } else {
            let naxis = rank + axis;
            if naxis < 0 {
                bail!("axis {axis} is too small, tensor rank {rank}")
            }
            Ok(naxis as usize)
        }
    }

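    /// Returns an `n x n` lower-triangular matrix of ones (including the diagonal) in the
    /// requested dtype, built by comparing broadcast row and column indices.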
    pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
        let t = Tensor::arange(0u32, n as u32, device)?;
        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
        t1.le(&t2)?.to_dtype(dtype)
    }

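    /// Returns an `n x n` upper-triangular matrix of ones (including the diagonal) in the
    /// requested dtype.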
    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
        let t = Tensor::arange(0u32, n as u32, device)?;
        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
        t1.ge(&t2)?.to_dtype(dtype)
    }

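    /// Returns an `n x n` identity matrix in the requested dtype.
    ///
    /// Illustrative usage (a sketch assuming the `candle_core` crate name and a CPU device):
    ///
    /// ```ignore
    /// let id = Tensor::eye(3, DType::F32, &Device::Cpu)?;
    /// assert_eq!(
    ///     id.to_vec2::<f32>()?,
    ///     vec![vec![1., 0., 0.], vec![0., 1., 0.], vec![0., 0., 1.]]
    /// );
    /// ```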
    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
        let t = Tensor::arange(0u32, n as u32, device)?;
        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
        t1.eq(&t2)?.to_dtype(dtype)
    }

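    /// Returns the cumulative sum of the elements along dimension `dim`, computed as a matrix
    /// product with an upper-triangular matrix of ones.
    ///
    /// Illustrative usage (a sketch assuming the `candle_core` crate name and a CPU device):
    ///
    /// ```ignore
    /// let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [1., 3., 6.]);
    /// ```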
    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "cumsum")?;
        let rank = self.rank();
        if rank == 0 {
            return Ok(self.clone());
        }
        let n_axis = self.dim(dim)?;
        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
        if rank == 1 {
            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
        } else {
            let last = rank - 1;
            let t = self.transpose(dim, last)?;
            let t = t.broadcast_matmul(&triu)?;
            t.transpose(dim, last)
        }
    }

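    /// Returns a copy of `self` where the values inside `ranges` have been replaced by the
    /// values from `src`. `src` must have the same rank as `self` and each of its dimensions
    /// must match the length of the corresponding range. Implemented by zero-padding `src` and
    /// a matching mask up to the full shape and selecting with `where_cond`.
    ///
    /// Illustrative usage (a sketch assuming the `candle_core` crate name and a CPU device):
    ///
    /// ```ignore
    /// let dst = Tensor::zeros((4, 4), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::ones((2, 2), DType::F32, &Device::Cpu)?;
    /// // Writes the 2x2 block of ones into rows 1..3 and columns 1..3.
    /// let out = dst.slice_assign(&[1..3, 1..3], &src)?;
    /// ```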
    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
        &self,
        ranges: &[D],
        src: &Tensor,
    ) -> Result<Self> {
        let src_dims = src.dims();
        let self_dims = self.dims();
        if self_dims.len() != src_dims.len() {
            bail!(
                "slice-assign requires input with the same rank {} <> {}",
                self_dims.len(),
                src_dims.len()
            )
        }
        if self_dims.len() != ranges.len() {
            bail!(
                "slice-assign requires input with the same rank as there are ranges {} <> {}",
                self_dims.len(),
                ranges.len()
            )
        }
        let mut src = src.clone();
        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
        for (i, range) in ranges.iter().enumerate() {
            let start_included = match range.start_bound() {
                std::ops::Bound::Unbounded => 0,
                std::ops::Bound::Included(v) => *v,
                std::ops::Bound::Excluded(v) => *v + 1,
            };
            let end_excluded = match range.end_bound() {
                std::ops::Bound::Unbounded => self_dims[i],
                std::ops::Bound::Included(v) => *v + 1,
                std::ops::Bound::Excluded(v) => *v,
            };
            if end_excluded <= start_included {
                bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
            }
            if self_dims[i] < end_excluded {
                bail!(
                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
                    self_dims[i]
                )
            }
            if end_excluded - start_included != src_dims[i] {
                bail!(
                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
                )
            }
            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
        }
        mask.where_cond(&src, self)
    }

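    /// Computes a numerically stable log-sum-exp over `sum_dims`: the maximum along those
    /// dimensions is subtracted before exponentiating, so `log(sum(exp(x)))` does not overflow
    /// for large inputs. The reduced dimensions are squeezed out of the result.
    ///
    /// Illustrative usage (a sketch assuming the `candle_core` crate name and a CPU device):
    ///
    /// ```ignore
    /// let t = Tensor::new(&[0f32, 0., 0., 0.], &Device::Cpu)?;
    /// // log(4 * exp(0)) = ln(4)
    /// let lse = t.log_sum_exp(0)?;
    /// ```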
    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
        if sum_dims.is_empty() {
            return Ok(self.clone());
        }
        let max = sum_dims[1..]
            .iter()
            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
                max.max_keepdim(dim)
            })?;
        let exp = self.broadcast_sub(&max)?.exp()?;
        let sum = exp.sum(sum_dims.clone())?;

        sum.log()? + max.squeeze_dims(&sum_dims)
    }

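    /// Element-wise power, computed as `exp(rhs * ln(self))`; both tensors must have the same
    /// shape and, since it goes through the logarithm, `self` is expected to be positive.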
    pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
        rhs.mul(&self.log()?)?.exp()
    }

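    /// Broadcasting variant of [`Self::pow`], computed as `exp(broadcast_mul(rhs, ln(self)))`.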
    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
        rhs.broadcast_mul(&self.log()?)?.exp()
    }

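    /// Reverses the order of the elements along each of the given dimensions, using an
    /// `index_select` with reversed indices for every dimension in `dims`.
    ///
    /// Illustrative usage (a sketch assuming the `candle_core` crate name and a CPU device):
    ///
    /// ```ignore
    /// let t = Tensor::new(&[1i64, 2, 3], &Device::Cpu)?;
    /// assert_eq!(t.flip(&[0])?.to_vec1::<i64>()?, [3, 2, 1]);
    /// ```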
    pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
        let mut result = self.clone();
        for &dim in dims.iter() {
            let size = result.dim(dim)?;
            let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
            let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
            result = result.index_select(&indices_tensor, dim)?;
        }
        Ok(result)
    }

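    /// Returns a view containing all sliding windows of length `size` taken with stride `step`
    /// over dimension `dim`. The windowed dimension shrinks to `(len - size) / step + 1` and a
    /// new trailing dimension of length `size` is appended; no data is copied, only the layout
    /// strides change.
    ///
    /// Illustrative usage (a sketch assuming the `candle_core` crate name and a CPU device):
    ///
    /// ```ignore
    /// let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
    /// // Windows: [0, 1, 2] and [2, 3, 4] -> shape (2, 3).
    /// assert_eq!(t.unfold(0, 3, 2)?.dims(), &[2, 3]);
    /// ```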
    pub fn unfold<D: Dim>(&self, dim: D, size: usize, step: usize) -> Result<Self> {
        let mut sizes = self.dims().to_vec();
        let mut strides = self.stride().to_vec();

        let dim = dim.to_index(self.shape(), "unfold")?;

        let max_len = if self.dims().is_empty() {
            1
        } else {
            sizes[dim]
        };
        if size > max_len {
            bail!(
                "unfold: maximum size for tensor at dimension {dim} is {max_len} but size is {size}"
            )
        }
        sizes.push(size);
        strides.push(if self.dims().is_empty() {
            1
        } else {
            strides[dim]
        });

        if !self.dims().is_empty() {
            sizes[dim] = ((sizes[dim] as f32 - size as f32) / step as f32 + 1.) as usize;
            strides[dim] *= step;
        }

        let tensor_ = Tensor_ {
            id: TensorId::new(),
            storage: self.storage.clone(),
            layout: Layout::new(sizes.into(), strides, self.layout.start_offset()),
            op: BackpropOp::new1(self, Op::Reshape),
            is_variable: false,
            dtype: self.dtype,
            device: self.device.clone(),
        };
        Ok(Tensor(Arc::new(tensor_)))
    }
}

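// The `bin_trait!` macro wires the standard arithmetic operator traits (`Add`, `Sub`, `Mul`,
// `Div`) to the corresponding tensor methods. Each invocation generates impls for `Tensor`,
// `&Tensor`, and `Result<Tensor>` operands, plus an `f64` right-hand side that is lowered to a
// single `affine` call (scale and offset).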
macro_rules! bin_trait {
    ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => {
        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: B) -> Self::Output {
                Tensor::$fn1(&self, rhs.borrow())
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for &Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: B) -> Self::Output {
                Tensor::$fn1(&self, rhs.borrow())
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: Tensor) -> Self::Output {
                Tensor::$fn1(self?.borrow(), &rhs)
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: &Tensor) -> Self::Output {
                Tensor::$fn1(self?.borrow(), rhs)
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: Result<B>) -> Self::Output {
                Tensor::$fn1(&self, rhs?.borrow())
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for &Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: Result<B>) -> Self::Output {
                Tensor::$fn1(&self, rhs?.borrow())
            }
        }

        impl std::ops::$trait<f64> for Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: f64) -> Self::Output {
                self.affine($mul(rhs), $add(rhs))
            }
        }

        impl std::ops::$trait<f64> for &Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: f64) -> Self::Output {
                self.affine($mul(rhs), $add(rhs))
            }
        }
    };
}

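// For an `f64` right-hand side the operation is expressed as `affine(mul, add)`, i.e.
// `x * mul + add`: `t + v` is `affine(1, v)`, `t - v` is `affine(1, -v)`, `t * v` is
// `affine(v, 0)` and `t / v` is `affine(1 / v, 0)`.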
bin_trait!(Add, add, |_| 1., |v| v);
bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
bin_trait!(Mul, mul, |v| v, |_| 0.);
bin_trait!(Div, div, |v| 1. / v, |_| 0.);

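// The impls below cover the case where the `f64` scalar sits on the left-hand side of the
// operator, e.g. `2.0 * tensor` or `1.0 / tensor`; they forward to the right-hand-side
// versions above, using `affine` for subtraction and `recip` for division.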
impl std::ops::Add<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn add(self, rhs: Tensor) -> Self::Output {
        rhs + self
    }
}

impl std::ops::Add<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn add(self, rhs: &Tensor) -> Self::Output {
        rhs + self
    }
}

impl std::ops::Mul<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn mul(self, rhs: Tensor) -> Self::Output {
        rhs * self
    }
}

impl std::ops::Mul<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn mul(self, rhs: &Tensor) -> Self::Output {
        rhs * self
    }
}

impl std::ops::Sub<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn sub(self, rhs: Tensor) -> Self::Output {
        rhs.affine(-1., self)
    }
}

impl std::ops::Sub<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn sub(self, rhs: &Tensor) -> Self::Output {
        rhs.affine(-1., self)
    }
}

impl std::ops::Div<Tensor> for f64 {
    type Output = Result<Tensor>;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: Tensor) -> Self::Output {
        rhs.recip()? * self
    }
}

impl std::ops::Div<&Tensor> for f64 {
    type Output = Result<Tensor>;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: &Tensor) -> Self::Output {
        rhs.recip()? * self
    }
}

impl<S: Into<Shape>> From<(Storage, S)> for Tensor {
    fn from((storage, shape): (Storage, S)) -> Self {
        from_storage(storage, shape, BackpropOp::none(), false)
    }
}