candle_core/tensor.rs

//! Tensors are N-dimensional matrices of elements using a single data type.
2#![allow(clippy::redundant_closure_call)]
3use crate::backend::{BackendDevice, BackendStorage};
4use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
5use crate::scalar::TensorOrScalar;
6use crate::shape::{Dim, Dims};
7use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
8use std::sync::{Arc, RwLock};
9
10/// Unique identifier for tensors.
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
12pub struct TensorId(usize);
13
14impl TensorId {
15    fn new() -> Self {
16        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
17        use std::sync::atomic;
18        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
19        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
20    }
21}
22
23pub struct Tensor_ {
24    id: TensorId,
25    // As we provide inner mutability on the tensor content, the alternatives are:
26    // - Using a mutex, this would have the highest cost when retrieving the storage but would
27    //   prevent errors when concurrent access takes place. Mutex would also be subject to
28    //   deadlocks for example using the current code if the same tensor is used twice by a single
29    //   binary op.
    // - Using a refcell, this would have some intermediary cost: borrow checking would be
    //   verified dynamically, but the resulting tensors would not be Send or Sync.
32    // - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent
33    //   accesses.
34    // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
35    // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but
36    // that's tricky to encode in the current setup.
37    storage: Arc<RwLock<Storage>>,
38    layout: Layout,
39    op: BackpropOp,
40    is_variable: bool,
41    dtype: DType,
42    device: Device,
43}
44
45impl AsRef<Tensor> for Tensor {
46    fn as_ref(&self) -> &Tensor {
47        self
48    }
49}
50
51// Tensors are refcounted so that cloning is cheap when building the op graph.
// Storages are also refcounted independently so that it's possible to avoid
53// copying the storage for operations that only modify the shape or stride.
54#[derive(Clone)]
55/// The core struct for manipulating tensors.
56///
57/// ```rust
58/// use candle_core::{Tensor, DType, Device};
59///
60/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
61/// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
62///
63/// let c = a.matmul(&b)?;
64/// # Ok::<(), candle_core::Error>(())
65/// ```
66///
67/// Tensors are reference counted with [`Arc`] so cloning them is cheap.
68pub struct Tensor(Arc<Tensor_>);
69
70impl std::ops::Deref for Tensor {
71    type Target = Tensor_;
72
73    fn deref(&self) -> &Self::Target {
74        self.0.as_ref()
75    }
76}
77
78macro_rules! unary_op {
79    ($fn_name:ident, $op_name:ident) => {
80        pub fn $fn_name(&self) -> Result<Self> {
81            let shape = self.shape();
82            if shape.elem_count() == 0 {
83                return Ok(self.clone());
84            }
85            let storage = self
86                .storage()
87                .unary_impl::<crate::op::$op_name>(self.layout())?;
88            let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
89            Ok(from_storage(storage, shape.clone(), op, false))
90        }
91    };
92}
93
94macro_rules! binary_op {
95    ($fn_name:ident, $op_name:ident) => {
96        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
97            let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
98            if shape.elem_count() == 0 {
99                return Ok(self.clone());
100            }
101            let storage = self.storage().binary_impl::<crate::op::$op_name>(
102                &*rhs.storage(),
103                self.layout(),
104                rhs.layout(),
105            )?;
106            let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
107            Ok(from_storage(storage, shape.clone(), op, false))
108        }
109    };
110}
111
112macro_rules! binary_op_scalar {
113    ($fn_name:ident, $op_name:ident) => {
114        pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
115            let rhs = match rhs.to_tensor_scalar()? {
116                crate::scalar::TensorScalar::Tensor(rhs) => rhs,
117                crate::scalar::TensorScalar::Scalar(rhs) => rhs
118                    .to_dtype(self.dtype())?
119                    .to_device(self.device())?
120                    .broadcast_as(self.shape())?,
121            };
122            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
123            if self.elem_count() == 0 {
124                return Ok(self.clone());
125            }
126            let storage = self.storage().binary_impl::<crate::op::$op_name>(
127                &*rhs.storage(),
128                self.layout(),
129                rhs.layout(),
130            )?;
131            let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
132            Ok(from_storage(storage, shape.clone(), op, false))
133        }
134    };
135}
136
137macro_rules! broadcast_binary_op {
138    ($fn_name:ident, $inner_fn_name:ident) => {
139        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
140            let lhs = self;
141            let shape = lhs
142                .shape()
143                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
144            let l_broadcast = shape != *lhs.shape();
145            let r_broadcast = shape != *rhs.shape();
146            match (l_broadcast, r_broadcast) {
147                (true, true) => lhs
148                    .broadcast_as(&shape)?
149                    .$inner_fn_name(&rhs.broadcast_as(&shape)?),
150                (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
151                (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
152                (false, false) => lhs.$inner_fn_name(rhs),
153            }
154        }
155    };
156}
157
/// Creates a fresh tensor structure based on a storage and a shape; this uses contiguous strides.
159pub(crate) fn from_storage<S: Into<Shape>>(
160    storage: Storage,
161    shape: S,
162    op: BackpropOp,
163    is_variable: bool,
164) -> Tensor {
165    let dtype = storage.dtype();
166    let device = storage.device();
167    let tensor_ = Tensor_ {
168        id: TensorId::new(),
169        storage: Arc::new(RwLock::new(storage)),
170        layout: Layout::contiguous(shape),
171        op,
172        is_variable,
173        dtype,
174        device,
175    };
176    Tensor(Arc::new(tensor_))
177}
178
179impl Tensor {
180    pub(crate) fn ones_impl<S: Into<Shape>>(
181        shape: S,
182        dtype: DType,
183        device: &Device,
184        is_variable: bool,
185    ) -> Result<Self> {
186        let none = BackpropOp::none();
187        let shape = shape.into();
188        let storage = device.ones(&shape, dtype)?;
189        Ok(from_storage(storage, shape, none, is_variable))
190    }
191
192    /// Creates a new tensor filled with ones.
193    ///
194    /// ```rust
195    /// use candle_core::{Tensor, DType, Device};
196    /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
197    /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
198    /// // a == b
199    /// # Ok::<(), candle_core::Error>(())
200    /// ```
201    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
202        Self::ones_impl(shape, dtype, device, false)
203    }
204
    /// Creates a new tensor filled with ones with the same shape, dtype, and device as the input tensor.
206    ///
207    /// ```rust
208    /// use candle_core::{Tensor, DType, Device};
209    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
210    /// let b = a.ones_like()?;
211    /// // b == a + 1
212    /// # Ok::<(), candle_core::Error>(())
213    /// ```
214    pub fn ones_like(&self) -> Result<Self> {
215        Tensor::ones(self.shape(), self.dtype(), self.device())
216    }
217
218    // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from
219    // the variable module.
220    pub(crate) fn zeros_impl<S: Into<Shape>>(
221        shape: S,
222        dtype: DType,
223        device: &Device,
224        is_variable: bool,
225    ) -> Result<Self> {
226        let none = BackpropOp::none();
227        let shape = shape.into();
228        let storage = device.zeros(&shape, dtype)?;
229        Ok(from_storage(storage, shape, none, is_variable))
230    }
231
232    /// Creates a new tensor filled with zeros.
233    ///
234    /// ```rust
235    /// use candle_core::{Tensor, DType, Device};
236    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
237    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
238    /// // a == b
239    /// # Ok::<(), candle_core::Error>(())
240    /// ```
241    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
242        Self::zeros_impl(shape, dtype, device, false)
243    }
244
    /// Creates a new tensor filled with zeros with the same shape, dtype, and device as the input
    /// tensor.
247    ///
248    /// ```rust
249    /// use candle_core::{Tensor, DType, Device};
250    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
251    /// let b = a.zeros_like()?;
252    /// // b is on CPU f32.
253    /// # Ok::<(), candle_core::Error>(())
254    /// ```
255    pub fn zeros_like(&self) -> Result<Self> {
256        Tensor::zeros(self.shape(), self.dtype(), self.device())
257    }
258
259    pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
260        lo: T,
261        up: T,
262        s: S,
263        device: &Device,
264        is_variable: bool,
265    ) -> Result<Self> {
266        let s = s.into();
267        let storage = device.rand_uniform(lo, up, &s)?;
268        let none = BackpropOp::none();
269        Ok(from_storage(storage, s, none, is_variable))
270    }
271
272    pub(crate) fn rand_f64_impl<S: Into<Shape>>(
273        lo: f64,
274        up: f64,
275        s: S,
276        dtype: DType,
277        device: &Device,
278        is_variable: bool,
279    ) -> Result<Self> {
280        let s = s.into();
281        let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
282        let none = BackpropOp::none();
283        Ok(from_storage(storage, s, none, is_variable))
284    }
285
286    /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
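    ///
    /// A short usage sketch; the values are random, so only the shape and dtype are checked here.
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::rand(0f32, 1f32, (2, 3), &Device::Cpu)?;
    /// assert_eq!(a.dims(), &[2, 3]);
    /// assert_eq!(a.dtype(), DType::F32);
    /// # Ok::<(), candle_core::Error>(())
    /// ```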
287    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
288        lo: T,
289        up: T,
290        s: S,
291        device: &Device,
292    ) -> Result<Self> {
293        Self::rand_impl(lo, up, s, device, false)
294    }
295
296    pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
297        Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
298    }
299
300    pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
301        mean: T,
302        std: T,
303        s: S,
304        device: &Device,
305        is_variable: bool,
306    ) -> Result<Self> {
307        let s = s.into();
308        let storage = device.rand_normal(mean, std, &s)?;
309        let none = BackpropOp::none();
310        Ok(from_storage(storage, s, none, is_variable))
311    }
312
313    pub(crate) fn randn_f64_impl<S: Into<Shape>>(
314        mean: f64,
315        std: f64,
316        s: S,
317        dtype: DType,
318        device: &Device,
319        is_variable: bool,
320    ) -> Result<Self> {
321        let s = s.into();
322        let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
323        let none = BackpropOp::none();
324        Ok(from_storage(storage, s, none, is_variable))
325    }
326
327    pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
328        Tensor::randn_f64_impl(
329            mean,
330            stdev,
331            self.shape(),
332            self.dtype(),
333            self.device(),
334            false,
335        )
336    }
337
338    /// Creates a new tensor initialized with values sampled from a normal distribution with the
339    /// specified `mean` and standard deviation `std`.
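    ///
    /// A short usage sketch; the values are random, so only the shape is checked here.
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::randn(0f32, 1f32, (2, 3), &Device::Cpu)?;
    /// assert_eq!(a.dims(), &[2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```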
340    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
341        mean: T,
342        std: T,
343        s: S,
344        device: &Device,
345    ) -> Result<Self> {
346        Self::randn_impl(mean, std, s, device, false)
347    }
348
349    pub(crate) fn new_impl<A: crate::device::NdArray>(
350        array: A,
351        shape: Shape,
352        device: &Device,
353        is_variable: bool,
354    ) -> Result<Self> {
355        let n: usize = shape.elem_count();
356        let buffer_size: usize = array.shape()?.elem_count();
357        if buffer_size != n {
358            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
359        }
360        let storage = device.storage(array)?;
361        let none = BackpropOp::none();
362        Ok(from_storage(storage, shape, none, is_variable))
363    }
364
365    /// Creates a new tensor on the specified device using the content and shape of the input.
366    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
367        let shape = array.shape()?;
368        Self::new_impl(array, shape, device, false)
369    }
370
371    /// Returns a new tensor with all the elements having the same specified value. Note that
372    /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
373    ///```rust
374    /// use candle_core::{Tensor, Device};
375    /// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
376    ///
377    /// assert_eq!(a.to_vec2::<f64>()?, &[
378    ///     [3.5, 3.5, 3.5, 3.5],
379    ///     [3.5, 3.5, 3.5, 3.5],
380    /// ]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
382    pub fn full<D: crate::WithDType, S: Into<Shape>>(
383        value: D,
384        shape: S,
385        device: &Device,
386    ) -> Result<Self> {
387        Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
388    }
389
390    /// Creates a new 1D tensor from an iterator.
391    ///```rust
392    /// use candle_core::{Tensor, Device};
393    /// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
394    ///
395    /// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
396    /// # Ok::<(), candle_core::Error>(())
397    /// ```
398    pub fn from_iter<D: crate::WithDType>(
399        iter: impl IntoIterator<Item = D>,
400        device: &Device,
401    ) -> Result<Self> {
402        let data = iter.into_iter().collect::<Vec<_>>();
403        let len = data.len();
404        Self::from_vec_impl(data, len, device, false)
405    }
406
407    /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
408    /// difference `1` from `start`.
409    ///```rust
410    /// use candle_core::{Tensor, Device};
411    /// let a = Tensor::arange(2., 5., &Device::Cpu)?;
412    ///
413    /// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
414    /// # Ok::<(), candle_core::Error>(())
415    /// ```
416    pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
417        Self::arange_step(start, end, D::one(), device)
418    }
419
420    /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
421    /// difference `step` from `start`.
422    ///```rust
423    /// use candle_core::{Tensor, Device};
424    /// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
425    ///
426    /// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
427    /// # Ok::<(), candle_core::Error>(())
428    /// ```
429    pub fn arange_step<D: crate::WithDType>(
430        start: D,
431        end: D,
432        step: D,
433        device: &Device,
434    ) -> Result<Self> {
435        if D::is_zero(&step) {
436            bail!("step cannot be zero")
437        }
438        let mut data = vec![];
439        let mut current = start;
440        if step >= D::zero() {
441            while current < end {
442                data.push(current);
443                current += step;
444            }
445        } else {
446            while current > end {
447                data.push(current);
448                current += step;
449            }
450        }
451        let len = data.len();
452        Self::from_vec_impl(data, len, device, false)
453    }
454
455    pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
456        data: Vec<D>,
457        shape: S,
458        device: &Device,
459        is_variable: bool,
460    ) -> Result<Self> {
461        let shape = shape.into();
462        let buffer_size = data.len();
463        if buffer_size != shape.elem_count() {
464            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
465        }
466        let storage = device.storage_owned(data)?;
467        let none = BackpropOp::none();
468        Ok(from_storage(storage, shape, none, is_variable))
469    }
470
471    /// Creates a new tensor initialized with values from the input vector. The number of elements
472    /// in this vector must be the same as the number of elements defined by the shape.
    /// If the device is CPU, no data copy is made.
474    ///```rust
475    /// use candle_core::{Tensor, Device};
476    /// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
477    ///
478    /// assert_eq!(a.to_vec2::<f64>()?, &[
479    ///     [1., 2., 3.],
480    ///     [4., 5., 6.]
481    /// ]);
482    /// # Ok::<(), candle_core::Error>(())
483    /// ```
484    pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
485        data: Vec<D>,
486        shape: S,
487        device: &Device,
488    ) -> Result<Self> {
489        Self::from_vec_impl(data, shape, device, false)
490    }
491
    /// Creates a new tensor initialized with values from the input slice. The number of elements
    /// in this slice must be the same as the number of elements defined by the shape.
494    ///```rust
495    /// use candle_core::{Tensor, Device};
496    /// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
497    /// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
498    ///
499    /// assert_eq!(a.to_vec2::<f64>()?, &[
500    ///     [2., 3., 4.],
501    ///     [5., 6., 7.]
502    /// ]);
503    /// # Ok::<(), candle_core::Error>(())
504    /// ```
505    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
506        array: &[D],
507        shape: S,
508        device: &Device,
509    ) -> Result<Self> {
510        let shape = shape.into();
511        let n: usize = shape.elem_count();
512        let buffer_size: usize = array.len();
513        if buffer_size != n {
514            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
515        }
516        let storage = device.storage_from_slice(array)?;
517        let none = BackpropOp::none();
518        Ok(from_storage(storage, shape, none, false))
519    }
520
521    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
522        let lhs = self.shape();
523        let rhs = rhs.shape();
524        if lhs != rhs {
525            Err(Error::ShapeMismatchBinaryOp {
526                lhs: lhs.clone(),
527                rhs: rhs.clone(),
528                op,
529            }
530            .bt())
531        } else {
532            Ok(lhs)
533        }
534    }
535
    /// Returns true if the computation graph should track this op, that is if it is
    /// a variable or if it has some variables among its dependencies.
538    pub fn track_op(&self) -> bool {
539        self.is_variable || self.op.is_some()
540    }
541
542    // TODO: Also make an inplace version or a pre-allocated? This could be tricky
543    // if this can create cycles in the compute graph.
544    binary_op!(add, Add);
545    binary_op!(mul, Mul);
546    binary_op!(sub, Sub);
547    binary_op!(div, Div);
548    binary_op_scalar!(maximum, Maximum);
549    binary_op_scalar!(minimum, Minimum);
550    broadcast_binary_op!(broadcast_add, add);
551    broadcast_binary_op!(broadcast_mul, mul);
552    broadcast_binary_op!(broadcast_sub, sub);
553    broadcast_binary_op!(broadcast_div, div);
554    broadcast_binary_op!(broadcast_maximum, maximum);
555    broadcast_binary_op!(broadcast_minimum, minimum);
556    broadcast_binary_op!(broadcast_eq, eq);
557    broadcast_binary_op!(broadcast_ne, ne);
558    broadcast_binary_op!(broadcast_lt, lt);
559    broadcast_binary_op!(broadcast_le, le);
560    broadcast_binary_op!(broadcast_gt, gt);
561    broadcast_binary_op!(broadcast_ge, ge);
562
563    unary_op!(recip, Recip);
564    unary_op!(neg, Neg);
565    unary_op!(exp, Exp);
566    unary_op!(log, Log);
567    unary_op!(sin, Sin);
568    unary_op!(cos, Cos);
569    unary_op!(tanh, Tanh);
570    unary_op!(abs, Abs);
571    unary_op!(sqr, Sqr);
572    unary_op!(sqrt, Sqrt);
573    unary_op!(gelu, Gelu);
574    unary_op!(gelu_erf, GeluErf);
575    unary_op!(erf, Erf);
576    unary_op!(relu, Relu);
577    unary_op!(silu, Silu);
578    unary_op!(ceil, Ceil);
579    unary_op!(floor, Floor);
580    unary_op!(round, Round);
581    unary_op!(sign, Sign);
582
    /// Rounds each element of the input tensor to `decimals` decimal places.
584    ///
585    /// If the number of decimals is negative, it specifies the number of positions to the left of
586    /// the decimal point.
587    pub fn round_to(&self, decimals: i32) -> Result<Self> {
588        let mult = 10f64.powi(decimals);
589        (self * mult)?.round()? * (1f64 / mult)
590    }
591
    /// Retrieves the single scalar value held in the tensor. If the tensor is not zero-dimensional,
    /// an error is returned instead.
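    ///
    /// A small sketch of the expected behaviour on a zero-dimensional tensor:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::ones((), DType::F32, &Device::Cpu)?;
    /// assert_eq!(t.to_scalar::<f32>()?, 1f32);
    /// # Ok::<(), candle_core::Error>(())
    /// ```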
594    pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
595        if self.rank() != 0 {
596            Err(Error::UnexpectedNumberOfDims {
597                expected: 0,
598                got: self.rank(),
599                shape: self.shape().clone(),
600            }
601            .bt())?
602        }
603        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
604            let data = S::cpu_storage_as_slice(cpu_storage)?;
605            Ok::<_, Error>(data[self.layout().start_offset()])
606        };
607        match &*self.storage() {
608            Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
609            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
610            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
611        }
612    }
613
614    /// An alias for `to_scalar`.
615    pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
616        self.to_scalar::<S>()
617    }
618
619    /// Repeat this tensor along the specified dimensions.
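    ///
    /// A small sketch of the expected behaviour, repeating a `(2, 2)` tensor twice along dim 0:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let b = a.repeat((2, 1))?;
    /// assert_eq!(b.to_vec2::<f32>()?, &[[1., 2.], [3., 4.], [1., 2.], [3., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```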
620    pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
621        // Similar to PyTorch, we extend the number of dimensions of self if needed.
622        let repeats = shape.into();
623        let repeats = repeats.dims();
624        let mut inp = if self.rank() < repeats.len() {
625            let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
626            self.reshape(shape)?
627        } else {
628            self.clone()
629        };
630        for (idx, &repeat) in repeats.iter().enumerate() {
631            if repeat > 1 {
632                inp = Tensor::cat(&vec![&inp; repeat], idx)?
633            }
634        }
635        Ok(inp)
636    }
637
638    /// Creates grids of coordinates specified by the 1D inputs.
639    ///
640    /// # Arguments
641    ///
642    /// * `args` - A slice of 1D tensors.
643    /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
644    ///   first dimension corresponds to the cardinality of the second input and the second
645    ///   dimension corresponds to the cardinality of the first input. If ij is selected, the
646    ///   dimensions are in the same order as the cardinality of the inputs.
647    ///
648    /// # Examples
649    ///
650    /// ```rust
651    /// use candle_core::{Tensor, Device, Shape};
652    /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
653    /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
654    ///
655    /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
656    ///
657    /// assert_eq!(grids_xy.len(), 2);
658    /// assert_eq!(grids_xy[0].dims(), &[3, 3]);
659    ///
660    /// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
661    /// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
662    ///
663    /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
664    ///
665    /// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
666    /// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
667    /// # Ok::<(), candle_core::Error>(())
668    /// ```
669    ///
670    /// # Errors
671    ///
    /// * Will return `Err` if `args` contains fewer than 2 tensors.
673    ///
674    pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
675        if args.len() <= 1 {
676            Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
677        }
678        let args: Vec<_> = if xy_indexing {
679            args.iter().rev().collect()
680        } else {
681            args.iter().collect()
682        };
683
684        let mut shape = Vec::with_capacity(args.len());
685        for arg in args.iter() {
686            shape.push(arg.as_ref().dims1()?)
687        }
688
689        let mut grids = Vec::with_capacity(args.len());
690        for idx in 0..args.len() {
691            let mut ones = vec![1usize; args.len()];
692            ones[idx] = shape[idx];
693            let arg = args[idx].as_ref().reshape(ones)?;
694            let mut repeats = shape.clone();
695            repeats[idx] = 1;
696            let repeated_tensor = arg.repeat(repeats)?;
697            grids.push(repeated_tensor);
698        }
699        if xy_indexing {
700            grids.reverse();
701        }
702        Ok(grids)
703    }
704
    /// This operation multiplies the input tensor by `mul` then adds `add` and returns the result.
    /// The input values `mul` and `add` are cast to the appropriate type so some rounding might
707    /// be performed.
708    ///
709    /// ```rust
710    /// use candle_core::{Tensor, Device};
711    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
712    /// let a = a.affine(4., -2.)?;
713    /// assert_eq!(a.to_vec2::<f32>()?, &[[-2.0, 2.0], [6.0, 10.0]]);
714    /// # Ok::<(), candle_core::Error>(())
715    /// ```
716    pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
717        if self.elem_count() == 0 {
718            return Ok(self.clone());
719        }
720        let storage = self.storage().affine(self.layout(), mul, add)?;
721        let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
722        Ok(from_storage(storage, self.shape(), op, false))
723    }
724
725    /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
726    pub fn elu(&self, alpha: f64) -> Result<Self> {
727        if self.elem_count() == 0 {
728            return Ok(self.clone());
729        }
730        let storage = self.storage().elu(self.layout(), alpha)?;
731        let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
732        Ok(from_storage(storage, self.shape(), op, false))
733    }
734
735    /// Raise the tensor to some float exponent `e`.
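    ///
    /// A small sketch of the expected behaviour; exact equality is expected to hold for these
    /// small integer values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[2f64, 3., 4.], &Device::Cpu)?;
    /// assert_eq!(a.powf(2.)?.to_vec1::<f64>()?, &[4., 9., 16.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```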
736    pub fn powf(&self, e: f64) -> Result<Self> {
737        if self.elem_count() == 0 {
738            return Ok(self.clone());
739        }
740        let storage = self.storage().powf(self.layout(), e)?;
741        let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
742        Ok(from_storage(storage, self.shape(), op, false))
743    }
744
745    pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
746        if dim >= self.dims().len() {
747            Err(Error::DimOutOfRange {
748                shape: self.shape().clone(),
749                dim: dim as i32,
750                op,
751            }
752            .bt())?
753        } else {
754            Ok(())
755        }
756    }
757
    /// Split a tensor into the specified number of chunks; this may return fewer chunks than
    /// specified.
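    ///
    /// A small sketch of the expected behaviour:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
    /// let chunks = a.chunk(3, 0)?;
    /// assert_eq!(chunks.len(), 3);
    /// assert_eq!(chunks[0].to_vec1::<f32>()?, &[0., 1.]);
    /// assert_eq!(chunks[2].to_vec1::<f32>()?, &[4., 5.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```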
760    pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
761        let dim = dim.to_index(self.shape(), "chunk")?;
762        let size = self.dim(dim)?;
763        if size < chunks {
764            (0..size).map(|i| self.narrow(dim, i, 1)).collect()
765        } else {
766            let chunk_size = size / chunks;
767            let cnt_additional = size % chunks;
768            let mut tensors = vec![];
769            let mut sum_chunk_size = 0;
770            for i in 0..chunks {
771                let chunk_size = if i < cnt_additional {
772                    chunk_size + 1
773                } else {
774                    chunk_size
775                };
776                let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
777                tensors.push(tensor);
778                sum_chunk_size += chunk_size
779            }
780            Ok(tensors)
781        }
782    }
783
784    /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
785    /// ranges from `start` to `start + len`.
786    /// ```
787    /// use candle_core::{Tensor, Device};
788    /// let a = Tensor::new(&[
789    ///     [0f32, 1., 2.],
790    ///     [3.  , 4., 5.],
791    ///     [6.  , 7., 8.]
792    /// ], &Device::Cpu)?;
793    ///
794    /// let b = a.narrow(0, 1, 2)?;
795    /// assert_eq!(b.shape().dims(), &[2, 3]);
796    /// assert_eq!(b.to_vec2::<f32>()?, &[
797    ///     [3., 4., 5.],
798    ///     [6., 7., 8.]
799    /// ]);
800    ///
801    /// let c = a.narrow(1, 1, 1)?;
802    /// assert_eq!(c.shape().dims(), &[3, 1]);
803    /// assert_eq!(c.to_vec2::<f32>()?, &[
804    ///     [1.],
805    ///     [4.],
806    ///     [7.]
807    /// ]);
808    /// # Ok::<(), candle_core::Error>(())
809    /// ```
810    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
811        let dims = self.dims();
812        let dim = dim.to_index(self.shape(), "narrow")?;
813        let err = |msg| {
814            Err::<(), _>(
815                Error::NarrowInvalidArgs {
816                    shape: self.shape().clone(),
817                    dim,
818                    start,
819                    len,
820                    msg,
821                }
822                .bt(),
823            )
824        };
825        if start > dims[dim] {
826            err("start > dim_len")?
827        }
828        if start.saturating_add(len) > dims[dim] {
829            err("start + len > dim_len")?
830        }
831        if start == 0 && dims[dim] == len {
832            Ok(self.clone())
833        } else {
834            let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
835            let layout = self.layout().narrow(dim, start, len)?;
836            let tensor_ = Tensor_ {
837                id: TensorId::new(),
838                storage: self.storage.clone(),
839                layout,
840                op,
841                is_variable: false,
842                dtype: self.dtype,
843                device: self.device.clone(),
844            };
845            Ok(Tensor(Arc::new(tensor_)))
846        }
847    }
848
849    fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
850        match dims {
851            [] => Ok(self),
852            [i] => self.squeeze(*i),
853            dims => {
854                let dims = self
855                    .dims()
856                    .iter()
857                    .enumerate()
858                    .filter_map(|(dim_idx, &v)| {
859                        if dims.contains(&dim_idx) {
860                            None
861                        } else {
862                            Some(v)
863                        }
864                    })
865                    .collect::<Vec<_>>();
866                self.reshape(dims)
867            }
868        }
869    }
870
871    fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
872        let dim = dim.to_index(self.shape(), op.name())?;
873        let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
874        let mut dims = self.dims().to_vec();
875        dims[dim] = 1;
876        let op = match op {
877            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
878                BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
879            }
880            ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
881        };
882        let res = from_storage(storage, dims, op, false);
883        if keepdim {
884            Ok(res)
885        } else {
886            res.squeeze_dims(&[dim])
887        }
888    }
889
890    fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
891        let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
892        let storage = self
893            .storage()
894            .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
895        let mut dims = self.dims().to_vec();
896        for &sum_dim in sum_dims.iter() {
897            dims[sum_dim] = 1
898        }
899        let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
900        let sum = from_storage(storage, dims, op, false);
901        if keepdim {
902            Ok(sum)
903        } else {
904            sum.squeeze_dims(&sum_dims)
905        }
906    }
907
908    /// Roll the tensor input along the given dimension.
909    /// Elements that are shifted beyond the last position are re-introduced at the first position.
910    ///
911    /// ```rust
912    /// # use candle_core::{Tensor, Device};
913    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
914    /// let tensor = tensor.roll(1, 0)?;
915    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
916    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
917    /// let tensor = tensor.roll(-1, 0)?;
918    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
919    /// # Ok::<(), candle_core::Error>(())
920    /// ```
921    pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
922    where
923        D: Dim + Clone,
924    {
925        let dim = dim.to_index(self.shape(), "roll")?;
926        let dim_size = self.dim(dim)?;
927        let shift = shift.rem_euclid(dim_size as i32) as usize;
928        if shift == 0 {
929            Ok(self.clone())
930        } else {
931            let a = self.narrow(dim, 0, dim_size - shift)?;
932            let b = self.narrow(dim, dim_size - shift, shift)?;
933            Tensor::cat(&[&b, &a], dim)
934        }
935    }
936
    /// Returns the sum of all elements in the input tensor. The sum is performed over the
    /// dimensions specified in `sum_dims`.
939    ///
940    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
941    /// that the number of elements for each dimension index in `sum_dims` is 1.
942    ///
943    /// ```rust
944    /// use candle_core::{Tensor, Device};
945    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
946    /// let s = a.sum_keepdim(0)?;
947    /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
948    /// let s = a.sum_keepdim(1)?;
949    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
950    /// let s = a.sum_keepdim((0, 1))?;
951    /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
952    /// # Ok::<(), candle_core::Error>(())
953    /// ```
954    pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
955        self.sum_impl(sum_dims, true)
956    }
957
    /// Returns the sum of all elements in the input tensor. The sum is performed over the
    /// dimensions specified in `sum_dims`; compared to `sum_keepdim` these dimensions are
    /// squeezed rather than kept.
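    ///
    /// A small example mirroring the `sum_keepdim` one above, with the summed dimensions squeezed:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    /// assert_eq!(a.sum(0)?.to_vec1::<f32>()?, &[2., 4.]);
    /// assert_eq!(a.sum((0, 1))?.to_scalar::<f32>()?, 6.);
    /// # Ok::<(), candle_core::Error>(())
    /// ```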
961    pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
962        self.sum_impl(sum_dims, false)
963    }
964
    /// Returns the mean of all elements in the input tensor. The mean is performed over the
    /// dimensions specified in `mean_dims`.
967    ///
968    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
969    /// that the number of elements for each dimension index in `mean_dims` is 1.
970    ///
971    /// ```rust
972    /// use candle_core::{Tensor, Device};
973    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
974    /// let s = a.mean_keepdim(0)?;
975    /// assert_eq!(s.to_vec2::<f32>()?, &[[1., 2.]]);
976    /// let s = a.mean_keepdim(1)?;
977    /// assert_eq!(s.to_vec2::<f32>()?, &[[0.5], [2.5]]);
978    /// let s = a.mean_keepdim((0, 1))?;
979    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.5]]);
980    /// # Ok::<(), candle_core::Error>(())
981    /// ```
982    pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
983        let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
984        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
985        let scale = 1f64 / (reduced_dim as f64);
986        self.sum_impl(mean_dims, true)? * scale
987    }
988
    /// Returns the mean of all elements in the input tensor. The mean is performed over the
    /// dimensions specified in `mean_dims`; compared to `mean_keepdim` these dimensions are
    /// squeezed rather than kept.
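    ///
    /// A small example mirroring the `mean_keepdim` one above, with the reduced dimensions squeezed:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    /// assert_eq!(a.mean(0)?.to_vec1::<f32>()?, &[1., 2.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```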
992    pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
993        let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
994        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
995        let scale = 1f64 / (reduced_dim as f64);
996        self.sum_impl(mean_dims, false)? * scale
997    }
998
999    /// Returns the unbiased variance over the selected dimension.
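    ///
    /// A small sketch of the expected behaviour:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1f32, 3.], [2., 4.]], &Device::Cpu)?;
    /// assert_eq!(a.var_keepdim(1)?.to_vec2::<f32>()?, &[[2.], [2.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```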
1000    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1001        let dim = dim.to_index(self.shape(), "var")?;
1002        let mean = self.mean_keepdim(dim)?;
1003        let squares = self.broadcast_sub(&mean)?.sqr()?;
1004        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
1005    }
1006
1007    /// Returns the unbiased variance over the selected dimension.
1008    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
1009        let dim = dim.to_index(self.shape(), "var")?;
1010        self.var_keepdim(dim)?.squeeze(dim)
1011    }
1012
    /// Gathers the maximum value across the selected dimension. The resulting shape has the same
    /// number of dimensions as the original tensor and the selected dimension has a single element.
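    ///
    /// A small sketch of the expected behaviour:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 5.], [2., 3.]], &Device::Cpu)?;
    /// assert_eq!(a.max_keepdim(0)?.to_vec2::<f32>()?, &[[2., 5.]]);
    /// assert_eq!(a.max_keepdim(1)?.to_vec2::<f32>()?, &[[5.], [3.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```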
1015    pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1016        self.reduce_impl(dim, true, ReduceOp::Max)
1017    }
1018
1019    /// Similar to `max_keepdim` but the target dimension is squeezed.
1020    pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
1021        self.reduce_impl(dim, false, ReduceOp::Max)
1022    }
1023
    /// Gathers the minimum value across the selected dimension. The resulting shape has the same
    /// number of dimensions as the original tensor and the selected dimension has a single element.
1026    pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1027        self.reduce_impl(dim, true, ReduceOp::Min)
1028    }
1029
1030    /// Similar to `min_keepdim` but the target dimension is squeezed.
1031    pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
1032        self.reduce_impl(dim, false, ReduceOp::Min)
1033    }
1034
1035    pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1036        self.reduce_impl(dim, true, ReduceOp::ArgMax)
1037    }
1038
1039    /// Similar to `argmax_keepdim` but the target dimension is squeezed.
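    ///
    /// A small sketch of the expected behaviour; the returned indexes use the `u32` dtype:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 5.], [2., 3.]], &Device::Cpu)?;
    /// assert_eq!(a.argmax(1)?.to_vec1::<u32>()?, &[1, 1]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```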
1040    pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
1041        self.reduce_impl(dim, false, ReduceOp::ArgMax)
1042    }
1043
1044    pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1045        self.reduce_impl(dim, true, ReduceOp::ArgMin)
1046    }
1047
1048    /// Similar to `argmin_keepdim` but the target dimension is squeezed.
1049    pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
1050        self.reduce_impl(dim, false, ReduceOp::ArgMin)
1051    }
1052
1053    /// Element-wise comparison between two tensors, e.g. equality, greater than, ... The actual
1054    /// comparison operation is specified by the `op` argument.
1055    ///
1056    /// The returned tensor has the same shape as the original tensors and uses `u8` elements.
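    ///
    /// A small sketch using the `eq` and `lt` helpers defined below, which forward to `cmp`:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[0u32, 1, 2], &Device::Cpu)?;
    /// let b = Tensor::new(&[0u32, 2, 2], &Device::Cpu)?;
    /// assert_eq!(a.eq(&b)?.to_vec1::<u8>()?, &[1, 0, 1]);
    /// assert_eq!(a.lt(&b)?.to_vec1::<u8>()?, &[0, 1, 0]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```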
1057    pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
1058        let rhs = match rhs.to_tensor_scalar()? {
1059            crate::scalar::TensorScalar::Tensor(rhs) => rhs,
1060            crate::scalar::TensorScalar::Scalar(rhs) => rhs
1061                .to_dtype(self.dtype())?
1062                .to_device(self.device())?
1063                .broadcast_as(self.shape())?,
1064        };
1065        let shape = self.same_shape_binary_op(&rhs, "cmp")?;
1066        let storage = self
1067            .storage()
1068            .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
1069        let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
1070        Ok(from_storage(storage, shape.dims(), op, false))
1071    }
1072
1073    /// Element-wise equality.
1074    pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1075        self.cmp(rhs, CmpOp::Eq)
1076    }
1077
1078    /// Element-wise non-equality.
1079    pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1080        self.cmp(rhs, CmpOp::Ne)
1081    }
1082
1083    /// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
1084    /// rhs` and 0 otherwise.
1085    pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1086        self.cmp(rhs, CmpOp::Lt)
1087    }
1088
1089    /// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
1090    /// rhs` and 0 otherwise.
1091    pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1092        self.cmp(rhs, CmpOp::Gt)
1093    }
1094
1095    /// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
1096    /// rhs` and 0 otherwise.
1097    pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1098        self.cmp(rhs, CmpOp::Ge)
1099    }
1100
1101    /// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
1102    /// rhs` and 0 otherwise.
1103    pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1104        self.cmp(rhs, CmpOp::Le)
1105    }
1106
1107    /// Clamp the tensor values to be between `min` and `max`.
1108    pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
1109        self.maximum(min)?.minimum(max)
1110    }
1111
    /// Interpolate the input tensor to length `target_size`, taking the value of the nearest element.
1113    ///
1114    /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
1115    /// tensor also has three dimensions, `(batch, channels, target_size)`.
1116    pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
1117        let (n, c, _l) = self.dims3()?;
1118        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
1119        let storage = self
1120            .storage()
1121            .upsample_nearest1d(self.layout(), target_size)?;
1122        Ok(from_storage(storage, (n, c, target_size), op, false))
1123    }
1124
1125    /// Alias for `interpolate1d`.
1126    pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
1127        self.interpolate1d(target_size)
1128    }
1129
1130    /// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
1131    /// nearest element.
1132    ///
1133    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1134    /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
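    ///
    /// A small sketch checking only the output shape:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0f32, 4f32, &Device::Cpu)?.reshape((1, 1, 2, 2))?;
    /// let up = t.interpolate2d(4, 4)?;
    /// assert_eq!(up.dims(), &[1, 1, 4, 4]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```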
1135    pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1136        let (n, c, _h, _w) = self.dims4()?;
1137        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
1138            arg,
1139            target_h,
1140            target_w,
1141        });
1142        let storage = self
1143            .storage()
1144            .upsample_nearest2d(self.layout(), target_h, target_w)?;
1145        Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1146    }
1147
1148    /// Alias for `interpolate2d`.
1149    pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1150        self.interpolate2d(target_h, target_w)
1151    }
1152
1153    /// 2D average pooling over an input tensor with multiple channels.
1154    ///
1155    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1156    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
1157    /// the two last dimensions using a kernel of size `sz`. The returned element is the average
1158    /// value over the kernel window.
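    ///
    /// A small sketch of the expected behaviour on a `(1, 1, 4, 4)` input:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0f32, 16f32, &Device::Cpu)?.reshape((1, 1, 4, 4))?;
    /// let pooled = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
    /// assert_eq!(pooled.to_vec2::<f32>()?, &[[2.5, 4.5], [10.5, 12.5]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```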
1159    pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1160        let sz = sz.to_usize2();
1161        self.avg_pool2d_with_stride(sz, sz)
1162    }
1163
1164    /// Same as `avg_pool2d` but with a `stride` that can be set to a value different from the
1165    /// kernel size.
1166    pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
1167        &self,
1168        kernel_size: T,
1169        stride: T,
1170    ) -> Result<Self> {
1171        let kernel_size = kernel_size.to_usize2();
1172        let stride = stride.to_usize2();
1173        let (n, c, h, w) = self.dims4()?;
1174        if h < kernel_size.0 || w < kernel_size.1 {
1175            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1176        }
1177        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
1178        let h_out = (h - kernel_size.0) / stride.0 + 1;
1179        let w_out = (w - kernel_size.1) / stride.1 + 1;
1180        let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
1181            arg,
1182            kernel_size,
1183            stride,
1184        });
1185        let storage = self
1186            .storage()
1187            .avg_pool2d(self.layout(), kernel_size, stride)?;
1188        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1189    }
1190
1191    /// 2D max pooling over an input tensor with multiple channels.
1192    ///
1193    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1194    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
1195    /// the two last dimensions using a kernel of size `sz`, the returned element is the maximum
1196    /// value over the kernel window.
1197    pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1198        let sz = sz.to_usize2();
1199        self.max_pool2d_with_stride(sz, sz)
1200    }
1201
1202    /// Same as `max_pool2d` but with a `stride` that can be set to a value different from the
1203    /// kernel size.
1204    pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
1205        &self,
1206        kernel_size: T,
1207        stride: T,
1208    ) -> Result<Self> {
1209        let kernel_size = kernel_size.to_usize2();
1210        let stride = stride.to_usize2();
1211        let (n, c, h, w) = self.dims4()?;
1212        if h < kernel_size.0 || w < kernel_size.1 {
1213            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1214        }
1215        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
1216        let h_out = (h - kernel_size.0) / stride.0 + 1;
1217        let w_out = (w - kernel_size.1) / stride.1 + 1;
1218        let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
1219            arg,
1220            kernel_size,
1221            stride,
1222        });
1223        let storage = self
1224            .storage()
1225            .max_pool2d(self.layout(), kernel_size, stride)?;
1226        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1227    }
1228
1229    /// Returns the matrix-multiplication of the input tensor with the other provided tensor.
1230    ///
1231    /// # Arguments
1232    ///
1233    /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`.
1234    /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`.
1235    ///
1236    /// The resulting tensor has dimensions `b1, b2, ..., bi, m, n`.
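    ///
    /// A small 2x2 example:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let b = Tensor::new(&[[5f32, 6.], [7., 8.]], &Device::Cpu)?;
    /// let c = a.matmul(&b)?;
    /// assert_eq!(c.to_vec2::<f32>()?, &[[19., 22.], [43., 50.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```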
1237    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
1238        let a_dims = self.shape().dims();
1239        let b_dims = rhs.shape().dims();
1240
1241        let dim = a_dims.len();
1242
1243        if dim < 2 || b_dims.len() != dim {
1244            Err(Error::ShapeMismatchBinaryOp {
1245                lhs: self.shape().clone(),
1246                rhs: rhs.shape().clone(),
1247                op: "matmul",
1248            }
1249            .bt())?
1250        }
1251
1252        let m = a_dims[dim - 2];
1253        let k = a_dims[dim - 1];
1254        let k2 = b_dims[dim - 2];
1255        let n = b_dims[dim - 1];
1256
1257        let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
1258        if c_shape.elem_count() == 0 || k == 0 {
1259            return Tensor::zeros(c_shape, self.dtype(), self.device());
1260        }
1261        let batching: usize = a_dims[..dim - 2].iter().product();
1262        let batching_b: usize = b_dims[..dim - 2].iter().product();
1263        if k != k2 || batching != batching_b {
1264            Err(Error::ShapeMismatchBinaryOp {
1265                lhs: self.shape().clone(),
1266                rhs: rhs.shape().clone(),
1267                op: "matmul",
1268            }
1269            .bt())?
1270        }
1271
1272        let storage = self.storage().matmul(
1273            &rhs.storage(),
1274            (batching, m, n, k),
1275            self.layout(),
1276            rhs.layout(),
1277        )?;
1278        let op = BackpropOp::new2(self, rhs, Op::Matmul);
1279        Ok(from_storage(storage, c_shape, op, false))
1280    }
1281
1282    /// Matrix-multiplication with broadcasting support.
1283    ///
    /// Compared to `matmul` the two matrices are allowed to have different dimensions as long as
1285    /// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
1286    /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
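    ///
    /// A small sketch checking only the broadcast output shape described above:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 1, 3, 4), DType::F32, &Device::Cpu)?;
    /// let b = Tensor::zeros((5, 4, 6), DType::F32, &Device::Cpu)?;
    /// let c = a.broadcast_matmul(&b)?;
    /// assert_eq!(c.dims(), &[2, 5, 3, 6]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```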
1287    pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
1288        let lhs = self;
1289        let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
1290        let l_broadcast = l_shape != *lhs.shape();
1291        let r_broadcast = r_shape != *rhs.shape();
        // TODO: Avoid concretising the broadcasted matrices via contiguous.
1293        match (l_broadcast, r_broadcast) {
1294            (true, true) => lhs
1295                .broadcast_as(&l_shape)?
1296                .contiguous()?
1297                .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1298            (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1299            (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
1300            (false, false) => lhs.matmul(rhs),
1301        }
1302    }
1303
1304    /// Returns a tensor with the same shape as the input tensor, the values are taken from
1305    /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
1306    /// input tensor is equal to zero.
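    ///
    /// A small sketch of the expected behaviour with a `u8` mask:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let mask = Tensor::new(&[1u8, 0, 1], &Device::Cpu)?;
    /// let on_true = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let on_false = Tensor::new(&[7f32, 8., 9.], &Device::Cpu)?;
    /// assert_eq!(mask.where_cond(&on_true, &on_false)?.to_vec1::<f32>()?, &[1., 8., 3.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```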
1307    pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
1308        let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
1309        let shape = self.same_shape_binary_op(on_false, "where_cond")?;
1310        let storage = self.storage().where_cond(
1311            self.layout(),
1312            &on_true.storage(),
1313            on_true.layout(),
1314            &on_false.storage(),
1315            on_false.layout(),
1316        )?;
1317        let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
1318        Ok(from_storage(storage, shape, op, false))
1319    }
1320
    /// Returns a tensor with the values from the `self` tensor at the indexes corresponding to the
    /// values held in the `ids` tensor.
1323    ///
1324    /// # Arguments
1325    ///
1326    /// * `self` - A tensor with dimensions `v, h`.
1327    /// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
1328    ///
1329    /// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
1330    /// vocabulary size, and `h` the hidden size.
1331    ///
1332    /// ```rust
1333    /// use candle_core::{Tensor, Device};
1334    /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1335    /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
1336    /// let emb = values.embedding(&ids)?;
1337    /// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
1338    /// # Ok::<(), candle_core::Error>(())
1339    /// ```
1340    pub fn embedding(&self, ids: &Self) -> Result<Self> {
1341        if self.rank() != 2 || ids.rank() != 1 {
1342            Err(Error::ShapeMismatchBinaryOp {
1343                lhs: self.shape().clone(),
1344                rhs: ids.shape().clone(),
1345                op: "embedding",
1346            }
1347            .bt())?
1348        }
1349        self.index_select(ids, 0)
1350    }
1351
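    /// Writes the values from `source` into `self` at the positions specified by `indexes` along
    /// `dim`, accumulating with addition. A small sketch of the expected behaviour, assuming
    /// PyTorch-style `scatter_add_` semantics:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let init = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::new(&[[1f32, 2., 3.]], &Device::Cpu)?;
    /// let idx = Tensor::new(&[[0u32, 1, 0]], &Device::Cpu)?;
    /// let res = init.scatter_add(&idx, &src, 0)?;
    /// assert_eq!(res.to_vec2::<f32>()?, &[[1., 0., 3.], [0., 2., 0.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```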
1352    pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1353        let dim = dim.to_index(self.shape(), "scatter-add")?;
1354        let source_dims = source.dims();
1355        let self_dims = self.dims();
1356        let mismatch = if source_dims.len() != self_dims.len() {
1357            true
1358        } else {
1359            let mut mismatch = false;
1360            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1361                if i != dim && d1 != d2 {
1362                    mismatch = true;
1363                    break;
1364                }
1365            }
1366            mismatch
1367        };
1368        if mismatch {
1369            Err(Error::ShapeMismatchBinaryOp {
1370                op: "scatter-add (self, src)",
1371                lhs: self.shape().clone(),
1372                rhs: source.shape().clone(),
1373            }
1374            .bt())?
1375        }
1376        if indexes.dims() != source.dims() {
1377            Err(Error::ShapeMismatchBinaryOp {
1378                op: "scatter-add (indexes, src)",
1379                lhs: indexes.shape().clone(),
1380                rhs: source.shape().clone(),
1381            }
1382            .bt())?
1383        }
1384        let storage = self.storage().scatter_add(
1385            self.layout(),
1386            &indexes.storage(),
1387            indexes.layout(),
1388            &source.storage(),
1389            source.layout(),
1390            dim,
1391        )?;
1392        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1393            Op::ScatterAdd(t1, t2, t3, dim)
1394        });
1395        Ok(from_storage(storage, self.shape(), op, false))
1396    }
1397
1398    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
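    ///
    /// A small sketch of the expected behaviour, overwriting one row of a zero tensor:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((3, 2), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::ones((1, 2), DType::F32, &Device::Cpu)?;
    /// let res = a.slice_scatter(&src, 0, 1)?;
    /// assert_eq!(res.to_vec2::<f32>()?, &[[0., 0.], [1., 1.], [0., 0.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```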
1399    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
1400        let dim = dim.to_index(self.shape(), "slice-scatter")?;
1401        if dim == 0 {
1402            self.slice_scatter0(src, start)
1403        } else {
1404            // TODO: Maybe we want to add a more efficient implementation at some point.
1405            self.transpose(0, dim)?
1406                .slice_scatter0(&src.transpose(0, dim)?, start)?
1407                .transpose(0, dim)
1408        }
1409    }
1410
1411    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
1412    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
1413        if self.dtype() != src.dtype() {
1414            Err(Error::DTypeMismatchBinaryOp {
1415                lhs: self.dtype(),
1416                rhs: src.dtype(),
1417                op: "slice-scatter",
1418            }
1419            .bt())?
1420        }
1421        if self.device().location() != src.device().location() {
1422            Err(Error::DeviceMismatchBinaryOp {
1423                lhs: self.device().location(),
1424                rhs: src.device().location(),
1425                op: "slice-scatter",
1426            }
1427            .bt())?
1428        }
1429        if self.rank() != src.rank() {
1430            Err(Error::UnexpectedNumberOfDims {
1431                expected: self.rank(),
1432                got: src.rank(),
1433                shape: src.shape().clone(),
1434            }
1435            .bt())?
1436        }
1437        let shape_ok =
1438            self.dims()
1439                .iter()
1440                .zip(src.dims().iter())
1441                .enumerate()
1442                .all(|(dim_idx, (&d1, &d2))| {
1443                    if 0 == dim_idx {
1444                        d2 + start <= d1
1445                    } else {
1446                        d1 == d2
1447                    }
1448                });
1449        if !shape_ok {
1450            Err(Error::ShapeMismatchBinaryOp {
1451                op: "slice-scatter (self, src)",
1452                lhs: self.shape().clone(),
1453                rhs: src.shape().clone(),
1454            }
1455            .bt())?
1456        }
1457        let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
1458        self.storage()
1459            .copy_strided_src(&mut storage, 0, self.layout())?;
1460        let offset = start * src.dims()[1..].iter().product::<usize>();
1461        src.storage()
1462            .copy_strided_src(&mut storage, offset, src.layout())?;
1463        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
1464        Ok(from_storage(storage, self.shape(), op, false))
1465    }
1466
1467    /// Accumulates the values from `source` into `self` at the positions along `dim` specified by `indexes`.
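    ///
    /// A small illustrative sketch, assuming that for `dim == 0` the operation performs
    /// `result[indexes[i]][j] += source[i][j]`:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((3, 2), DType::F32, &Device::Cpu)?;
    /// let ids = Tensor::new(&[2u32, 0], &Device::Cpu)?;
    /// let src = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let t = t.index_add(&ids, &src, 0)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[3., 4.], [0., 0.], [1., 2.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```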
1468    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1469        let dim = dim.to_index(self.shape(), "index-add")?;
1470        let source_dims = source.dims();
1471        let self_dims = self.dims();
1472        let mismatch = if source_dims.len() != self_dims.len() {
1473            true
1474        } else {
1475            let mut mismatch = false;
1476            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1477                if i != dim && d1 != d2 {
1478                    mismatch = true;
1479                    break;
1480                }
1481            }
1482            mismatch
1483        };
1484        if mismatch {
1485            Err(Error::ShapeMismatchBinaryOp {
1486                op: "index-add (self, source)",
1487                lhs: self.shape().clone(),
1488                rhs: source.shape().clone(),
1489            }
1490            .bt())?
1491        }
1492        // The number of elements in indexes must match the dimension on which the add is
1493        // performed on the source tensor (and the index values from `indexes` refer to
1494        // positions in the target tensor self).
1495        let indexes_len = indexes.dims1()?;
1496        if source_dims[dim] != indexes_len {
1497            Err(Error::ShapeMismatchBinaryOp {
1498                op: "index-add (ids, source))",
1499                lhs: indexes.shape().clone(),
1500                rhs: source.shape().clone(),
1501            }
1502            .bt())?
1503        }
1504        let storage = self.storage().index_add(
1505            self.layout(),
1506            &indexes.storage(),
1507            indexes.layout(),
1508            &source.storage(),
1509            source.layout(),
1510            dim,
1511        )?;
1512        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1513            Op::IndexAdd(t1, t2, t3, dim)
1514        });
1515        Ok(from_storage(storage, self.shape(), op, false))
1516    }
1517
1518    /// Gather values across the target dimension.
1519    ///
1520    /// # Arguments
1521    ///
1522    /// * `self` - The input tensor.
1523    /// * `indexes` - The indices of the elements to gather. This should have the same number of
1524    ///   dimensions as `self`, and `indexes.dims()[d] <= self.dims()[d]` for all dimensions `d != dim`.
1525    /// * `dim` - the target dimension.
1526    ///
1527    /// The resulting tensor has the same shape as `indexes` and uses values from `self` indexed on
1528    /// dimension `dim` by the values in `indexes`.
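    ///
    /// A small illustrative sketch, assuming that for `dim == 1` the result satisfies
    /// `result[i][j] = self[i][indexes[i][j]]`:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[[1u32], [0], [1]], &Device::Cpu)?;
    /// let t = values.gather(&ids, 1)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[1.], [2.], [5.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```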
1529    pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1530        let dim = dim.to_index(self.shape(), "gather")?;
1531
1532        let self_dims = self.dims();
1533        let indexes_dims = indexes.dims();
1534        let mismatch = if indexes_dims.len() != self_dims.len() {
1535            true
1536        } else {
1537            let mut mismatch = false;
1538            for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
1539                if i != dim && d1 < d2 {
1540                    mismatch = true;
1541                    break;
1542                }
1543            }
1544            mismatch
1545        };
1546        if mismatch {
1547            Err(Error::ShapeMismatchBinaryOp {
1548                op: "gather",
1549                lhs: self.shape().clone(),
1550                rhs: indexes.shape().clone(),
1551            }
1552            .bt())?
1553        }
1554        let storage =
1555            self.storage()
1556                .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
1557        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
1558        Ok(from_storage(storage, indexes.shape(), op, false))
1559    }
1560
1561    /// Select values for the input tensor at the target indexes across the specified dimension.
1562    ///
1563    /// The `indexes` argument is an integer tensor with a single dimension.
1564    /// The output has the same number of dimensions as the `self` input. The target dimension of
1565    /// the output has the same length as `indexes` and its values are taken from `self` using
1566    /// the indexes from `indexes`. Other dimensions have the same number of elements as the input
1567    /// tensor.
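    ///
    /// A short example selecting rows on the first dimension (the values are just for illustration):
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[2u32, 0], &Device::Cpu)?;
    /// let t = values.index_select(&ids, 0)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[4., 5.], [0., 1.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```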
1568    pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1569        let dim = dim.to_index(self.shape(), "index-select")?;
1570        let indexes_len = match indexes.dims() {
1571            [l] => *l,
1572            _ => Err(Error::ShapeMismatchBinaryOp {
1573                lhs: self.shape().clone(),
1574                rhs: indexes.shape().clone(),
1575                op: "index-select",
1576            }
1577            .bt())?,
1578        };
1579        let storage = self.storage().index_select(
1580            &indexes.storage(),
1581            self.layout(),
1582            indexes.layout(),
1583            dim,
1584        )?;
1585        let mut dims = self.dims().to_vec();
1586        dims[dim] = indexes_len;
1587        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
1588        Ok(from_storage(storage, dims, op, false))
1589    }
1590
1591    /// Returns an iterator over the positions of the elements in the storage when ranging over the
1592    /// index tuples in lexicographic order.
1593    pub fn strided_index(&self) -> crate::StridedIndex {
1594        self.layout.strided_index()
1595    }
1596
1597    /// Similar to `strided_index` but returns the position of the start of each contiguous block
1598    /// as well as the length of the contiguous blocks. For a contiguous tensor, the index iterator
1599    /// only returns the start offset, and the size is the number of elements in the
1600    /// tensor.
1601    pub fn strided_blocks(&self) -> crate::StridedBlocks {
1602        self.layout.strided_blocks()
1603    }
1604
1605    /// Returns the data contained in a 1D tensor as a vector of scalar values.
1606    pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
1607        if self.rank() != 1 {
1608            Err(Error::UnexpectedNumberOfDims {
1609                expected: 1,
1610                got: self.rank(),
1611                shape: self.shape().clone(),
1612            }
1613            .bt())?
1614        }
1615        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1616            let data = S::cpu_storage_as_slice(cpu_storage)?;
1617            let data = match self.layout.contiguous_offsets() {
1618                Some((o1, o2)) => data[o1..o2].to_vec(),
1619                None => self.strided_index().map(|i| data[i]).collect(),
1620            };
1621            Ok::<Vec<_>, Error>(data)
1622        };
1623        match &*self.storage() {
1624            Storage::Cpu(storage) => from_cpu_storage(storage),
1625            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1626            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1627        }
1628    }
1629
1630    /// Returns the data contained in a 2D tensor as a vector of vector of scalar values.
1631    pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
1632        let (dim1, dim2) = self.dims2()?;
1633        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1634            let data = S::cpu_storage_as_slice(cpu_storage)?;
1635            let mut rows = vec![];
1636            match self.layout.contiguous_offsets() {
1637                Some((o1, o2)) => {
1638                    let data = &data[o1..o2];
1639                    for idx_row in 0..dim1 {
1640                        rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
1641                    }
1642                }
1643                None => {
1644                    let mut src_index = self.strided_index();
1645                    for _idx_row in 0..dim1 {
1646                        let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
1647                        rows.push(row)
1648                    }
1649                    assert!(src_index.next().is_none());
1650                }
1651            }
1652            Ok(rows)
1653        };
1654        match &*self.storage() {
1655            Storage::Cpu(storage) => from_cpu_storage(storage),
1656            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1657            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1658        }
1659    }
1660
1661    /// Returns the data contained in a 3D tensor.
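    ///
    /// A short illustrative example:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0u32, 8u32, &Device::Cpu)?.reshape((2, 2, 2))?;
    /// assert_eq!(t.to_vec3::<u32>()?, &[[[0, 1], [2, 3]], [[4, 5], [6, 7]]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```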
1662    pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
1663        let (dim1, dim2, dim3) = self.dims3()?;
1664        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1665            let data = S::cpu_storage_as_slice(cpu_storage)?;
1666            let mut top_rows = vec![];
1667            match self.layout.contiguous_offsets() {
1668                Some((o1, o2)) => {
1669                    let data = &data[o1..o2];
1670                    let dim23 = dim2 * dim3;
1671                    for idx1 in 0..dim1 {
1672                        let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
1673                        let mut rows = vec![];
1674                        for idx2 in 0..dim2 {
1675                            rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
1676                        }
1677                        top_rows.push(rows);
1678                    }
1679                }
1680                None => {
1681                    let mut src_index = self.strided_index();
1682                    for _idx in 0..dim1 {
1683                        let mut rows = vec![];
1684                        for _jdx in 0..dim2 {
1685                            let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
1686                            rows.push(row)
1687                        }
1688                        top_rows.push(rows);
1689                    }
1690                    assert!(src_index.next().is_none());
1691                }
1692            }
1693            Ok(top_rows)
1694        };
1695        match &*self.storage() {
1696            Storage::Cpu(storage) => from_cpu_storage(storage),
1697            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1698            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1699        }
1700    }
1701
1702    /// The dtype for the elements stored in the input tensor.
1703    pub fn dtype(&self) -> DType {
1704        self.dtype
1705    }
1706
1707    /// The device on which the input tensor is located.
1708    pub fn device(&self) -> &Device {
1709        &self.device
1710    }
1711
1712    /// The tensor shape, i.e. dimension sizes on each axis.
1713    pub fn shape(&self) -> &Shape {
1714        self.layout().shape()
1715    }
1716
1717    /// The dimension size for this tensor on each axis.
1718    pub fn dims(&self) -> &[usize] {
1719        self.shape().dims()
1720    }
1721
1722    /// The dimension size for a specified dimension index.
1723    pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
1724        let dim = dim.to_index(self.shape(), "dim")?;
1725        Ok(self.dims()[dim])
1726    }
1727
1728    /// The layout of the input tensor, this stores both the shape of the tensor as well as the
1729    /// strides and the start offset to apply to the underlying storage.
1730    pub fn layout(&self) -> &Layout {
1731        &self.layout
1732    }
1733
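    /// The stride for each dimension, i.e. the number of elements to skip in the underlying
    /// storage to move by one along that dimension.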
1734    pub fn stride(&self) -> &[usize] {
1735        self.layout.stride()
1736    }
1737
1738    /// The number of dimensions for this tensor, 0 for a scalar tensor, 1 for a 1D tensor, etc.
1739    pub fn rank(&self) -> usize {
1740        self.shape().rank()
1741    }
1742
1743    /// The number of elements stored in this tensor.
1744    pub fn elem_count(&self) -> usize {
1745        self.shape().elem_count()
1746    }
1747
1748    /// The unique identifier for this tensor.
1749    pub fn id(&self) -> TensorId {
1750        self.id
1751    }
1752
1753    /// Whether this tensor is a variable or not. A variable is a tensor for which gradients are
1754    /// tracked and on which backpropagation can be performed.
1755    pub fn is_variable(&self) -> bool {
1756        self.is_variable
1757    }
1758
1759    pub(crate) fn op(&self) -> &Option<Op> {
1760        &self.op
1761    }
1762
1763    /// Computes the max of all the elements in this tensor and returns a tensor holding this
1764    /// scalar with zero dimensions.
1765    ///
1766    /// ```rust
1767    /// use candle_core::{Tensor, Device};
1768    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1769    /// let tensor = tensor.max_all()?;
1770    /// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
1771    /// # Ok::<(), candle_core::Error>(())
1772    /// ```
1773    pub fn max_all(&self) -> Result<Tensor> {
1774        if self.rank() == 0 {
1775            Ok(self.clone())
1776        } else {
1777            self.flatten_all()?.max(0)
1778        }
1779    }
1780
1781    /// Computes the min of all the elements in this tensor and returns a tensor holding this
1782    /// scalar with zero dimensions.
1783    ///
1784    /// ```rust
1785    /// use candle_core::{Tensor, Device};
1786    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1787    /// let tensor = tensor.min_all()?;
1788    /// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
1789    /// # Ok::<(), candle_core::Error>(())
1790    /// ```
1791    pub fn min_all(&self) -> Result<Tensor> {
1792        if self.rank() == 0 {
1793            Ok(self.clone())
1794        } else {
1795            self.flatten_all()?.min(0)
1796        }
1797    }
1798
1799    /// Computes the sum of all the elements in this tensor and returns a tensor holding this
1800    /// scalar with zero dimensions.
1801    ///
1802    /// ```rust
1803    /// use candle_core::{Tensor, Device};
1804    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1805    /// let tensor = tensor.sum_all()?;
1806    /// assert_eq!(tensor.to_scalar::<f32>()?, 15.);
1807    /// # Ok::<(), candle_core::Error>(())
1808    /// ```
1809    pub fn sum_all(&self) -> Result<Tensor> {
1810        let dims: Vec<_> = (0..self.rank()).collect();
1811        self.sum(dims)
1812    }
1813
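    /// Computes the mean of all the elements in this tensor and returns a tensor holding this
    /// scalar with zero dimensions.
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    /// let tensor = tensor.mean_all()?;
    /// assert_eq!(tensor.to_scalar::<f32>()?, 1.5);
    /// # Ok::<(), candle_core::Error>(())
    /// ```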
1814    pub fn mean_all(&self) -> Result<Tensor> {
1815        self.sum_all()? / self.elem_count() as f64
1816    }
1817
1818    fn flatten_<D1: Dim, D2: Dim>(
1819        &self,
1820        start_dim: Option<D1>,
1821        end_dim: Option<D2>,
1822    ) -> Result<Tensor> {
1823        if self.rank() == 0 {
1824            self.reshape(1)
1825        } else {
1826            let start_dim = match start_dim {
1827                None => 0,
1828                Some(dim) => dim.to_index(self.shape(), "flatten")?,
1829            };
1830            let end_dim = match end_dim {
1831                None => self.rank() - 1,
1832                Some(dim) => dim.to_index(self.shape(), "flatten")?,
1833            };
1834            if start_dim < end_dim {
1835                let dims = self.dims();
1836                let mut dst_dims = dims[..start_dim].to_vec();
1837                dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
1838                if end_dim + 1 < dims.len() {
1839                    dst_dims.extend(&dims[end_dim + 1..]);
1840                }
1841                self.reshape(dst_dims)
1842            } else {
1843                Ok(self.clone())
1844            }
1845        }
1846    }
1847
1848    /// Flattens the input tensor on the dimension indexes from `start_dim` to `end_dim` (both
1849    /// inclusive).
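    ///
    /// A short example flattening the two middle dimensions of a 4D-like shape (values omitted, only
    /// the shape is checked):
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    /// let t = t.flatten(1, 2)?;
    /// assert_eq!(t.shape().dims(), &[2, 12]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```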
1850    pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
1851        self.flatten_(Some(start_dim), Some(end_dim))
1852    }
1853
1854    /// Flattens the input tensor on the dimension indexes from `0` to `end_dim` (inclusive).
1855    pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
1856        self.flatten_(None::<usize>, Some(end_dim))
1857    }
1858
1859    /// Flattens the input tensor on the dimension indexes from `start_dim` (inclusive) to the last
1860    /// dimension.
1861    pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
1862        self.flatten_(Some(start_dim), None::<usize>)
1863    }
1864
1865    /// Flattens the input tensor by reshaping it into a one dimension tensor.
1866    ///
1867    /// ```rust
1868    /// use candle_core::{Tensor, Device};
1869    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1870    /// let tensor = tensor.flatten_all()?;
1871    /// assert_eq!(tensor.to_vec1::<f32>()?, &[0., 1., 2., 3., 4., 5.]);
1872    /// # Ok::<(), candle_core::Error>(())
1873    /// ```
1874    pub fn flatten_all(&self) -> Result<Tensor> {
1875        self.flatten_(None::<usize>, None::<usize>)
1876    }
1877
1878    /// Returns the sub-tensor fixing the index at `i` on the first dimension.
1879    ///
1880    /// ```rust
1881    /// use candle_core::{Tensor, Device};
1882    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1883    /// let t = tensor.get(0)?;
1884    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 1.]);
1885    /// let t = tensor.get(1)?;
1886    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
1887    /// # Ok::<(), candle_core::Error>(())
1888    /// ```
1889    pub fn get(&self, i: usize) -> Result<Tensor> {
1890        let dims = self.dims();
1891        if dims.is_empty() {
1892            Ok(self.clone())
1893        } else {
1894            self.narrow(0, i, 1)?.reshape(&dims[1..])
1895        }
1896    }
1897
1898    /// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
1899    ///
1900    /// ```rust
1901    /// use candle_core::{Tensor, Device};
1902    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1903    /// let t = tensor.get_on_dim(1, 0)?;
1904    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
1905    /// let t = tensor.get_on_dim(1, 1)?;
1906    /// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
1907    /// let t = tensor.get_on_dim(0, 1)?;
1908    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
1909    /// # Ok::<(), candle_core::Error>(())
1910    /// ```
1911    pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
1912        let dim = dim.to_index(self.shape(), "get_on_dim")?;
1913        self.narrow(dim, index, 1)?.squeeze(dim)
1914    }
1915
1916    /// Returns a tensor that is a transposed version of the input, the last two dimensions of the
1917    /// input are swapped.
1918    ///
1919    /// ```rust
1920    /// use candle_core::{Tensor, Device};
1921    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1922    /// let tensor = tensor.t()?;
1923    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[0.0, 2.0, 4.0], [1.0, 3.0, 5.0]]);
1924    /// # Ok::<(), candle_core::Error>(())
1925    /// ```
1926    pub fn t(&self) -> Result<Tensor> {
1927        let rank = self.rank();
1928        if rank < 2 {
1929            Err(Error::UnexpectedNumberOfDims {
1930                expected: 2,
1931                got: rank,
1932                shape: self.shape().clone(),
1933            }
1934            .bt())?
1935        }
1936        self.transpose(rank - 2, rank - 1)
1937    }
1938
1939    /// Returns a tensor that is a transposed version of the input, the given dimensions are
1940    /// swapped.
1941    pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
1942        let dim1 = dim1.to_index(self.shape(), "transpose")?;
1943        let dim2 = dim2.to_index(self.shape(), "transpose")?;
1944        if dim1 == dim2 {
1945            return Ok(self.clone());
1946        }
1947        let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
1948        let tensor_ = Tensor_ {
1949            id: TensorId::new(),
1950            storage: self.storage.clone(),
1951            layout: self.layout.transpose(dim1, dim2)?,
1952            op,
1953            is_variable: false,
1954            dtype: self.dtype,
1955            device: self.device.clone(),
1956        };
1957        Ok(Tensor(Arc::new(tensor_)))
1958    }
1959
1960    /// Returns a tensor with the same data as the input where the dimensions have been permuted.
1961    /// dims must be a permutation, i.e. include each dimension index exactly once.
1962    ///
1963    /// ```rust
1964    /// use candle_core::{Tensor, Device};
1965    /// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
1966    /// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
1967    /// let tensor = tensor.permute((2, 3, 1, 0))?;
1968    /// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
1969    /// # Ok::<(), candle_core::Error>(())
1970    /// ```
1971    pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
1972        let dims = dims.to_indexes(self.shape(), "permute")?;
1973        // O(n^2) permutation check but these arrays are small.
1974        let is_permutation =
1975            dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
1976        if !is_permutation {
1977            bail!(
1978                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
1979                self.dims(),
1980                dims
1981            )
1982        }
1983        let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
1984        let tensor_ = Tensor_ {
1985            id: TensorId::new(),
1986            storage: self.storage.clone(),
1987            layout: self.layout.permute(&dims)?,
1988            op,
1989            is_variable: false,
1990            dtype: self.dtype,
1991            device: self.device.clone(),
1992        };
1993        Ok(Tensor(Arc::new(tensor_)))
1994    }
1995
1996    /// Returns true if the data is stored in a C contiguous (aka row major) way.
1997    pub fn is_contiguous(&self) -> bool {
1998        self.layout.is_contiguous()
1999    }
2000
2001    /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
2002    pub fn is_fortran_contiguous(&self) -> bool {
2003        self.layout.is_fortran_contiguous()
2004    }
2005
2006    /// Compared to clone, this copies the actual storage and may therefore fail if the device runs
2007    /// out of memory.
2008    pub fn copy(&self) -> Result<Tensor> {
2009        let op = BackpropOp::new1(self, Op::Copy);
2010        let tensor_ = Tensor_ {
2011            id: TensorId::new(),
2012            storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
2013            layout: self.layout.clone(),
2014            op,
2015            is_variable: false,
2016            dtype: self.dtype,
2017            device: self.device.clone(),
2018        };
2019        Ok(Tensor(Arc::new(tensor_)))
2020    }
2021
2022    /// Returns a new tensor detached from the current graph; gradients are not propagated through
2023    /// this new node. The storage of this tensor is shared with the initial tensor.
2024    ///
2025    /// If the tensor is already detached from the computation graph, the same tensor is returned.
2026    pub fn detach(&self) -> Tensor {
2027        if self.op.is_none() && !self.is_variable {
2028            self.clone()
2029        } else {
2030            let tensor_ = Tensor_ {
2031                id: TensorId::new(),
2032                storage: self.storage.clone(),
2033                layout: self.layout.clone(),
2034                op: BackpropOp::none(),
2035                is_variable: false,
2036                dtype: self.dtype,
2037                device: self.device.clone(),
2038            };
2039            Tensor(Arc::new(tensor_))
2040        }
2041    }
2042
2043    /// If the target device is the same as the tensor device, only a shallow copy is performed.
2044    pub fn to_device(&self, device: &Device) -> Result<Tensor> {
2045        if self.device().same_device(device) {
2046            Ok(self.clone())
2047        } else {
2048            let storage = match (&*self.storage(), device) {
2049                (Storage::Cpu(storage), Device::Cuda(cuda)) => {
2050                    Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
2051                }
2052                (Storage::Cpu(storage), Device::Metal(metal)) => {
2053                    Storage::Metal(metal.storage_from_cpu_storage(storage)?)
2054                }
2055                (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2056                (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2057                (Storage::Cuda(storage), Device::Cuda(cuda)) => {
2058                    // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
2059                    // are the same.
2060                    let cpu_storage = storage.to_cpu_storage()?;
2061                    Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
2062                }
2063                (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
2064                _ => {
2065                    bail!(
2066                        "not implemented yet, self.device: {:?}, device: {:?}",
2067                        self.device(),
2068                        device
2069                    )
2070                }
2071            };
2072            let op = BackpropOp::new1(self, Op::ToDevice);
2073            let tensor_ = Tensor_ {
2074                id: TensorId::new(),
2075                storage: Arc::new(RwLock::new(storage)),
2076                layout: self.layout.clone(),
2077                op,
2078                is_variable: false,
2079                dtype: self.dtype,
2080                device: device.clone(),
2081            };
2082            Ok(Tensor(Arc::new(tensor_)))
2083        }
2084    }
2085
2086    /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
2087    /// on the left.
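    ///
    /// A short example (shapes chosen purely for illustration):
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let t = t.broadcast_left((2, 4))?;
    /// assert_eq!(t.dims(), &[2, 4, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```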
2088    pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
2089        let left_shape = left_shape.into();
2090        let mut dims = left_shape.into_dims();
2091        dims.extend(self.dims());
2092        self.broadcast_as(dims)
2093    }
2094
2095    /// Broadcast the input tensor to the target shape. This returns an error if the input shape is
2096    /// not compatible with the target shape.
2097    ///
2098    /// If the input shape is `i_1, i_2, ... i_k`, the target shape has to have `k` dimensions or
2099    /// more, i.e. be of the form `j_1, ..., j_l, t_1, t_2, ..., t_k`. The dimensions `j_1` to `j_l`
2100    /// can take any value; the dimension `t_a` must be equal to `i_a` unless `i_a` is equal to 1,
2101    /// in which case any value can be used.
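    ///
    /// A short example broadcasting a column vector to a matrix:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32], [2.], [3.]], &Device::Cpu)?;
    /// let t = t.broadcast_as((3, 2))?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[1., 1.], [2., 2.], [3., 3.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```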
2102    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2103        let tensor_ = Tensor_ {
2104            id: TensorId::new(),
2105            storage: self.storage.clone(),
2106            layout: self.layout.broadcast_as(shape)?,
2107            op: BackpropOp::new1(self, Op::Broadcast),
2108            is_variable: false,
2109            dtype: self.dtype,
2110            device: self.device.clone(),
2111        };
2112        Ok(Tensor(Arc::new(tensor_)))
2113    }
2114
2115    /// An alias for broadcast_as.
2116    pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2117        self.broadcast_as(shape)
2118    }
2119
2120    /// Casts the input tensor to the target `dtype`.
2121    ///
2122    /// ```rust
2123    /// use candle_core::{Tensor, Device};
2124    /// let tensor = Tensor::new(3.14159265358979f64, &Device::Cpu)?;
2125    /// assert_eq!(tensor.to_scalar::<f64>()?, 3.14159265358979);
2126    /// let tensor = tensor.to_dtype(candle_core::DType::F32)?;
2127    /// assert_eq!(tensor.to_scalar::<f32>()?, 3.1415927);
2128    /// # Ok::<(), candle_core::Error>(())
2129    /// ```
2130    pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2131        if self.dtype() == dtype {
2132            Ok(self.clone())
2133        } else {
2134            let shape = self.shape();
2135            let storage = self.storage().to_dtype(self.layout(), dtype)?;
2136            let op = BackpropOp::new1(self, Op::ToDType);
2137            Ok(from_storage(storage, shape.clone(), op, false))
2138        }
2139    }
2140
2141    /// Returns a tensor that is in row major order. This is the same as the original tensor if it
2142    /// was already contiguous, otherwise a copy is triggered.
2143    pub fn contiguous(&self) -> Result<Tensor> {
2144        if self.is_contiguous() {
2145            Ok(self.clone())
2146        } else {
2147            let shape = self.shape();
2148            let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2149            self.storage()
2150                .copy_strided_src(&mut storage, 0, self.layout())?;
2151            let op = BackpropOp::new1(self, Op::Copy);
2152            Ok(from_storage(storage, shape.clone(), op, false))
2153        }
2154    }
2155
2156    /// Returns a tensor that is in row major order. This always makes a copy.
2157    pub fn force_contiguous(&self) -> Result<Tensor> {
2158        let shape = self.shape();
2159        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2160        self.storage()
2161            .copy_strided_src(&mut storage, 0, self.layout())?;
2162        let op = BackpropOp::new1(self, Op::Copy);
2163        Ok(from_storage(storage, shape.clone(), op, false))
2164    }
2165
2166    /// Create a variable based on the values currently stored in a tensor. The storage is always
2167    /// copied.
2168    pub(crate) fn make_var(&self) -> Result<Tensor> {
2169        let shape = self.shape().clone();
2170        let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2171        self.storage()
2172            .copy_strided_src(&mut storage, 0, self.layout())?;
2173        Ok(from_storage(storage, shape, BackpropOp::none(), true))
2174    }
2175
2176    /// Reshape returns a tensor with the target shape provided that the number of elements of the
2177    /// original tensor is the same.
2178    /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
2179    /// a new storage and copies the data over; the returned tensor is always contiguous.
2180    ///
2181    /// The shape can be specified using a tuple of `usize` and at most one `()` in which case
2182    /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
2183    /// as to match the number of elements in the tensor.
2184    ///
2185    /// ```rust
2186    /// # use candle_core::{Tensor, DType, Device, D};
2187    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2188    ///
2189    /// let c = a.reshape((1, 6))?;
2190    /// assert_eq!(c.shape().dims(), &[1, 6]);
2191    ///
2192    /// let c = a.reshape((3, 2))?;
2193    /// assert_eq!(c.shape().dims(), &[3, 2]);
2194    ///
2195    /// let c = a.reshape((2, (), 1))?;
2196    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
2197    ///
2198    /// # Ok::<(), candle_core::Error>(())
2199    /// ```
2200    pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
2201        let shape = s.into_shape(self.elem_count())?;
2202        if shape.elem_count() != self.elem_count() {
2203            return Err(Error::ShapeMismatchBinaryOp {
2204                lhs: self.shape().clone(),
2205                rhs: shape,
2206                op: "reshape",
2207            }
2208            .bt());
2209        }
2210        let op = BackpropOp::new1(self, Op::Reshape);
2211        if self.is_contiguous() {
2212            let tensor_ = Tensor_ {
2213                id: TensorId::new(),
2214                storage: self.storage.clone(),
2215                layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
2216                op,
2217                is_variable: false,
2218                dtype: self.dtype,
2219                device: self.device.clone(),
2220            };
2221            Ok(Tensor(Arc::new(tensor_)))
2222        } else {
2223            let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2224            self.storage()
2225                .copy_strided_src(&mut storage, 0, self.layout())?;
2226            Ok(from_storage(storage, shape, op, false))
2227        }
2228    }
2229
2230    /// Creates a new tensor with the specified dimension removed if its size was one.
2231    ///
2232    /// ```rust
2233    /// # use candle_core::{Tensor, DType, Device, D};
2234    /// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
2235    ///
2236    /// let c = a.squeeze(2)?;
2237    /// assert_eq!(c.shape().dims(), &[2, 3]);
2238    ///
2239    /// let c = a.squeeze(D::Minus1)?;
2240    /// assert_eq!(c.shape().dims(), &[2, 3]);
2241    /// # Ok::<(), candle_core::Error>(())
2242    /// ```
2243    pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
2244        // The PyTorch semantics are to return the same tensor if the target dimension
2245        // does not have a size of 1.
2246        let dims = self.dims();
2247        let dim = dim.to_index(self.shape(), "squeeze")?;
2248        if dims[dim] == 1 {
2249            let mut dims = dims.to_vec();
2250            let mut strides = self.stride().to_vec();
2251            dims.remove(dim);
2252            strides.remove(dim);
2253            let tensor_ = Tensor_ {
2254                id: TensorId::new(),
2255                storage: self.storage.clone(),
2256                layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2257                op: BackpropOp::new1(self, Op::Reshape),
2258                is_variable: false,
2259                dtype: self.dtype,
2260                device: self.device.clone(),
2261            };
2262            Ok(Tensor(Arc::new(tensor_)))
2263        } else {
2264            Ok(self.clone())
2265        }
2266    }
2267
2268    /// Creates a new tensor with a dimension of size one inserted at the specified position.
2269    ///
2270    /// ```rust
2271    /// # use candle_core::{Tensor, DType, Device, D};
2272    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2273    ///
2274    /// let c = a.unsqueeze(0)?;
2275    /// assert_eq!(c.shape().dims(), &[1, 2, 3]);
2276    ///
2277    /// let c = a.unsqueeze(D::Minus1)?;
2278    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
2279    /// # Ok::<(), candle_core::Error>(())
2280    /// ```
2281    pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
2282        let mut dims = self.dims().to_vec();
2283        let mut strides = self.stride().to_vec();
2284        let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
2285        // Cannot panic because to_index_plus_one already checks dimensions
2286        dims.insert(dim, 1);
2287        // Any stride would work here, but we pick one so as to maximize the probability of
2288        // remaining C contiguous.
2289        let stride = if dim < strides.len() { strides[dim] } else { 1 };
2290        strides.insert(dim, stride);
2291        let tensor_ = Tensor_ {
2292            id: TensorId::new(),
2293            storage: self.storage.clone(),
2294            layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2295            op: BackpropOp::new1(self, Op::Reshape),
2296            is_variable: false,
2297            dtype: self.dtype,
2298            device: self.device.clone(),
2299        };
2300        Ok(Tensor(Arc::new(tensor_)))
2301    }
2302
2303    /// Stacks two or more tensors along a particular dimension.
2304    ///
2305    /// All tensors must have the same rank, and the output has one additional dimension.
2306    ///
2307    /// ```rust
2308    /// # use candle_core::{Tensor, DType, Device};
2309    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2310    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2311    ///
2312    /// let c = Tensor::stack(&[&a, &b], 0)?;
2313    /// assert_eq!(c.shape().dims(), &[2, 2, 3]);
2314    ///
2315    /// let c = Tensor::stack(&[&a, &b], 2)?;
2316    /// assert_eq!(c.shape().dims(), &[2, 3, 2]);
2317    /// # Ok::<(), candle_core::Error>(())
2318    /// ```
2319    pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
2320        if args.is_empty() {
2321            Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
2322        }
2323        let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
2324        let args = args
2325            .iter()
2326            .map(|t| t.as_ref().unsqueeze(dim))
2327            .collect::<Result<Vec<_>>>()?;
2328        Self::cat(&args, dim)
2329    }
2330
2331    /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
2332    /// input tensor values and `right` elements after.
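    ///
    /// A short example padding the second dimension with one column of zeros on each side:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let t = t.pad_with_zeros(1, 1, 1)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[0., 1., 2., 0.], [0., 3., 4., 0.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```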
2333    pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2334        if left == 0 && right == 0 {
2335            Ok(self.clone())
2336        } else if left == 0 {
2337            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2338            let mut dims = self.dims().to_vec();
2339            dims[dim] = right;
2340            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2341            Tensor::cat(&[self, &right], dim)
2342        } else if right == 0 {
2343            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2344            let mut dims = self.dims().to_vec();
2345            dims[dim] = left;
2346            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2347            Tensor::cat(&[&left, self], dim)
2348        } else {
2349            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2350            let mut dims = self.dims().to_vec();
2351            dims[dim] = left;
2352            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2353            dims[dim] = right;
2354            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2355            Tensor::cat(&[&left, self, &right], dim)
2356        }
2357    }
2358
2359    /// Pad the input tensor using same values along dimension `dim`. This adds `left` elements before the
2360    /// input tensor values and `right` elements after.
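    ///
    /// A short example replicating the first column once on the left and the last column twice on
    /// the right:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let t = t.pad_with_same(1, 1, 2)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[1., 1., 2., 2., 2.], [3., 3., 4., 4., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```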
2361    pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2362        if left == 0 && right == 0 {
2363            Ok(self.clone())
2364        } else if self.elem_count() == 0 {
2365            bail!("cannot use pad_with_same on an empty tensor")
2366        } else if left == 0 {
2367            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2368            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2369            let mut v = vec![self];
2370            for _ in 0..right {
2371                v.push(&r)
2372            }
2373            Tensor::cat(&v, dim)
2374        } else if right == 0 {
2375            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2376            let l = self.narrow(dim, 0, 1)?;
2377            let mut v = vec![];
2378            for _ in 0..left {
2379                v.push(&l)
2380            }
2381            v.push(self);
2382            Tensor::cat(&v, dim)
2383        } else {
2384            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2385            let l = self.narrow(dim, 0, 1)?;
2386            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2387            let mut v = vec![];
2388            for _ in 0..left {
2389                v.push(&l)
2390            }
2391            v.push(self);
2392            for _ in 0..right {
2393                v.push(&r)
2394            }
2395            Tensor::cat(&v, dim)
2396        }
2397    }
2398
2399    /// Run the `forward` method of `m` on `self`.
2400    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
2401        m.forward(self)
2402    }
2403
2404    /// Run the `forward` method of `m` on `self`.
2405    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
2406        m.forward_t(self, train)
2407    }
2408
2409    pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
2410        self.storage.read().unwrap()
2411    }
2412
2413    pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
2414        self.storage.write().unwrap()
2415    }
2416
2417    // If we extend the visibility of this function to be usable outside of this crate, we should
2418    // make it unsafe.
2419    pub(crate) fn storage_mut_and_layout(
2420        &self,
2421    ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
2422        let storage = self.storage.write().unwrap();
2423        (storage, &self.layout)
2424    }
2425
2426    /// The storage used by this tensor, together with the layout to use to access it safely.
2427    pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
2428        let storage = self.storage.read().unwrap();
2429        (storage, &self.layout)
2430    }
2431
2432    pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
2433        let lhs: &RwLock<Storage> = self.storage.as_ref();
2434        let rhs: &RwLock<Storage> = rhs.storage.as_ref();
2435        std::ptr::eq(lhs, rhs)
2436    }
2437
2438    /// Normalize a 'relative' axis value: positive values are kept, negative
2439    /// values mean counting the dimensions from the back.
2440    pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
2441        let rank = self.rank() as i64;
2442        if rank <= axis {
2443            bail!("axis {axis} is too large, tensor rank {rank}")
2444        } else if 0 <= axis {
2445            Ok(axis as usize)
2446        } else {
2447            let naxis = rank + axis;
2448            if naxis < 0 {
2449                bail!("axis {axis} is too small, tensor rank {rank}")
2450            }
2451            Ok(naxis as usize)
2452        }
2453    }
2454
2455    /// Returns a lower triangular matrix of ones of size n by n.
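    ///
    /// A short example:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::tril2(3, DType::F32, &Device::Cpu)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```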
2456    pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2457        let t = Tensor::arange(0u32, n as u32, device)?;
2458        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2459        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2460        t1.le(&t2)?.to_dtype(dtype)
2461    }
2462
2463    /// Returns an upper triangular matrix of ones of size n by n.
2464    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2465        let t = Tensor::arange(0u32, n as u32, device)?;
2466        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2467        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2468        t1.ge(&t2)?.to_dtype(dtype)
2469    }
2470
2471    /// Returns a matrix with a diagonal of ones of size n by n.
2472    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2473        let t = Tensor::arange(0u32, n as u32, device)?;
2474        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2475        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2476        t1.eq(&t2)?.to_dtype(dtype)
2477    }
2478
2479    /// Returns the cumulative sum of the elements of the input tensor along the specified
2480    /// dimension.
2481    ///
2482    /// This operation is most efficient when dim is the last dimension of the tensor.
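    ///
    /// A short 1D example:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
    /// let t = t.cumsum(0)?;
    /// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 6., 10.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```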
2483    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
2484        let dim = dim.to_index(self.shape(), "cumsum")?;
2485        let rank = self.rank();
2486        if rank == 0 {
2487            return Ok(self.clone());
2488        }
2489        let n_axis = self.dim(dim)?;
2490        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
2491        if rank == 1 {
2492            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
2493        } else {
2494            let last = rank - 1;
2495            let t = self.transpose(dim, last)?;
2496            let t = t.broadcast_matmul(&triu)?;
2497            t.transpose(dim, last)
2498        }
2499    }
2500
2501    /// Returns a copy of `self` where the values within `ranges` have been replaced with the
2502    /// content of `src`.
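    ///
    /// A short example writing a `1x2` block of ones into a `3x3` tensor of zeros (the shapes are
    /// purely illustrative):
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((3, 3), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::ones((1, 2), DType::F32, &Device::Cpu)?;
    /// let t = t.slice_assign(&[1..2, 0..2], &src)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[0., 0., 0.], [1., 1., 0.], [0., 0., 0.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```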
2503    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
2504        &self,
2505        ranges: &[D],
2506        src: &Tensor,
2507    ) -> Result<Self> {
2508        let src_dims = src.dims();
2509        let self_dims = self.dims();
2510        if self_dims.len() != src_dims.len() {
2511            bail!(
2512                "slice-assign requires input with the same rank {} <> {}",
2513                self_dims.len(),
2514                src_dims.len()
2515            )
2516        }
2517        if self_dims.len() != ranges.len() {
2518            bail!(
2519                "slice-assign requires input with the same rank as there are ranges {} <> {}",
2520                self_dims.len(),
2521                ranges.len()
2522            )
2523        }
2524        let mut src = src.clone();
2525        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
2526        for (i, range) in ranges.iter().enumerate() {
2527            let start_included = match range.start_bound() {
2528                std::ops::Bound::Unbounded => 0,
2529                std::ops::Bound::Included(v) => *v,
2530                std::ops::Bound::Excluded(v) => *v + 1,
2531            };
2532            let end_excluded = match range.end_bound() {
2533                std::ops::Bound::Unbounded => self_dims[i],
2534                std::ops::Bound::Included(v) => *v + 1,
2535                std::ops::Bound::Excluded(v) => *v,
2536            };
2537            if end_excluded <= start_included {
2538                bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
2539            }
2540            if self_dims[i] < end_excluded {
2541                bail!(
2542                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
2543                    self_dims[i]
2544                )
2545            }
2546            if end_excluded - start_included != src_dims[i] {
2547                bail!(
2548                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
2549                )
2550            }
2551            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
2552            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
2553        }
2554        mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
2555    }
2556
2557    /// Returns log(sum(exp(tensor), dim)).
2558    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
2559        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
2560        if sum_dims.is_empty() {
2561            return Ok(self.clone());
2562        }
2563        let max = sum_dims[1..]
2564            .iter()
2565            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
2566                max.max_keepdim(dim)
2567            })?;
2568        let exp = self.broadcast_sub(&max)?.exp()?;
2569        let sum = exp.sum(sum_dims.clone())?;
2570
2571        sum.log()? + max.squeeze_dims(&sum_dims)
2572    }
2573
2574    /// Pointwise pow operation, computed as `exp(rhs * ln(self))`, so the values in `self` are expected to be positive.
2575    pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
2576        rhs.mul(&self.log()?)?.exp()
2577    }
2578
2579    /// Broadcasting version of `pow`.
2580    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
2581        rhs.broadcast_mul(&self.log()?)?.exp()
2582    }
2583}
2584
2585macro_rules! bin_trait {
2586    ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => {
2587        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for Tensor {
2588            type Output = Result<Tensor>;
2589
2590            fn $fn1(self, rhs: B) -> Self::Output {
2591                Tensor::$fn1(&self, rhs.borrow())
2592            }
2593        }
2594
2595        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for &Tensor {
2596            type Output = Result<Tensor>;
2597
2598            fn $fn1(self, rhs: B) -> Self::Output {
2599                Tensor::$fn1(&self, rhs.borrow())
2600            }
2601        }
2602
2603        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
2604            type Output = Result<Tensor>;
2605
2606            fn $fn1(self, rhs: Tensor) -> Self::Output {
2607                Tensor::$fn1(self?.borrow(), &rhs)
2608            }
2609        }
2610
2611        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
2612            type Output = Result<Tensor>;
2613
2614            fn $fn1(self, rhs: &Tensor) -> Self::Output {
2615                Tensor::$fn1(self?.borrow(), rhs)
2616            }
2617        }
2618
2619        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
2620            type Output = Result<Tensor>;
2621
2622            fn $fn1(self, rhs: Result<B>) -> Self::Output {
2623                Tensor::$fn1(&self, rhs?.borrow())
2624            }
2625        }
2626
2627        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for &Tensor {
2628            type Output = Result<Tensor>;
2629
2630            fn $fn1(self, rhs: Result<B>) -> Self::Output {
2631                Tensor::$fn1(&self, rhs?.borrow())
2632            }
2633        }
2634
2635        impl std::ops::$trait<f64> for Tensor {
2636            type Output = Result<Tensor>;
2637
2638            fn $fn1(self, rhs: f64) -> Self::Output {
2639                self.affine($mul(rhs), $add(rhs))
2640            }
2641        }
2642
2643        impl std::ops::$trait<f64> for &Tensor {
2644            type Output = Result<Tensor>;
2645
2646            fn $fn1(self, rhs: f64) -> Self::Output {
2647                self.affine($mul(rhs), $add(rhs))
2648            }
2649        }
2650    };
2651}
2652
2653bin_trait!(Add, add, |_| 1., |v| v);
2654bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
2655bin_trait!(Mul, mul, |v| v, |_| 0.);
2656bin_trait!(Div, div, |v| 1. / v, |_| 0.);
2657
2658impl std::ops::Add<Tensor> for f64 {
2659    type Output = Result<Tensor>;
2660
2661    fn add(self, rhs: Tensor) -> Self::Output {
2662        rhs + self
2663    }
2664}
2665
2666impl std::ops::Add<&Tensor> for f64 {
2667    type Output = Result<Tensor>;
2668
2669    fn add(self, rhs: &Tensor) -> Self::Output {
2670        rhs + self
2671    }
2672}
2673
2674impl std::ops::Mul<Tensor> for f64 {
2675    type Output = Result<Tensor>;
2676
2677    fn mul(self, rhs: Tensor) -> Self::Output {
2678        rhs * self
2679    }
2680}
2681
2682impl std::ops::Mul<&Tensor> for f64 {
2683    type Output = Result<Tensor>;
2684
2685    fn mul(self, rhs: &Tensor) -> Self::Output {
2686        rhs * self
2687    }
2688}
2689
2690impl std::ops::Sub<Tensor> for f64 {
2691    type Output = Result<Tensor>;
2692
2693    fn sub(self, rhs: Tensor) -> Self::Output {
2694        rhs.affine(-1., self)
2695    }
2696}
2697
2698impl std::ops::Sub<&Tensor> for f64 {
2699    type Output = Result<Tensor>;
2700
2701    fn sub(self, rhs: &Tensor) -> Self::Output {
2702        rhs.affine(-1., self)
2703    }
2704}
2705
2706impl std::ops::Div<Tensor> for f64 {
2707    type Output = Result<Tensor>;
2708
2709    #[allow(clippy::suspicious_arithmetic_impl)]
2710    fn div(self, rhs: Tensor) -> Self::Output {
2711        rhs.recip()? * self
2712    }
2713}
2714
2715impl std::ops::Div<&Tensor> for f64 {
2716    type Output = Result<Tensor>;
2717
2718    #[allow(clippy::suspicious_arithmetic_impl)]
2719    fn div(self, rhs: &Tensor) -> Self::Output {
2720        rhs.recip()? * self
2721    }
2722}