1#![allow(clippy::redundant_closure_call)]
3use crate::backend::{BackendDevice, BackendStorage};
4use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
5use crate::scalar::TensorOrScalar;
6use crate::shape::{Dim, Dims, ShapeWithOneHole};
7use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
8use std::sync::{Arc, RwLock};
9
10#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
12pub struct TensorId(usize);
13
14impl TensorId {
15 fn new() -> Self {
16 use std::sync::atomic;
18 static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
19 Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
20 }
21}
22
23pub struct Tensor_ {
24 id: TensorId,
25 storage: Arc<RwLock<Storage>>,
38 layout: Layout,
39 op: BackpropOp,
40 is_variable: bool,
41 dtype: DType,
42 device: Device,
43}
44
45impl AsRef<Tensor> for Tensor {
46 fn as_ref(&self) -> &Tensor {
47 self
48 }
49}
50
51#[derive(Clone)]
55pub struct Tensor(Arc<Tensor_>);
69
70impl std::ops::Deref for Tensor {
71 type Target = Tensor_;
72
73 fn deref(&self) -> &Self::Target {
74 self.0.as_ref()
75 }
76}
77
78macro_rules! unary_op {
79 ($fn_name:ident, $op_name:ident) => {
80 pub fn $fn_name(&self) -> Result<Self> {
81 let shape = self.shape();
82 if shape.elem_count() == 0 {
83 return Ok(self.clone());
84 }
85 let storage = self
86 .storage()
87 .unary_impl::<crate::op::$op_name>(self.layout())?;
88 let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
89 Ok(from_storage(storage, shape.clone(), op, false))
90 }
91 };
92}
93
94macro_rules! binary_op {
95 ($fn_name:ident, $op_name:ident) => {
96 pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
97 let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
98 if shape.elem_count() == 0 {
99 return Ok(self.clone());
100 }
101 let storage = self.storage().binary_impl::<crate::op::$op_name>(
102 &*rhs.storage(),
103 self.layout(),
104 rhs.layout(),
105 )?;
106 let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
107 Ok(from_storage(storage, shape.clone(), op, false))
108 }
109 };
110}
111
112macro_rules! binary_op_scalar {
113 ($fn_name:ident, $op_name:ident) => {
114 pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
115 let rhs = match rhs.to_tensor_scalar()? {
116 crate::scalar::TensorScalar::Tensor(rhs) => rhs,
117 crate::scalar::TensorScalar::Scalar(rhs) => rhs
118 .to_dtype(self.dtype())?
119 .to_device(self.device())?
120 .broadcast_as(self.shape())?,
121 };
122 let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
123 if self.elem_count() == 0 {
124 return Ok(self.clone());
125 }
126 let storage = self.storage().binary_impl::<crate::op::$op_name>(
127 &*rhs.storage(),
128 self.layout(),
129 rhs.layout(),
130 )?;
131 let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
132 Ok(from_storage(storage, shape.clone(), op, false))
133 }
134 };
135}
136
137macro_rules! broadcast_binary_op {
138 ($fn_name:ident, $inner_fn_name:ident) => {
139 pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
140 let lhs = self;
141 let shape = lhs
142 .shape()
143 .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
144 let l_broadcast = shape != *lhs.shape();
145 let r_broadcast = shape != *rhs.shape();
146 match (l_broadcast, r_broadcast) {
147 (true, true) => lhs
148 .broadcast_as(&shape)?
149 .$inner_fn_name(&rhs.broadcast_as(&shape)?),
150 (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
151 (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
152 (false, false) => lhs.$inner_fn_name(rhs),
153 }
154 }
155 };
156}
157
158pub(crate) fn from_storage<S: Into<Shape>>(
160 storage: Storage,
161 shape: S,
162 op: BackpropOp,
163 is_variable: bool,
164) -> Tensor {
165 let dtype = storage.dtype();
166 let device = storage.device();
167 let tensor_ = Tensor_ {
168 id: TensorId::new(),
169 storage: Arc::new(RwLock::new(storage)),
170 layout: Layout::contiguous(shape),
171 op,
172 is_variable,
173 dtype,
174 device,
175 };
176 Tensor(Arc::new(tensor_))
177}
178
179impl Tensor {
180 pub(crate) fn ones_impl<S: Into<Shape>>(
181 shape: S,
182 dtype: DType,
183 device: &Device,
184 is_variable: bool,
185 ) -> Result<Self> {
186 let none = BackpropOp::none();
187 let shape = shape.into();
188 let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
189 let layout = Layout::contiguous(shape.clone());
190 storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
191 Ok(from_storage(storage, shape, none, is_variable))
192 }
193
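    /// Illustrative usage sketch (added example, not original documentation; it assumes
    /// this crate is exposed as `candle_core` and that the CPU device is available):
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    /// // Every element of the freshly allocated tensor is set to one.
    /// assert_eq!(t.to_vec2::<f32>()?, vec![vec![1., 1., 1.], vec![1., 1., 1.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```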
194 pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
204 Self::ones_impl(shape, dtype, device, false)
205 }
206
207 pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
208 self.storage_mut().const_set(value, self.layout())
209 }
210
211 pub fn zero_set(&self) -> Result<()> {
212 self.const_set(crate::scalar::Scalar::zero(self.dtype()))
213 }
214
215 pub fn one_set(&self) -> Result<()> {
216 self.const_set(crate::scalar::Scalar::one(self.dtype()))
217 }
218
219 pub fn ones_like(&self) -> Result<Self> {
229 Tensor::ones(self.shape(), self.dtype(), self.device())
230 }
231
232 pub(crate) fn zeros_impl<S: Into<Shape>>(
235 shape: S,
236 dtype: DType,
237 device: &Device,
238 is_variable: bool,
239 ) -> Result<Self> {
240 let none = BackpropOp::none();
241 let shape = shape.into();
242 let storage = device.zeros(&shape, dtype)?;
243 Ok(from_storage(storage, shape, none, is_variable))
244 }
245
246 pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
256 Self::zeros_impl(shape, dtype, device, false)
257 }
258
259 pub fn zeros_like(&self) -> Result<Self> {
270 Tensor::zeros(self.shape(), self.dtype(), self.device())
271 }
272
273 pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
274 lo: T,
275 up: T,
276 s: S,
277 device: &Device,
278 is_variable: bool,
279 ) -> Result<Self> {
280 let s = s.into();
281 let storage = device.rand_uniform(lo, up, &s)?;
282 let none = BackpropOp::none();
283 Ok(from_storage(storage, s, none, is_variable))
284 }
285
286 pub(crate) fn rand_f64_impl<S: Into<Shape>>(
287 lo: f64,
288 up: f64,
289 s: S,
290 dtype: DType,
291 device: &Device,
292 is_variable: bool,
293 ) -> Result<Self> {
294 let s = s.into();
295 let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
296 let none = BackpropOp::none();
297 Ok(from_storage(storage, s, none, is_variable))
298 }
299
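    /// Illustrative sketch (assumed `candle_core` crate name); values are drawn uniformly
    /// between `lo` and `up`, so only the shape and dtype are deterministic here.
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::rand(0f32, 1f32, (2, 3), &Device::Cpu)?;
    /// assert_eq!(t.dims(), &[2, 3]);
    /// assert_eq!(t.dtype(), DType::F32);
    /// # Ok::<(), candle_core::Error>(())
    /// ```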
300 pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
302 lo: T,
303 up: T,
304 s: S,
305 device: &Device,
306 ) -> Result<Self> {
307 Self::rand_impl(lo, up, s, device, false)
308 }
309
310 pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
311 Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
312 }
313
314 pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
315 mean: T,
316 std: T,
317 s: S,
318 device: &Device,
319 is_variable: bool,
320 ) -> Result<Self> {
321 let s = s.into();
322 let storage = device.rand_normal(mean, std, &s)?;
323 let none = BackpropOp::none();
324 Ok(from_storage(storage, s, none, is_variable))
325 }
326
327 pub(crate) fn randn_f64_impl<S: Into<Shape>>(
328 mean: f64,
329 std: f64,
330 s: S,
331 dtype: DType,
332 device: &Device,
333 is_variable: bool,
334 ) -> Result<Self> {
335 let s = s.into();
336 let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
337 let none = BackpropOp::none();
338 Ok(from_storage(storage, s, none, is_variable))
339 }
340
341 pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
342 Tensor::randn_f64_impl(
343 mean,
344 stdev,
345 self.shape(),
346 self.dtype(),
347 self.device(),
348 false,
349 )
350 }
351
352 pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
355 mean: T,
356 std: T,
357 s: S,
358 device: &Device,
359 ) -> Result<Self> {
360 Self::randn_impl(mean, std, s, device, false)
361 }
362
363 pub(crate) fn new_impl<A: crate::device::NdArray>(
364 array: A,
365 shape: Shape,
366 device: &Device,
367 is_variable: bool,
368 ) -> Result<Self> {
369 let n: usize = shape.elem_count();
370 let buffer_size: usize = array.shape()?.elem_count();
371 if buffer_size != n {
372 return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
373 }
374 let storage = device.storage(array)?;
375 let none = BackpropOp::none();
376 Ok(from_storage(storage, shape, none, is_variable))
377 }
378
379 pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
381 let shape = array.shape()?;
382 Self::new_impl(array, shape, device, false)
383 }
384
385 pub fn full<D: crate::WithDType, S: Into<Shape>>(
396 value: D,
397 shape: S,
398 device: &Device,
399 ) -> Result<Self> {
400 let none = BackpropOp::none();
401 let shape = shape.into();
402 let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
403 let layout = Layout::contiguous(shape.clone());
404 storage.const_set(value.to_scalar(), &layout)?;
405 Ok(from_storage(storage, shape, none, false))
406 }
407
408 pub fn from_iter<D: crate::WithDType>(
417 iter: impl IntoIterator<Item = D>,
418 device: &Device,
419 ) -> Result<Self> {
420 let data = iter.into_iter().collect::<Vec<_>>();
421 let len = data.len();
422 Self::from_vec_impl(data, len, device, false)
423 }
424
425 pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
435 Self::arange_step(start, end, D::one(), device)
436 }
437
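    /// Illustrative sketch (assumed `candle_core` crate name); as in the loop below, the end
    /// bound is exclusive:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange_step(0f32, 1f32, 0.25, &Device::Cpu)?;
    /// assert_eq!(t.to_vec1::<f32>()?, vec![0.0, 0.25, 0.5, 0.75]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```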
438 pub fn arange_step<D: crate::WithDType>(
448 start: D,
449 end: D,
450 step: D,
451 device: &Device,
452 ) -> Result<Self> {
453 if D::is_zero(&step) {
454 bail!("step cannot be zero")
455 }
456 let mut data = vec![];
457 let mut current = start;
458 if step >= D::zero() {
459 while current < end {
460 data.push(current);
461 current += step;
462 }
463 } else {
464 while current > end {
465 data.push(current);
466 current += step;
467 }
468 }
469 let len = data.len();
470 Self::from_vec_impl(data, len, device, false)
471 }
472
473 pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
474 data: Vec<D>,
475 shape: S,
476 device: &Device,
477 is_variable: bool,
478 ) -> Result<Self> {
479 let shape = shape.into_shape(data.len())?;
480 let storage = device.storage_owned(data)?;
481 let none = BackpropOp::none();
482 Ok(from_storage(storage, shape, none, is_variable))
483 }
484
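    /// Illustrative sketch (assumed `candle_core` crate name); the vector is interpreted in
    /// row-major order for the requested shape:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::from_vec(vec![1f32, 2., 3., 4., 5., 6.], (2, 3), &Device::Cpu)?;
    /// assert_eq!(t.to_vec2::<f32>()?, vec![vec![1., 2., 3.], vec![4., 5., 6.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```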
485 pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
499 data: Vec<D>,
500 shape: S,
501 device: &Device,
502 ) -> Result<Self> {
503 Self::from_vec_impl(data, shape, device, false)
504 }
505
506 pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
520 array: &[D],
521 shape: S,
522 device: &Device,
523 ) -> Result<Self> {
524 let shape = shape.into_shape(array.len())?;
525 let storage = device.storage_from_slice(array)?;
526 let none = BackpropOp::none();
527 Ok(from_storage(storage, shape, none, false))
528 }
529
530 pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
531 let lhs = self.shape();
532 let rhs = rhs.shape();
533 if lhs != rhs {
534 Err(Error::ShapeMismatchBinaryOp {
535 lhs: lhs.clone(),
536 rhs: rhs.clone(),
537 op,
538 }
539 .bt())
540 } else {
541 Ok(lhs)
542 }
543 }
544
545 pub fn track_op(&self) -> bool {
548 self.is_variable || self.op.is_some()
549 }
550
551 binary_op!(add, Add);
554 binary_op!(mul, Mul);
555 binary_op!(sub, Sub);
556 binary_op!(div, Div);
557 binary_op_scalar!(maximum, Maximum);
558 binary_op_scalar!(minimum, Minimum);
559 broadcast_binary_op!(broadcast_add, add);
560 broadcast_binary_op!(broadcast_mul, mul);
561 broadcast_binary_op!(broadcast_sub, sub);
562 broadcast_binary_op!(broadcast_div, div);
563 broadcast_binary_op!(broadcast_maximum, maximum);
564 broadcast_binary_op!(broadcast_minimum, minimum);
565 broadcast_binary_op!(broadcast_eq, eq);
566 broadcast_binary_op!(broadcast_ne, ne);
567 broadcast_binary_op!(broadcast_lt, lt);
568 broadcast_binary_op!(broadcast_le, le);
569 broadcast_binary_op!(broadcast_gt, gt);
570 broadcast_binary_op!(broadcast_ge, ge);
571
572 unary_op!(recip, Recip);
573 unary_op!(neg, Neg);
574 unary_op!(exp, Exp);
575 unary_op!(log, Log);
576 unary_op!(sin, Sin);
577 unary_op!(cos, Cos);
578 unary_op!(tanh, Tanh);
579 unary_op!(abs, Abs);
580 unary_op!(sqr, Sqr);
581 unary_op!(sqrt, Sqrt);
582 unary_op!(gelu, Gelu);
583 unary_op!(gelu_erf, GeluErf);
584 unary_op!(erf, Erf);
585 unary_op!(relu, Relu);
586 unary_op!(silu, Silu);
587 unary_op!(ceil, Ceil);
588 unary_op!(floor, Floor);
589 unary_op!(round, Round);
590 unary_op!(sign, Sign);
591
592 pub fn round_to(&self, decimals: i32) -> Result<Self> {
597 let mult = 10f64.powi(decimals);
598 (self * mult)?.round()? * (1f64 / mult)
599 }
600
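    /// Illustrative sketch (assumed `candle_core` crate name); `to_scalar` only succeeds on
    /// rank-0 tensors, such as the result of a full reduction:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::ones((2, 2), DType::F32, &Device::Cpu)?;
    /// let total = t.sum_all()?;
    /// assert_eq!(total.to_scalar::<f32>()?, 4.0);
    /// # Ok::<(), candle_core::Error>(())
    /// ```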
601 pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
604 if self.rank() != 0 {
605 Err(Error::UnexpectedNumberOfDims {
606 expected: 0,
607 got: self.rank(),
608 shape: self.shape().clone(),
609 }
610 .bt())?
611 }
612 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
613 let data = S::cpu_storage_as_slice(cpu_storage)?;
614 Ok::<_, Error>(data[self.layout().start_offset()])
615 };
616 match &*self.storage() {
617 Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
618 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
619 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
620 }
621 }
622
623 pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
625 self.to_scalar::<S>()
626 }
627
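    /// Illustrative sketch (assumed `candle_core` crate name); each entry of the repeat shape
    /// tiles the tensor along the corresponding dimension:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1u32, 2], [3, 4]], &Device::Cpu)?;
    /// let r = t.repeat((2, 1))?;
    /// assert_eq!(
    ///     r.to_vec2::<u32>()?,
    ///     vec![vec![1, 2], vec![3, 4], vec![1, 2], vec![3, 4]]
    /// );
    /// # Ok::<(), candle_core::Error>(())
    /// ```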
628 pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
630 let repeats = shape.into();
632 let repeats = repeats.dims();
633 let mut inp = if self.rank() < repeats.len() {
634 let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
635 self.reshape(shape)?
636 } else {
637 self.clone()
638 };
639 for (idx, &repeat) in repeats.iter().enumerate() {
640 if repeat > 1 {
641 inp = Tensor::cat(&vec![&inp; repeat], idx)?
642 }
643 }
644 Ok(inp)
645 }
646
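    /// Illustrative sketch (assumed `candle_core` crate name); with `xy_indexing = true` the
    /// first returned grid varies along the last axis:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let y = Tensor::new(&[4f32, 5.], &Device::Cpu)?;
    /// let grids = Tensor::meshgrid(&[&x, &y], true)?;
    /// assert_eq!(grids[0].to_vec2::<f32>()?, vec![vec![1., 2., 3.], vec![1., 2., 3.]]);
    /// assert_eq!(grids[1].to_vec2::<f32>()?, vec![vec![4., 4., 4.], vec![5., 5., 5.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```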
647 pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
684 if args.len() <= 1 {
685 Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
686 }
687 let args: Vec<_> = if xy_indexing {
688 args.iter().rev().collect()
689 } else {
690 args.iter().collect()
691 };
692
693 let mut shape = Vec::with_capacity(args.len());
694 for arg in args.iter() {
695 shape.push(arg.as_ref().dims1()?)
696 }
697
698 let mut grids = Vec::with_capacity(args.len());
699 for idx in 0..args.len() {
700 let mut ones = vec![1usize; args.len()];
701 ones[idx] = shape[idx];
702 let arg = args[idx].as_ref().reshape(ones)?;
703 let mut repeats = shape.clone();
704 repeats[idx] = 1;
705 let repeated_tensor = arg.repeat(repeats)?;
706 grids.push(repeated_tensor);
707 }
708 if xy_indexing {
709 grids.reverse();
710 }
711 Ok(grids)
712 }
713
714 pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
726 if self.elem_count() == 0 {
727 return Ok(self.clone());
728 }
729 let storage = self.storage().affine(self.layout(), mul, add)?;
730 let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
731 Ok(from_storage(storage, self.shape(), op, false))
732 }
733
734 pub fn elu(&self, alpha: f64) -> Result<Self> {
736 if self.elem_count() == 0 {
737 return Ok(self.clone());
738 }
739 let storage = self.storage().elu(self.layout(), alpha)?;
740 let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
741 Ok(from_storage(storage, self.shape(), op, false))
742 }
743
744 pub fn powf(&self, e: f64) -> Result<Self> {
746 if self.elem_count() == 0 {
747 return Ok(self.clone());
748 }
749 let storage = self.storage().powf(self.layout(), e)?;
750 let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
751 Ok(from_storage(storage, self.shape(), op, false))
752 }
753
754 pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
755 if dim >= self.dims().len() {
756 Err(Error::DimOutOfRange {
757 shape: self.shape().clone(),
758 dim: dim as i32,
759 op,
760 }
761 .bt())?
762 } else {
763 Ok(())
764 }
765 }
766
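    /// Illustrative sketch (assumed `candle_core` crate name); when the dimension does not
    /// divide evenly, the first chunks receive one extra element:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0u32, 5u32, &Device::Cpu)?;
    /// let chunks = t.chunk(2, 0)?;
    /// assert_eq!(chunks[0].to_vec1::<u32>()?, vec![0, 1, 2]);
    /// assert_eq!(chunks[1].to_vec1::<u32>()?, vec![3, 4]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```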
767 pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
770 let dim = dim.to_index(self.shape(), "chunk")?;
771 let size = self.dim(dim)?;
772 if size < chunks {
773 (0..size).map(|i| self.narrow(dim, i, 1)).collect()
774 } else {
775 let chunk_size = size / chunks;
776 let cnt_additional = size % chunks;
777 let mut tensors = vec![];
778 let mut sum_chunk_size = 0;
779 for i in 0..chunks {
780 let chunk_size = if i < cnt_additional {
781 chunk_size + 1
782 } else {
783 chunk_size
784 };
785 let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
786 tensors.push(tensor);
787 sum_chunk_size += chunk_size
788 }
789 Ok(tensors)
790 }
791 }
792
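    /// Illustrative sketch (assumed `candle_core` crate name); `narrow` keeps `len` entries of
    /// dimension `dim` starting at `start`, sharing the underlying storage rather than copying:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0u32, 12u32, &Device::Cpu)?.reshape((3, 4))?;
    /// let rows = t.narrow(0, 1, 2)?;
    /// assert_eq!(rows.to_vec2::<u32>()?, vec![vec![4, 5, 6, 7], vec![8, 9, 10, 11]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```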
793 pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
820 let dims = self.dims();
821 let dim = dim.to_index(self.shape(), "narrow")?;
822 let err = |msg| {
823 Err::<(), _>(
824 Error::NarrowInvalidArgs {
825 shape: self.shape().clone(),
826 dim,
827 start,
828 len,
829 msg,
830 }
831 .bt(),
832 )
833 };
834 if start > dims[dim] {
835 err("start > dim_len")?
836 }
837 if start.saturating_add(len) > dims[dim] {
838 err("start + len > dim_len")?
839 }
840 if start == 0 && dims[dim] == len {
841 Ok(self.clone())
842 } else {
843 let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
844 let layout = self.layout().narrow(dim, start, len)?;
845 let tensor_ = Tensor_ {
846 id: TensorId::new(),
847 storage: self.storage.clone(),
848 layout,
849 op,
850 is_variable: false,
851 dtype: self.dtype,
852 device: self.device.clone(),
853 };
854 Ok(Tensor(Arc::new(tensor_)))
855 }
856 }
857
858 fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
859 match dims {
860 [] => Ok(self),
861 [i] => self.squeeze(*i),
862 dims => {
863 let dims = self
864 .dims()
865 .iter()
866 .enumerate()
867 .filter_map(|(dim_idx, &v)| {
868 if dims.contains(&dim_idx) {
869 None
870 } else {
871 Some(v)
872 }
873 })
874 .collect::<Vec<_>>();
875 self.reshape(dims)
876 }
877 }
878 }
879
880 fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
881 let dim = dim.to_index(self.shape(), op.name())?;
882 let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
883 let mut dims = self.dims().to_vec();
884 dims[dim] = 1;
885 let op = match op {
886 ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
887 BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
888 }
889 ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
890 };
891 let res = from_storage(storage, dims, op, false);
892 if keepdim {
893 Ok(res)
894 } else {
895 res.squeeze_dims(&[dim])
896 }
897 }
898
899 fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
900 let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
901 let storage = self
902 .storage()
903 .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
904 let mut dims = self.dims().to_vec();
905 for &sum_dim in sum_dims.iter() {
906 dims[sum_dim] = 1
907 }
908 let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
909 let sum = from_storage(storage, dims, op, false);
910 if keepdim {
911 Ok(sum)
912 } else {
913 sum.squeeze_dims(&sum_dims)
914 }
915 }
916
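    /// Illustrative sketch (assumed `candle_core` crate name); elements shifted past the end
    /// wrap around to the front, and negative shifts wrap the other way:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[1u32, 2, 3, 4, 5], &Device::Cpu)?;
    /// assert_eq!(t.roll(2, 0)?.to_vec1::<u32>()?, vec![4, 5, 1, 2, 3]);
    /// assert_eq!(t.roll(-1, 0)?.to_vec1::<u32>()?, vec![2, 3, 4, 5, 1]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```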
917 pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
931 where
932 D: Dim + Clone,
933 {
934 let dim = dim.to_index(self.shape(), "roll")?;
935 let dim_size = self.dim(dim)?;
936 let shift = shift.rem_euclid(dim_size as i32) as usize;
937 if shift == 0 {
938 Ok(self.clone())
939 } else {
940 let a = self.narrow(dim, 0, dim_size - shift)?;
941 let b = self.narrow(dim, dim_size - shift, shift)?;
942 Tensor::cat(&[&b, &a], dim)
943 }
944 }
945
946 pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
964 self.sum_impl(sum_dims, true)
965 }
966
967 pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
971 self.sum_impl(sum_dims, false)
972 }
973
974 pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
992 let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
993 let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
994 let scale = 1f64 / (reduced_dim as f64);
995 self.sum_impl(mean_dims, true)? * scale
996 }
997
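    /// Illustrative sketch (assumed `candle_core` crate name); the reduced dimensions are
    /// dropped from the output shape (compare with [`Tensor::mean_keepdim`]):
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let m = t.mean(1)?;
    /// assert_eq!(m.to_vec1::<f32>()?, vec![1.5, 3.5]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```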
998 pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
1002 let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
1003 let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1004 let scale = 1f64 / (reduced_dim as f64);
1005 self.sum_impl(mean_dims, false)? * scale
1006 }
1007
1008 pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1010 let dim = dim.to_index(self.shape(), "var")?;
1011 let mean = self.mean_keepdim(dim)?;
1012 let squares = self.broadcast_sub(&mean)?.sqr()?;
1013 squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
1014 }
1015
1016 pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
1018 let dim = dim.to_index(self.shape(), "var")?;
1019 self.var_keepdim(dim)?.squeeze(dim)
1020 }
1021
1022 pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1025 self.reduce_impl(dim, true, ReduceOp::Max)
1026 }
1027
1028 pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
1030 self.reduce_impl(dim, false, ReduceOp::Max)
1031 }
1032
1033 pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1036 self.reduce_impl(dim, true, ReduceOp::Min)
1037 }
1038
1039 pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
1041 self.reduce_impl(dim, false, ReduceOp::Min)
1042 }
1043
1044 pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1045 self.reduce_impl(dim, true, ReduceOp::ArgMax)
1046 }
1047
1048 pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
1050 self.reduce_impl(dim, false, ReduceOp::ArgMax)
1051 }
1052
1053 pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1054 self.reduce_impl(dim, true, ReduceOp::ArgMin)
1055 }
1056
1057 pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
1059 self.reduce_impl(dim, false, ReduceOp::ArgMin)
1060 }
1061
1062 pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
1067 let rhs = match rhs.to_tensor_scalar()? {
1068 crate::scalar::TensorScalar::Tensor(rhs) => rhs,
1069 crate::scalar::TensorScalar::Scalar(rhs) => rhs
1070 .to_dtype(self.dtype())?
1071 .to_device(self.device())?
1072 .broadcast_as(self.shape())?,
1073 };
1074 let shape = self.same_shape_binary_op(&rhs, "cmp")?;
1075 let storage = self
1076 .storage()
1077 .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
1078 let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
1079 Ok(from_storage(storage, shape.dims(), op, false))
1080 }
1081
1082 pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1084 self.cmp(rhs, CmpOp::Eq)
1085 }
1086
1087 pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1089 self.cmp(rhs, CmpOp::Ne)
1090 }
1091
1092 pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1095 self.cmp(rhs, CmpOp::Lt)
1096 }
1097
1098 pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1101 self.cmp(rhs, CmpOp::Gt)
1102 }
1103
1104 pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1107 self.cmp(rhs, CmpOp::Ge)
1108 }
1109
1110 pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1113 self.cmp(rhs, CmpOp::Le)
1114 }
1115
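    /// Illustrative sketch (assumed `candle_core` crate name); scalar bounds are broadcast to
    /// the shape of `self`, so both scalar and tensor bounds work:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[-1f32, 0.5, 2.], &Device::Cpu)?;
    /// assert_eq!(t.clamp(0f32, 1f32)?.to_vec1::<f32>()?, vec![0., 0.5, 1.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```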
1116 pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
1118 self.maximum(min)?.minimum(max)
1119 }
1120
1121 pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
1126 let (n, c, _l) = self.dims3()?;
1127 let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
1128 let storage = self
1129 .storage()
1130 .upsample_nearest1d(self.layout(), target_size)?;
1131 Ok(from_storage(storage, (n, c, target_size), op, false))
1132 }
1133
1134 pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
1136 self.interpolate1d(target_size)
1137 }
1138
1139 pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1145 let (n, c, _h, _w) = self.dims4()?;
1146 let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
1147 arg,
1148 target_h,
1149 target_w,
1150 });
1151 let storage = self
1152 .storage()
1153 .upsample_nearest2d(self.layout(), target_h, target_w)?;
1154 Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1155 }
1156
1157 pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1159 self.interpolate2d(target_h, target_w)
1160 }
1161
1162 pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1169 let sz = sz.to_usize2();
1170 self.avg_pool2d_with_stride(sz, sz)
1171 }
1172
1173 pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
1176 &self,
1177 kernel_size: T,
1178 stride: T,
1179 ) -> Result<Self> {
1180 let kernel_size = kernel_size.to_usize2();
1181 let stride = stride.to_usize2();
1182 let (n, c, h, w) = self.dims4()?;
1183 if h < kernel_size.0 || w < kernel_size.1 {
1184 bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1185 }
1186 let h_out = (h - kernel_size.0) / stride.0 + 1;
1188 let w_out = (w - kernel_size.1) / stride.1 + 1;
1189 let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
1190 arg,
1191 kernel_size,
1192 stride,
1193 });
1194 let storage = self
1195 .storage()
1196 .avg_pool2d(self.layout(), kernel_size, stride)?;
1197 Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1198 }
1199
1200 pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1207 let sz = sz.to_usize2();
1208 self.max_pool2d_with_stride(sz, sz)
1209 }
1210
1211 pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
1214 &self,
1215 kernel_size: T,
1216 stride: T,
1217 ) -> Result<Self> {
1218 let kernel_size = kernel_size.to_usize2();
1219 let stride = stride.to_usize2();
1220 let (n, c, h, w) = self.dims4()?;
1221 if h < kernel_size.0 || w < kernel_size.1 {
1222 bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1223 }
1224 let h_out = (h - kernel_size.0) / stride.0 + 1;
1226 let w_out = (w - kernel_size.1) / stride.1 + 1;
1227 let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
1228 arg,
1229 kernel_size,
1230 stride,
1231 });
1232 let storage = self
1233 .storage()
1234 .max_pool2d(self.layout(), kernel_size, stride)?;
1235 Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1236 }
1237
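    /// Illustrative sketch (assumed `candle_core` crate name); both operands must share the
    /// same batch dimensions and the two inner dimensions must agree:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let b = Tensor::new(&[[5f32, 6.], [7., 8.]], &Device::Cpu)?;
    /// let c = a.matmul(&b)?;
    /// assert_eq!(c.to_vec2::<f32>()?, vec![vec![19., 22.], vec![43., 50.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```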
1238 pub fn matmul(&self, rhs: &Self) -> Result<Self> {
1247 let a_dims = self.shape().dims();
1248 let b_dims = rhs.shape().dims();
1249
1250 let dim = a_dims.len();
1251
1252 if dim < 2 || b_dims.len() != dim {
1253 Err(Error::ShapeMismatchBinaryOp {
1254 lhs: self.shape().clone(),
1255 rhs: rhs.shape().clone(),
1256 op: "matmul",
1257 }
1258 .bt())?
1259 }
1260
1261 let m = a_dims[dim - 2];
1262 let k = a_dims[dim - 1];
1263 let k2 = b_dims[dim - 2];
1264 let n = b_dims[dim - 1];
1265
1266 let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
1267 if c_shape.elem_count() == 0 || k == 0 {
1268 return Tensor::zeros(c_shape, self.dtype(), self.device());
1269 }
1270 let batching: usize = a_dims[..dim - 2].iter().product();
1271 let batching_b: usize = b_dims[..dim - 2].iter().product();
1272 if k != k2 || batching != batching_b {
1273 Err(Error::ShapeMismatchBinaryOp {
1274 lhs: self.shape().clone(),
1275 rhs: rhs.shape().clone(),
1276 op: "matmul",
1277 }
1278 .bt())?
1279 }
1280
1281 let storage = self.storage().matmul(
1282 &rhs.storage(),
1283 (batching, m, n, k),
1284 self.layout(),
1285 rhs.layout(),
1286 )?;
1287 let op = BackpropOp::new2(self, rhs, Op::Matmul);
1288 Ok(from_storage(storage, c_shape, op, false))
1289 }
1290
1291 pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
1297 let lhs = self;
1298 let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
1299 let l_broadcast = l_shape != *lhs.shape();
1300 let r_broadcast = r_shape != *rhs.shape();
1301 match (l_broadcast, r_broadcast) {
1303 (true, true) => lhs
1304 .broadcast_as(&l_shape)?
1305 .contiguous()?
1306 .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1307 (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1308 (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
1309 (false, false) => lhs.matmul(rhs),
1310 }
1311 }
1312
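    /// Illustrative sketch (assumed `candle_core` crate name, with a `u8` mask); non-zero
    /// entries of the condition select from `on_true`, zero entries select from `on_false`:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let cond = Tensor::new(&[1u8, 0, 1], &Device::Cpu)?;
    /// let on_true = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let on_false = Tensor::new(&[-1f32, -2., -3.], &Device::Cpu)?;
    /// let r = cond.where_cond(&on_true, &on_false)?;
    /// assert_eq!(r.to_vec1::<f32>()?, vec![1., -2., 3.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```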
1313 pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
        let _shape = self.same_shape_binary_op(on_true, "where_cond")?;
1318 let shape = self.same_shape_binary_op(on_false, "where_cond")?;
1319 let storage = self.storage().where_cond(
1320 self.layout(),
1321 &on_true.storage(),
1322 on_true.layout(),
1323 &on_false.storage(),
1324 on_false.layout(),
1325 )?;
1326 let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
1327 Ok(from_storage(storage, shape, op, false))
1328 }
1329
1330 pub fn embedding(&self, ids: &Self) -> Result<Self> {
1350 if self.rank() != 2 || ids.rank() != 1 {
1351 Err(Error::ShapeMismatchBinaryOp {
1352 lhs: self.shape().clone(),
1353 rhs: ids.shape().clone(),
1354 op: "embedding",
1355 }
1356 .bt())?
1357 }
1358 self.index_select(ids, 0)
1359 }
1360
1361 fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
1362 let source_dims = source.dims();
1363 let self_dims = self.dims();
1364 let mismatch = if source_dims.len() != self_dims.len() {
1365 true
1366 } else {
1367 let mut mismatch = false;
1368 for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1369 if i != dim && d1 != d2 {
1370 mismatch = true;
1371 break;
1372 }
1373 }
1374 mismatch
1375 };
1376 if mismatch {
1377 Err(Error::ShapeMismatchBinaryOp {
1378 op: "scatter (self, src)",
1379 lhs: self.shape().clone(),
1380 rhs: source.shape().clone(),
1381 }
1382 .bt())?
1383 }
1384 if indexes.dims() != source.dims() {
1385 Err(Error::ShapeMismatchBinaryOp {
1386 op: "scatter (indexes, src)",
1387 lhs: indexes.shape().clone(),
1388 rhs: source.shape().clone(),
1389 }
1390 .bt())?
1391 }
1392 Ok(())
1393 }
1394
1395 pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1396 let dim = dim.to_index(self.shape(), "scatter")?;
1397 self.scatter_checks(indexes, source, dim)?;
1398 let shape = self.shape();
1399 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1400 self.storage()
1401 .copy_strided_src(&mut storage, 0, self.layout())?;
1402 let layout = Layout::contiguous(shape);
1403 storage.scatter_set(
1404 &layout,
1405 &indexes.storage(),
1406 indexes.layout(),
1407 &source.storage(),
1408 source.layout(),
1409 dim,
1410 )?;
1411 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1412 Op::Scatter(t1, t2, t3, dim)
1413 });
1414 Ok(from_storage(storage, self.shape(), op, false))
1415 }
1416
1417 pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1418 if self.same_storage(source) {
            crate::bail!("cannot use scatter_set when self and source share their storage")
1420 }
1421 let dim = dim.to_index(self.shape(), "scatter-set")?;
1422 self.scatter_checks(indexes, source, dim)?;
1423 self.storage_mut().scatter_set(
1424 self.layout(),
1425 &indexes.storage(),
1426 indexes.layout(),
1427 &source.storage(),
1428 source.layout(),
1429 dim,
1430 )?;
1431 Ok(())
1432 }
1433
1434 pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1435 let dim = dim.to_index(self.shape(), "scatter-add")?;
1436 self.scatter_checks(indexes, source, dim)?;
1437 let shape = self.shape();
1438 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1439 self.storage()
1440 .copy_strided_src(&mut storage, 0, self.layout())?;
1441 let layout = Layout::contiguous(shape);
1442 storage.scatter_add(
1443 &layout,
1444 &indexes.storage(),
1445 indexes.layout(),
1446 &source.storage(),
1447 source.layout(),
1448 dim,
1449 )?;
1450 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1451 Op::ScatterAdd(t1, t2, t3, dim)
1452 });
1453 Ok(from_storage(storage, self.shape(), op, false))
1454 }
1455
1456 pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1457 if self.same_storage(source) {
            crate::bail!("cannot use scatter_add_set when self and source share their storage")
1459 }
1460 let dim = dim.to_index(self.shape(), "scatter-add-set")?;
1461 self.scatter_checks(indexes, source, dim)?;
1462 self.storage_mut().scatter_add(
1463 self.layout(),
1464 &indexes.storage(),
1465 indexes.layout(),
1466 &source.storage(),
1467 source.layout(),
1468 dim,
1469 )?;
1470 Ok(())
1471 }
1472
1473 pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
1475 let dim = dim.to_index(self.shape(), "slice-scatter")?;
1476 if dim == 0 {
1477 self.slice_scatter0(src, start)
1478 } else {
1479 self.transpose(0, dim)?
1481 .slice_scatter0(&src.transpose(0, dim)?, start)?
1482 .transpose(0, dim)
1483 }
1484 }
1485
1486 pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
1488 if self.dtype() != src.dtype() {
1489 Err(Error::DTypeMismatchBinaryOp {
1490 lhs: self.dtype(),
1491 rhs: src.dtype(),
1492 op: "slice-scatter",
1493 }
1494 .bt())?
1495 }
1496 if self.device().location() != src.device.location() {
1497 Err(Error::DeviceMismatchBinaryOp {
1498 lhs: self.device().location(),
1499 rhs: src.device().location(),
1500 op: "slice-scatter",
1501 }
1502 .bt())?
1503 }
1504 if self.rank() != src.rank() {
1505 Err(Error::UnexpectedNumberOfDims {
1506 expected: self.rank(),
1507 got: src.rank(),
1508 shape: src.shape().clone(),
1509 }
1510 .bt())?
1511 }
1512 let shape_ok =
1513 self.dims()
1514 .iter()
1515 .zip(src.dims().iter())
1516 .enumerate()
1517 .all(|(dim_idx, (&d1, &d2))| {
1518 if 0 == dim_idx {
1519 d2 + start <= d1
1520 } else {
1521 d1 == d2
1522 }
1523 });
1524 if !shape_ok {
1525 Err(Error::ShapeMismatchBinaryOp {
1526 op: "slice-scatter (self, src)",
1527 lhs: self.shape().clone(),
1528 rhs: src.shape().clone(),
1529 }
1530 .bt())?
1531 }
1532 let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
1533 self.storage()
1534 .copy_strided_src(&mut storage, 0, self.layout())?;
1535 let offset = start * src.dims()[1..].iter().product::<usize>();
1536 src.storage()
1537 .copy_strided_src(&mut storage, offset, src.layout())?;
1538 let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
1539 Ok(from_storage(storage, self.shape(), op, false))
1540 }
1541
1542 pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1544 let dim = dim.to_index(self.shape(), "index-add")?;
1545 let source_dims = source.dims();
1546 let self_dims = self.dims();
1547 let mismatch = if source_dims.len() != self_dims.len() {
1548 true
1549 } else {
1550 let mut mismatch = false;
1551 for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1552 if i != dim && d1 != d2 {
1553 mismatch = true;
1554 break;
1555 }
1556 }
1557 mismatch
1558 };
1559 if mismatch {
1560 Err(Error::ShapeMismatchBinaryOp {
1561 op: "index-add (self, source)",
1562 lhs: self.shape().clone(),
1563 rhs: source.shape().clone(),
1564 }
1565 .bt())?
1566 }
1567 let indexes_len = indexes.dims1()?;
1571 if source_dims[dim] != indexes_len {
1572 Err(Error::ShapeMismatchBinaryOp {
                op: "index-add (ids, source)",
1574 lhs: indexes.shape().clone(),
1575 rhs: source.shape().clone(),
1576 }
1577 .bt())?
1578 }
1579 let storage = self.storage().index_add(
1580 self.layout(),
1581 &indexes.storage(),
1582 indexes.layout(),
1583 &source.storage(),
1584 source.layout(),
1585 dim,
1586 )?;
1587 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1588 Op::IndexAdd(t1, t2, t3, dim)
1589 });
1590 Ok(from_storage(storage, self.shape(), op, false))
1591 }
1592
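    /// Illustrative sketch (assumed `candle_core` crate name); the output has the shape of
    /// `indexes`, with `out[i][j] = self[i][indexes[i][j]]` when gathering along dim 1:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1u32, 2, 3], [4, 5, 6]], &Device::Cpu)?;
    /// let idx = Tensor::new(&[[0u32, 2], [1, 0]], &Device::Cpu)?;
    /// let g = t.gather(&idx, 1)?;
    /// assert_eq!(g.to_vec2::<u32>()?, vec![vec![1, 3], vec![5, 4]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```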
1593 pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1605 let dim = dim.to_index(self.shape(), "gather")?;
1606
1607 let self_dims = self.dims();
1608 let indexes_dims = indexes.dims();
1609 let mismatch = if indexes_dims.len() != self_dims.len() {
1610 true
1611 } else {
1612 let mut mismatch = false;
1613 for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
1614 if i != dim && d1 < d2 {
1615 mismatch = true;
1616 break;
1617 }
1618 }
1619 mismatch
1620 };
1621 if mismatch {
1622 Err(Error::ShapeMismatchBinaryOp {
1623 op: "gather",
1624 lhs: self.shape().clone(),
1625 rhs: indexes.shape().clone(),
1626 }
1627 .bt())?
1628 }
1629 let storage =
1630 self.storage()
1631 .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
1632 let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
1633 Ok(from_storage(storage, indexes.shape(), op, false))
1634 }
1635
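    /// Illustrative sketch (assumed `candle_core` crate name); entries are picked along `dim`
    /// in the order given by the 1-D index tensor:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Device::Cpu)?;
    /// let idx = Tensor::new(&[2u32, 0], &Device::Cpu)?;
    /// let s = t.index_select(&idx, 0)?;
    /// assert_eq!(s.to_vec2::<f32>()?, vec![vec![5., 6.], vec![1., 2.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```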
1636 pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1644 let dim = dim.to_index(self.shape(), "index-select")?;
1645 let indexes_len = match indexes.dims() {
1646 [l] => *l,
1647 _ => Err(Error::ShapeMismatchBinaryOp {
1648 lhs: self.shape().clone(),
1649 rhs: indexes.shape().clone(),
1650 op: "index-select",
1651 }
1652 .bt())?,
1653 };
1654 let storage = self.storage().index_select(
1655 &indexes.storage(),
1656 self.layout(),
1657 indexes.layout(),
1658 dim,
1659 )?;
1660 let mut dims = self.dims().to_vec();
1661 dims[dim] = indexes_len;
1662 let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
1663 Ok(from_storage(storage, dims, op, false))
1664 }
1665
1666 pub fn strided_index(&self) -> crate::StridedIndex {
1669 self.layout.strided_index()
1670 }
1671
1672 pub fn strided_blocks(&self) -> crate::StridedBlocks {
1677 self.layout.strided_blocks()
1678 }
1679
1680 pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
1682 if self.rank() != 1 {
1683 Err(Error::UnexpectedNumberOfDims {
1684 expected: 1,
1685 got: self.rank(),
1686 shape: self.shape().clone(),
1687 }
1688 .bt())?
1689 }
1690 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1691 let data = S::cpu_storage_as_slice(cpu_storage)?;
1692 let data = match self.layout.contiguous_offsets() {
1693 Some((o1, o2)) => data[o1..o2].to_vec(),
1694 None => self.strided_index().map(|i| data[i]).collect(),
1695 };
1696 Ok::<Vec<_>, Error>(data)
1697 };
1698 match &*self.storage() {
1699 Storage::Cpu(storage) => from_cpu_storage(storage),
1700 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1701 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1702 }
1703 }
1704
1705 pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
1707 let (dim1, dim2) = self.dims2()?;
1708 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1709 let data = S::cpu_storage_as_slice(cpu_storage)?;
1710 let mut rows = vec![];
1711 match self.layout.contiguous_offsets() {
1712 Some((o1, o2)) => {
1713 let data = &data[o1..o2];
1714 for idx_row in 0..dim1 {
1715 rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
1716 }
1717 }
1718 None => {
1719 let mut src_index = self.strided_index();
1720 for _idx_row in 0..dim1 {
1721 let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
1722 rows.push(row)
1723 }
1724 assert!(src_index.next().is_none());
1725 }
1726 }
1727 Ok(rows)
1728 };
1729 match &*self.storage() {
1730 Storage::Cpu(storage) => from_cpu_storage(storage),
1731 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1732 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1733 }
1734 }
1735
1736 pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
1738 let (dim1, dim2, dim3) = self.dims3()?;
1739 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1740 let data = S::cpu_storage_as_slice(cpu_storage)?;
1741 let mut top_rows = vec![];
1742 match self.layout.contiguous_offsets() {
1743 Some((o1, o2)) => {
1744 let data = &data[o1..o2];
1745 let dim23 = dim2 * dim3;
1746 for idx1 in 0..dim1 {
1747 let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
1748 let mut rows = vec![];
1749 for idx2 in 0..dim2 {
1750 rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
1751 }
1752 top_rows.push(rows);
1753 }
1754 }
1755 None => {
1756 let mut src_index = self.strided_index();
1757 for _idx in 0..dim1 {
1758 let mut rows = vec![];
1759 for _jdx in 0..dim2 {
1760 let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
1761 rows.push(row)
1762 }
1763 top_rows.push(rows);
1764 }
1765 assert!(src_index.next().is_none());
1766 }
1767 }
1768 Ok(top_rows)
1769 };
1770 match &*self.storage() {
1771 Storage::Cpu(storage) => from_cpu_storage(storage),
1772 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1773 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1774 }
1775 }
1776
1777 pub fn dtype(&self) -> DType {
1779 self.dtype
1780 }
1781
1782 pub fn device(&self) -> &Device {
1784 &self.device
1785 }
1786
1787 pub fn shape(&self) -> &Shape {
1789 self.layout().shape()
1790 }
1791
1792 pub fn dims(&self) -> &[usize] {
1794 self.shape().dims()
1795 }
1796
1797 pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
1799 let dim = dim.to_index(self.shape(), "dim")?;
1800 Ok(self.dims()[dim])
1801 }
1802
1803 pub fn layout(&self) -> &Layout {
1806 &self.layout
1807 }
1808
1809 pub fn stride(&self) -> &[usize] {
1810 self.layout.stride()
1811 }
1812
1813 pub fn rank(&self) -> usize {
1815 self.shape().rank()
1816 }
1817
1818 pub fn elem_count(&self) -> usize {
1820 self.shape().elem_count()
1821 }
1822
1823 pub fn id(&self) -> TensorId {
1825 self.id
1826 }
1827
1828 pub fn is_variable(&self) -> bool {
1831 self.is_variable
1832 }
1833
1834 pub(crate) fn op(&self) -> &Option<Op> {
1835 &self.op
1836 }
1837
1838 pub fn max_all(&self) -> Result<Tensor> {
1849 if self.rank() == 0 {
1850 Ok(self.clone())
1851 } else {
1852 self.flatten_all()?.max(0)
1853 }
1854 }
1855
1856 pub fn min_all(&self) -> Result<Tensor> {
1867 if self.rank() == 0 {
1868 Ok(self.clone())
1869 } else {
1870 self.flatten_all()?.min(0)
1871 }
1872 }
1873
1874 pub fn sum_all(&self) -> Result<Tensor> {
1885 let dims: Vec<_> = (0..self.rank()).collect();
1886 self.sum(dims)
1887 }
1888
1889 pub fn mean_all(&self) -> Result<Tensor> {
1890 self.sum_all()? / self.elem_count() as f64
1891 }
1892
1893 fn flatten_<D1: Dim, D2: Dim>(
1894 &self,
1895 start_dim: Option<D1>,
1896 end_dim: Option<D2>,
1897 ) -> Result<Tensor> {
1898 if self.rank() == 0 {
1899 self.reshape(1)
1900 } else {
1901 let start_dim = match start_dim {
1902 None => 0,
1903 Some(dim) => dim.to_index(self.shape(), "flatten")?,
1904 };
1905 let end_dim = match end_dim {
1906 None => self.rank() - 1,
1907 Some(dim) => dim.to_index(self.shape(), "flatten")?,
1908 };
1909 if start_dim < end_dim {
1910 let dims = self.dims();
1911 let mut dst_dims = dims[..start_dim].to_vec();
1912 dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
1913 if end_dim + 1 < dims.len() {
1914 dst_dims.extend(&dims[end_dim + 1..]);
1915 }
1916 self.reshape(dst_dims)
1917 } else {
1918 Ok(self.clone())
1919 }
1920 }
1921 }
1922
1923 pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
1926 self.flatten_(Some(start_dim), Some(end_dim))
1927 }
1928
1929 pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
1931 self.flatten_(None::<usize>, Some(end_dim))
1932 }
1933
1934 pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
1937 self.flatten_(Some(start_dim), None::<usize>)
1938 }
1939
1940 pub fn flatten_all(&self) -> Result<Tensor> {
1950 self.flatten_(None::<usize>, None::<usize>)
1951 }
1952
1953 pub fn get(&self, i: usize) -> Result<Tensor> {
1965 let dims = self.dims();
1966 if dims.is_empty() {
1967 Ok(self.clone())
1968 } else {
1969 self.narrow(0, i, 1)?.reshape(&dims[1..])
1970 }
1971 }
1972
1973 pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
1987 let dim = dim.to_index(self.shape(), "get_on_dim")?;
1988 self.narrow(dim, index, 1)?.squeeze(dim)
1989 }
1990
1991 pub fn t(&self) -> Result<Tensor> {
2002 let rank = self.rank();
2003 if rank < 2 {
2004 Err(Error::UnexpectedNumberOfDims {
2005 expected: 2,
2006 got: rank,
2007 shape: self.shape().clone(),
2008 }
2009 .bt())?
2010 }
2011 self.transpose(rank - 2, rank - 1)
2012 }
2013
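    /// Illustrative sketch (assumed `candle_core` crate name); only the layout changes, the
    /// storage is shared with the original tensor:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    /// let tt = t.transpose(0, 1)?;
    /// assert_eq!(tt.to_vec2::<f32>()?, vec![vec![1., 4.], vec![2., 5.], vec![3., 6.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```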
2014 pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
2017 let dim1 = dim1.to_index(self.shape(), "transpose")?;
2018 let dim2 = dim2.to_index(self.shape(), "transpose")?;
2019 if dim1 == dim2 {
2020 return Ok(self.clone());
2021 }
2022 let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
2023 let tensor_ = Tensor_ {
2024 id: TensorId::new(),
2025 storage: self.storage.clone(),
2026 layout: self.layout.transpose(dim1, dim2)?,
2027 op,
2028 is_variable: false,
2029 dtype: self.dtype,
2030 device: self.device.clone(),
2031 };
2032 Ok(Tensor(Arc::new(tensor_)))
2033 }
2034
2035 pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
2047 let dims = dims.to_indexes(self.shape(), "permute")?;
2048 let is_permutation =
2050 dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
2051 if !is_permutation {
2052 bail!(
2053 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
2054 self.dims(),
2055 dims
2056 )
2057 }
2058 let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
2059 let tensor_ = Tensor_ {
2060 id: TensorId::new(),
2061 storage: self.storage.clone(),
2062 layout: self.layout.permute(&dims)?,
2063 op,
2064 is_variable: false,
2065 dtype: self.dtype,
2066 device: self.device.clone(),
2067 };
2068 Ok(Tensor(Arc::new(tensor_)))
2069 }
2070
2071 pub fn is_contiguous(&self) -> bool {
2073 self.layout.is_contiguous()
2074 }
2075
2076 pub fn is_fortran_contiguous(&self) -> bool {
2078 self.layout.is_fortran_contiguous()
2079 }
2080
2081 pub fn copy(&self) -> Result<Tensor> {
2084 let op = BackpropOp::new1(self, Op::Copy);
2085 let tensor_ = Tensor_ {
2086 id: TensorId::new(),
2087 storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
2088 layout: self.layout.clone(),
2089 op,
2090 is_variable: false,
2091 dtype: self.dtype,
2092 device: self.device.clone(),
2093 };
2094 Ok(Tensor(Arc::new(tensor_)))
2095 }
2096
2097 pub fn detach(&self) -> Tensor {
2102 if self.op.is_none() && !self.is_variable {
2103 self.clone()
2104 } else {
2105 let tensor_ = Tensor_ {
2106 id: TensorId::new(),
2107 storage: self.storage.clone(),
2108 layout: self.layout.clone(),
2109 op: BackpropOp::none(),
2110 is_variable: false,
2111 dtype: self.dtype,
2112 device: self.device.clone(),
2113 };
2114 Tensor(Arc::new(tensor_))
2115 }
2116 }
2117
2118 pub fn to_device(&self, device: &Device) -> Result<Tensor> {
2120 if self.device().same_device(device) {
2121 Ok(self.clone())
2122 } else {
2123 let storage = match (&*self.storage(), device) {
2124 (Storage::Cpu(storage), Device::Cuda(cuda)) => {
2125 Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
2126 }
2127 (Storage::Cpu(storage), Device::Metal(metal)) => {
2128 Storage::Metal(metal.storage_from_cpu_storage(storage)?)
2129 }
2130 (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2131 (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2132 (Storage::Cuda(storage), Device::Cuda(cuda)) => {
2133 let cpu_storage = storage.to_cpu_storage()?;
2136 Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
2137 }
2138 (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
2139 _ => {
2140 bail!(
2141 "not implemented yet, self.device: {:?}, device: {:?}",
2142 self.device(),
2143 device
2144 )
2145 }
2146 };
2147 let op = BackpropOp::new1(self, Op::ToDevice);
2148 let tensor_ = Tensor_ {
2149 id: TensorId::new(),
2150 storage: Arc::new(RwLock::new(storage)),
2151 layout: self.layout.clone(),
2152 op,
2153 is_variable: false,
2154 dtype: self.dtype,
2155 device: device.clone(),
2156 };
2157 Ok(Tensor(Arc::new(tensor_)))
2158 }
2159 }
2160
2161 pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
2164 let left_shape = left_shape.into();
2165 let mut dims = left_shape.into_dims();
2166 dims.extend(self.dims());
2167 self.broadcast_as(dims)
2168 }
2169
2170 pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2178 let tensor_ = Tensor_ {
2179 id: TensorId::new(),
2180 storage: self.storage.clone(),
2181 layout: self.layout.broadcast_as(shape)?,
2182 op: BackpropOp::new1(self, Op::Broadcast),
2183 is_variable: false,
2184 dtype: self.dtype,
2185 device: self.device.clone(),
2186 };
2187 Ok(Tensor(Arc::new(tensor_)))
2188 }
2189
2190 pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2192 self.broadcast_as(shape)
2193 }
2194
2195 pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2206 if self.dtype() == dtype {
2207 Ok(self.clone())
2208 } else {
2209 let shape = self.shape();
2210 let storage = self.storage().to_dtype(self.layout(), dtype)?;
2211 let op = BackpropOp::new1(self, Op::ToDType);
2212 Ok(from_storage(storage, shape.clone(), op, false))
2213 }
2214 }
2215
2216 pub fn contiguous(&self) -> Result<Tensor> {
2219 if self.is_contiguous() {
2220 Ok(self.clone())
2221 } else {
2222 let shape = self.shape();
2223 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2224 self.storage()
2225 .copy_strided_src(&mut storage, 0, self.layout())?;
2226 let op = BackpropOp::new1(self, Op::Copy);
2227 Ok(from_storage(storage, shape.clone(), op, false))
2228 }
2229 }
2230
2231 pub fn force_contiguous(&self) -> Result<Tensor> {
2233 let shape = self.shape();
2234 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2235 self.storage()
2236 .copy_strided_src(&mut storage, 0, self.layout())?;
2237 let op = BackpropOp::new1(self, Op::Copy);
2238 Ok(from_storage(storage, shape.clone(), op, false))
2239 }
2240
2241 pub(crate) fn make_var(&self) -> Result<Tensor> {
2244 let shape = self.shape().clone();
2245 let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2246 self.storage()
2247 .copy_strided_src(&mut storage, 0, self.layout())?;
2248 Ok(from_storage(storage, shape, BackpropOp::none(), true))
2249 }
2250
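    /// Illustrative sketch (assumed `candle_core` crate name); a single `()` entry in the
    /// target shape is inferred from the element count, as allowed by the `ShapeWithOneHole`
    /// bound on this method:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0u32, 6u32, &Device::Cpu)?;
    /// assert_eq!(t.reshape((2, 3))?.dims(), &[2, 3]);
    /// assert_eq!(t.reshape(((), 3))?.dims(), &[2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```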
2251 pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
2276 let shape = s.into_shape(self.elem_count())?;
2277 if shape.elem_count() != self.elem_count() {
2278 return Err(Error::ShapeMismatchBinaryOp {
2279 lhs: self.shape().clone(),
2280 rhs: shape,
2281 op: "reshape",
2282 }
2283 .bt());
2284 }
2285 let op = BackpropOp::new1(self, Op::Reshape);
2286 if self.is_contiguous() {
2287 let tensor_ = Tensor_ {
2288 id: TensorId::new(),
2289 storage: self.storage.clone(),
2290 layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
2291 op,
2292 is_variable: false,
2293 dtype: self.dtype,
2294 device: self.device.clone(),
2295 };
2296 Ok(Tensor(Arc::new(tensor_)))
2297 } else {
2298 let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2299 self.storage()
2300 .copy_strided_src(&mut storage, 0, self.layout())?;
2301 Ok(from_storage(storage, shape, op, false))
2302 }
2303 }
2304
2305 pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
2319 let dims = self.dims();
2322 let dim = dim.to_index(self.shape(), "squeeze")?;
2323 if dims[dim] == 1 {
2324 let mut dims = dims.to_vec();
2325 let mut strides = self.stride().to_vec();
2326 dims.remove(dim);
2327 strides.remove(dim);
2328 let tensor_ = Tensor_ {
2329 id: TensorId::new(),
2330 storage: self.storage.clone(),
2331 layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2332 op: BackpropOp::new1(self, Op::Reshape),
2333 is_variable: false,
2334 dtype: self.dtype,
2335 device: self.device.clone(),
2336 };
2337 Ok(Tensor(Arc::new(tensor_)))
2338 } else {
2339 Ok(self.clone())
2340 }
2341 }
2342
2343 pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
2357 let mut dims = self.dims().to_vec();
2358 let mut strides = self.stride().to_vec();
2359 let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
2360 dims.insert(dim, 1);
2362 let stride = if dim < strides.len() { strides[dim] } else { 1 };
2365 strides.insert(dim, stride);
2366 let tensor_ = Tensor_ {
2367 id: TensorId::new(),
2368 storage: self.storage.clone(),
2369 layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2370 op: BackpropOp::new1(self, Op::Reshape),
2371 is_variable: false,
2372 dtype: self.dtype,
2373 device: self.device.clone(),
2374 };
2375 Ok(Tensor(Arc::new(tensor_)))
2376 }
2377
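    /// Illustrative sketch (assumed `candle_core` crate name); all tensors are unsqueezed
    /// along `dim` and then concatenated, so a new dimension is created:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let b = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    /// let s = Tensor::stack(&[&a, &b], 0)?;
    /// assert_eq!(s.dims(), &[2, 2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```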
2378 pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
2395 if args.is_empty() {
2396 Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
2397 }
2398 let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
2399 let args = args
2400 .iter()
2401 .map(|t| t.as_ref().unsqueeze(dim))
2402 .collect::<Result<Vec<_>>>()?;
2403 Self::cat(&args, dim)
2404 }
2405
2406 pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2409 if left == 0 && right == 0 {
2410 Ok(self.clone())
2411 } else if left == 0 {
2412 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2413 let mut dims = self.dims().to_vec();
2414 dims[dim] = right;
2415 let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2416 Tensor::cat(&[self, &right], dim)
2417 } else if right == 0 {
2418 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2419 let mut dims = self.dims().to_vec();
2420 dims[dim] = left;
2421 let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2422 Tensor::cat(&[&left, self], dim)
2423 } else {
2424 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2425 let mut dims = self.dims().to_vec();
2426 dims[dim] = left;
2427 let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2428 dims[dim] = right;
2429 let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2430 Tensor::cat(&[&left, self, &right], dim)
2431 }
2432 }
2433
2434 pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2437 if left == 0 && right == 0 {
2438 Ok(self.clone())
2439 } else if self.elem_count() == 0 {
2440 bail!("cannot use pad_with_same on an empty tensor")
2441 } else if left == 0 {
2442 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2443 let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2444 let mut v = vec![self];
2445 for _ in 0..right {
2446 v.push(&r)
2447 }
2448 Tensor::cat(&v, dim)
2449 } else if right == 0 {
2450 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2451 let l = self.narrow(dim, 0, 1)?;
2452 let mut v = vec![];
2453 for _ in 0..left {
2454 v.push(&l)
2455 }
2456 v.push(self);
2457 Tensor::cat(&v, dim)
2458 } else {
2459 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2460 let l = self.narrow(dim, 0, 1)?;
2461 let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2462 let mut v = vec![];
2463 for _ in 0..left {
2464 v.push(&l)
2465 }
2466 v.push(self);
2467 for _ in 0..right {
2468 v.push(&r)
2469 }
2470 Tensor::cat(&v, dim)
2471 }
2472 }
2473
2474 pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
2476 m.forward(self)
2477 }
2478
2479 pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
2481 m.forward_t(self, train)
2482 }
2483
2484 pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
2485 self.storage.read().unwrap()
2486 }
2487
2488 pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
2489 self.storage.write().unwrap()
2490 }
2491
2492 pub(crate) fn storage_mut_and_layout(
2495 &self,
2496 ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
2497 let storage = self.storage.write().unwrap();
2498 (storage, &self.layout)
2499 }
2500
2501 pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
2503 let storage = self.storage.read().unwrap();
2504 (storage, &self.layout)
2505 }
2506
2507 pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
2508 let lhs: &RwLock<Storage> = self.storage.as_ref();
2509 let rhs: &RwLock<Storage> = rhs.storage.as_ref();
2510 std::ptr::eq(lhs, rhs)
2511 }
2512
2513 pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
2516 let rank = self.rank() as i64;
2517 if rank <= axis {
2518 bail!("axis {axis} is too large, tensor rank {rank}")
2519 } else if 0 <= axis {
2520 Ok(axis as usize)
2521 } else {
2522 let naxis = rank + axis;
2523 if naxis < 0 {
2524 bail!("axis {axis} is too small, tensor rank {rank}")
2525 }
2526 Ok(naxis as usize)
2527 }
2528 }
2529
    pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
        let t = Tensor::arange(0u32, n as u32, device)?;
        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
        t1.le(&t2)?.to_dtype(dtype)
    }

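    /// Returns an `n` by `n` upper-triangular matrix with the requested dtype: ones on and above
    /// the diagonal, zeros below it.
    ///
    /// Illustrative sketch (added example, assuming the crate is used as `candle_core`):
    ///
    /// ```rust
    /// use candle_core::{DType, Device, Tensor};
    /// let t = Tensor::triu2(3, DType::F32, &Device::Cpu)?;
    /// assert_eq!(
    ///     t.to_vec2::<f32>()?,
    ///     [[1., 1., 1.], [0., 1., 1.], [0., 0., 1.]]
    /// );
    /// # Ok::<(), candle_core::Error>(())
    /// ```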
    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
        let t = Tensor::arange(0u32, n as u32, device)?;
        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
        t1.ge(&t2)?.to_dtype(dtype)
    }

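    /// Returns the `n` by `n` identity matrix with the requested dtype.
    ///
    /// Illustrative sketch (added example, assuming the crate is used as `candle_core`):
    ///
    /// ```rust
    /// use candle_core::{DType, Device, Tensor};
    /// let t = Tensor::eye(3, DType::F32, &Device::Cpu)?;
    /// assert_eq!(
    ///     t.to_vec2::<f32>()?,
    ///     [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]
    /// );
    /// # Ok::<(), candle_core::Error>(())
    /// ```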
    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
        let t = Tensor::arange(0u32, n as u32, device)?;
        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
        t1.eq(&t2)?.to_dtype(dtype)
    }

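    /// Returns the cumulative sum of the elements along the given dimension. The computation is
    /// expressed as a matrix multiplication with the upper-triangular matrix produced by
    /// [`Tensor::triu2`], so an `n` by `n` helper tensor is materialized.
    ///
    /// Illustrative sketch (added example, assuming the crate is used as `candle_core`):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
    /// assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [1., 3., 6., 10.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```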
    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "cumsum")?;
        let rank = self.rank();
        if rank == 0 {
            return Ok(self.clone());
        }
        let n_axis = self.dim(dim)?;
        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
        if rank == 1 {
            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
        } else {
            let last = rank - 1;
            let t = self.transpose(dim, last)?;
            let t = t.broadcast_matmul(&triu)?;
            t.transpose(dim, last)
        }
    }

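    /// Returns a copy of `self` where the values within `ranges` have been replaced by the values
    /// from `src`. One range is expected per dimension of `self`, and each range must match the
    /// corresponding dimension of `src`. Internally, `src` and a mask are padded with zeros and
    /// the result is produced with `where_cond`.
    ///
    /// Illustrative sketch (added example, assuming the crate is used as `candle_core`):
    ///
    /// ```rust
    /// use candle_core::{DType, Device, Tensor};
    /// let t = Tensor::zeros((3, 3), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::ones((1, 2), DType::F32, &Device::Cpu)?;
    /// let t = t.slice_assign(&[1..2, 0..2], &src)?;
    /// assert_eq!(
    ///     t.to_vec2::<f32>()?,
    ///     [[0., 0., 0.], [1., 1., 0.], [0., 0., 0.]]
    /// );
    /// # Ok::<(), candle_core::Error>(())
    /// ```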
    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
        &self,
        ranges: &[D],
        src: &Tensor,
    ) -> Result<Self> {
        let src_dims = src.dims();
        let self_dims = self.dims();
        if self_dims.len() != src_dims.len() {
            bail!(
                "slice-assign requires input with the same rank {} <> {}",
                self_dims.len(),
                src_dims.len()
            )
        }
        if self_dims.len() != ranges.len() {
            bail!(
                "slice-assign requires input with the same rank as there are ranges {} <> {}",
                self_dims.len(),
                ranges.len()
            )
        }
        let mut src = src.clone();
        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
        for (i, range) in ranges.iter().enumerate() {
            let start_included = match range.start_bound() {
                std::ops::Bound::Unbounded => 0,
                std::ops::Bound::Included(v) => *v,
                std::ops::Bound::Excluded(v) => *v + 1,
            };
            let end_excluded = match range.end_bound() {
                std::ops::Bound::Unbounded => self_dims[i],
                std::ops::Bound::Included(v) => *v + 1,
                std::ops::Bound::Excluded(v) => *v,
            };
            if end_excluded <= start_included {
                bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
            }
            if self_dims[i] < end_excluded {
                bail!(
                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
                    self_dims[i]
                )
            }
            if end_excluded - start_included != src_dims[i] {
                bail!(
                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
                )
            }
            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
        }
        mask.where_cond(&src, self)
    }

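    /// Computes a numerically stable log-sum-exp over the dimensions in `sum_dims`: the running
    /// maximum is subtracted before exponentiating and added back after taking the log of the
    /// sum. The reduced dimensions are squeezed out of the result.
    ///
    /// Illustrative sketch (added example, assuming the crate is used as `candle_core`):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[[0f64, 0.], [0., 0.]], &Device::Cpu)?;
    /// // Each row reduces to log(exp(0) + exp(0)) = ln(2).
    /// let lse = t.log_sum_exp(1)?;
    /// assert_eq!(lse.dims(), &[2]);
    /// let err = (lse - 2f64.ln())?.abs()?.sum_all()?.to_scalar::<f64>()?;
    /// assert!(err < 1e-12);
    /// # Ok::<(), candle_core::Error>(())
    /// ```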
    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
        if sum_dims.is_empty() {
            return Ok(self.clone());
        }
        let max = sum_dims[1..]
            .iter()
            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
                max.max_keepdim(dim)
            })?;
        let exp = self.broadcast_sub(&max)?.exp()?;
        let sum = exp.sum(sum_dims.clone())?;

        sum.log()? + max.squeeze_dims(&sum_dims)
    }

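    /// Pointwise `self^rhs`, computed as `exp(rhs * ln(self))`; both tensors must have the same
    /// shape. See [`Tensor::broadcast_pow`] for the broadcasting variant. Since the computation
    /// goes through the logarithm, negative bases generally produce NaNs.
    ///
    /// Illustrative sketch (added example, assuming the crate is used as `candle_core`):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let base = Tensor::new(&[2f64, 3., 4.], &Device::Cpu)?;
    /// let exponent = Tensor::new(&[2f64, 2., 2.], &Device::Cpu)?;
    /// let p = base.pow(&exponent)?;
    /// let expected = Tensor::new(&[4f64, 9., 16.], &Device::Cpu)?;
    /// let err = (p - expected)?.abs()?.sum_all()?.to_scalar::<f64>()?;
    /// assert!(err < 1e-12);
    /// # Ok::<(), candle_core::Error>(())
    /// ```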
    pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
        rhs.mul(&self.log()?)?.exp()
    }

    /// Broadcasting variant of [`Tensor::pow`]: `rhs` is broadcast against `self` before the
    /// pointwise `exp(rhs * ln(self))` is computed.
    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
        rhs.broadcast_mul(&self.log()?)?.exp()
    }

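    /// Reverses the order of elements along each of the dimensions listed in `dims`, returning a
    /// new tensor. The implementation builds a reversed index vector per dimension and applies
    /// `index_select`, so each flipped dimension incurs a gather.
    ///
    /// Illustrative sketch (added example, assuming the crate is used as `candle_core`):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[[1u32, 2], [3, 4]], &Device::Cpu)?;
    /// let flipped = t.flip(&[0])?;
    /// assert_eq!(flipped.to_vec2::<u32>()?, [[3, 4], [1, 2]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```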
    pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
        let mut result = self.clone();
        for &dim in dims.iter() {
            let size = result.dim(dim)?;
            let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
            let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
            result = result.index_select(&indices_tensor, dim)?;
        }
        Ok(result)
    }
}

macro_rules! bin_trait {
    ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => {
        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: B) -> Self::Output {
                Tensor::$fn1(&self, rhs.borrow())
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for &Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: B) -> Self::Output {
                Tensor::$fn1(&self, rhs.borrow())
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: Tensor) -> Self::Output {
                Tensor::$fn1(self?.borrow(), &rhs)
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: &Tensor) -> Self::Output {
                Tensor::$fn1(self?.borrow(), rhs)
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: Result<B>) -> Self::Output {
                Tensor::$fn1(&self, rhs?.borrow())
            }
        }

        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for &Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: Result<B>) -> Self::Output {
                Tensor::$fn1(&self, rhs?.borrow())
            }
        }

        impl std::ops::$trait<f64> for Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: f64) -> Self::Output {
                self.affine($mul(rhs), $add(rhs))
            }
        }

        impl std::ops::$trait<f64> for &Tensor {
            type Output = Result<Tensor>;

            fn $fn1(self, rhs: f64) -> Self::Output {
                self.affine($mul(rhs), $add(rhs))
            }
        }
    };
}

bin_trait!(Add, add, |_| 1., |v| v);
bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
bin_trait!(Mul, mul, |v| v, |_| 0.);
bin_trait!(Div, div, |v| 1. / v, |_| 0.);
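
// The `bin_trait!` expansions above make the arithmetic operators usable on tensors, tensor
// references, `Result<Tensor>` values, and `f64` scalars (scalar operands map onto `affine`).
// Rough usage sketch, illustrative only and assuming the crate is consumed as `candle_core`:
//
//     use candle_core::{Device, Tensor};
//     let a = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
//     let b = Tensor::new(&[10f32, 20., 30.], &Device::Cpu)?;
//     let c = ((&a + &b)? * 2.0)?;  // (a + b) * 2
//     let d = (&c - 1.0)?;          // scalar subtraction maps to affine(1., -1.)
//     assert_eq!(d.to_vec1::<f32>()?, [21., 43., 65.]);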

impl std::ops::Add<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn add(self, rhs: Tensor) -> Self::Output {
        rhs + self
    }
}

impl std::ops::Add<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn add(self, rhs: &Tensor) -> Self::Output {
        rhs + self
    }
}

impl std::ops::Mul<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn mul(self, rhs: Tensor) -> Self::Output {
        rhs * self
    }
}

impl std::ops::Mul<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn mul(self, rhs: &Tensor) -> Self::Output {
        rhs * self
    }
}

impl std::ops::Sub<Tensor> for f64 {
    type Output = Result<Tensor>;

    fn sub(self, rhs: Tensor) -> Self::Output {
        rhs.affine(-1., self)
    }
}

impl std::ops::Sub<&Tensor> for f64 {
    type Output = Result<Tensor>;

    fn sub(self, rhs: &Tensor) -> Self::Output {
        rhs.affine(-1., self)
    }
}

impl std::ops::Div<Tensor> for f64 {
    type Output = Result<Tensor>;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: Tensor) -> Self::Output {
        rhs.recip()? * self
    }
}

impl std::ops::Div<&Tensor> for f64 {
    type Output = Result<Tensor>;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: &Tensor) -> Self::Output {
        rhs.recip()? * self
    }
}
2819}