1#![allow(clippy::redundant_closure_call)]
3use crate::backend::{BackendDevice, BackendStorage};
4use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
5use crate::scalar::TensorOrScalar;
6use crate::shape::{Dim, Dims, ShapeWithOneHole};
7use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
8use std::sync::{Arc, RwLock};
9
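/// Unique identifier attached to every tensor; fresh ids are handed out from a global atomic
/// counter so two distinct tensors never share the same id.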
10#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
12pub struct TensorId(usize);
13
14impl TensorId {
15 fn new() -> Self {
16 use std::sync::atomic;
18 static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
19 Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
20 }
21}
22
23pub struct Tensor_ {
24 id: TensorId,
25 storage: Arc<RwLock<Storage>>,
38 layout: Layout,
39 op: BackpropOp,
40 is_variable: bool,
41 dtype: DType,
42 device: Device,
43}
44
45impl AsRef<Tensor> for Tensor {
46 fn as_ref(&self) -> &Tensor {
47 self
48 }
49}
50
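/// The main tensor type: a cheaply clonable handle over the reference-counted `Tensor_` above.
///
/// A small doc-test style sketch (it assumes this crate is exposed as `candle_core` and uses the
/// CPU device):
///
/// ```rust
/// use candle_core::{Device, Tensor};
///
/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
/// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
/// let c = a.matmul(&b)?; // (2, 3) x (3, 4) -> (2, 4)
/// assert_eq!(c.dims(), &[2, 4]);
/// # Ok::<(), candle_core::Error>(())
/// ```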
51#[derive(Clone)]
55pub struct Tensor(Arc<Tensor_>);
69
70impl std::ops::Deref for Tensor {
71 type Target = Tensor_;
72
73 fn deref(&self) -> &Self::Target {
74 self.0.as_ref()
75 }
76}
77
78macro_rules! unary_op {
79 ($fn_name:ident, $op_name:ident) => {
80 pub fn $fn_name(&self) -> Result<Self> {
81 let shape = self.shape();
82 if shape.elem_count() == 0 {
83 return Ok(self.clone());
84 }
85 let storage = self
86 .storage()
87 .unary_impl::<crate::op::$op_name>(self.layout())?;
88 let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
89 Ok(from_storage(storage, shape.clone(), op, false))
90 }
91 };
92}
93
94macro_rules! binary_op {
95 ($fn_name:ident, $op_name:ident) => {
96 pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
97 let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
98 if shape.elem_count() == 0 {
99 return Ok(self.clone());
100 }
101 let storage = self.storage().binary_impl::<crate::op::$op_name>(
102 &*rhs.storage(),
103 self.layout(),
104 rhs.layout(),
105 )?;
106 let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
107 Ok(from_storage(storage, shape.clone(), op, false))
108 }
109 };
110}
111
112macro_rules! binary_op_scalar {
113 ($fn_name:ident, $op_name:ident) => {
114 pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
115 let rhs = match rhs.to_tensor_scalar()? {
116 crate::scalar::TensorScalar::Tensor(rhs) => rhs,
117 crate::scalar::TensorScalar::Scalar(rhs) => rhs
118 .to_dtype(self.dtype())?
119 .to_device(self.device())?
120 .broadcast_as(self.shape())?,
121 };
122 let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
123 if self.elem_count() == 0 {
124 return Ok(self.clone());
125 }
126 let storage = self.storage().binary_impl::<crate::op::$op_name>(
127 &*rhs.storage(),
128 self.layout(),
129 rhs.layout(),
130 )?;
131 let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
132 Ok(from_storage(storage, shape.clone(), op, false))
133 }
134 };
135}
136
137macro_rules! broadcast_binary_op {
138 ($fn_name:ident, $inner_fn_name:ident) => {
139 pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
140 let lhs = self;
141 let shape = lhs
142 .shape()
143 .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
144 let l_broadcast = shape != *lhs.shape();
145 let r_broadcast = shape != *rhs.shape();
146 match (l_broadcast, r_broadcast) {
147 (true, true) => lhs
148 .broadcast_as(&shape)?
149 .$inner_fn_name(&rhs.broadcast_as(&shape)?),
150 (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
151 (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
152 (false, false) => lhs.$inner_fn_name(rhs),
153 }
154 }
155 };
156}
157
158pub(crate) fn from_storage<S: Into<Shape>>(
160 storage: Storage,
161 shape: S,
162 op: BackpropOp,
163 is_variable: bool,
164) -> Tensor {
165 let dtype = storage.dtype();
166 let device = storage.device();
167 let tensor_ = Tensor_ {
168 id: TensorId::new(),
169 storage: Arc::new(RwLock::new(storage)),
170 layout: Layout::contiguous(shape),
171 op,
172 is_variable,
173 dtype,
174 device,
175 };
176 Tensor(Arc::new(tensor_))
177}
178
179impl Tensor {
180 pub(crate) fn ones_impl<S: Into<Shape>>(
181 shape: S,
182 dtype: DType,
183 device: &Device,
184 is_variable: bool,
185 ) -> Result<Self> {
186 let none = BackpropOp::none();
187 let shape = shape.into();
188 let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
189 let layout = Layout::contiguous(shape.clone());
190 storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
191 Ok(from_storage(storage, shape, none, is_variable))
192 }
193
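    /// Creates a tensor of the requested shape filled with ones.
    ///
    /// A minimal sketch in doc-test style (the `candle_core` crate name is assumed):
    ///
    /// ```rust
    /// use candle_core::{DType, Device, Tensor};
    /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    /// assert_eq!(a.to_vec2::<f32>()?, [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```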
194 pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
204 Self::ones_impl(shape, dtype, device, false)
205 }
206
207 pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
208 self.storage_mut().const_set(value, self.layout())
209 }
210
211 pub fn zero_set(&self) -> Result<()> {
212 self.const_set(crate::scalar::Scalar::zero(self.dtype()))
213 }
214
215 pub fn one_set(&self) -> Result<()> {
216 self.const_set(crate::scalar::Scalar::one(self.dtype()))
217 }
218
219 pub fn ones_like(&self) -> Result<Self> {
229 Tensor::ones(self.shape(), self.dtype(), self.device())
230 }
231
232 pub(crate) fn zeros_impl<S: Into<Shape>>(
235 shape: S,
236 dtype: DType,
237 device: &Device,
238 is_variable: bool,
239 ) -> Result<Self> {
240 let none = BackpropOp::none();
241 let shape = shape.into();
242 let storage = device.zeros(&shape, dtype)?;
243 Ok(from_storage(storage, shape, none, is_variable))
244 }
245
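    /// Creates a tensor of the requested shape filled with zeros.
    ///
    /// A small sketch, doc-test style (`candle_core` crate name assumed):
    ///
    /// ```rust
    /// use candle_core::{DType, Device, Tensor};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// assert_eq!(a.to_vec2::<f32>()?, [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```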
246 pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
256 Self::zeros_impl(shape, dtype, device, false)
257 }
258
259 pub fn zeros_like(&self) -> Result<Self> {
270 Tensor::zeros(self.shape(), self.dtype(), self.device())
271 }
272
273 pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
274 lo: T,
275 up: T,
276 s: S,
277 device: &Device,
278 is_variable: bool,
279 ) -> Result<Self> {
280 let s = s.into();
281 let storage = device.rand_uniform(lo, up, &s)?;
282 let none = BackpropOp::none();
283 Ok(from_storage(storage, s, none, is_variable))
284 }
285
286 pub(crate) fn rand_f64_impl<S: Into<Shape>>(
287 lo: f64,
288 up: f64,
289 s: S,
290 dtype: DType,
291 device: &Device,
292 is_variable: bool,
293 ) -> Result<Self> {
294 let s = s.into();
295 let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
296 let none = BackpropOp::none();
297 Ok(from_storage(storage, s, none, is_variable))
298 }
299
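    /// Creates a tensor of the given shape with values sampled uniformly between `lo` and `up`.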
300 pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
302 lo: T,
303 up: T,
304 s: S,
305 device: &Device,
306 ) -> Result<Self> {
307 Self::rand_impl(lo, up, s, device, false)
308 }
309
310 pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
311 Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
312 }
313
314 pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
315 mean: T,
316 std: T,
317 s: S,
318 device: &Device,
319 is_variable: bool,
320 ) -> Result<Self> {
321 let s = s.into();
322 let storage = device.rand_normal(mean, std, &s)?;
323 let none = BackpropOp::none();
324 Ok(from_storage(storage, s, none, is_variable))
325 }
326
327 pub(crate) fn randn_f64_impl<S: Into<Shape>>(
328 mean: f64,
329 std: f64,
330 s: S,
331 dtype: DType,
332 device: &Device,
333 is_variable: bool,
334 ) -> Result<Self> {
335 let s = s.into();
336 let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
337 let none = BackpropOp::none();
338 Ok(from_storage(storage, s, none, is_variable))
339 }
340
341 pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
342 Tensor::randn_f64_impl(
343 mean,
344 stdev,
345 self.shape(),
346 self.dtype(),
347 self.device(),
348 false,
349 )
350 }
351
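    /// Creates a tensor of the given shape with values sampled from a normal distribution with the
    /// specified `mean` and standard deviation `std`.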
352 pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
355 mean: T,
356 std: T,
357 s: S,
358 device: &Device,
359 ) -> Result<Self> {
360 Self::randn_impl(mean, std, s, device, false)
361 }
362
363 pub(crate) fn new_impl<A: crate::device::NdArray>(
364 array: A,
365 shape: Shape,
366 device: &Device,
367 is_variable: bool,
368 ) -> Result<Self> {
369 let n: usize = shape.elem_count();
370 let buffer_size: usize = array.shape()?.elem_count();
371 if buffer_size != n {
372 return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
373 }
374 let storage = device.storage(array)?;
375 let none = BackpropOp::none();
376 Ok(from_storage(storage, shape, none, is_variable))
377 }
378
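    /// Creates a tensor on the given device from values implementing `NdArray`, such as scalars,
    /// slices, or nested fixed-size arrays.
    ///
    /// Minimal sketch in doc-test style, assuming the `candle_core` crate name:
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// assert_eq!(t.dims(), &[2, 2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```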
379 pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
381 let shape = array.shape()?;
382 Self::new_impl(array, shape, device, false)
383 }
384
385 pub fn full<D: crate::WithDType, S: Into<Shape>>(
397 value: D,
398 shape: S,
399 device: &Device,
400 ) -> Result<Self> {
401 Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
402 }
403
404 pub fn from_iter<D: crate::WithDType>(
413 iter: impl IntoIterator<Item = D>,
414 device: &Device,
415 ) -> Result<Self> {
416 let data = iter.into_iter().collect::<Vec<_>>();
417 let len = data.len();
418 Self::from_vec_impl(data, len, device, false)
419 }
420
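    /// Returns a 1D tensor holding `start, start + 1, ..` up to but excluding `end`; see
    /// `arange_step` below for other step sizes.
    ///
    /// Sketch, doc-test style (the `candle_core` crate name is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::arange(0u32, 5u32, &Device::Cpu)?;
    /// assert_eq!(t.to_vec1::<u32>()?, [0, 1, 2, 3, 4]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```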
421 pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
431 Self::arange_step(start, end, D::one(), device)
432 }
433
434 pub fn arange_step<D: crate::WithDType>(
444 start: D,
445 end: D,
446 step: D,
447 device: &Device,
448 ) -> Result<Self> {
449 if D::is_zero(&step) {
450 bail!("step cannot be zero")
451 }
452 let mut data = vec![];
453 let mut current = start;
454 if step >= D::zero() {
455 while current < end {
456 data.push(current);
457 current += step;
458 }
459 } else {
460 while current > end {
461 data.push(current);
462 current += step;
463 }
464 }
465 let len = data.len();
466 Self::from_vec_impl(data, len, device, false)
467 }
468
469 pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
470 data: Vec<D>,
471 shape: S,
472 device: &Device,
473 is_variable: bool,
474 ) -> Result<Self> {
475 let shape = shape.into_shape(data.len())?;
476 let storage = device.storage_owned(data)?;
477 let none = BackpropOp::none();
478 Ok(from_storage(storage, shape, none, is_variable))
479 }
480
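    /// Builds a tensor from a `Vec` of values and a shape covering all of its elements; one
    /// dimension can be left as `()` and is then inferred from the element count.
    ///
    /// A small sketch (doc-test style, `candle_core` crate name assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::from_vec(vec![1u32, 2, 3, 4, 5, 6], (2, 3), &Device::Cpu)?;
    /// assert_eq!(t.to_vec2::<u32>()?, [[1, 2, 3], [4, 5, 6]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```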
481 pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
495 data: Vec<D>,
496 shape: S,
497 device: &Device,
498 ) -> Result<Self> {
499 Self::from_vec_impl(data, shape, device, false)
500 }
501
502 pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
516 array: &[D],
517 shape: S,
518 device: &Device,
519 ) -> Result<Self> {
520 let shape = shape.into_shape(array.len())?;
521 let storage = device.storage_from_slice(array)?;
522 let none = BackpropOp::none();
523 Ok(from_storage(storage, shape, none, false))
524 }
525
526 pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
527 let lhs = self.shape();
528 let rhs = rhs.shape();
529 if lhs != rhs {
530 Err(Error::ShapeMismatchBinaryOp {
531 lhs: lhs.clone(),
532 rhs: rhs.clone(),
533 op,
534 }
535 .bt())
536 } else {
537 Ok(lhs)
538 }
539 }
540
541 pub fn track_op(&self) -> bool {
544 self.is_variable || self.op.is_some()
545 }
546
547 binary_op!(add, Add);
550 binary_op!(mul, Mul);
551 binary_op!(sub, Sub);
552 binary_op!(div, Div);
553 binary_op_scalar!(maximum, Maximum);
554 binary_op_scalar!(minimum, Minimum);
555 broadcast_binary_op!(broadcast_add, add);
556 broadcast_binary_op!(broadcast_mul, mul);
557 broadcast_binary_op!(broadcast_sub, sub);
558 broadcast_binary_op!(broadcast_div, div);
559 broadcast_binary_op!(broadcast_maximum, maximum);
560 broadcast_binary_op!(broadcast_minimum, minimum);
561 broadcast_binary_op!(broadcast_eq, eq);
562 broadcast_binary_op!(broadcast_ne, ne);
563 broadcast_binary_op!(broadcast_lt, lt);
564 broadcast_binary_op!(broadcast_le, le);
565 broadcast_binary_op!(broadcast_gt, gt);
566 broadcast_binary_op!(broadcast_ge, ge);
567
568 unary_op!(recip, Recip);
569 unary_op!(neg, Neg);
570 unary_op!(exp, Exp);
571 unary_op!(log, Log);
572 unary_op!(sin, Sin);
573 unary_op!(cos, Cos);
574 unary_op!(tanh, Tanh);
575 unary_op!(abs, Abs);
576 unary_op!(sqr, Sqr);
577 unary_op!(sqrt, Sqrt);
578 unary_op!(gelu, Gelu);
579 unary_op!(gelu_erf, GeluErf);
580 unary_op!(erf, Erf);
581 unary_op!(relu, Relu);
582 unary_op!(silu, Silu);
583 unary_op!(ceil, Ceil);
584 unary_op!(floor, Floor);
585 unary_op!(round, Round);
586 unary_op!(sign, Sign);
587
588 pub fn round_to(&self, decimals: i32) -> Result<Self> {
593 let mult = 10f64.powi(decimals);
594 (self * mult)?.round()? * (1f64 / mult)
595 }
596
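    /// Extracts the single value of a rank-0 tensor, failing if the tensor has any dimensions.
    ///
    /// Small sketch in doc-test style (crate name `candle_core` assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let total = t.sum_all()?.to_scalar::<f32>()?; // sum_all reduces to a rank-0 tensor
    /// assert_eq!(total, 6.0);
    /// # Ok::<(), candle_core::Error>(())
    /// ```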
597 pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
600 if self.rank() != 0 {
601 Err(Error::UnexpectedNumberOfDims {
602 expected: 0,
603 got: self.rank(),
604 shape: self.shape().clone(),
605 }
606 .bt())?
607 }
608 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
609 let data = S::cpu_storage_as_slice(cpu_storage)?;
610 Ok::<_, Error>(data[self.layout().start_offset()])
611 };
612 match &*self.storage() {
613 Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
614 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
615 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
616 }
617 }
618
619 pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
621 self.to_scalar::<S>()
622 }
623
624 pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
626 let repeats = shape.into();
628 let repeats = repeats.dims();
629 let mut inp = if self.rank() < repeats.len() {
630 let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
631 self.reshape(shape)?
632 } else {
633 self.clone()
634 };
635 for (idx, &repeat) in repeats.iter().enumerate() {
636 if repeat > 1 {
637 inp = Tensor::cat(&vec![&inp; repeat], idx)?
638 }
639 }
640 Ok(inp)
641 }
642
643 pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
680 if args.len() <= 1 {
681 Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
682 }
683 let args: Vec<_> = if xy_indexing {
684 args.iter().rev().collect()
685 } else {
686 args.iter().collect()
687 };
688
689 let mut shape = Vec::with_capacity(args.len());
690 for arg in args.iter() {
691 shape.push(arg.as_ref().dims1()?)
692 }
693
694 let mut grids = Vec::with_capacity(args.len());
695 for idx in 0..args.len() {
696 let mut ones = vec![1usize; args.len()];
697 ones[idx] = shape[idx];
698 let arg = args[idx].as_ref().reshape(ones)?;
699 let mut repeats = shape.clone();
700 repeats[idx] = 1;
701 let repeated_tensor = arg.repeat(repeats)?;
702 grids.push(repeated_tensor);
703 }
704 if xy_indexing {
705 grids.reverse();
706 }
707 Ok(grids)
708 }
709
710 pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
722 if self.elem_count() == 0 {
723 return Ok(self.clone());
724 }
725 let storage = self.storage().affine(self.layout(), mul, add)?;
726 let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
727 Ok(from_storage(storage, self.shape(), op, false))
728 }
729
730 pub fn elu(&self, alpha: f64) -> Result<Self> {
732 if self.elem_count() == 0 {
733 return Ok(self.clone());
734 }
735 let storage = self.storage().elu(self.layout(), alpha)?;
736 let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
737 Ok(from_storage(storage, self.shape(), op, false))
738 }
739
740 pub fn powf(&self, e: f64) -> Result<Self> {
742 if self.elem_count() == 0 {
743 return Ok(self.clone());
744 }
745 let storage = self.storage().powf(self.layout(), e)?;
746 let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
747 Ok(from_storage(storage, self.shape(), op, false))
748 }
749
750 pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
751 if dim >= self.dims().len() {
752 Err(Error::DimOutOfRange {
753 shape: self.shape().clone(),
754 dim: dim as i32,
755 op,
756 }
757 .bt())?
758 } else {
759 Ok(())
760 }
761 }
762
763 pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
766 let dim = dim.to_index(self.shape(), "chunk")?;
767 let size = self.dim(dim)?;
768 if size < chunks {
769 (0..size).map(|i| self.narrow(dim, i, 1)).collect()
770 } else {
771 let chunk_size = size / chunks;
772 let cnt_additional = size % chunks;
773 let mut tensors = vec![];
774 let mut sum_chunk_size = 0;
775 for i in 0..chunks {
776 let chunk_size = if i < cnt_additional {
777 chunk_size + 1
778 } else {
779 chunk_size
780 };
781 let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
782 tensors.push(tensor);
783 sum_chunk_size += chunk_size
784 }
785 Ok(tensors)
786 }
787 }
788
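    /// Returns a view of `self` restricted to `len` entries of dimension `dim`, starting at
    /// `start`; only the layout changes, the underlying storage is shared.
    ///
    /// Sketch, doc-test style (`candle_core` crate name assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let rows = t.narrow(0, 1, 2)?; // keep rows 1 and 2
    /// assert_eq!(rows.to_vec2::<f32>()?, [[2.0, 3.0], [4.0, 5.0]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```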
789 pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
816 let dims = self.dims();
817 let dim = dim.to_index(self.shape(), "narrow")?;
818 let err = |msg| {
819 Err::<(), _>(
820 Error::NarrowInvalidArgs {
821 shape: self.shape().clone(),
822 dim,
823 start,
824 len,
825 msg,
826 }
827 .bt(),
828 )
829 };
830 if start > dims[dim] {
831 err("start > dim_len")?
832 }
833 if start.saturating_add(len) > dims[dim] {
834 err("start + len > dim_len")?
835 }
836 if start == 0 && dims[dim] == len {
837 Ok(self.clone())
838 } else {
839 let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
840 let layout = self.layout().narrow(dim, start, len)?;
841 let tensor_ = Tensor_ {
842 id: TensorId::new(),
843 storage: self.storage.clone(),
844 layout,
845 op,
846 is_variable: false,
847 dtype: self.dtype,
848 device: self.device.clone(),
849 };
850 Ok(Tensor(Arc::new(tensor_)))
851 }
852 }
853
854 fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
855 match dims {
856 [] => Ok(self),
857 [i] => self.squeeze(*i),
858 dims => {
859 let dims = self
860 .dims()
861 .iter()
862 .enumerate()
863 .filter_map(|(dim_idx, &v)| {
864 if dims.contains(&dim_idx) {
865 None
866 } else {
867 Some(v)
868 }
869 })
870 .collect::<Vec<_>>();
871 self.reshape(dims)
872 }
873 }
874 }
875
876 fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
877 let dim = dim.to_index(self.shape(), op.name())?;
878 let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
879 let mut dims = self.dims().to_vec();
880 dims[dim] = 1;
881 let op = match op {
882 ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
883 BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
884 }
885 ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
886 };
887 let res = from_storage(storage, dims, op, false);
888 if keepdim {
889 Ok(res)
890 } else {
891 res.squeeze_dims(&[dim])
892 }
893 }
894
895 fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
896 let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
897 let storage = self
898 .storage()
899 .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
900 let mut dims = self.dims().to_vec();
901 for &sum_dim in sum_dims.iter() {
902 dims[sum_dim] = 1
903 }
904 let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
905 let sum = from_storage(storage, dims, op, false);
906 if keepdim {
907 Ok(sum)
908 } else {
909 sum.squeeze_dims(&sum_dims)
910 }
911 }
912
913 pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
927 where
928 D: Dim + Clone,
929 {
930 let dim = dim.to_index(self.shape(), "roll")?;
931 let dim_size = self.dim(dim)?;
932 let shift = shift.rem_euclid(dim_size as i32) as usize;
933 if shift == 0 {
934 Ok(self.clone())
935 } else {
936 let a = self.narrow(dim, 0, dim_size - shift)?;
937 let b = self.narrow(dim, dim_size - shift, shift)?;
938 Tensor::cat(&[&b, &a], dim)
939 }
940 }
941
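    /// Sums over the specified dimensions, keeping them in the output shape with a size of one;
    /// `sum` below behaves the same but squeezes the reduced dimensions out.
    ///
    /// A small sketch in doc-test style (the `candle_core` crate name is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// assert_eq!(t.sum_keepdim(0)?.to_vec2::<f32>()?, [[4.0, 6.0]]);
    /// assert_eq!(t.sum(0)?.to_vec1::<f32>()?, [4.0, 6.0]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```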
942 pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
960 self.sum_impl(sum_dims, true)
961 }
962
963 pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
967 self.sum_impl(sum_dims, false)
968 }
969
970 pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
988 let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
989 let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
990 let scale = 1f64 / (reduced_dim as f64);
991 self.sum_impl(mean_dims, true)? * scale
992 }
993
994 pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
998 let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
999 let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1000 let scale = 1f64 / (reduced_dim as f64);
1001 self.sum_impl(mean_dims, false)? * scale
1002 }
1003
1004 pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1006 let dim = dim.to_index(self.shape(), "var")?;
1007 let mean = self.mean_keepdim(dim)?;
1008 let squares = self.broadcast_sub(&mean)?.sqr()?;
1009 squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
1010 }
1011
1012 pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
1014 let dim = dim.to_index(self.shape(), "var")?;
1015 self.var_keepdim(dim)?.squeeze(dim)
1016 }
1017
1018 pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1021 self.reduce_impl(dim, true, ReduceOp::Max)
1022 }
1023
1024 pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
1026 self.reduce_impl(dim, false, ReduceOp::Max)
1027 }
1028
1029 pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1032 self.reduce_impl(dim, true, ReduceOp::Min)
1033 }
1034
1035 pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
1037 self.reduce_impl(dim, false, ReduceOp::Min)
1038 }
1039
1040 pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1041 self.reduce_impl(dim, true, ReduceOp::ArgMax)
1042 }
1043
1044 pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
1046 self.reduce_impl(dim, false, ReduceOp::ArgMax)
1047 }
1048
1049 pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1050 self.reduce_impl(dim, true, ReduceOp::ArgMin)
1051 }
1052
1053 pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
1055 self.reduce_impl(dim, false, ReduceOp::ArgMin)
1056 }
1057
1058 pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
1063 let rhs = match rhs.to_tensor_scalar()? {
1064 crate::scalar::TensorScalar::Tensor(rhs) => rhs,
1065 crate::scalar::TensorScalar::Scalar(rhs) => rhs
1066 .to_dtype(self.dtype())?
1067 .to_device(self.device())?
1068 .broadcast_as(self.shape())?,
1069 };
1070 let shape = self.same_shape_binary_op(&rhs, "cmp")?;
1071 let storage = self
1072 .storage()
1073 .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
1074 let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
1075 Ok(from_storage(storage, shape.dims(), op, false))
1076 }
1077
1078 pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1080 self.cmp(rhs, CmpOp::Eq)
1081 }
1082
1083 pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1085 self.cmp(rhs, CmpOp::Ne)
1086 }
1087
1088 pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1091 self.cmp(rhs, CmpOp::Lt)
1092 }
1093
1094 pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1097 self.cmp(rhs, CmpOp::Gt)
1098 }
1099
1100 pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1103 self.cmp(rhs, CmpOp::Ge)
1104 }
1105
1106 pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1109 self.cmp(rhs, CmpOp::Le)
1110 }
1111
1112 pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
1114 self.maximum(min)?.minimum(max)
1115 }
1116
1117 pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
1122 let (n, c, _l) = self.dims3()?;
1123 let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
1124 let storage = self
1125 .storage()
1126 .upsample_nearest1d(self.layout(), target_size)?;
1127 Ok(from_storage(storage, (n, c, target_size), op, false))
1128 }
1129
1130 pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
1132 self.interpolate1d(target_size)
1133 }
1134
1135 pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1141 let (n, c, _h, _w) = self.dims4()?;
1142 let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
1143 arg,
1144 target_h,
1145 target_w,
1146 });
1147 let storage = self
1148 .storage()
1149 .upsample_nearest2d(self.layout(), target_h, target_w)?;
1150 Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1151 }
1152
1153 pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1155 self.interpolate2d(target_h, target_w)
1156 }
1157
1158 pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1165 let sz = sz.to_usize2();
1166 self.avg_pool2d_with_stride(sz, sz)
1167 }
1168
1169 pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
1172 &self,
1173 kernel_size: T,
1174 stride: T,
1175 ) -> Result<Self> {
1176 let kernel_size = kernel_size.to_usize2();
1177 let stride = stride.to_usize2();
1178 let (n, c, h, w) = self.dims4()?;
1179 if h < kernel_size.0 || w < kernel_size.1 {
1180 bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1181 }
1182 let h_out = (h - kernel_size.0) / stride.0 + 1;
1184 let w_out = (w - kernel_size.1) / stride.1 + 1;
1185 let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
1186 arg,
1187 kernel_size,
1188 stride,
1189 });
1190 let storage = self
1191 .storage()
1192 .avg_pool2d(self.layout(), kernel_size, stride)?;
1193 Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1194 }
1195
1196 pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1203 let sz = sz.to_usize2();
1204 self.max_pool2d_with_stride(sz, sz)
1205 }
1206
1207 pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
1210 &self,
1211 kernel_size: T,
1212 stride: T,
1213 ) -> Result<Self> {
1214 let kernel_size = kernel_size.to_usize2();
1215 let stride = stride.to_usize2();
1216 let (n, c, h, w) = self.dims4()?;
1217 if h < kernel_size.0 || w < kernel_size.1 {
1218 bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1219 }
1220 let h_out = (h - kernel_size.0) / stride.0 + 1;
1222 let w_out = (w - kernel_size.1) / stride.1 + 1;
1223 let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
1224 arg,
1225 kernel_size,
1226 stride,
1227 });
1228 let storage = self
1229 .storage()
1230 .max_pool2d(self.layout(), kernel_size, stride)?;
1231 Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1232 }
1233
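    /// Batched matrix multiplication: both operands must have the same rank (at least two) and
    /// matching batch dimensions, with `self` of shape `(.., m, k)` and `rhs` of shape `(.., k, n)`.
    ///
    /// Minimal sketch, doc-test style (`candle_core` crate name assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let b = Tensor::new(&[[5f32, 6.], [7., 8.]], &Device::Cpu)?;
    /// let c = a.matmul(&b)?;
    /// assert_eq!(c.to_vec2::<f32>()?, [[19.0, 22.0], [43.0, 50.0]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```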
1234 pub fn matmul(&self, rhs: &Self) -> Result<Self> {
1243 let a_dims = self.shape().dims();
1244 let b_dims = rhs.shape().dims();
1245
1246 let dim = a_dims.len();
1247
1248 if dim < 2 || b_dims.len() != dim {
1249 Err(Error::ShapeMismatchBinaryOp {
1250 lhs: self.shape().clone(),
1251 rhs: rhs.shape().clone(),
1252 op: "matmul",
1253 }
1254 .bt())?
1255 }
1256
1257 let m = a_dims[dim - 2];
1258 let k = a_dims[dim - 1];
1259 let k2 = b_dims[dim - 2];
1260 let n = b_dims[dim - 1];
1261
1262 let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
1263 if c_shape.elem_count() == 0 || k == 0 {
1264 return Tensor::zeros(c_shape, self.dtype(), self.device());
1265 }
1266 let batching: usize = a_dims[..dim - 2].iter().product();
1267 let batching_b: usize = b_dims[..dim - 2].iter().product();
1268 if k != k2 || batching != batching_b {
1269 Err(Error::ShapeMismatchBinaryOp {
1270 lhs: self.shape().clone(),
1271 rhs: rhs.shape().clone(),
1272 op: "matmul",
1273 }
1274 .bt())?
1275 }
1276
1277 let storage = self.storage().matmul(
1278 &rhs.storage(),
1279 (batching, m, n, k),
1280 self.layout(),
1281 rhs.layout(),
1282 )?;
1283 let op = BackpropOp::new2(self, rhs, Op::Matmul);
1284 Ok(from_storage(storage, c_shape, op, false))
1285 }
1286
1287 pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
1293 let lhs = self;
1294 let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
1295 let l_broadcast = l_shape != *lhs.shape();
1296 let r_broadcast = r_shape != *rhs.shape();
1297 match (l_broadcast, r_broadcast) {
1299 (true, true) => lhs
1300 .broadcast_as(&l_shape)?
1301 .contiguous()?
1302 .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1303 (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1304 (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
1305 (false, false) => lhs.matmul(rhs),
1306 }
1307 }
1308
1309 pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
        let _shape = self.same_shape_binary_op(on_true, "where_cond")?;
1314 let shape = self.same_shape_binary_op(on_false, "where_cond")?;
1315 let storage = self.storage().where_cond(
1316 self.layout(),
1317 &on_true.storage(),
1318 on_true.layout(),
1319 &on_false.storage(),
1320 on_false.layout(),
1321 )?;
1322 let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
1323 Ok(from_storage(storage, shape, op, false))
1324 }
1325
1326 pub fn embedding(&self, ids: &Self) -> Result<Self> {
1346 if self.rank() != 2 || ids.rank() != 1 {
1347 Err(Error::ShapeMismatchBinaryOp {
1348 lhs: self.shape().clone(),
1349 rhs: ids.shape().clone(),
1350 op: "embedding",
1351 }
1352 .bt())?
1353 }
1354 self.index_select(ids, 0)
1355 }
1356
1357 fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
1358 let source_dims = source.dims();
1359 let self_dims = self.dims();
1360 let mismatch = if source_dims.len() != self_dims.len() {
1361 true
1362 } else {
1363 let mut mismatch = false;
1364 for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1365 if i != dim && d1 != d2 {
1366 mismatch = true;
1367 break;
1368 }
1369 }
1370 mismatch
1371 };
1372 if mismatch {
1373 Err(Error::ShapeMismatchBinaryOp {
1374 op: "scatter (self, src)",
1375 lhs: self.shape().clone(),
1376 rhs: source.shape().clone(),
1377 }
1378 .bt())?
1379 }
1380 if indexes.dims() != source.dims() {
1381 Err(Error::ShapeMismatchBinaryOp {
1382 op: "scatter (indexes, src)",
1383 lhs: indexes.shape().clone(),
1384 rhs: source.shape().clone(),
1385 }
1386 .bt())?
1387 }
1388 Ok(())
1389 }
1390
1391 pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1392 let dim = dim.to_index(self.shape(), "scatter")?;
1393 self.scatter_checks(indexes, source, dim)?;
1394 let shape = self.shape();
1395 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1396 self.storage()
1397 .copy_strided_src(&mut storage, 0, self.layout())?;
1398 let layout = Layout::contiguous(shape);
1399 storage.scatter_set(
1400 &layout,
1401 &indexes.storage(),
1402 indexes.layout(),
1403 &source.storage(),
1404 source.layout(),
1405 dim,
1406 )?;
1407 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1408 Op::Scatter(t1, t2, t3, dim)
1409 });
1410 Ok(from_storage(storage, self.shape(), op, false))
1411 }
1412
1413 pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1414 if self.same_storage(source) {
            crate::bail!("cannot use scatter_set when self and src share their storage")
1416 }
1417 let dim = dim.to_index(self.shape(), "scatter-set")?;
1418 self.scatter_checks(indexes, source, dim)?;
1419 self.storage_mut().scatter_set(
1420 self.layout(),
1421 &indexes.storage(),
1422 indexes.layout(),
1423 &source.storage(),
1424 source.layout(),
1425 dim,
1426 )?;
1427 Ok(())
1428 }
1429
1430 pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1431 let dim = dim.to_index(self.shape(), "scatter-add")?;
1432 self.scatter_checks(indexes, source, dim)?;
1433 let shape = self.shape();
1434 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1435 self.storage()
1436 .copy_strided_src(&mut storage, 0, self.layout())?;
1437 let layout = Layout::contiguous(shape);
1438 storage.scatter_add(
1439 &layout,
1440 &indexes.storage(),
1441 indexes.layout(),
1442 &source.storage(),
1443 source.layout(),
1444 dim,
1445 )?;
1446 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1447 Op::ScatterAdd(t1, t2, t3, dim)
1448 });
1449 Ok(from_storage(storage, self.shape(), op, false))
1450 }
1451
1452 pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1453 if self.same_storage(source) {
            crate::bail!("cannot use scatter_add_set when self and src share their storage")
1455 }
1456 let dim = dim.to_index(self.shape(), "scatter-add-set")?;
1457 self.scatter_checks(indexes, source, dim)?;
1458 self.storage_mut().scatter_add(
1459 self.layout(),
1460 &indexes.storage(),
1461 indexes.layout(),
1462 &source.storage(),
1463 source.layout(),
1464 dim,
1465 )?;
1466 Ok(())
1467 }
1468
1469 pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
1471 let dim = dim.to_index(self.shape(), "slice-scatter")?;
1472 if dim == 0 {
1473 self.slice_scatter0(src, start)
1474 } else {
1475 self.transpose(0, dim)?
1477 .slice_scatter0(&src.transpose(0, dim)?, start)?
1478 .transpose(0, dim)
1479 }
1480 }
1481
1482 pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
1484 if self.dtype() != src.dtype() {
1485 Err(Error::DTypeMismatchBinaryOp {
1486 lhs: self.dtype(),
1487 rhs: src.dtype(),
1488 op: "slice-scatter",
1489 }
1490 .bt())?
1491 }
1492 if self.device().location() != src.device.location() {
1493 Err(Error::DeviceMismatchBinaryOp {
1494 lhs: self.device().location(),
1495 rhs: src.device().location(),
1496 op: "slice-scatter",
1497 }
1498 .bt())?
1499 }
1500 if self.rank() != src.rank() {
1501 Err(Error::UnexpectedNumberOfDims {
1502 expected: self.rank(),
1503 got: src.rank(),
1504 shape: src.shape().clone(),
1505 }
1506 .bt())?
1507 }
1508 let shape_ok =
1509 self.dims()
1510 .iter()
1511 .zip(src.dims().iter())
1512 .enumerate()
1513 .all(|(dim_idx, (&d1, &d2))| {
1514 if 0 == dim_idx {
1515 d2 + start <= d1
1516 } else {
1517 d1 == d2
1518 }
1519 });
1520 if !shape_ok {
1521 Err(Error::ShapeMismatchBinaryOp {
1522 op: "slice-scatter (self, src)",
1523 lhs: self.shape().clone(),
1524 rhs: src.shape().clone(),
1525 }
1526 .bt())?
1527 }
1528 let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
1529 self.storage()
1530 .copy_strided_src(&mut storage, 0, self.layout())?;
1531 let offset = start * src.dims()[1..].iter().product::<usize>();
1532 src.storage()
1533 .copy_strided_src(&mut storage, offset, src.layout())?;
1534 let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
1535 Ok(from_storage(storage, self.shape(), op, false))
1536 }
1537
1538 pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1540 let dim = dim.to_index(self.shape(), "index-add")?;
1541 let source_dims = source.dims();
1542 let self_dims = self.dims();
1543 let mismatch = if source_dims.len() != self_dims.len() {
1544 true
1545 } else {
1546 let mut mismatch = false;
1547 for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1548 if i != dim && d1 != d2 {
1549 mismatch = true;
1550 break;
1551 }
1552 }
1553 mismatch
1554 };
1555 if mismatch {
1556 Err(Error::ShapeMismatchBinaryOp {
1557 op: "index-add (self, source)",
1558 lhs: self.shape().clone(),
1559 rhs: source.shape().clone(),
1560 }
1561 .bt())?
1562 }
1563 let indexes_len = indexes.dims1()?;
1567 if source_dims[dim] != indexes_len {
1568 Err(Error::ShapeMismatchBinaryOp {
                op: "index-add (ids, source)",
1570 lhs: indexes.shape().clone(),
1571 rhs: source.shape().clone(),
1572 }
1573 .bt())?
1574 }
1575 let storage = self.storage().index_add(
1576 self.layout(),
1577 &indexes.storage(),
1578 indexes.layout(),
1579 &source.storage(),
1580 source.layout(),
1581 dim,
1582 )?;
1583 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1584 Op::IndexAdd(t1, t2, t3, dim)
1585 });
1586 Ok(from_storage(storage, self.shape(), op, false))
1587 }
1588
1589 pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1601 let dim = dim.to_index(self.shape(), "gather")?;
1602
1603 let self_dims = self.dims();
1604 let indexes_dims = indexes.dims();
1605 let mismatch = if indexes_dims.len() != self_dims.len() {
1606 true
1607 } else {
1608 let mut mismatch = false;
1609 for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
1610 if i != dim && d1 < d2 {
1611 mismatch = true;
1612 break;
1613 }
1614 }
1615 mismatch
1616 };
1617 if mismatch {
1618 Err(Error::ShapeMismatchBinaryOp {
1619 op: "gather",
1620 lhs: self.shape().clone(),
1621 rhs: indexes.shape().clone(),
1622 }
1623 .bt())?
1624 }
1625 let storage =
1626 self.storage()
1627 .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
1628 let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
1629 Ok(from_storage(storage, indexes.shape(), op, false))
1630 }
1631
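    /// Selects entries along dimension `dim` using a 1D tensor of indices; the output keeps the
    /// shape of `self` except that `dim` now has as many entries as there are indices.
    ///
    /// Sketch in doc-test style (crate name `candle_core` assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[2u32, 0u32], &Device::Cpu)?;
    /// let picked = t.index_select(&ids, 0)?;
    /// assert_eq!(picked.to_vec2::<f32>()?, [[4.0, 5.0], [0.0, 1.0]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```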
1632 pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1640 let dim = dim.to_index(self.shape(), "index-select")?;
1641 let indexes_len = match indexes.dims() {
1642 [l] => *l,
1643 _ => Err(Error::ShapeMismatchBinaryOp {
1644 lhs: self.shape().clone(),
1645 rhs: indexes.shape().clone(),
1646 op: "index-select",
1647 }
1648 .bt())?,
1649 };
1650 let storage = self.storage().index_select(
1651 &indexes.storage(),
1652 self.layout(),
1653 indexes.layout(),
1654 dim,
1655 )?;
1656 let mut dims = self.dims().to_vec();
1657 dims[dim] = indexes_len;
1658 let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
1659 Ok(from_storage(storage, dims, op, false))
1660 }
1661
1662 pub fn strided_index(&self) -> crate::StridedIndex {
1665 self.layout.strided_index()
1666 }
1667
1668 pub fn strided_blocks(&self) -> crate::StridedBlocks {
1673 self.layout.strided_blocks()
1674 }
1675
1676 pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
1678 if self.rank() != 1 {
1679 Err(Error::UnexpectedNumberOfDims {
1680 expected: 1,
1681 got: self.rank(),
1682 shape: self.shape().clone(),
1683 }
1684 .bt())?
1685 }
1686 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1687 let data = S::cpu_storage_as_slice(cpu_storage)?;
1688 let data = match self.layout.contiguous_offsets() {
1689 Some((o1, o2)) => data[o1..o2].to_vec(),
1690 None => self.strided_index().map(|i| data[i]).collect(),
1691 };
1692 Ok::<Vec<_>, Error>(data)
1693 };
1694 match &*self.storage() {
1695 Storage::Cpu(storage) => from_cpu_storage(storage),
1696 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1697 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1698 }
1699 }
1700
1701 pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
1703 let (dim1, dim2) = self.dims2()?;
1704 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1705 let data = S::cpu_storage_as_slice(cpu_storage)?;
1706 let mut rows = vec![];
1707 match self.layout.contiguous_offsets() {
1708 Some((o1, o2)) => {
1709 let data = &data[o1..o2];
1710 for idx_row in 0..dim1 {
1711 rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
1712 }
1713 }
1714 None => {
1715 let mut src_index = self.strided_index();
1716 for _idx_row in 0..dim1 {
1717 let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
1718 rows.push(row)
1719 }
1720 assert!(src_index.next().is_none());
1721 }
1722 }
1723 Ok(rows)
1724 };
1725 match &*self.storage() {
1726 Storage::Cpu(storage) => from_cpu_storage(storage),
1727 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1728 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1729 }
1730 }
1731
1732 pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
1734 let (dim1, dim2, dim3) = self.dims3()?;
1735 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1736 let data = S::cpu_storage_as_slice(cpu_storage)?;
1737 let mut top_rows = vec![];
1738 match self.layout.contiguous_offsets() {
1739 Some((o1, o2)) => {
1740 let data = &data[o1..o2];
1741 let dim23 = dim2 * dim3;
1742 for idx1 in 0..dim1 {
1743 let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
1744 let mut rows = vec![];
1745 for idx2 in 0..dim2 {
1746 rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
1747 }
1748 top_rows.push(rows);
1749 }
1750 }
1751 None => {
1752 let mut src_index = self.strided_index();
1753 for _idx in 0..dim1 {
1754 let mut rows = vec![];
1755 for _jdx in 0..dim2 {
1756 let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
1757 rows.push(row)
1758 }
1759 top_rows.push(rows);
1760 }
1761 assert!(src_index.next().is_none());
1762 }
1763 }
1764 Ok(top_rows)
1765 };
1766 match &*self.storage() {
1767 Storage::Cpu(storage) => from_cpu_storage(storage),
1768 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1769 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1770 }
1771 }
1772
1773 pub fn dtype(&self) -> DType {
1775 self.dtype
1776 }
1777
1778 pub fn device(&self) -> &Device {
1780 &self.device
1781 }
1782
1783 pub fn shape(&self) -> &Shape {
1785 self.layout().shape()
1786 }
1787
1788 pub fn dims(&self) -> &[usize] {
1790 self.shape().dims()
1791 }
1792
1793 pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
1795 let dim = dim.to_index(self.shape(), "dim")?;
1796 Ok(self.dims()[dim])
1797 }
1798
1799 pub fn layout(&self) -> &Layout {
1802 &self.layout
1803 }
1804
1805 pub fn stride(&self) -> &[usize] {
1806 self.layout.stride()
1807 }
1808
1809 pub fn rank(&self) -> usize {
1811 self.shape().rank()
1812 }
1813
1814 pub fn elem_count(&self) -> usize {
1816 self.shape().elem_count()
1817 }
1818
1819 pub fn id(&self) -> TensorId {
1821 self.id
1822 }
1823
1824 pub fn is_variable(&self) -> bool {
1827 self.is_variable
1828 }
1829
1830 pub(crate) fn op(&self) -> &Option<Op> {
1831 &self.op
1832 }
1833
1834 pub fn max_all(&self) -> Result<Tensor> {
1845 if self.rank() == 0 {
1846 Ok(self.clone())
1847 } else {
1848 self.flatten_all()?.max(0)
1849 }
1850 }
1851
1852 pub fn min_all(&self) -> Result<Tensor> {
1863 if self.rank() == 0 {
1864 Ok(self.clone())
1865 } else {
1866 self.flatten_all()?.min(0)
1867 }
1868 }
1869
1870 pub fn sum_all(&self) -> Result<Tensor> {
1881 let dims: Vec<_> = (0..self.rank()).collect();
1882 self.sum(dims)
1883 }
1884
1885 pub fn mean_all(&self) -> Result<Tensor> {
1886 self.sum_all()? / self.elem_count() as f64
1887 }
1888
1889 fn flatten_<D1: Dim, D2: Dim>(
1890 &self,
1891 start_dim: Option<D1>,
1892 end_dim: Option<D2>,
1893 ) -> Result<Tensor> {
1894 if self.rank() == 0 {
1895 self.reshape(1)
1896 } else {
1897 let start_dim = match start_dim {
1898 None => 0,
1899 Some(dim) => dim.to_index(self.shape(), "flatten")?,
1900 };
1901 let end_dim = match end_dim {
1902 None => self.rank() - 1,
1903 Some(dim) => dim.to_index(self.shape(), "flatten")?,
1904 };
1905 if start_dim < end_dim {
1906 let dims = self.dims();
1907 let mut dst_dims = dims[..start_dim].to_vec();
1908 dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
1909 if end_dim + 1 < dims.len() {
1910 dst_dims.extend(&dims[end_dim + 1..]);
1911 }
1912 self.reshape(dst_dims)
1913 } else {
1914 Ok(self.clone())
1915 }
1916 }
1917 }
1918
1919 pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
1922 self.flatten_(Some(start_dim), Some(end_dim))
1923 }
1924
1925 pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
1927 self.flatten_(None::<usize>, Some(end_dim))
1928 }
1929
1930 pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
1933 self.flatten_(Some(start_dim), None::<usize>)
1934 }
1935
1936 pub fn flatten_all(&self) -> Result<Tensor> {
1946 self.flatten_(None::<usize>, None::<usize>)
1947 }
1948
1949 pub fn get(&self, i: usize) -> Result<Tensor> {
1961 let dims = self.dims();
1962 if dims.is_empty() {
1963 Ok(self.clone())
1964 } else {
1965 self.narrow(0, i, 1)?.reshape(&dims[1..])
1966 }
1967 }
1968
1969 pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
1983 let dim = dim.to_index(self.shape(), "get_on_dim")?;
1984 self.narrow(dim, index, 1)?.squeeze(dim)
1985 }
1986
1987 pub fn t(&self) -> Result<Tensor> {
1998 let rank = self.rank();
1999 if rank < 2 {
2000 Err(Error::UnexpectedNumberOfDims {
2001 expected: 2,
2002 got: rank,
2003 shape: self.shape().clone(),
2004 }
2005 .bt())?
2006 }
2007 self.transpose(rank - 2, rank - 1)
2008 }
2009
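    /// Swaps two dimensions of the tensor; only the layout is rewritten, no data is copied.
    ///
    /// A small doc-test style sketch (`candle_core` crate name assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    /// let tt = t.transpose(0, 1)?;
    /// assert_eq!(tt.to_vec2::<f32>()?, [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```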
2010 pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
2013 let dim1 = dim1.to_index(self.shape(), "transpose")?;
2014 let dim2 = dim2.to_index(self.shape(), "transpose")?;
2015 if dim1 == dim2 {
2016 return Ok(self.clone());
2017 }
2018 let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
2019 let tensor_ = Tensor_ {
2020 id: TensorId::new(),
2021 storage: self.storage.clone(),
2022 layout: self.layout.transpose(dim1, dim2)?,
2023 op,
2024 is_variable: false,
2025 dtype: self.dtype,
2026 device: self.device.clone(),
2027 };
2028 Ok(Tensor(Arc::new(tensor_)))
2029 }
2030
2031 pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
2043 let dims = dims.to_indexes(self.shape(), "permute")?;
2044 let is_permutation =
2046 dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
2047 if !is_permutation {
2048 bail!(
2049 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
2050 self.dims(),
2051 dims
2052 )
2053 }
2054 let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
2055 let tensor_ = Tensor_ {
2056 id: TensorId::new(),
2057 storage: self.storage.clone(),
2058 layout: self.layout.permute(&dims)?,
2059 op,
2060 is_variable: false,
2061 dtype: self.dtype,
2062 device: self.device.clone(),
2063 };
2064 Ok(Tensor(Arc::new(tensor_)))
2065 }
2066
2067 pub fn is_contiguous(&self) -> bool {
2069 self.layout.is_contiguous()
2070 }
2071
2072 pub fn is_fortran_contiguous(&self) -> bool {
2074 self.layout.is_fortran_contiguous()
2075 }
2076
2077 pub fn copy(&self) -> Result<Tensor> {
2080 let op = BackpropOp::new1(self, Op::Copy);
2081 let tensor_ = Tensor_ {
2082 id: TensorId::new(),
2083 storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
2084 layout: self.layout.clone(),
2085 op,
2086 is_variable: false,
2087 dtype: self.dtype,
2088 device: self.device.clone(),
2089 };
2090 Ok(Tensor(Arc::new(tensor_)))
2091 }
2092
2093 pub fn detach(&self) -> Tensor {
2098 if self.op.is_none() && !self.is_variable {
2099 self.clone()
2100 } else {
2101 let tensor_ = Tensor_ {
2102 id: TensorId::new(),
2103 storage: self.storage.clone(),
2104 layout: self.layout.clone(),
2105 op: BackpropOp::none(),
2106 is_variable: false,
2107 dtype: self.dtype,
2108 device: self.device.clone(),
2109 };
2110 Tensor(Arc::new(tensor_))
2111 }
2112 }
2113
2114 pub fn to_device(&self, device: &Device) -> Result<Tensor> {
2116 if self.device().same_device(device) {
2117 Ok(self.clone())
2118 } else {
2119 let storage = match (&*self.storage(), device) {
2120 (Storage::Cpu(storage), Device::Cuda(cuda)) => {
2121 Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
2122 }
2123 (Storage::Cpu(storage), Device::Metal(metal)) => {
2124 Storage::Metal(metal.storage_from_cpu_storage(storage)?)
2125 }
2126 (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2127 (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2128 (Storage::Cuda(storage), Device::Cuda(cuda)) => {
2129 let cpu_storage = storage.to_cpu_storage()?;
2132 Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
2133 }
2134 (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
2135 _ => {
2136 bail!(
2137 "not implemented yet, self.device: {:?}, device: {:?}",
2138 self.device(),
2139 device
2140 )
2141 }
2142 };
2143 let op = BackpropOp::new1(self, Op::ToDevice);
2144 let tensor_ = Tensor_ {
2145 id: TensorId::new(),
2146 storage: Arc::new(RwLock::new(storage)),
2147 layout: self.layout.clone(),
2148 op,
2149 is_variable: false,
2150 dtype: self.dtype,
2151 device: device.clone(),
2152 };
2153 Ok(Tensor(Arc::new(tensor_)))
2154 }
2155 }
2156
2157 pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
2160 let left_shape = left_shape.into();
2161 let mut dims = left_shape.into_dims();
2162 dims.extend(self.dims());
2163 self.broadcast_as(dims)
2164 }
2165
2166 pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2174 let tensor_ = Tensor_ {
2175 id: TensorId::new(),
2176 storage: self.storage.clone(),
2177 layout: self.layout.broadcast_as(shape)?,
2178 op: BackpropOp::new1(self, Op::Broadcast),
2179 is_variable: false,
2180 dtype: self.dtype,
2181 device: self.device.clone(),
2182 };
2183 Ok(Tensor(Arc::new(tensor_)))
2184 }
2185
2186 pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2188 self.broadcast_as(shape)
2189 }
2190
2191 pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2202 if self.dtype() == dtype {
2203 Ok(self.clone())
2204 } else {
2205 let shape = self.shape();
2206 let storage = self.storage().to_dtype(self.layout(), dtype)?;
2207 let op = BackpropOp::new1(self, Op::ToDType);
2208 Ok(from_storage(storage, shape.clone(), op, false))
2209 }
2210 }
2211
2212 pub fn contiguous(&self) -> Result<Tensor> {
2215 if self.is_contiguous() {
2216 Ok(self.clone())
2217 } else {
2218 let shape = self.shape();
2219 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2220 self.storage()
2221 .copy_strided_src(&mut storage, 0, self.layout())?;
2222 let op = BackpropOp::new1(self, Op::Copy);
2223 Ok(from_storage(storage, shape.clone(), op, false))
2224 }
2225 }
2226
2227 pub fn force_contiguous(&self) -> Result<Tensor> {
2229 let shape = self.shape();
2230 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2231 self.storage()
2232 .copy_strided_src(&mut storage, 0, self.layout())?;
2233 let op = BackpropOp::new1(self, Op::Copy);
2234 Ok(from_storage(storage, shape.clone(), op, false))
2235 }
2236
2237 pub(crate) fn make_var(&self) -> Result<Tensor> {
2240 let shape = self.shape().clone();
2241 let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2242 self.storage()
2243 .copy_strided_src(&mut storage, 0, self.layout())?;
2244 Ok(from_storage(storage, shape, BackpropOp::none(), true))
2245 }
2246
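    /// Reshapes the tensor to a new shape with the same number of elements; one dimension may be
    /// given as `()` and is then inferred. Contiguous tensors share their storage with the result,
    /// otherwise the data is copied.
    ///
    /// Sketch, doc-test style (the `candle_core` crate name is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
    /// let m = t.reshape((2, 3))?;
    /// assert_eq!(m.dims(), &[2, 3]);
    /// let n = m.reshape(((), 2))?; // the () dimension is inferred as 3
    /// assert_eq!(n.dims(), &[3, 2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```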
2247 pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
2272 let shape = s.into_shape(self.elem_count())?;
2273 if shape.elem_count() != self.elem_count() {
2274 return Err(Error::ShapeMismatchBinaryOp {
2275 lhs: self.shape().clone(),
2276 rhs: shape,
2277 op: "reshape",
2278 }
2279 .bt());
2280 }
2281 let op = BackpropOp::new1(self, Op::Reshape);
2282 if self.is_contiguous() {
2283 let tensor_ = Tensor_ {
2284 id: TensorId::new(),
2285 storage: self.storage.clone(),
2286 layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
2287 op,
2288 is_variable: false,
2289 dtype: self.dtype,
2290 device: self.device.clone(),
2291 };
2292 Ok(Tensor(Arc::new(tensor_)))
2293 } else {
2294 let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2295 self.storage()
2296 .copy_strided_src(&mut storage, 0, self.layout())?;
2297 Ok(from_storage(storage, shape, op, false))
2298 }
2299 }
2300
2301 pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
2315 let dims = self.dims();
2318 let dim = dim.to_index(self.shape(), "squeeze")?;
2319 if dims[dim] == 1 {
2320 let mut dims = dims.to_vec();
2321 let mut strides = self.stride().to_vec();
2322 dims.remove(dim);
2323 strides.remove(dim);
2324 let tensor_ = Tensor_ {
2325 id: TensorId::new(),
2326 storage: self.storage.clone(),
2327 layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2328 op: BackpropOp::new1(self, Op::Reshape),
2329 is_variable: false,
2330 dtype: self.dtype,
2331 device: self.device.clone(),
2332 };
2333 Ok(Tensor(Arc::new(tensor_)))
2334 } else {
2335 Ok(self.clone())
2336 }
2337 }
2338
2339 pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
2353 let mut dims = self.dims().to_vec();
2354 let mut strides = self.stride().to_vec();
2355 let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
2356 dims.insert(dim, 1);
2358 let stride = if dim < strides.len() { strides[dim] } else { 1 };
2361 strides.insert(dim, stride);
2362 let tensor_ = Tensor_ {
2363 id: TensorId::new(),
2364 storage: self.storage.clone(),
2365 layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2366 op: BackpropOp::new1(self, Op::Reshape),
2367 is_variable: false,
2368 dtype: self.dtype,
2369 device: self.device.clone(),
2370 };
2371 Ok(Tensor(Arc::new(tensor_)))
2372 }
2373
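    /// Stacks several tensors of identical shape along a new dimension `dim`: each input is
    /// unsqueezed at `dim` and the results are concatenated.
    ///
    /// Minimal sketch in doc-test style (`candle_core` crate name assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let a = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let b = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
    /// let s = Tensor::stack(&[&a, &b], 0)?;
    /// assert_eq!(s.to_vec2::<f32>()?, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```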
2374 pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
2391 if args.is_empty() {
2392 Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
2393 }
2394 let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
2395 let args = args
2396 .iter()
2397 .map(|t| t.as_ref().unsqueeze(dim))
2398 .collect::<Result<Vec<_>>>()?;
2399 Self::cat(&args, dim)
2400 }
2401
2402 pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2405 if left == 0 && right == 0 {
2406 Ok(self.clone())
2407 } else if left == 0 {
2408 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2409 let mut dims = self.dims().to_vec();
2410 dims[dim] = right;
2411 let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2412 Tensor::cat(&[self, &right], dim)
2413 } else if right == 0 {
2414 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2415 let mut dims = self.dims().to_vec();
2416 dims[dim] = left;
2417 let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2418 Tensor::cat(&[&left, self], dim)
2419 } else {
2420 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2421 let mut dims = self.dims().to_vec();
2422 dims[dim] = left;
2423 let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2424 dims[dim] = right;
2425 let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2426 Tensor::cat(&[&left, self, &right], dim)
2427 }
2428 }
2429
2430 pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2433 if left == 0 && right == 0 {
2434 Ok(self.clone())
2435 } else if self.elem_count() == 0 {
2436 bail!("cannot use pad_with_same on an empty tensor")
2437 } else if left == 0 {
2438 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2439 let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2440 let mut v = vec![self];
2441 for _ in 0..right {
2442 v.push(&r)
2443 }
2444 Tensor::cat(&v, dim)
2445 } else if right == 0 {
2446 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2447 let l = self.narrow(dim, 0, 1)?;
2448 let mut v = vec![];
2449 for _ in 0..left {
2450 v.push(&l)
2451 }
2452 v.push(self);
2453 Tensor::cat(&v, dim)
2454 } else {
2455 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2456 let l = self.narrow(dim, 0, 1)?;
2457 let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2458 let mut v = vec![];
2459 for _ in 0..left {
2460 v.push(&l)
2461 }
2462 v.push(self);
2463 for _ in 0..right {
2464 v.push(&r)
2465 }
2466 Tensor::cat(&v, dim)
2467 }
2468 }
2469
2470 pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
2472 m.forward(self)
2473 }
2474
2475 pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
2477 m.forward_t(self, train)
2478 }
2479
2480 pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
2481 self.storage.read().unwrap()
2482 }
2483
2484 pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
2485 self.storage.write().unwrap()
2486 }
2487
2488 pub(crate) fn storage_mut_and_layout(
2491 &self,
2492 ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
2493 let storage = self.storage.write().unwrap();
2494 (storage, &self.layout)
2495 }
2496
2497 pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
2499 let storage = self.storage.read().unwrap();
2500 (storage, &self.layout)
2501 }
2502
2503 pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
2504 let lhs: &RwLock<Storage> = self.storage.as_ref();
2505 let rhs: &RwLock<Storage> = rhs.storage.as_ref();
2506 std::ptr::eq(lhs, rhs)
2507 }
2508
2509 pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
2512 let rank = self.rank() as i64;
2513 if rank <= axis {
2514 bail!("axis {axis} is too large, tensor rank {rank}")
2515 } else if 0 <= axis {
2516 Ok(axis as usize)
2517 } else {
2518 let naxis = rank + axis;
2519 if naxis < 0 {
2520 bail!("axis {axis} is too small, tensor rank {rank}")
2521 }
2522 Ok(naxis as usize)
2523 }
2524 }
2525
2526 pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2528 let t = Tensor::arange(0u32, n as u32, device)?;
2529 let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2530 let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2531 t1.le(&t2)?.to_dtype(dtype)
2532 }
2533
    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
        let t = Tensor::arange(0u32, n as u32, device)?;
        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
        t1.ge(&t2)?.to_dtype(dtype)
    }

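    /// Returns the `n x n` identity matrix of the given dtype, built from an equality
    /// comparison of broadcast row and column indices.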
    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
        let t = Tensor::arange(0u32, n as u32, device)?;
        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
        t1.eq(&t2)?.to_dtype(dtype)
    }

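    /// Cumulative sum along dimension `dim`, keeping the input shape. The implementation
    /// multiplies by an upper-triangular matrix of ones, so it costs O(n^2) in the size of
    /// that dimension rather than O(n).
    ///
    /// Illustrative sketch (not compiled as a doctest):
    /// ```ignore
    /// let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let c = t.cumsum(0)?;
    /// // c is [1., 3., 6.]
    /// ```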
    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "cumsum")?;
        let rank = self.rank();
        if rank == 0 {
            return Ok(self.clone());
        }
        let n_axis = self.dim(dim)?;
        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
        if rank == 1 {
            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
        } else {
            let last = rank - 1;
            let t = self.transpose(dim, last)?;
            let t = t.broadcast_matmul(&triu)?;
            t.transpose(dim, last)
        }
    }

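    /// Returns a copy of `self` where the values inside `ranges` have been replaced by the
    /// content of `src`. There must be one range per dimension, each range must stay within
    /// bounds, and its length must match the corresponding dimension of `src`. The result is
    /// assembled with zero padding and a `where_cond` mask, so nothing is mutated in place.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    /// ```ignore
    /// let base = Tensor::zeros((4, 4), DType::F32, &Device::Cpu)?;
    /// let patch = Tensor::ones((2, 2), DType::F32, &Device::Cpu)?;
    /// let out = base.slice_assign(&[1..3, 1..3], &patch)?;
    /// // rows 1..3, cols 1..3 of `out` are ones, everything else stays zero
    /// ```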
    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
        &self,
        ranges: &[D],
        src: &Tensor,
    ) -> Result<Self> {
        let src_dims = src.dims();
        let self_dims = self.dims();
        if self_dims.len() != src_dims.len() {
            bail!(
                "slice-assign requires input with the same rank {} <> {}",
                self_dims.len(),
                src_dims.len()
            )
        }
        if self_dims.len() != ranges.len() {
            bail!(
                "slice-assign requires input with the same rank as there are ranges {} <> {}",
                self_dims.len(),
                ranges.len()
            )
        }
        let mut src = src.clone();
        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
        for (i, range) in ranges.iter().enumerate() {
            let start_included = match range.start_bound() {
                std::ops::Bound::Unbounded => 0,
                std::ops::Bound::Included(v) => *v,
                std::ops::Bound::Excluded(v) => *v + 1,
            };
            let end_excluded = match range.end_bound() {
                std::ops::Bound::Unbounded => self_dims[i],
                std::ops::Bound::Included(v) => *v + 1,
                std::ops::Bound::Excluded(v) => *v,
            };
            if end_excluded <= start_included {
                bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
            }
            if self_dims[i] < end_excluded {
                bail!(
                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
                    self_dims[i]
                )
            }
            if end_excluded - start_included != src_dims[i] {
                bail!(
                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
                )
            }
            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
        }
        mask.where_cond(&src, self)
    }

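    /// Numerically stable log-sum-exp over `sum_dims`: the per-slice maximum is subtracted
    /// before exponentiating and added back after the log. The reduced dimensions are
    /// squeezed out of the result.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    /// ```ignore
    /// let t = Tensor::new(&[[0f32, 0.], [0., 0.]], &Device::Cpu)?;
    /// let lse = t.log_sum_exp(1)?;
    /// // lse is [ln(2), ln(2)] ≈ [0.6931, 0.6931]
    /// ```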
    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
        if sum_dims.is_empty() {
            return Ok(self.clone());
        }
        let max = sum_dims[1..]
            .iter()
            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
                max.max_keepdim(dim)
            })?;
        let exp = self.broadcast_sub(&max)?.exp()?;
        let sum = exp.sum(sum_dims.clone())?;

        sum.log()? + max.squeeze_dims(&sum_dims)
    }

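    /// Elementwise power, computed as `exp(rhs * ln(self))`. Both tensors must have the same
    /// shape, and the result is only meaningful for strictly positive values of `self`.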
    pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
        rhs.mul(&self.log()?)?.exp()
    }

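    /// Same as [`Tensor::pow`] but with broadcasting between `self` and `rhs`.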
    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
        rhs.broadcast_mul(&self.log()?)?.exp()
    }

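    /// Reverses the order of elements along each of the dimensions listed in `dims`,
    /// implemented as one `index_select` with reversed indices per dimension.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    /// ```ignore
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let f = t.flip(&[0])?;
    /// // f is [[3., 4.], [1., 2.]]
    /// ```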
    pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
        let mut result = self.clone();
        for &dim in dims.iter() {
            let size = result.dim(dim)?;
            let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
            let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
            result = result.index_select(&indices_tensor, dim)?;
        }
        Ok(result)
    }
}

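// Implements the std arithmetic operators (`Add`, `Sub`, `Mul`, `Div`) for the various
// tensor/tensor and tensor/`Result<Tensor>` combinations. The output type is `Result<Tensor>`
// so shape or device mismatches surface as errors rather than panics, and `f64` operands are
// lowered onto a single `affine` (scale + offset) call.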
macro_rules! bin_trait {
    ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => {
        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: B) -> Self::Output {
                Tensor::$fn1(&self, rhs.borrow())
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for &Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: B) -> Self::Output {
                Tensor::$fn1(&self, rhs.borrow())
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: Tensor) -> Self::Output {
                Tensor::$fn1(self?.borrow(), &rhs)
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: &Tensor) -> Self::Output {
                Tensor::$fn1(self?.borrow(), rhs)
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: Result<B>) -> Self::Output {
                Tensor::$fn1(&self, rhs?.borrow())
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for &Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: Result<B>) -> Self::Output {
                Tensor::$fn1(&self, rhs?.borrow())
            }
        }

        impl std::ops::$trait<f64> for Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: f64) -> Self::Output {
                self.affine($mul(rhs), $add(rhs))
            }
        }

        impl std::ops::$trait<f64> for &Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: f64) -> Self::Output {
                self.affine($mul(rhs), $add(rhs))
            }
        }
    };
}

bin_trait!(Add, add, |_| 1., |v| v);
bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
bin_trait!(Mul, mul, |v| v, |_| 0.);
bin_trait!(Div, div, |v| 1. / v, |_| 0.);

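// The impls below cover the scalar-on-the-left spelling, e.g. `2.0 * &t` or `1.0 - &t`,
// again mapping onto `affine`/`recip` and returning `Result<Tensor>`.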
impl std::ops::Add<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn add(self, rhs: Tensor) -> Self::Output {
        rhs + self
    }
}

impl std::ops::Add<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn add(self, rhs: &Tensor) -> Self::Output {
        rhs + self
    }
}

impl std::ops::Mul<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn mul(self, rhs: Tensor) -> Self::Output {
        rhs * self
    }
}

impl std::ops::Mul<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn mul(self, rhs: &Tensor) -> Self::Output {
        rhs * self
    }
}

impl std::ops::Sub<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn sub(self, rhs: Tensor) -> Self::Output {
        rhs.affine(-1., self)
    }
}

impl std::ops::Sub<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn sub(self, rhs: &Tensor) -> Self::Output {
        rhs.affine(-1., self)
    }
}

impl std::ops::Div<Tensor> for f64 {
    type Output = Result<Tensor>;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: Tensor) -> Self::Output {
        rhs.recip()? * self
    }
}

impl std::ops::Div<&Tensor> for f64 {
    type Output = Result<Tensor>;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: &Tensor) -> Self::Output {
        rhs.recip()? * self
    }
}