1#![allow(clippy::redundant_closure_call)]
3use crate::backend::{BackendDevice, BackendStorage};
4use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
5use crate::scalar::TensorOrScalar;
6use crate::shape::{Dim, Dims};
7use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
8use std::sync::{Arc, RwLock};
9
10#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
12pub struct TensorId(usize);
13
14impl TensorId {
15 fn new() -> Self {
16 use std::sync::atomic;
18 static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
19 Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
20 }
21}
22
23pub struct Tensor_ {
24 id: TensorId,
25 storage: Arc<RwLock<Storage>>,
38 layout: Layout,
39 op: BackpropOp,
40 is_variable: bool,
41 dtype: DType,
42 device: Device,
43}
44
45impl AsRef<Tensor> for Tensor {
46 fn as_ref(&self) -> &Tensor {
47 self
48 }
49}
50
51#[derive(Clone)]
55pub struct Tensor(Arc<Tensor_>);
69
70impl std::ops::Deref for Tensor {
71 type Target = Tensor_;
72
73 fn deref(&self) -> &Self::Target {
74 self.0.as_ref()
75 }
76}
77
78macro_rules! unary_op {
79 ($fn_name:ident, $op_name:ident) => {
80 pub fn $fn_name(&self) -> Result<Self> {
81 let shape = self.shape();
82 if shape.elem_count() == 0 {
83 return Ok(self.clone());
84 }
85 let storage = self
86 .storage()
87 .unary_impl::<crate::op::$op_name>(self.layout())?;
88 let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
89 Ok(from_storage(storage, shape.clone(), op, false))
90 }
91 };
92}
93
94macro_rules! binary_op {
95 ($fn_name:ident, $op_name:ident) => {
96 pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
97 let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
98 if shape.elem_count() == 0 {
99 return Ok(self.clone());
100 }
101 let storage = self.storage().binary_impl::<crate::op::$op_name>(
102 &*rhs.storage(),
103 self.layout(),
104 rhs.layout(),
105 )?;
106 let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
107 Ok(from_storage(storage, shape.clone(), op, false))
108 }
109 };
110}
111
112macro_rules! binary_op_scalar {
113 ($fn_name:ident, $op_name:ident) => {
114 pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
115 let rhs = match rhs.to_tensor_scalar()? {
116 crate::scalar::TensorScalar::Tensor(rhs) => rhs,
117 crate::scalar::TensorScalar::Scalar(rhs) => rhs
118 .to_dtype(self.dtype())?
119 .to_device(self.device())?
120 .broadcast_as(self.shape())?,
121 };
122 let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
123 if self.elem_count() == 0 {
124 return Ok(self.clone());
125 }
126 let storage = self.storage().binary_impl::<crate::op::$op_name>(
127 &*rhs.storage(),
128 self.layout(),
129 rhs.layout(),
130 )?;
131 let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
132 Ok(from_storage(storage, shape.clone(), op, false))
133 }
134 };
135}
136
137macro_rules! broadcast_binary_op {
138 ($fn_name:ident, $inner_fn_name:ident) => {
139 pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
140 let lhs = self;
141 let shape = lhs
142 .shape()
143 .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
144 let l_broadcast = shape != *lhs.shape();
145 let r_broadcast = shape != *rhs.shape();
146 match (l_broadcast, r_broadcast) {
147 (true, true) => lhs
148 .broadcast_as(&shape)?
149 .$inner_fn_name(&rhs.broadcast_as(&shape)?),
150 (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
151 (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
152 (false, false) => lhs.$inner_fn_name(rhs),
153 }
154 }
155 };
156}
157
158pub(crate) fn from_storage<S: Into<Shape>>(
160 storage: Storage,
161 shape: S,
162 op: BackpropOp,
163 is_variable: bool,
164) -> Tensor {
165 let dtype = storage.dtype();
166 let device = storage.device();
167 let tensor_ = Tensor_ {
168 id: TensorId::new(),
169 storage: Arc::new(RwLock::new(storage)),
170 layout: Layout::contiguous(shape),
171 op,
172 is_variable,
173 dtype,
174 device,
175 };
176 Tensor(Arc::new(tensor_))
177}
178
179impl Tensor {
180 pub(crate) fn ones_impl<S: Into<Shape>>(
181 shape: S,
182 dtype: DType,
183 device: &Device,
184 is_variable: bool,
185 ) -> Result<Self> {
186 let none = BackpropOp::none();
187 let shape = shape.into();
188 let storage = device.ones(&shape, dtype)?;
189 Ok(from_storage(storage, shape, none, is_variable))
190 }
191
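    /// Creates a tensor of ones with the given shape and dtype on the specified device.
    ///
    /// A minimal usage sketch; the `candle_core` crate path is an assumption of this example
    /// and may need to be adjusted to however this crate is actually imported.
    ///
    /// ```rust
    /// use candle_core::{DType, Device, Tensor}; // crate path assumed for illustration
    /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    /// assert_eq!(a.to_vec2::<f32>()?, vec![vec![1., 1., 1.], vec![1., 1., 1.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```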
192 pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
202 Self::ones_impl(shape, dtype, device, false)
203 }
204
205 pub fn ones_like(&self) -> Result<Self> {
215 Tensor::ones(self.shape(), self.dtype(), self.device())
216 }
217
218 pub(crate) fn zeros_impl<S: Into<Shape>>(
221 shape: S,
222 dtype: DType,
223 device: &Device,
224 is_variable: bool,
225 ) -> Result<Self> {
226 let none = BackpropOp::none();
227 let shape = shape.into();
228 let storage = device.zeros(&shape, dtype)?;
229 Ok(from_storage(storage, shape, none, is_variable))
230 }
231
232 pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
242 Self::zeros_impl(shape, dtype, device, false)
243 }
244
245 pub fn zeros_like(&self) -> Result<Self> {
256 Tensor::zeros(self.shape(), self.dtype(), self.device())
257 }
258
259 pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
260 lo: T,
261 up: T,
262 s: S,
263 device: &Device,
264 is_variable: bool,
265 ) -> Result<Self> {
266 let s = s.into();
267 let storage = device.rand_uniform(lo, up, &s)?;
268 let none = BackpropOp::none();
269 Ok(from_storage(storage, s, none, is_variable))
270 }
271
272 pub(crate) fn rand_f64_impl<S: Into<Shape>>(
273 lo: f64,
274 up: f64,
275 s: S,
276 dtype: DType,
277 device: &Device,
278 is_variable: bool,
279 ) -> Result<Self> {
280 let s = s.into();
281 let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
282 let none = BackpropOp::none();
283 Ok(from_storage(storage, s, none, is_variable))
284 }
285
286 pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
288 lo: T,
289 up: T,
290 s: S,
291 device: &Device,
292 ) -> Result<Self> {
293 Self::rand_impl(lo, up, s, device, false)
294 }
295
296 pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
297 Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
298 }
299
300 pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
301 mean: T,
302 std: T,
303 s: S,
304 device: &Device,
305 is_variable: bool,
306 ) -> Result<Self> {
307 let s = s.into();
308 let storage = device.rand_normal(mean, std, &s)?;
309 let none = BackpropOp::none();
310 Ok(from_storage(storage, s, none, is_variable))
311 }
312
313 pub(crate) fn randn_f64_impl<S: Into<Shape>>(
314 mean: f64,
315 std: f64,
316 s: S,
317 dtype: DType,
318 device: &Device,
319 is_variable: bool,
320 ) -> Result<Self> {
321 let s = s.into();
322 let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
323 let none = BackpropOp::none();
324 Ok(from_storage(storage, s, none, is_variable))
325 }
326
327 pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
328 Tensor::randn_f64_impl(
329 mean,
330 stdev,
331 self.shape(),
332 self.dtype(),
333 self.device(),
334 false,
335 )
336 }
337
338 pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
341 mean: T,
342 std: T,
343 s: S,
344 device: &Device,
345 ) -> Result<Self> {
346 Self::randn_impl(mean, std, s, device, false)
347 }
348
349 pub(crate) fn new_impl<A: crate::device::NdArray>(
350 array: A,
351 shape: Shape,
352 device: &Device,
353 is_variable: bool,
354 ) -> Result<Self> {
355 let n: usize = shape.elem_count();
356 let buffer_size: usize = array.shape()?.elem_count();
357 if buffer_size != n {
358 return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
359 }
360 let storage = device.storage(array)?;
361 let none = BackpropOp::none();
362 Ok(from_storage(storage, shape, none, is_variable))
363 }
364
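    /// Builds a tensor from an array-like value (scalar, slice, or nested arrays), inferring the
    /// shape from the input.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// assert_eq!(t.dims(), &[2, 2]);
    /// assert_eq!(t.to_vec2::<f32>()?, vec![vec![1., 2.], vec![3., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```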
365 pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
367 let shape = array.shape()?;
368 Self::new_impl(array, shape, device, false)
369 }
370
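    /// Creates a tensor where every element is set to the given scalar value, by broadcasting a
    /// single-element tensor to the requested shape.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::full(7u32, (2, 2), &Device::Cpu)?;
    /// assert_eq!(t.to_vec2::<u32>()?, vec![vec![7, 7], vec![7, 7]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```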
371 pub fn full<D: crate::WithDType, S: Into<Shape>>(
383 value: D,
384 shape: S,
385 device: &Device,
386 ) -> Result<Self> {
387 Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
388 }
389
390 pub fn from_iter<D: crate::WithDType>(
399 iter: impl IntoIterator<Item = D>,
400 device: &Device,
401 ) -> Result<Self> {
402 let data = iter.into_iter().collect::<Vec<_>>();
403 let len = data.len();
404 Self::from_vec_impl(data, len, device, false)
405 }
406
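    /// Returns a 1-D tensor with values from `start` (inclusive) to `end` (exclusive), using a
    /// step of one.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::arange(0u32, 5u32, &Device::Cpu)?;
    /// assert_eq!(t.to_vec1::<u32>()?, vec![0, 1, 2, 3, 4]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```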
407 pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
417 Self::arange_step(start, end, D::one(), device)
418 }
419
420 pub fn arange_step<D: crate::WithDType>(
430 start: D,
431 end: D,
432 step: D,
433 device: &Device,
434 ) -> Result<Self> {
435 if D::is_zero(&step) {
436 bail!("step cannot be zero")
437 }
438 let mut data = vec![];
439 let mut current = start;
440 if step >= D::zero() {
441 while current < end {
442 data.push(current);
443 current += step;
444 }
445 } else {
446 while current > end {
447 data.push(current);
448 current += step;
449 }
450 }
451 let len = data.len();
452 Self::from_vec_impl(data, len, device, false)
453 }
454
455 pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
456 data: Vec<D>,
457 shape: S,
458 device: &Device,
459 is_variable: bool,
460 ) -> Result<Self> {
461 let shape = shape.into();
462 let buffer_size = data.len();
463 if buffer_size != shape.elem_count() {
464 return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
465 }
466 let storage = device.storage_owned(data)?;
467 let none = BackpropOp::none();
468 Ok(from_storage(storage, shape, none, is_variable))
469 }
470
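    /// Creates a tensor that owns the provided data, laid out contiguously with the given shape.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::from_vec(vec![1u32, 2, 3, 4, 5, 6], (2, 3), &Device::Cpu)?;
    /// assert_eq!(t.to_vec2::<u32>()?, vec![vec![1, 2, 3], vec![4, 5, 6]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```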
471 pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
485 data: Vec<D>,
486 shape: S,
487 device: &Device,
488 ) -> Result<Self> {
489 Self::from_vec_impl(data, shape, device, false)
490 }
491
492 pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
506 array: &[D],
507 shape: S,
508 device: &Device,
509 ) -> Result<Self> {
510 let shape = shape.into();
511 let n: usize = shape.elem_count();
512 let buffer_size: usize = array.len();
513 if buffer_size != n {
514 return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
515 }
516 let storage = device.storage_from_slice(array)?;
517 let none = BackpropOp::none();
518 Ok(from_storage(storage, shape, none, false))
519 }
520
521 pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
522 let lhs = self.shape();
523 let rhs = rhs.shape();
524 if lhs != rhs {
525 Err(Error::ShapeMismatchBinaryOp {
526 lhs: lhs.clone(),
527 rhs: rhs.clone(),
528 op,
529 }
530 .bt())
531 } else {
532 Ok(lhs)
533 }
534 }
535
536 pub fn track_op(&self) -> bool {
539 self.is_variable || self.op.is_some()
540 }
541
542 binary_op!(add, Add);
545 binary_op!(mul, Mul);
546 binary_op!(sub, Sub);
547 binary_op!(div, Div);
548 binary_op_scalar!(maximum, Maximum);
549 binary_op_scalar!(minimum, Minimum);
550 broadcast_binary_op!(broadcast_add, add);
551 broadcast_binary_op!(broadcast_mul, mul);
552 broadcast_binary_op!(broadcast_sub, sub);
553 broadcast_binary_op!(broadcast_div, div);
554 broadcast_binary_op!(broadcast_maximum, maximum);
555 broadcast_binary_op!(broadcast_minimum, minimum);
556 broadcast_binary_op!(broadcast_eq, eq);
557 broadcast_binary_op!(broadcast_ne, ne);
558 broadcast_binary_op!(broadcast_lt, lt);
559 broadcast_binary_op!(broadcast_le, le);
560 broadcast_binary_op!(broadcast_gt, gt);
561 broadcast_binary_op!(broadcast_ge, ge);
562
563 unary_op!(recip, Recip);
564 unary_op!(neg, Neg);
565 unary_op!(exp, Exp);
566 unary_op!(log, Log);
567 unary_op!(sin, Sin);
568 unary_op!(cos, Cos);
569 unary_op!(tanh, Tanh);
570 unary_op!(abs, Abs);
571 unary_op!(sqr, Sqr);
572 unary_op!(sqrt, Sqrt);
573 unary_op!(gelu, Gelu);
574 unary_op!(gelu_erf, GeluErf);
575 unary_op!(erf, Erf);
576 unary_op!(relu, Relu);
577 unary_op!(silu, Silu);
578 unary_op!(ceil, Ceil);
579 unary_op!(floor, Floor);
580 unary_op!(round, Round);
581 unary_op!(sign, Sign);
582
583 pub fn round_to(&self, decimals: i32) -> Result<Self> {
588 let mult = 10f64.powi(decimals);
589 (self * mult)?.round()? * (1f64 / mult)
590 }
591
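    /// Extracts the value of a rank-0 tensor as a plain Rust scalar.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// // `sum_all` reduces to a rank-0 tensor, which can then be read back as a scalar.
    /// let total = t.sum_all()?.to_scalar::<f32>()?;
    /// assert_eq!(total, 6.0);
    /// # Ok::<(), candle_core::Error>(())
    /// ```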
592 pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
595 if self.rank() != 0 {
596 Err(Error::UnexpectedNumberOfDims {
597 expected: 0,
598 got: self.rank(),
599 shape: self.shape().clone(),
600 }
601 .bt())?
602 }
603 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
604 let data = S::cpu_storage_as_slice(cpu_storage)?;
605 Ok::<_, Error>(data[self.layout().start_offset()])
606 };
607 match &*self.storage() {
608 Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
609 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
610 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
611 }
612 }
613
614 pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
616 self.to_scalar::<S>()
617 }
618
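    /// Tiles the tensor by repeating it along each dimension, similar to PyTorch's
    /// `Tensor.repeat` (whole-tensor tiling, not element-wise repetition).
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[1u32, 2, 3], &Device::Cpu)?;
    /// let r = t.repeat(2)?;
    /// assert_eq!(r.to_vec1::<u32>()?, vec![1, 2, 3, 1, 2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```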
619 pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
621 let repeats = shape.into();
623 let repeats = repeats.dims();
624 let mut inp = if self.rank() < repeats.len() {
625 let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
626 self.reshape(shape)?
627 } else {
628 self.clone()
629 };
630 for (idx, &repeat) in repeats.iter().enumerate() {
631 if repeat > 1 {
632 inp = Tensor::cat(&vec![&inp; repeat], idx)?
633 }
634 }
635 Ok(inp)
636 }
637
638 pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
675 if args.len() <= 1 {
676 Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
677 }
678 let args: Vec<_> = if xy_indexing {
679 args.iter().rev().collect()
680 } else {
681 args.iter().collect()
682 };
683
684 let mut shape = Vec::with_capacity(args.len());
685 for arg in args.iter() {
686 shape.push(arg.as_ref().dims1()?)
687 }
688
689 let mut grids = Vec::with_capacity(args.len());
690 for idx in 0..args.len() {
691 let mut ones = vec![1usize; args.len()];
692 ones[idx] = shape[idx];
693 let arg = args[idx].as_ref().reshape(ones)?;
694 let mut repeats = shape.clone();
695 repeats[idx] = 1;
696 let repeated_tensor = arg.repeat(repeats)?;
697 grids.push(repeated_tensor);
698 }
699 if xy_indexing {
700 grids.reverse();
701 }
702 Ok(grids)
703 }
704
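    /// Applies the element-wise affine transform `mul * x + add`.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[0f32, 1., 2.], &Device::Cpu)?;
    /// let a = t.affine(2., 1.)?;
    /// assert_eq!(a.to_vec1::<f32>()?, vec![1., 3., 5.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```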
705 pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
717 if self.elem_count() == 0 {
718 return Ok(self.clone());
719 }
720 let storage = self.storage().affine(self.layout(), mul, add)?;
721 let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
722 Ok(from_storage(storage, self.shape(), op, false))
723 }
724
725 pub fn elu(&self, alpha: f64) -> Result<Self> {
727 if self.elem_count() == 0 {
728 return Ok(self.clone());
729 }
730 let storage = self.storage().elu(self.layout(), alpha)?;
731 let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
732 Ok(from_storage(storage, self.shape(), op, false))
733 }
734
735 pub fn powf(&self, e: f64) -> Result<Self> {
737 if self.elem_count() == 0 {
738 return Ok(self.clone());
739 }
740 let storage = self.storage().powf(self.layout(), e)?;
741 let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
742 Ok(from_storage(storage, self.shape(), op, false))
743 }
744
745 pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
746 if dim >= self.dims().len() {
747 Err(Error::DimOutOfRange {
748 shape: self.shape().clone(),
749 dim: dim as i32,
750 op,
751 }
752 .bt())?
753 } else {
754 Ok(())
755 }
756 }
757
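    /// Splits the tensor into `chunks` pieces along the given dimension; when the dimension size
    /// is not evenly divisible, the leading chunks are one element larger.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::arange(0u32, 6u32, &Device::Cpu)?;
    /// let chunks = t.chunk(3, 0)?;
    /// assert_eq!(chunks.len(), 3);
    /// assert_eq!(chunks[1].to_vec1::<u32>()?, vec![2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```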
758 pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
761 let dim = dim.to_index(self.shape(), "chunk")?;
762 let size = self.dim(dim)?;
763 if size < chunks {
764 (0..size).map(|i| self.narrow(dim, i, 1)).collect()
765 } else {
766 let chunk_size = size / chunks;
767 let cnt_additional = size % chunks;
768 let mut tensors = vec![];
769 let mut sum_chunk_size = 0;
770 for i in 0..chunks {
771 let chunk_size = if i < cnt_additional {
772 chunk_size + 1
773 } else {
774 chunk_size
775 };
776 let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
777 tensors.push(tensor);
778 sum_chunk_size += chunk_size
779 }
780 Ok(tensors)
781 }
782 }
783
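    /// Returns a view of `len` elements starting at `start` along dimension `dim`; the storage is
    /// shared with the original tensor rather than copied.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::arange(0u32, 12u32, &Device::Cpu)?.reshape((3, 4))?;
    /// let n = t.narrow(1, 1, 2)?;
    /// assert_eq!(n.to_vec2::<u32>()?, vec![vec![1, 2], vec![5, 6], vec![9, 10]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```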
784 pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
811 let dims = self.dims();
812 let dim = dim.to_index(self.shape(), "narrow")?;
813 let err = |msg| {
814 Err::<(), _>(
815 Error::NarrowInvalidArgs {
816 shape: self.shape().clone(),
817 dim,
818 start,
819 len,
820 msg,
821 }
822 .bt(),
823 )
824 };
825 if start > dims[dim] {
826 err("start > dim_len")?
827 }
828 if start.saturating_add(len) > dims[dim] {
829 err("start + len > dim_len")?
830 }
831 if start == 0 && dims[dim] == len {
832 Ok(self.clone())
833 } else {
834 let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
835 let layout = self.layout().narrow(dim, start, len)?;
836 let tensor_ = Tensor_ {
837 id: TensorId::new(),
838 storage: self.storage.clone(),
839 layout,
840 op,
841 is_variable: false,
842 dtype: self.dtype,
843 device: self.device.clone(),
844 };
845 Ok(Tensor(Arc::new(tensor_)))
846 }
847 }
848
849 fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
850 match dims {
851 [] => Ok(self),
852 [i] => self.squeeze(*i),
853 dims => {
854 let dims = self
855 .dims()
856 .iter()
857 .enumerate()
858 .filter_map(|(dim_idx, &v)| {
859 if dims.contains(&dim_idx) {
860 None
861 } else {
862 Some(v)
863 }
864 })
865 .collect::<Vec<_>>();
866 self.reshape(dims)
867 }
868 }
869 }
870
871 fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
872 let dim = dim.to_index(self.shape(), op.name())?;
873 let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
874 let mut dims = self.dims().to_vec();
875 dims[dim] = 1;
876 let op = match op {
877 ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
878 BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
879 }
880 ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
881 };
882 let res = from_storage(storage, dims, op, false);
883 if keepdim {
884 Ok(res)
885 } else {
886 res.squeeze_dims(&[dim])
887 }
888 }
889
890 fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
891 let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
892 let storage = self
893 .storage()
894 .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
895 let mut dims = self.dims().to_vec();
896 for &sum_dim in sum_dims.iter() {
897 dims[sum_dim] = 1
898 }
899 let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
900 let sum = from_storage(storage, dims, op, false);
901 if keepdim {
902 Ok(sum)
903 } else {
904 sum.squeeze_dims(&sum_dims)
905 }
906 }
907
908 pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
922 where
923 D: Dim + Clone,
924 {
925 let dim = dim.to_index(self.shape(), "roll")?;
926 let dim_size = self.dim(dim)?;
927 let shift = shift.rem_euclid(dim_size as i32) as usize;
928 if shift == 0 {
929 Ok(self.clone())
930 } else {
931 let a = self.narrow(dim, 0, dim_size - shift)?;
932 let b = self.narrow(dim, dim_size - shift, shift)?;
933 Tensor::cat(&[&b, &a], dim)
934 }
935 }
936
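    /// Sums over the given dimensions, keeping them in the output shape with a size of one;
    /// `sum` behaves the same but squeezes the reduced dimensions away.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// assert_eq!(t.sum_keepdim(0)?.to_vec2::<f32>()?, vec![vec![4., 6.]]);
    /// assert_eq!(t.sum(0)?.to_vec1::<f32>()?, vec![4., 6.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```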
937 pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
955 self.sum_impl(sum_dims, true)
956 }
957
958 pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
962 self.sum_impl(sum_dims, false)
963 }
964
965 pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
983 let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
984 let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
985 let scale = 1f64 / (reduced_dim as f64);
986 self.sum_impl(mean_dims, true)? * scale
987 }
988
989 pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
993 let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
994 let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
995 let scale = 1f64 / (reduced_dim as f64);
996 self.sum_impl(mean_dims, false)? * scale
997 }
998
999 pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1001 let dim = dim.to_index(self.shape(), "var")?;
1002 let mean = self.mean_keepdim(dim)?;
1003 let squares = self.broadcast_sub(&mean)?.sqr()?;
1004 squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
1005 }
1006
1007 pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
1009 let dim = dim.to_index(self.shape(), "var")?;
1010 self.var_keepdim(dim)?.squeeze(dim)
1011 }
1012
1013 pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1016 self.reduce_impl(dim, true, ReduceOp::Max)
1017 }
1018
1019 pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
1021 self.reduce_impl(dim, false, ReduceOp::Max)
1022 }
1023
1024 pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1027 self.reduce_impl(dim, true, ReduceOp::Min)
1028 }
1029
1030 pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
1032 self.reduce_impl(dim, false, ReduceOp::Min)
1033 }
1034
1035 pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1036 self.reduce_impl(dim, true, ReduceOp::ArgMax)
1037 }
1038
1039 pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
1041 self.reduce_impl(dim, false, ReduceOp::ArgMax)
1042 }
1043
1044 pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1045 self.reduce_impl(dim, true, ReduceOp::ArgMin)
1046 }
1047
1048 pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
1050 self.reduce_impl(dim, false, ReduceOp::ArgMin)
1051 }
1052
1053 pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
1058 let rhs = match rhs.to_tensor_scalar()? {
1059 crate::scalar::TensorScalar::Tensor(rhs) => rhs,
1060 crate::scalar::TensorScalar::Scalar(rhs) => rhs
1061 .to_dtype(self.dtype())?
1062 .to_device(self.device())?
1063 .broadcast_as(self.shape())?,
1064 };
1065 let shape = self.same_shape_binary_op(&rhs, "cmp")?;
1066 let storage = self
1067 .storage()
1068 .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
1069 let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
1070 Ok(from_storage(storage, shape.dims(), op, false))
1071 }
1072
1073 pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1075 self.cmp(rhs, CmpOp::Eq)
1076 }
1077
1078 pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1080 self.cmp(rhs, CmpOp::Ne)
1081 }
1082
1083 pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1086 self.cmp(rhs, CmpOp::Lt)
1087 }
1088
1089 pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1092 self.cmp(rhs, CmpOp::Gt)
1093 }
1094
1095 pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1098 self.cmp(rhs, CmpOp::Ge)
1099 }
1100
1101 pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1104 self.cmp(rhs, CmpOp::Le)
1105 }
1106
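    /// Clamps every element to the `[min, max]` range; `min` and `max` can be scalars or tensors.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[-1f32, 0.5, 2.0], &Device::Cpu)?;
    /// let c = t.clamp(0f32, 1f32)?;
    /// assert_eq!(c.to_vec1::<f32>()?, vec![0.0, 0.5, 1.0]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```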
1107 pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
1109 self.maximum(min)?.minimum(max)
1110 }
1111
1112 pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
1117 let (n, c, _l) = self.dims3()?;
1118 let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
1119 let storage = self
1120 .storage()
1121 .upsample_nearest1d(self.layout(), target_size)?;
1122 Ok(from_storage(storage, (n, c, target_size), op, false))
1123 }
1124
1125 pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
1127 self.interpolate1d(target_size)
1128 }
1129
1130 pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1136 let (n, c, _h, _w) = self.dims4()?;
1137 let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
1138 arg,
1139 target_h,
1140 target_w,
1141 });
1142 let storage = self
1143 .storage()
1144 .upsample_nearest2d(self.layout(), target_h, target_w)?;
1145 Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1146 }
1147
1148 pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1150 self.interpolate2d(target_h, target_w)
1151 }
1152
1153 pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1160 let sz = sz.to_usize2();
1161 self.avg_pool2d_with_stride(sz, sz)
1162 }
1163
1164 pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
1167 &self,
1168 kernel_size: T,
1169 stride: T,
1170 ) -> Result<Self> {
1171 let kernel_size = kernel_size.to_usize2();
1172 let stride = stride.to_usize2();
1173 let (n, c, h, w) = self.dims4()?;
1174 if h < kernel_size.0 || w < kernel_size.1 {
1175 bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1176 }
1177 let h_out = (h - kernel_size.0) / stride.0 + 1;
1179 let w_out = (w - kernel_size.1) / stride.1 + 1;
1180 let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
1181 arg,
1182 kernel_size,
1183 stride,
1184 });
1185 let storage = self
1186 .storage()
1187 .avg_pool2d(self.layout(), kernel_size, stride)?;
1188 Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1189 }
1190
1191 pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1198 let sz = sz.to_usize2();
1199 self.max_pool2d_with_stride(sz, sz)
1200 }
1201
1202 pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
1205 &self,
1206 kernel_size: T,
1207 stride: T,
1208 ) -> Result<Self> {
1209 let kernel_size = kernel_size.to_usize2();
1210 let stride = stride.to_usize2();
1211 let (n, c, h, w) = self.dims4()?;
1212 if h < kernel_size.0 || w < kernel_size.1 {
1213 bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1214 }
1215 let h_out = (h - kernel_size.0) / stride.0 + 1;
1217 let w_out = (w - kernel_size.1) / stride.1 + 1;
1218 let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
1219 arg,
1220 kernel_size,
1221 stride,
1222 });
1223 let storage = self
1224 .storage()
1225 .max_pool2d(self.layout(), kernel_size, stride)?;
1226 Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1227 }
1228
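    /// Standard (batched) matrix multiplication; both operands must have the same rank and
    /// matching batch and inner dimensions.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let b = Tensor::new(&[[5f32, 6.], [7., 8.]], &Device::Cpu)?;
    /// let c = a.matmul(&b)?;
    /// assert_eq!(c.to_vec2::<f32>()?, vec![vec![19., 22.], vec![43., 50.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```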
1229 pub fn matmul(&self, rhs: &Self) -> Result<Self> {
1238 let a_dims = self.shape().dims();
1239 let b_dims = rhs.shape().dims();
1240
1241 let dim = a_dims.len();
1242
1243 if dim < 2 || b_dims.len() != dim {
1244 Err(Error::ShapeMismatchBinaryOp {
1245 lhs: self.shape().clone(),
1246 rhs: rhs.shape().clone(),
1247 op: "matmul",
1248 }
1249 .bt())?
1250 }
1251
1252 let m = a_dims[dim - 2];
1253 let k = a_dims[dim - 1];
1254 let k2 = b_dims[dim - 2];
1255 let n = b_dims[dim - 1];
1256
1257 let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
1258 if c_shape.elem_count() == 0 || k == 0 {
1259 return Tensor::zeros(c_shape, self.dtype(), self.device());
1260 }
1261 let batching: usize = a_dims[..dim - 2].iter().product();
1262 let batching_b: usize = b_dims[..dim - 2].iter().product();
1263 if k != k2 || batching != batching_b {
1264 Err(Error::ShapeMismatchBinaryOp {
1265 lhs: self.shape().clone(),
1266 rhs: rhs.shape().clone(),
1267 op: "matmul",
1268 }
1269 .bt())?
1270 }
1271
1272 let storage = self.storage().matmul(
1273 &rhs.storage(),
1274 (batching, m, n, k),
1275 self.layout(),
1276 rhs.layout(),
1277 )?;
1278 let op = BackpropOp::new2(self, rhs, Op::Matmul);
1279 Ok(from_storage(storage, c_shape, op, false))
1280 }
1281
1282 pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
1288 let lhs = self;
1289 let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
1290 let l_broadcast = l_shape != *lhs.shape();
1291 let r_broadcast = r_shape != *rhs.shape();
1292 match (l_broadcast, r_broadcast) {
1294 (true, true) => lhs
1295 .broadcast_as(&l_shape)?
1296 .contiguous()?
1297 .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1298 (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1299 (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
1300 (false, false) => lhs.matmul(rhs),
1301 }
1302 }
1303
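    /// Element-wise select: returns values from `on_true` where `self` is non-zero and from
    /// `on_false` elsewhere; all three tensors must share the same shape.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let cond = Tensor::new(&[1u8, 0, 1], &Device::Cpu)?;
    /// let on_true = Tensor::new(&[10u32, 20, 30], &Device::Cpu)?;
    /// let on_false = Tensor::new(&[0u32, 0, 0], &Device::Cpu)?;
    /// assert_eq!(cond.where_cond(&on_true, &on_false)?.to_vec1::<u32>()?, vec![10, 0, 30]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```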
1304 pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
        let _shape = self.same_shape_binary_op(on_true, "where_cond")?;
1309 let shape = self.same_shape_binary_op(on_false, "where_cond")?;
1310 let storage = self.storage().where_cond(
1311 self.layout(),
1312 &on_true.storage(),
1313 on_true.layout(),
1314 &on_false.storage(),
1315 on_false.layout(),
1316 )?;
1317 let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
1318 Ok(from_storage(storage, shape, op, false))
1319 }
1320
1321 pub fn embedding(&self, ids: &Self) -> Result<Self> {
1341 if self.rank() != 2 || ids.rank() != 1 {
1342 Err(Error::ShapeMismatchBinaryOp {
1343 lhs: self.shape().clone(),
1344 rhs: ids.shape().clone(),
1345 op: "embedding",
1346 }
1347 .bt())?
1348 }
1349 self.index_select(ids, 0)
1350 }
1351
1352 pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1353 let dim = dim.to_index(self.shape(), "scatter-add")?;
1354 let source_dims = source.dims();
1355 let self_dims = self.dims();
1356 let mismatch = if source_dims.len() != self_dims.len() {
1357 true
1358 } else {
1359 let mut mismatch = false;
1360 for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1361 if i != dim && d1 != d2 {
1362 mismatch = true;
1363 break;
1364 }
1365 }
1366 mismatch
1367 };
1368 if mismatch {
1369 Err(Error::ShapeMismatchBinaryOp {
1370 op: "scatter-add (self, src)",
1371 lhs: self.shape().clone(),
1372 rhs: source.shape().clone(),
1373 }
1374 .bt())?
1375 }
1376 if indexes.dims() != source.dims() {
1377 Err(Error::ShapeMismatchBinaryOp {
1378 op: "scatter-add (indexes, src)",
1379 lhs: indexes.shape().clone(),
1380 rhs: source.shape().clone(),
1381 }
1382 .bt())?
1383 }
1384 let storage = self.storage().scatter_add(
1385 self.layout(),
1386 &indexes.storage(),
1387 indexes.layout(),
1388 &source.storage(),
1389 source.layout(),
1390 dim,
1391 )?;
1392 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1393 Op::ScatterAdd(t1, t2, t3, dim)
1394 });
1395 Ok(from_storage(storage, self.shape(), op, false))
1396 }
1397
1398 pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
1400 let dim = dim.to_index(self.shape(), "slice-scatter")?;
1401 if dim == 0 {
1402 self.slice_scatter0(src, start)
1403 } else {
1404 self.transpose(0, dim)?
1406 .slice_scatter0(&src.transpose(0, dim)?, start)?
1407 .transpose(0, dim)
1408 }
1409 }
1410
1411 pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
1413 if self.dtype() != src.dtype() {
1414 Err(Error::DTypeMismatchBinaryOp {
1415 lhs: self.dtype(),
1416 rhs: src.dtype(),
1417 op: "slice-scatter",
1418 }
1419 .bt())?
1420 }
        if self.device().location() != src.device().location() {
1422 Err(Error::DeviceMismatchBinaryOp {
1423 lhs: self.device().location(),
1424 rhs: src.device().location(),
1425 op: "slice-scatter",
1426 }
1427 .bt())?
1428 }
1429 if self.rank() != src.rank() {
1430 Err(Error::UnexpectedNumberOfDims {
1431 expected: self.rank(),
1432 got: src.rank(),
1433 shape: src.shape().clone(),
1434 }
1435 .bt())?
1436 }
1437 let shape_ok =
1438 self.dims()
1439 .iter()
1440 .zip(src.dims().iter())
1441 .enumerate()
1442 .all(|(dim_idx, (&d1, &d2))| {
1443 if 0 == dim_idx {
1444 d2 + start <= d1
1445 } else {
1446 d1 == d2
1447 }
1448 });
1449 if !shape_ok {
1450 Err(Error::ShapeMismatchBinaryOp {
1451 op: "slice-scatter (self, src)",
1452 lhs: self.shape().clone(),
1453 rhs: src.shape().clone(),
1454 }
1455 .bt())?
1456 }
1457 let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
1458 self.storage()
1459 .copy_strided_src(&mut storage, 0, self.layout())?;
1460 let offset = start * src.dims()[1..].iter().product::<usize>();
1461 src.storage()
1462 .copy_strided_src(&mut storage, offset, src.layout())?;
1463 let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
1464 Ok(from_storage(storage, self.shape(), op, false))
1465 }
1466
1467 pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1469 let dim = dim.to_index(self.shape(), "index-add")?;
1470 let source_dims = source.dims();
1471 let self_dims = self.dims();
1472 let mismatch = if source_dims.len() != self_dims.len() {
1473 true
1474 } else {
1475 let mut mismatch = false;
1476 for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1477 if i != dim && d1 != d2 {
1478 mismatch = true;
1479 break;
1480 }
1481 }
1482 mismatch
1483 };
1484 if mismatch {
1485 Err(Error::ShapeMismatchBinaryOp {
1486 op: "index-add (self, source)",
1487 lhs: self.shape().clone(),
1488 rhs: source.shape().clone(),
1489 }
1490 .bt())?
1491 }
1492 let indexes_len = indexes.dims1()?;
1496 if source_dims[dim] != indexes_len {
1497 Err(Error::ShapeMismatchBinaryOp {
                op: "index-add (ids, source)",
1499 lhs: indexes.shape().clone(),
1500 rhs: source.shape().clone(),
1501 }
1502 .bt())?
1503 }
1504 let storage = self.storage().index_add(
1505 self.layout(),
1506 &indexes.storage(),
1507 indexes.layout(),
1508 &source.storage(),
1509 source.layout(),
1510 dim,
1511 )?;
1512 let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1513 Op::IndexAdd(t1, t2, t3, dim)
1514 });
1515 Ok(from_storage(storage, self.shape(), op, false))
1516 }
1517
1518 pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1530 let dim = dim.to_index(self.shape(), "gather")?;
1531
1532 let self_dims = self.dims();
1533 let indexes_dims = indexes.dims();
1534 let mismatch = if indexes_dims.len() != self_dims.len() {
1535 true
1536 } else {
1537 let mut mismatch = false;
1538 for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
1539 if i != dim && d1 < d2 {
1540 mismatch = true;
1541 break;
1542 }
1543 }
1544 mismatch
1545 };
1546 if mismatch {
1547 Err(Error::ShapeMismatchBinaryOp {
1548 op: "gather",
1549 lhs: self.shape().clone(),
1550 rhs: indexes.shape().clone(),
1551 }
1552 .bt())?
1553 }
1554 let storage =
1555 self.storage()
1556 .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
1557 let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
1558 Ok(from_storage(storage, indexes.shape(), op, false))
1559 }
1560
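    /// Gathers slices along `dim` according to a 1-D tensor of integer indexes.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[[0u32, 1], [2, 3], [4, 5]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[2u32, 0], &Device::Cpu)?;
    /// let picked = t.index_select(&ids, 0)?;
    /// assert_eq!(picked.to_vec2::<u32>()?, vec![vec![4, 5], vec![0, 1]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```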
1561 pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1569 let dim = dim.to_index(self.shape(), "index-select")?;
1570 let indexes_len = match indexes.dims() {
1571 [l] => *l,
1572 _ => Err(Error::ShapeMismatchBinaryOp {
1573 lhs: self.shape().clone(),
1574 rhs: indexes.shape().clone(),
1575 op: "index-select",
1576 }
1577 .bt())?,
1578 };
1579 let storage = self.storage().index_select(
1580 &indexes.storage(),
1581 self.layout(),
1582 indexes.layout(),
1583 dim,
1584 )?;
1585 let mut dims = self.dims().to_vec();
1586 dims[dim] = indexes_len;
1587 let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
1588 Ok(from_storage(storage, dims, op, false))
1589 }
1590
1591 pub fn strided_index(&self) -> crate::StridedIndex {
1594 self.layout.strided_index()
1595 }
1596
1597 pub fn strided_blocks(&self) -> crate::StridedBlocks {
1602 self.layout.strided_blocks()
1603 }
1604
1605 pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
1607 if self.rank() != 1 {
1608 Err(Error::UnexpectedNumberOfDims {
1609 expected: 1,
1610 got: self.rank(),
1611 shape: self.shape().clone(),
1612 }
1613 .bt())?
1614 }
1615 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1616 let data = S::cpu_storage_as_slice(cpu_storage)?;
1617 let data = match self.layout.contiguous_offsets() {
1618 Some((o1, o2)) => data[o1..o2].to_vec(),
1619 None => self.strided_index().map(|i| data[i]).collect(),
1620 };
1621 Ok::<Vec<_>, Error>(data)
1622 };
1623 match &*self.storage() {
1624 Storage::Cpu(storage) => from_cpu_storage(storage),
1625 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1626 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1627 }
1628 }
1629
1630 pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
1632 let (dim1, dim2) = self.dims2()?;
1633 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1634 let data = S::cpu_storage_as_slice(cpu_storage)?;
1635 let mut rows = vec![];
1636 match self.layout.contiguous_offsets() {
1637 Some((o1, o2)) => {
1638 let data = &data[o1..o2];
1639 for idx_row in 0..dim1 {
1640 rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
1641 }
1642 }
1643 None => {
1644 let mut src_index = self.strided_index();
1645 for _idx_row in 0..dim1 {
1646 let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
1647 rows.push(row)
1648 }
1649 assert!(src_index.next().is_none());
1650 }
1651 }
1652 Ok(rows)
1653 };
1654 match &*self.storage() {
1655 Storage::Cpu(storage) => from_cpu_storage(storage),
1656 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1657 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1658 }
1659 }
1660
1661 pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
1663 let (dim1, dim2, dim3) = self.dims3()?;
1664 let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1665 let data = S::cpu_storage_as_slice(cpu_storage)?;
1666 let mut top_rows = vec![];
1667 match self.layout.contiguous_offsets() {
1668 Some((o1, o2)) => {
1669 let data = &data[o1..o2];
1670 let dim23 = dim2 * dim3;
1671 for idx1 in 0..dim1 {
1672 let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
1673 let mut rows = vec![];
1674 for idx2 in 0..dim2 {
1675 rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
1676 }
1677 top_rows.push(rows);
1678 }
1679 }
1680 None => {
1681 let mut src_index = self.strided_index();
1682 for _idx in 0..dim1 {
1683 let mut rows = vec![];
1684 for _jdx in 0..dim2 {
1685 let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
1686 rows.push(row)
1687 }
1688 top_rows.push(rows);
1689 }
1690 assert!(src_index.next().is_none());
1691 }
1692 }
1693 Ok(top_rows)
1694 };
1695 match &*self.storage() {
1696 Storage::Cpu(storage) => from_cpu_storage(storage),
1697 Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1698 Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1699 }
1700 }
1701
1702 pub fn dtype(&self) -> DType {
1704 self.dtype
1705 }
1706
1707 pub fn device(&self) -> &Device {
1709 &self.device
1710 }
1711
1712 pub fn shape(&self) -> &Shape {
1714 self.layout().shape()
1715 }
1716
1717 pub fn dims(&self) -> &[usize] {
1719 self.shape().dims()
1720 }
1721
1722 pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
1724 let dim = dim.to_index(self.shape(), "dim")?;
1725 Ok(self.dims()[dim])
1726 }
1727
1728 pub fn layout(&self) -> &Layout {
1731 &self.layout
1732 }
1733
1734 pub fn stride(&self) -> &[usize] {
1735 self.layout.stride()
1736 }
1737
1738 pub fn rank(&self) -> usize {
1740 self.shape().rank()
1741 }
1742
1743 pub fn elem_count(&self) -> usize {
1745 self.shape().elem_count()
1746 }
1747
1748 pub fn id(&self) -> TensorId {
1750 self.id
1751 }
1752
1753 pub fn is_variable(&self) -> bool {
1756 self.is_variable
1757 }
1758
1759 pub(crate) fn op(&self) -> &Option<Op> {
1760 &self.op
1761 }
1762
1763 pub fn max_all(&self) -> Result<Tensor> {
1774 if self.rank() == 0 {
1775 Ok(self.clone())
1776 } else {
1777 self.flatten_all()?.max(0)
1778 }
1779 }
1780
1781 pub fn min_all(&self) -> Result<Tensor> {
1792 if self.rank() == 0 {
1793 Ok(self.clone())
1794 } else {
1795 self.flatten_all()?.min(0)
1796 }
1797 }
1798
1799 pub fn sum_all(&self) -> Result<Tensor> {
1810 let dims: Vec<_> = (0..self.rank()).collect();
1811 self.sum(dims)
1812 }
1813
1814 pub fn mean_all(&self) -> Result<Tensor> {
1815 self.sum_all()? / self.elem_count() as f64
1816 }
1817
1818 fn flatten_<D1: Dim, D2: Dim>(
1819 &self,
1820 start_dim: Option<D1>,
1821 end_dim: Option<D2>,
1822 ) -> Result<Tensor> {
1823 if self.rank() == 0 {
1824 self.reshape(1)
1825 } else {
1826 let start_dim = match start_dim {
1827 None => 0,
1828 Some(dim) => dim.to_index(self.shape(), "flatten")?,
1829 };
1830 let end_dim = match end_dim {
1831 None => self.rank() - 1,
1832 Some(dim) => dim.to_index(self.shape(), "flatten")?,
1833 };
1834 if start_dim < end_dim {
1835 let dims = self.dims();
1836 let mut dst_dims = dims[..start_dim].to_vec();
1837 dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
1838 if end_dim + 1 < dims.len() {
1839 dst_dims.extend(&dims[end_dim + 1..]);
1840 }
1841 self.reshape(dst_dims)
1842 } else {
1843 Ok(self.clone())
1844 }
1845 }
1846 }
1847
1848 pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
1851 self.flatten_(Some(start_dim), Some(end_dim))
1852 }
1853
1854 pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
1856 self.flatten_(None::<usize>, Some(end_dim))
1857 }
1858
1859 pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
1862 self.flatten_(Some(start_dim), None::<usize>)
1863 }
1864
1865 pub fn flatten_all(&self) -> Result<Tensor> {
1875 self.flatten_(None::<usize>, None::<usize>)
1876 }
1877
1878 pub fn get(&self, i: usize) -> Result<Tensor> {
1890 let dims = self.dims();
1891 if dims.is_empty() {
1892 Ok(self.clone())
1893 } else {
1894 self.narrow(0, i, 1)?.reshape(&dims[1..])
1895 }
1896 }
1897
1898 pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
1912 let dim = dim.to_index(self.shape(), "get_on_dim")?;
1913 self.narrow(dim, index, 1)?.squeeze(dim)
1914 }
1915
1916 pub fn t(&self) -> Result<Tensor> {
1927 let rank = self.rank();
1928 if rank < 2 {
1929 Err(Error::UnexpectedNumberOfDims {
1930 expected: 2,
1931 got: rank,
1932 shape: self.shape().clone(),
1933 }
1934 .bt())?
1935 }
1936 self.transpose(rank - 2, rank - 1)
1937 }
1938
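    /// Swaps two dimensions of the tensor; only the layout is rearranged, no data is copied.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::arange(0u32, 6u32, &Device::Cpu)?.reshape((2, 3))?;
    /// let tt = t.transpose(0, 1)?;
    /// assert_eq!(tt.dims(), &[3, 2]);
    /// assert_eq!(tt.to_vec2::<u32>()?, vec![vec![0, 3], vec![1, 4], vec![2, 5]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```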
1939 pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
1942 let dim1 = dim1.to_index(self.shape(), "transpose")?;
1943 let dim2 = dim2.to_index(self.shape(), "transpose")?;
1944 if dim1 == dim2 {
1945 return Ok(self.clone());
1946 }
1947 let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
1948 let tensor_ = Tensor_ {
1949 id: TensorId::new(),
1950 storage: self.storage.clone(),
1951 layout: self.layout.transpose(dim1, dim2)?,
1952 op,
1953 is_variable: false,
1954 dtype: self.dtype,
1955 device: self.device.clone(),
1956 };
1957 Ok(Tensor(Arc::new(tensor_)))
1958 }
1959
1960 pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
1972 let dims = dims.to_indexes(self.shape(), "permute")?;
1973 let is_permutation =
1975 dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
1976 if !is_permutation {
1977 bail!(
1978 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
1979 self.dims(),
1980 dims
1981 )
1982 }
1983 let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
1984 let tensor_ = Tensor_ {
1985 id: TensorId::new(),
1986 storage: self.storage.clone(),
1987 layout: self.layout.permute(&dims)?,
1988 op,
1989 is_variable: false,
1990 dtype: self.dtype,
1991 device: self.device.clone(),
1992 };
1993 Ok(Tensor(Arc::new(tensor_)))
1994 }
1995
1996 pub fn is_contiguous(&self) -> bool {
1998 self.layout.is_contiguous()
1999 }
2000
2001 pub fn is_fortran_contiguous(&self) -> bool {
2003 self.layout.is_fortran_contiguous()
2004 }
2005
2006 pub fn copy(&self) -> Result<Tensor> {
2009 let op = BackpropOp::new1(self, Op::Copy);
2010 let tensor_ = Tensor_ {
2011 id: TensorId::new(),
2012 storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
2013 layout: self.layout.clone(),
2014 op,
2015 is_variable: false,
2016 dtype: self.dtype,
2017 device: self.device.clone(),
2018 };
2019 Ok(Tensor(Arc::new(tensor_)))
2020 }
2021
2022 pub fn detach(&self) -> Tensor {
2027 if self.op.is_none() && !self.is_variable {
2028 self.clone()
2029 } else {
2030 let tensor_ = Tensor_ {
2031 id: TensorId::new(),
2032 storage: self.storage.clone(),
2033 layout: self.layout.clone(),
2034 op: BackpropOp::none(),
2035 is_variable: false,
2036 dtype: self.dtype,
2037 device: self.device.clone(),
2038 };
2039 Tensor(Arc::new(tensor_))
2040 }
2041 }
2042
2043 pub fn to_device(&self, device: &Device) -> Result<Tensor> {
2045 if self.device().same_device(device) {
2046 Ok(self.clone())
2047 } else {
2048 let storage = match (&*self.storage(), device) {
2049 (Storage::Cpu(storage), Device::Cuda(cuda)) => {
2050 Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
2051 }
2052 (Storage::Cpu(storage), Device::Metal(metal)) => {
2053 Storage::Metal(metal.storage_from_cpu_storage(storage)?)
2054 }
2055 (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2056 (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2057 (Storage::Cuda(storage), Device::Cuda(cuda)) => {
2058 let cpu_storage = storage.to_cpu_storage()?;
2061 Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
2062 }
2063 (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
2064 _ => {
2065 bail!(
2066 "not implemented yet, self.device: {:?}, device: {:?}",
2067 self.device(),
2068 device
2069 )
2070 }
2071 };
2072 let op = BackpropOp::new1(self, Op::ToDevice);
2073 let tensor_ = Tensor_ {
2074 id: TensorId::new(),
2075 storage: Arc::new(RwLock::new(storage)),
2076 layout: self.layout.clone(),
2077 op,
2078 is_variable: false,
2079 dtype: self.dtype,
2080 device: device.clone(),
2081 };
2082 Ok(Tensor(Arc::new(tensor_)))
2083 }
2084 }
2085
2086 pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
2089 let left_shape = left_shape.into();
2090 let mut dims = left_shape.into_dims();
2091 dims.extend(self.dims());
2092 self.broadcast_as(dims)
2093 }
2094
2095 pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2103 let tensor_ = Tensor_ {
2104 id: TensorId::new(),
2105 storage: self.storage.clone(),
2106 layout: self.layout.broadcast_as(shape)?,
2107 op: BackpropOp::new1(self, Op::Broadcast),
2108 is_variable: false,
2109 dtype: self.dtype,
2110 device: self.device.clone(),
2111 };
2112 Ok(Tensor(Arc::new(tensor_)))
2113 }
2114
2115 pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2117 self.broadcast_as(shape)
2118 }
2119
2120 pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2131 if self.dtype() == dtype {
2132 Ok(self.clone())
2133 } else {
2134 let shape = self.shape();
2135 let storage = self.storage().to_dtype(self.layout(), dtype)?;
2136 let op = BackpropOp::new1(self, Op::ToDType);
2137 Ok(from_storage(storage, shape.clone(), op, false))
2138 }
2139 }
2140
2141 pub fn contiguous(&self) -> Result<Tensor> {
2144 if self.is_contiguous() {
2145 Ok(self.clone())
2146 } else {
2147 let shape = self.shape();
2148 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2149 self.storage()
2150 .copy_strided_src(&mut storage, 0, self.layout())?;
2151 let op = BackpropOp::new1(self, Op::Copy);
2152 Ok(from_storage(storage, shape.clone(), op, false))
2153 }
2154 }
2155
2156 pub fn force_contiguous(&self) -> Result<Tensor> {
2158 let shape = self.shape();
2159 let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2160 self.storage()
2161 .copy_strided_src(&mut storage, 0, self.layout())?;
2162 let op = BackpropOp::new1(self, Op::Copy);
2163 Ok(from_storage(storage, shape.clone(), op, false))
2164 }
2165
2166 pub(crate) fn make_var(&self) -> Result<Tensor> {
2169 let shape = self.shape().clone();
2170 let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2171 self.storage()
2172 .copy_strided_src(&mut storage, 0, self.layout())?;
2173 Ok(from_storage(storage, shape, BackpropOp::none(), true))
2174 }
2175
2176 pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
2201 let shape = s.into_shape(self.elem_count())?;
2202 if shape.elem_count() != self.elem_count() {
2203 return Err(Error::ShapeMismatchBinaryOp {
2204 lhs: self.shape().clone(),
2205 rhs: shape,
2206 op: "reshape",
2207 }
2208 .bt());
2209 }
2210 let op = BackpropOp::new1(self, Op::Reshape);
2211 if self.is_contiguous() {
2212 let tensor_ = Tensor_ {
2213 id: TensorId::new(),
2214 storage: self.storage.clone(),
2215 layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
2216 op,
2217 is_variable: false,
2218 dtype: self.dtype,
2219 device: self.device.clone(),
2220 };
2221 Ok(Tensor(Arc::new(tensor_)))
2222 } else {
2223 let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2224 self.storage()
2225 .copy_strided_src(&mut storage, 0, self.layout())?;
2226 Ok(from_storage(storage, shape, op, false))
2227 }
2228 }
2229
2230 pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
2244 let dims = self.dims();
2247 let dim = dim.to_index(self.shape(), "squeeze")?;
2248 if dims[dim] == 1 {
2249 let mut dims = dims.to_vec();
2250 let mut strides = self.stride().to_vec();
2251 dims.remove(dim);
2252 strides.remove(dim);
2253 let tensor_ = Tensor_ {
2254 id: TensorId::new(),
2255 storage: self.storage.clone(),
2256 layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2257 op: BackpropOp::new1(self, Op::Reshape),
2258 is_variable: false,
2259 dtype: self.dtype,
2260 device: self.device.clone(),
2261 };
2262 Ok(Tensor(Arc::new(tensor_)))
2263 } else {
2264 Ok(self.clone())
2265 }
2266 }
2267
2268 pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
2282 let mut dims = self.dims().to_vec();
2283 let mut strides = self.stride().to_vec();
2284 let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
2285 dims.insert(dim, 1);
2287 let stride = if dim < strides.len() { strides[dim] } else { 1 };
2290 strides.insert(dim, stride);
2291 let tensor_ = Tensor_ {
2292 id: TensorId::new(),
2293 storage: self.storage.clone(),
2294 layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2295 op: BackpropOp::new1(self, Op::Reshape),
2296 is_variable: false,
2297 dtype: self.dtype,
2298 device: self.device.clone(),
2299 };
2300 Ok(Tensor(Arc::new(tensor_)))
2301 }
2302
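    /// Stacks tensors of identical shape along a new dimension, implemented as `unsqueeze`
    /// followed by `cat`.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let a = Tensor::new(&[1f32, 2.], &Device::Cpu)?;
    /// let b = Tensor::new(&[3f32, 4.], &Device::Cpu)?;
    /// let s = Tensor::stack(&[&a, &b], 0)?;
    /// assert_eq!(s.to_vec2::<f32>()?, vec![vec![1., 2.], vec![3., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```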
2303 pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
2320 if args.is_empty() {
2321 Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
2322 }
2323 let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
2324 let args = args
2325 .iter()
2326 .map(|t| t.as_ref().unsqueeze(dim))
2327 .collect::<Result<Vec<_>>>()?;
2328 Self::cat(&args, dim)
2329 }
2330
2331 pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2334 if left == 0 && right == 0 {
2335 Ok(self.clone())
2336 } else if left == 0 {
2337 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2338 let mut dims = self.dims().to_vec();
2339 dims[dim] = right;
2340 let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2341 Tensor::cat(&[self, &right], dim)
2342 } else if right == 0 {
2343 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2344 let mut dims = self.dims().to_vec();
2345 dims[dim] = left;
2346 let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2347 Tensor::cat(&[&left, self], dim)
2348 } else {
2349 let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2350 let mut dims = self.dims().to_vec();
2351 dims[dim] = left;
2352 let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2353 dims[dim] = right;
2354 let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2355 Tensor::cat(&[&left, self, &right], dim)
2356 }
2357 }
2358
2359 pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2362 if left == 0 && right == 0 {
2363 Ok(self.clone())
2364 } else if self.elem_count() == 0 {
2365 bail!("cannot use pad_with_same on an empty tensor")
2366 } else if left == 0 {
2367 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2368 let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2369 let mut v = vec![self];
2370 for _ in 0..right {
2371 v.push(&r)
2372 }
2373 Tensor::cat(&v, dim)
2374 } else if right == 0 {
2375 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2376 let l = self.narrow(dim, 0, 1)?;
2377 let mut v = vec![];
2378 for _ in 0..left {
2379 v.push(&l)
2380 }
2381 v.push(self);
2382 Tensor::cat(&v, dim)
2383 } else {
2384 let dim = dim.to_index(self.shape(), "pad_with_same")?;
2385 let l = self.narrow(dim, 0, 1)?;
2386 let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2387 let mut v = vec![];
2388 for _ in 0..left {
2389 v.push(&l)
2390 }
2391 v.push(self);
2392 for _ in 0..right {
2393 v.push(&r)
2394 }
2395 Tensor::cat(&v, dim)
2396 }
2397 }
2398
2399 pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
2401 m.forward(self)
2402 }
2403
2404 pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
2406 m.forward_t(self, train)
2407 }
2408
2409 pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
2410 self.storage.read().unwrap()
2411 }
2412
2413 pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
2414 self.storage.write().unwrap()
2415 }
2416
2417 pub(crate) fn storage_mut_and_layout(
2420 &self,
2421 ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
2422 let storage = self.storage.write().unwrap();
2423 (storage, &self.layout)
2424 }
2425
2426 pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
2428 let storage = self.storage.read().unwrap();
2429 (storage, &self.layout)
2430 }
2431
2432 pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
2433 let lhs: &RwLock<Storage> = self.storage.as_ref();
2434 let rhs: &RwLock<Storage> = rhs.storage.as_ref();
2435 std::ptr::eq(lhs, rhs)
2436 }
2437
2438 pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
2441 let rank = self.rank() as i64;
2442 if rank <= axis {
2443 bail!("axis {axis} is too large, tensor rank {rank}")
2444 } else if 0 <= axis {
2445 Ok(axis as usize)
2446 } else {
2447 let naxis = rank + axis;
2448 if naxis < 0 {
2449 bail!("axis {axis} is too small, tensor rank {rank}")
2450 }
2451 Ok(naxis as usize)
2452 }
2453 }
2454
2455 pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2457 let t = Tensor::arange(0u32, n as u32, device)?;
2458 let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2459 let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2460 t1.le(&t2)?.to_dtype(dtype)
2461 }
2462
2463 pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2465 let t = Tensor::arange(0u32, n as u32, device)?;
2466 let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2467 let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2468 t1.ge(&t2)?.to_dtype(dtype)
2469 }
2470
2471 pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2473 let t = Tensor::arange(0u32, n as u32, device)?;
2474 let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2475 let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2476 t1.eq(&t2)?.to_dtype(dtype)
2477 }
2478
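    /// Cumulative sum along a dimension, computed here as a (broadcast) matmul with an upper
    /// triangular matrix of ones.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
    /// assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, vec![1., 3., 6., 10.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```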
2479 pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
2484 let dim = dim.to_index(self.shape(), "cumsum")?;
2485 let rank = self.rank();
2486 if rank == 0 {
2487 return Ok(self.clone());
2488 }
2489 let n_axis = self.dim(dim)?;
2490 let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
2491 if rank == 1 {
2492 self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
2493 } else {
2494 let last = rank - 1;
2495 let t = self.transpose(dim, last)?;
2496 let t = t.broadcast_matmul(&triu)?;
2497 t.transpose(dim, last)
2498 }
2499 }
2500
2501 pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
2504 &self,
2505 ranges: &[D],
2506 src: &Tensor,
2507 ) -> Result<Self> {
2508 let src_dims = src.dims();
2509 let self_dims = self.dims();
2510 if self_dims.len() != src_dims.len() {
2511 bail!(
                "slice-assign requires src to have the same rank as self ({} <> {})",
2513 self_dims.len(),
2514 src_dims.len()
2515 )
2516 }
2517 if self_dims.len() != ranges.len() {
2518 bail!(
                "slice-assign requires as many ranges as self has dimensions ({} <> {})",
2520 self_dims.len(),
2521 ranges.len()
2522 )
2523 }
2524 let mut src = src.clone();
2525 let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
2526 for (i, range) in ranges.iter().enumerate() {
2527 let start_included = match range.start_bound() {
2528 std::ops::Bound::Unbounded => 0,
2529 std::ops::Bound::Included(v) => *v,
2530 std::ops::Bound::Excluded(v) => *v + 1,
2531 };
2532 let end_excluded = match range.end_bound() {
2533 std::ops::Bound::Unbounded => self_dims[i],
2534 std::ops::Bound::Included(v) => *v + 1,
2535 std::ops::Bound::Excluded(v) => *v,
2536 };
2537 if end_excluded <= start_included {
2538 bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
2539 }
2540 if self_dims[i] < end_excluded {
2541 bail!(
2542 "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
2543 self_dims[i]
2544 )
2545 }
2546 if end_excluded - start_included != src_dims[i] {
2547 bail!(
2548 "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
2549 )
2550 }
2551 src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
2552 mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
2553 }
2554 mask.where_cond(&src, self)
2555 }
2556
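    /// Computes `log(sum(exp(x)))` over the given dimensions in a numerically stable way, by
    /// subtracting the per-slice maximum before exponentiating.
    ///
    /// A minimal usage sketch (the `candle_core` crate path is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[[0f32, 0.], [0., 0.]], &Device::Cpu)?;
    /// let lse = t.log_sum_exp(1)?;
    /// assert_eq!(lse.dims(), &[2]);
    /// let v = lse.to_vec1::<f32>()?;
    /// // log(exp(0) + exp(0)) = ln(2) for each row.
    /// assert!((v[0] - 2f32.ln()).abs() < 1e-6);
    /// # Ok::<(), candle_core::Error>(())
    /// ```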
2557 pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
2559 let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
2560 if sum_dims.is_empty() {
2561 return Ok(self.clone());
2562 }
2563 let max = sum_dims[1..]
2564 .iter()
2565 .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
2566 max.max_keepdim(dim)
2567 })?;
2568 let exp = self.broadcast_sub(&max)?.exp()?;
2569 let sum = exp.sum(sum_dims.clone())?;
2570
2571 sum.log()? + max.squeeze_dims(&sum_dims)
2572 }
2573
2574 pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
2576 rhs.mul(&self.log()?)?.exp()
2577 }
2578
2579 pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
2581 rhs.broadcast_mul(&self.log()?)?.exp()
2582 }
2583}
2584
2585macro_rules! bin_trait {
2586 ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => {
2587 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for Tensor {
2588 type Output = Result<Tensor>;
2589
2590 fn $fn1(self, rhs: B) -> Self::Output {
2591 Tensor::$fn1(&self, rhs.borrow())
2592 }
2593 }
2594
2595 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for &Tensor {
2596 type Output = Result<Tensor>;
2597
2598 fn $fn1(self, rhs: B) -> Self::Output {
2599 Tensor::$fn1(&self, rhs.borrow())
2600 }
2601 }
2602
2603 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
2604 type Output = Result<Tensor>;
2605
2606 fn $fn1(self, rhs: Tensor) -> Self::Output {
2607 Tensor::$fn1(self?.borrow(), &rhs)
2608 }
2609 }
2610
2611 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
2612 type Output = Result<Tensor>;
2613
2614 fn $fn1(self, rhs: &Tensor) -> Self::Output {
2615 Tensor::$fn1(self?.borrow(), rhs)
2616 }
2617 }
2618
2619 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
2620 type Output = Result<Tensor>;
2621
2622 fn $fn1(self, rhs: Result<B>) -> Self::Output {
2623 Tensor::$fn1(&self, rhs?.borrow())
2624 }
2625 }
2626
2627 impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for &Tensor {
2628 type Output = Result<Tensor>;
2629
2630 fn $fn1(self, rhs: Result<B>) -> Self::Output {
2631 Tensor::$fn1(&self, rhs?.borrow())
2632 }
2633 }
2634
2635 impl std::ops::$trait<f64> for Tensor {
2636 type Output = Result<Tensor>;
2637
2638 fn $fn1(self, rhs: f64) -> Self::Output {
2639 self.affine($mul(rhs), $add(rhs))
2640 }
2641 }
2642
2643 impl std::ops::$trait<f64> for &Tensor {
2644 type Output = Result<Tensor>;
2645
2646 fn $fn1(self, rhs: f64) -> Self::Output {
2647 self.affine($mul(rhs), $add(rhs))
2648 }
2649 }
2650 };
2651}
2652
2653bin_trait!(Add, add, |_| 1., |v| v);
2654bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
2655bin_trait!(Mul, mul, |v| v, |_| 0.);
2656bin_trait!(Div, div, |v| 1. / v, |_| 0.);
2657
2658impl std::ops::Add<Tensor> for f64 {
2659 type Output = Result<Tensor>;
2660
2661 fn add(self, rhs: Tensor) -> Self::Output {
2662 rhs + self
2663 }
2664}
2665
2666impl std::ops::Add<&Tensor> for f64 {
2667 type Output = Result<Tensor>;
2668
2669 fn add(self, rhs: &Tensor) -> Self::Output {
2670 rhs + self
2671 }
2672}
2673
2674impl std::ops::Mul<Tensor> for f64 {
2675 type Output = Result<Tensor>;
2676
2677 fn mul(self, rhs: Tensor) -> Self::Output {
2678 rhs * self
2679 }
2680}
2681
2682impl std::ops::Mul<&Tensor> for f64 {
2683 type Output = Result<Tensor>;
2684
2685 fn mul(self, rhs: &Tensor) -> Self::Output {
2686 rhs * self
2687 }
2688}
2689
2690impl std::ops::Sub<Tensor> for f64 {
2691 type Output = Result<Tensor>;
2692
2693 fn sub(self, rhs: Tensor) -> Self::Output {
2694 rhs.affine(-1., self)
2695 }
2696}
2697
2698impl std::ops::Sub<&Tensor> for f64 {
2699 type Output = Result<Tensor>;
2700
2701 fn sub(self, rhs: &Tensor) -> Self::Output {
2702 rhs.affine(-1., self)
2703 }
2704}
2705
2706impl std::ops::Div<Tensor> for f64 {
2707 type Output = Result<Tensor>;
2708
2709 #[allow(clippy::suspicious_arithmetic_impl)]
2710 fn div(self, rhs: Tensor) -> Self::Output {
2711 rhs.recip()? * self
2712 }
2713}
2714
2715impl std::ops::Div<&Tensor> for f64 {
2716 type Output = Result<Tensor>;
2717
2718 #[allow(clippy::suspicious_arithmetic_impl)]
2719 fn div(self, rhs: &Tensor) -> Self::Output {
2720 rhs.recip()? * self
2721 }
2722}