candle_core/
tensor.rs

//! Tensors are N-dimensional arrays of elements using a single data type.
2#![allow(clippy::redundant_closure_call)]
3use crate::backend::{BackendDevice, BackendStorage};
4use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
5use crate::scalar::TensorOrScalar;
6use crate::shape::{Dim, Dims, ShapeWithOneHole};
7use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
8use std::sync::{Arc, RwLock};
9
10/// Unique identifier for tensors.
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
12pub struct TensorId(usize);
13
14impl TensorId {
15    fn new() -> Self {
16        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
17        use std::sync::atomic;
18        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
19        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
20    }
21}
22
23pub struct Tensor_ {
24    id: TensorId,
    // As we provide inner mutability on the tensor content, the alternatives are:
    // - Using a mutex: this would have the highest cost when retrieving the storage but would
    //   prevent errors when concurrent accesses take place. A mutex would also be subject to
    //   deadlocks, e.g. with the current code if the same tensor were used twice by a single
    //   binary op.
    // - Using a RefCell: this would have an intermediate cost, borrows would be checked
    //   dynamically, but the resulting tensors would not be Send or Sync.
    // - Using an UnsafeCell: this would have the lowest cost but undefined behavior on concurrent
    //   accesses.
    // Ideally, we would use Arc<Storage> for tensors whose data we don't plan on modifying, and
    // Arc<Mutex<Storage>> for tensors whose data could be modified, e.g. variables, but that's
    // tricky to encode in the current setup.
37    storage: Arc<RwLock<Storage>>,
38    layout: Layout,
39    op: BackpropOp,
40    is_variable: bool,
41    dtype: DType,
42    device: Device,
43}
44
45impl AsRef<Tensor> for Tensor {
46    fn as_ref(&self) -> &Tensor {
47        self
48    }
49}
50
// Tensors are refcounted so that cloning is cheap when building the op graph.
// Storages are also refcounted independently so that it's possible to avoid
// copying the storage for operations that only modify the shape or stride.
54#[derive(Clone)]
55/// The core struct for manipulating tensors.
56///
57/// ```rust
58/// use candle_core::{Tensor, DType, Device};
59///
60/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
61/// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
62///
63/// let c = a.matmul(&b)?;
64/// # Ok::<(), candle_core::Error>(())
65/// ```
66///
67/// Tensors are reference counted with [`Arc`] so cloning them is cheap.
68pub struct Tensor(Arc<Tensor_>);
69
70impl std::ops::Deref for Tensor {
71    type Target = Tensor_;
72
73    fn deref(&self) -> &Self::Target {
74        self.0.as_ref()
75    }
76}
77
78macro_rules! unary_op {
79    ($fn_name:ident, $op_name:ident) => {
80        pub fn $fn_name(&self) -> Result<Self> {
81            let shape = self.shape();
82            if shape.elem_count() == 0 {
83                return Ok(self.clone());
84            }
85            let storage = self
86                .storage()
87                .unary_impl::<crate::op::$op_name>(self.layout())?;
88            let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
89            Ok(from_storage(storage, shape.clone(), op, false))
90        }
91    };
92}
93
94macro_rules! binary_op {
95    ($fn_name:ident, $op_name:ident) => {
96        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
97            let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
98            if shape.elem_count() == 0 {
99                return Ok(self.clone());
100            }
101            let storage = self.storage().binary_impl::<crate::op::$op_name>(
102                &*rhs.storage(),
103                self.layout(),
104                rhs.layout(),
105            )?;
106            let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
107            Ok(from_storage(storage, shape.clone(), op, false))
108        }
109    };
110}
111
112macro_rules! binary_op_scalar {
113    ($fn_name:ident, $op_name:ident) => {
114        pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
115            let rhs = match rhs.to_tensor_scalar()? {
116                crate::scalar::TensorScalar::Tensor(rhs) => rhs,
117                crate::scalar::TensorScalar::Scalar(rhs) => rhs
118                    .to_dtype(self.dtype())?
119                    .to_device(self.device())?
120                    .broadcast_as(self.shape())?,
121            };
122            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
123            if self.elem_count() == 0 {
124                return Ok(self.clone());
125            }
126            let storage = self.storage().binary_impl::<crate::op::$op_name>(
127                &*rhs.storage(),
128                self.layout(),
129                rhs.layout(),
130            )?;
131            let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
132            Ok(from_storage(storage, shape.clone(), op, false))
133        }
134    };
135}
136
137macro_rules! broadcast_binary_op {
138    ($fn_name:ident, $inner_fn_name:ident) => {
139        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
140            let lhs = self;
141            let shape = lhs
142                .shape()
143                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
144            let l_broadcast = shape != *lhs.shape();
145            let r_broadcast = shape != *rhs.shape();
146            match (l_broadcast, r_broadcast) {
147                (true, true) => lhs
148                    .broadcast_as(&shape)?
149                    .$inner_fn_name(&rhs.broadcast_as(&shape)?),
150                (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
151                (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
152                (false, false) => lhs.$inner_fn_name(rhs),
153            }
154        }
155    };
156}
157
/// Creates a fresh tensor structure based on a storage and a shape; this uses contiguous strides.
159pub(crate) fn from_storage<S: Into<Shape>>(
160    storage: Storage,
161    shape: S,
162    op: BackpropOp,
163    is_variable: bool,
164) -> Tensor {
165    let dtype = storage.dtype();
166    let device = storage.device();
167    let tensor_ = Tensor_ {
168        id: TensorId::new(),
169        storage: Arc::new(RwLock::new(storage)),
170        layout: Layout::contiguous(shape),
171        op,
172        is_variable,
173        dtype,
174        device,
175    };
176    Tensor(Arc::new(tensor_))
177}
178
179impl Tensor {
180    pub(crate) fn ones_impl<S: Into<Shape>>(
181        shape: S,
182        dtype: DType,
183        device: &Device,
184        is_variable: bool,
185    ) -> Result<Self> {
186        let none = BackpropOp::none();
187        let shape = shape.into();
188        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
189        let layout = Layout::contiguous(shape.clone());
190        storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
191        Ok(from_storage(storage, shape, none, is_variable))
192    }
193
194    /// Creates a new tensor filled with ones.
195    ///
196    /// ```rust
197    /// use candle_core::{Tensor, DType, Device};
198    /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
199    /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
200    /// // a == b
201    /// # Ok::<(), candle_core::Error>(())
202    /// ```
203    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
204        Self::ones_impl(shape, dtype, device, false)
205    }
206
207    pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
208        self.storage_mut().const_set(value, self.layout())
209    }
210
211    pub fn zero_set(&self) -> Result<()> {
212        self.const_set(crate::scalar::Scalar::zero(self.dtype()))
213    }
214
215    pub fn one_set(&self) -> Result<()> {
216        self.const_set(crate::scalar::Scalar::one(self.dtype()))
217    }
218
    /// Creates a new tensor filled with ones, with the same shape, dtype, and device as the input tensor.
220    ///
221    /// ```rust
222    /// use candle_core::{Tensor, DType, Device};
223    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
224    /// let b = a.ones_like()?;
225    /// // b == a + 1
226    /// # Ok::<(), candle_core::Error>(())
227    /// ```
228    pub fn ones_like(&self) -> Result<Self> {
229        Tensor::ones(self.shape(), self.dtype(), self.device())
230    }
231
    // Do not expose outside of the crate; the `is_variable=true` case should only be accessed from
    // the variable module.
234    pub(crate) fn zeros_impl<S: Into<Shape>>(
235        shape: S,
236        dtype: DType,
237        device: &Device,
238        is_variable: bool,
239    ) -> Result<Self> {
240        let none = BackpropOp::none();
241        let shape = shape.into();
242        let storage = device.zeros(&shape, dtype)?;
243        Ok(from_storage(storage, shape, none, is_variable))
244    }
245
246    /// Creates a new tensor filled with zeros.
247    ///
248    /// ```rust
249    /// use candle_core::{Tensor, DType, Device};
250    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
251    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
252    /// // a == b
253    /// # Ok::<(), candle_core::Error>(())
254    /// ```
255    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
256        Self::zeros_impl(shape, dtype, device, false)
257    }
258
    /// Creates a new tensor filled with zeros, with the same shape, dtype, and device as the
    /// input tensor.
261    ///
262    /// ```rust
263    /// use candle_core::{Tensor, DType, Device};
264    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
265    /// let b = a.zeros_like()?;
266    /// // b is on CPU f32.
267    /// # Ok::<(), candle_core::Error>(())
268    /// ```
269    pub fn zeros_like(&self) -> Result<Self> {
270        Tensor::zeros(self.shape(), self.dtype(), self.device())
271    }
272
273    pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
274        lo: T,
275        up: T,
276        s: S,
277        device: &Device,
278        is_variable: bool,
279    ) -> Result<Self> {
280        let s = s.into();
281        let storage = device.rand_uniform(lo, up, &s)?;
282        let none = BackpropOp::none();
283        Ok(from_storage(storage, s, none, is_variable))
284    }
285
286    pub(crate) fn rand_f64_impl<S: Into<Shape>>(
287        lo: f64,
288        up: f64,
289        s: S,
290        dtype: DType,
291        device: &Device,
292        is_variable: bool,
293    ) -> Result<Self> {
294        let s = s.into();
295        let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
296        let none = BackpropOp::none();
297        Ok(from_storage(storage, s, none, is_variable))
298    }
299
300    /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
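    /// A minimal usage sketch (the shape and bounds below are illustrative):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::rand(0f32, 1f32, (2, 3), &Device::Cpu)?;
    /// assert_eq!(a.dims(), &[2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```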
301    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
302        lo: T,
303        up: T,
304        s: S,
305        device: &Device,
306    ) -> Result<Self> {
307        Self::rand_impl(lo, up, s, device, false)
308    }
309
310    pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
311        Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
312    }
313
314    pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
315        mean: T,
316        std: T,
317        s: S,
318        device: &Device,
319        is_variable: bool,
320    ) -> Result<Self> {
321        let s = s.into();
322        let storage = device.rand_normal(mean, std, &s)?;
323        let none = BackpropOp::none();
324        Ok(from_storage(storage, s, none, is_variable))
325    }
326
327    pub(crate) fn randn_f64_impl<S: Into<Shape>>(
328        mean: f64,
329        std: f64,
330        s: S,
331        dtype: DType,
332        device: &Device,
333        is_variable: bool,
334    ) -> Result<Self> {
335        let s = s.into();
336        let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
337        let none = BackpropOp::none();
338        Ok(from_storage(storage, s, none, is_variable))
339    }
340
341    pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
342        Tensor::randn_f64_impl(
343            mean,
344            stdev,
345            self.shape(),
346            self.dtype(),
347            self.device(),
348            false,
349        )
350    }
351
352    /// Creates a new tensor initialized with values sampled from a normal distribution with the
353    /// specified `mean` and standard deviation `std`.
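    /// A minimal usage sketch (the mean, standard deviation, and shape below are illustrative):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::randn(0f32, 1f32, (2, 3), &Device::Cpu)?;
    /// assert_eq!(a.dims(), &[2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```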
354    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
355        mean: T,
356        std: T,
357        s: S,
358        device: &Device,
359    ) -> Result<Self> {
360        Self::randn_impl(mean, std, s, device, false)
361    }
362
363    pub(crate) fn new_impl<A: crate::device::NdArray>(
364        array: A,
365        shape: Shape,
366        device: &Device,
367        is_variable: bool,
368    ) -> Result<Self> {
369        let n: usize = shape.elem_count();
370        let buffer_size: usize = array.shape()?.elem_count();
371        if buffer_size != n {
372            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
373        }
374        let storage = device.storage(array)?;
375        let none = BackpropOp::none();
376        Ok(from_storage(storage, shape, none, is_variable))
377    }
378
379    /// Creates a new tensor on the specified device using the content and shape of the input.
380    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
381        let shape = array.shape()?;
382        Self::new_impl(array, shape, device, false)
383    }
384
385    /// Returns a new tensor with all the elements having the same specified value. Note that
386    /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
387    ///```rust
388    /// use candle_core::{Tensor, Device};
389    /// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
390    ///
391    /// assert_eq!(a.to_vec2::<f64>()?, &[
392    ///     [3.5, 3.5, 3.5, 3.5],
393    ///     [3.5, 3.5, 3.5, 3.5],
394    /// ]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
396    pub fn full<D: crate::WithDType, S: Into<Shape>>(
397        value: D,
398        shape: S,
399        device: &Device,
400    ) -> Result<Self> {
401        Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
402    }
403
404    /// Creates a new 1D tensor from an iterator.
405    ///```rust
406    /// use candle_core::{Tensor, Device};
407    /// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
408    ///
409    /// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
410    /// # Ok::<(), candle_core::Error>(())
411    /// ```
412    pub fn from_iter<D: crate::WithDType>(
413        iter: impl IntoIterator<Item = D>,
414        device: &Device,
415    ) -> Result<Self> {
416        let data = iter.into_iter().collect::<Vec<_>>();
417        let len = data.len();
418        Self::from_vec_impl(data, len, device, false)
419    }
420
421    /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
422    /// difference `1` from `start`.
423    ///```rust
424    /// use candle_core::{Tensor, Device};
425    /// let a = Tensor::arange(2., 5., &Device::Cpu)?;
426    ///
427    /// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
428    /// # Ok::<(), candle_core::Error>(())
429    /// ```
430    pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
431        Self::arange_step(start, end, D::one(), device)
432    }
433
434    /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
435    /// difference `step` from `start`.
436    ///```rust
437    /// use candle_core::{Tensor, Device};
438    /// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
439    ///
440    /// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
441    /// # Ok::<(), candle_core::Error>(())
442    /// ```
443    pub fn arange_step<D: crate::WithDType>(
444        start: D,
445        end: D,
446        step: D,
447        device: &Device,
448    ) -> Result<Self> {
449        if D::is_zero(&step) {
450            bail!("step cannot be zero")
451        }
452        let mut data = vec![];
453        let mut current = start;
454        if step >= D::zero() {
455            while current < end {
456                data.push(current);
457                current += step;
458            }
459        } else {
460            while current > end {
461                data.push(current);
462                current += step;
463            }
464        }
465        let len = data.len();
466        Self::from_vec_impl(data, len, device, false)
467    }
468
469    pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
470        data: Vec<D>,
471        shape: S,
472        device: &Device,
473        is_variable: bool,
474    ) -> Result<Self> {
475        let shape = shape.into_shape(data.len())?;
476        let storage = device.storage_owned(data)?;
477        let none = BackpropOp::none();
478        Ok(from_storage(storage, shape, none, is_variable))
479    }
480
481    /// Creates a new tensor initialized with values from the input vector. The number of elements
482    /// in this vector must be the same as the number of elements defined by the shape.
    /// If the device is the CPU, no data copy is made.
484    ///```rust
485    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::from_vec(vec![1., 2., 3., 4., 5., 6.], (2, 3), &Device::Cpu)?;
487    ///
488    /// assert_eq!(a.to_vec2::<f64>()?, &[
489    ///     [1., 2., 3.],
490    ///     [4., 5., 6.]
491    /// ]);
492    /// # Ok::<(), candle_core::Error>(())
493    /// ```
494    pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
495        data: Vec<D>,
496        shape: S,
497        device: &Device,
498    ) -> Result<Self> {
499        Self::from_vec_impl(data, shape, device, false)
500    }
501
502    /// Creates a new tensor initialized with values from the input slice. The number of elements
    /// in this slice must be the same as the number of elements defined by the shape.
504    ///```rust
505    /// use candle_core::{Tensor, Device};
506    /// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
507    /// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
508    ///
509    /// assert_eq!(a.to_vec2::<f64>()?, &[
510    ///     [2., 3., 4.],
511    ///     [5., 6., 7.]
512    /// ]);
513    /// # Ok::<(), candle_core::Error>(())
514    /// ```
515    pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
516        array: &[D],
517        shape: S,
518        device: &Device,
519    ) -> Result<Self> {
520        let shape = shape.into_shape(array.len())?;
521        let storage = device.storage_from_slice(array)?;
522        let none = BackpropOp::none();
523        Ok(from_storage(storage, shape, none, false))
524    }
525
526    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
527        let lhs = self.shape();
528        let rhs = rhs.shape();
529        if lhs != rhs {
530            Err(Error::ShapeMismatchBinaryOp {
531                lhs: lhs.clone(),
532                rhs: rhs.clone(),
533                op,
534            }
535            .bt())
536        } else {
537            Ok(lhs)
538        }
539    }
540
    /// Returns true if the computation graph should track this op, that is if it is
    /// a variable or if it depends on some variable.
543    pub fn track_op(&self) -> bool {
544        self.is_variable || self.op.is_some()
545    }
546
547    // TODO: Also make an inplace version or a pre-allocated? This could be tricky
548    // if this can create cycles in the compute graph.
549    binary_op!(add, Add);
550    binary_op!(mul, Mul);
551    binary_op!(sub, Sub);
552    binary_op!(div, Div);
553    binary_op_scalar!(maximum, Maximum);
554    binary_op_scalar!(minimum, Minimum);
555    broadcast_binary_op!(broadcast_add, add);
556    broadcast_binary_op!(broadcast_mul, mul);
557    broadcast_binary_op!(broadcast_sub, sub);
558    broadcast_binary_op!(broadcast_div, div);
559    broadcast_binary_op!(broadcast_maximum, maximum);
560    broadcast_binary_op!(broadcast_minimum, minimum);
561    broadcast_binary_op!(broadcast_eq, eq);
562    broadcast_binary_op!(broadcast_ne, ne);
563    broadcast_binary_op!(broadcast_lt, lt);
564    broadcast_binary_op!(broadcast_le, le);
565    broadcast_binary_op!(broadcast_gt, gt);
566    broadcast_binary_op!(broadcast_ge, ge);
567
568    unary_op!(recip, Recip);
569    unary_op!(neg, Neg);
570    unary_op!(exp, Exp);
571    unary_op!(log, Log);
572    unary_op!(sin, Sin);
573    unary_op!(cos, Cos);
574    unary_op!(tanh, Tanh);
575    unary_op!(abs, Abs);
576    unary_op!(sqr, Sqr);
577    unary_op!(sqrt, Sqrt);
578    unary_op!(gelu, Gelu);
579    unary_op!(gelu_erf, GeluErf);
580    unary_op!(erf, Erf);
581    unary_op!(relu, Relu);
582    unary_op!(silu, Silu);
583    unary_op!(ceil, Ceil);
584    unary_op!(floor, Floor);
585    unary_op!(round, Round);
586    unary_op!(sign, Sign);
587
    /// Rounds each element of the input tensor to the nearest integer.
589    ///
590    /// If the number of decimals is negative, it specifies the number of positions to the left of
591    /// the decimal point.
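    /// A small illustrative sketch (exact float comparisons are avoided on purpose):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[1.234f64, 5.678], &Device::Cpu)?;
    /// // Keep two decimal places, giving roughly [1.23, 5.68].
    /// let b = a.round_to(2)?;
    /// assert_eq!(b.dims(), &[2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```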
592    pub fn round_to(&self, decimals: i32) -> Result<Self> {
593        let mult = 10f64.powi(decimals);
594        (self * mult)?.round()? * (1f64 / mult)
595    }
596
    /// Retrieves the single scalar value held in the tensor. If the tensor is not a scalar,
    /// i.e. if it has one or more dimensions, an error is returned instead.
599    pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
600        if self.rank() != 0 {
601            Err(Error::UnexpectedNumberOfDims {
602                expected: 0,
603                got: self.rank(),
604                shape: self.shape().clone(),
605            }
606            .bt())?
607        }
608        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
609            let data = S::cpu_storage_as_slice(cpu_storage)?;
610            Ok::<_, Error>(data[self.layout().start_offset()])
611        };
612        match &*self.storage() {
613            Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
614            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
615            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
616        }
617    }
618
619    /// An alias for `to_scalar`.
620    pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
621        self.to_scalar::<S>()
622    }
623
    /// Repeats this tensor along the specified dimensions.
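    /// A minimal usage sketch (the repeat counts below are illustrative):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// // Repeat twice along the first dimension, once along the second.
    /// let b = a.repeat((2, 1))?;
    /// assert_eq!(b.to_vec2::<f32>()?, &[[1., 2.], [3., 4.], [1., 2.], [3., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```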
625    pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
626        // Similar to PyTorch, we extend the number of dimensions of self if needed.
627        let repeats = shape.into();
628        let repeats = repeats.dims();
629        let mut inp = if self.rank() < repeats.len() {
630            let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
631            self.reshape(shape)?
632        } else {
633            self.clone()
634        };
635        for (idx, &repeat) in repeats.iter().enumerate() {
636            if repeat > 1 {
637                inp = Tensor::cat(&vec![&inp; repeat], idx)?
638            }
639        }
640        Ok(inp)
641    }
642
643    /// Creates grids of coordinates specified by the 1D inputs.
644    ///
645    /// # Arguments
646    ///
647    /// * `args` - A slice of 1D tensors.
648    /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
649    ///   first dimension corresponds to the cardinality of the second input and the second
650    ///   dimension corresponds to the cardinality of the first input. If ij is selected, the
651    ///   dimensions are in the same order as the cardinality of the inputs.
652    ///
653    /// # Examples
654    ///
655    /// ```rust
656    /// use candle_core::{Tensor, Device, Shape};
657    /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
658    /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
659    ///
660    /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
661    ///
662    /// assert_eq!(grids_xy.len(), 2);
663    /// assert_eq!(grids_xy[0].dims(), &[3, 3]);
664    ///
665    /// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
666    /// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
667    ///
668    /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
669    ///
670    /// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
671    /// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
672    /// # Ok::<(), candle_core::Error>(())
673    /// ```
674    ///
675    /// # Errors
676    ///
    /// * Will return `Err` if `args` contains fewer than 2 tensors.
678    ///
679    pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
680        if args.len() <= 1 {
681            Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
682        }
683        let args: Vec<_> = if xy_indexing {
684            args.iter().rev().collect()
685        } else {
686            args.iter().collect()
687        };
688
689        let mut shape = Vec::with_capacity(args.len());
690        for arg in args.iter() {
691            shape.push(arg.as_ref().dims1()?)
692        }
693
694        let mut grids = Vec::with_capacity(args.len());
695        for idx in 0..args.len() {
696            let mut ones = vec![1usize; args.len()];
697            ones[idx] = shape[idx];
698            let arg = args[idx].as_ref().reshape(ones)?;
699            let mut repeats = shape.clone();
700            repeats[idx] = 1;
701            let repeated_tensor = arg.repeat(repeats)?;
702            grids.push(repeated_tensor);
703        }
704        if xy_indexing {
705            grids.reverse();
706        }
707        Ok(grids)
708    }
709
    /// This operation multiplies the input tensor by `mul`, then adds `add`, and returns the
    /// result. The input values `mul` and `add` are cast to the appropriate type so some rounding
    /// might be performed.
713    ///
714    /// ```rust
715    /// use candle_core::{Tensor, Device};
716    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
717    /// let a = a.affine(4., -2.)?;
718    /// assert_eq!(a.to_vec2::<f32>()?, &[[-2.0, 2.0], [6.0, 10.0]]);
719    /// # Ok::<(), candle_core::Error>(())
720    /// ```
721    pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
722        if self.elem_count() == 0 {
723            return Ok(self.clone());
724        }
725        let storage = self.storage().affine(self.layout(), mul, add)?;
726        let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
727        Ok(from_storage(storage, self.shape(), op, false))
728    }
729
730    /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
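    /// A minimal usage sketch (values are illustrative; exact outputs depend on `exp`):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[-1f32, 0., 2.], &Device::Cpu)?;
    /// // Negative entries become alpha * (exp(x) - 1), non-negative entries are unchanged.
    /// let b = a.elu(1.0)?;
    /// assert_eq!(b.dims(), &[3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```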
731    pub fn elu(&self, alpha: f64) -> Result<Self> {
732        if self.elem_count() == 0 {
733            return Ok(self.clone());
734        }
735        let storage = self.storage().elu(self.layout(), alpha)?;
736        let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
737        Ok(from_storage(storage, self.shape(), op, false))
738    }
739
    /// Raises each element of the tensor to the float exponent `e`.
741    pub fn powf(&self, e: f64) -> Result<Self> {
742        if self.elem_count() == 0 {
743            return Ok(self.clone());
744        }
745        let storage = self.storage().powf(self.layout(), e)?;
746        let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
747        Ok(from_storage(storage, self.shape(), op, false))
748    }
749
750    pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
751        if dim >= self.dims().len() {
752            Err(Error::DimOutOfRange {
753                shape: self.shape().clone(),
754                dim: dim as i32,
755                op,
756            }
757            .bt())?
758        } else {
759            Ok(())
760        }
761    }
762
    /// Splits a tensor into the specified number of chunks; this may return fewer chunks than
    /// specified.
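    /// A minimal sketch of the chunking behaviour (the sizes below are illustrative):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
    /// let chunks = a.chunk(2, 0)?;
    /// assert_eq!(chunks.len(), 2);
    /// assert_eq!(chunks[0].to_vec1::<f32>()?, &[0., 1., 2.]);
    /// assert_eq!(chunks[1].to_vec1::<f32>()?, &[3., 4., 5.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```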
765    pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
766        let dim = dim.to_index(self.shape(), "chunk")?;
767        let size = self.dim(dim)?;
768        if size < chunks {
769            (0..size).map(|i| self.narrow(dim, i, 1)).collect()
770        } else {
771            let chunk_size = size / chunks;
772            let cnt_additional = size % chunks;
773            let mut tensors = vec![];
774            let mut sum_chunk_size = 0;
775            for i in 0..chunks {
776                let chunk_size = if i < cnt_additional {
777                    chunk_size + 1
778                } else {
779                    chunk_size
780                };
781                let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
782                tensors.push(tensor);
783                sum_chunk_size += chunk_size
784            }
785            Ok(tensors)
786        }
787    }
788
    /// Returns a new tensor that is a narrowed version of the input: along the dimension `dim`,
    /// it spans the range from `start` to `start + len`.
791    /// ```
792    /// use candle_core::{Tensor, Device};
793    /// let a = Tensor::new(&[
794    ///     [0f32, 1., 2.],
795    ///     [3.  , 4., 5.],
796    ///     [6.  , 7., 8.]
797    /// ], &Device::Cpu)?;
798    ///
799    /// let b = a.narrow(0, 1, 2)?;
800    /// assert_eq!(b.shape().dims(), &[2, 3]);
801    /// assert_eq!(b.to_vec2::<f32>()?, &[
802    ///     [3., 4., 5.],
803    ///     [6., 7., 8.]
804    /// ]);
805    ///
806    /// let c = a.narrow(1, 1, 1)?;
807    /// assert_eq!(c.shape().dims(), &[3, 1]);
808    /// assert_eq!(c.to_vec2::<f32>()?, &[
809    ///     [1.],
810    ///     [4.],
811    ///     [7.]
812    /// ]);
813    /// # Ok::<(), candle_core::Error>(())
814    /// ```
815    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
816        let dims = self.dims();
817        let dim = dim.to_index(self.shape(), "narrow")?;
818        let err = |msg| {
819            Err::<(), _>(
820                Error::NarrowInvalidArgs {
821                    shape: self.shape().clone(),
822                    dim,
823                    start,
824                    len,
825                    msg,
826                }
827                .bt(),
828            )
829        };
830        if start > dims[dim] {
831            err("start > dim_len")?
832        }
833        if start.saturating_add(len) > dims[dim] {
834            err("start + len > dim_len")?
835        }
836        if start == 0 && dims[dim] == len {
837            Ok(self.clone())
838        } else {
839            let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
840            let layout = self.layout().narrow(dim, start, len)?;
841            let tensor_ = Tensor_ {
842                id: TensorId::new(),
843                storage: self.storage.clone(),
844                layout,
845                op,
846                is_variable: false,
847                dtype: self.dtype,
848                device: self.device.clone(),
849            };
850            Ok(Tensor(Arc::new(tensor_)))
851        }
852    }
853
854    fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
855        match dims {
856            [] => Ok(self),
857            [i] => self.squeeze(*i),
858            dims => {
859                let dims = self
860                    .dims()
861                    .iter()
862                    .enumerate()
863                    .filter_map(|(dim_idx, &v)| {
864                        if dims.contains(&dim_idx) {
865                            None
866                        } else {
867                            Some(v)
868                        }
869                    })
870                    .collect::<Vec<_>>();
871                self.reshape(dims)
872            }
873        }
874    }
875
876    fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
877        let dim = dim.to_index(self.shape(), op.name())?;
878        let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
879        let mut dims = self.dims().to_vec();
880        dims[dim] = 1;
881        let op = match op {
882            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
883                BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
884            }
885            ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
886        };
887        let res = from_storage(storage, dims, op, false);
888        if keepdim {
889            Ok(res)
890        } else {
891            res.squeeze_dims(&[dim])
892        }
893    }
894
895    fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
896        let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
897        let storage = self
898            .storage()
899            .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
900        let mut dims = self.dims().to_vec();
901        for &sum_dim in sum_dims.iter() {
902            dims[sum_dim] = 1
903        }
904        let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
905        let sum = from_storage(storage, dims, op, false);
906        if keepdim {
907            Ok(sum)
908        } else {
909            sum.squeeze_dims(&sum_dims)
910        }
911    }
912
    /// Rolls the input tensor along the given dimension.
914    /// Elements that are shifted beyond the last position are re-introduced at the first position.
915    ///
916    /// ```rust
917    /// # use candle_core::{Tensor, Device};
918    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
919    /// let tensor = tensor.roll(1, 0)?;
920    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
921    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
922    /// let tensor = tensor.roll(-1, 0)?;
923    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
924    /// # Ok::<(), candle_core::Error>(())
925    /// ```
926    pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
927    where
928        D: Dim + Clone,
929    {
930        let dim = dim.to_index(self.shape(), "roll")?;
931        let dim_size = self.dim(dim)?;
932        let shift = shift.rem_euclid(dim_size as i32) as usize;
933        if shift == 0 {
934            Ok(self.clone())
935        } else {
936            let a = self.narrow(dim, 0, dim_size - shift)?;
937            let b = self.narrow(dim, dim_size - shift, shift)?;
938            Tensor::cat(&[&b, &a], dim)
939        }
940    }
941
    /// Returns the sum of elements in the input tensor. The sum is performed over the
    /// dimensions given in `sum_dims`.
944    ///
945    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
946    /// that the number of elements for each dimension index in `sum_dims` is 1.
947    ///
948    /// ```rust
949    /// use candle_core::{Tensor, Device};
950    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
951    /// let s = a.sum_keepdim(0)?;
952    /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
953    /// let s = a.sum_keepdim(1)?;
954    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
955    /// let s = a.sum_keepdim((0, 1))?;
956    /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
957    /// # Ok::<(), candle_core::Error>(())
958    /// ```
959    pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
960        self.sum_impl(sum_dims, true)
961    }
962
    /// Returns the sum of elements in the input tensor. The sum is performed over the given
    /// dimensions and, compared to `sum_keepdim`, these dimensions are squeezed rather than
    /// kept.
966    pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
967        self.sum_impl(sum_dims, false)
968    }
969
    /// Returns the mean of elements in the input tensor. The mean is performed over the
    /// dimensions given in `mean_dims`.
972    ///
973    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
974    /// that the number of elements for each dimension index in `mean_dims` is 1.
975    ///
976    /// ```rust
977    /// use candle_core::{Tensor, Device};
978    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
979    /// let s = a.mean_keepdim(0)?;
980    /// assert_eq!(s.to_vec2::<f32>()?, &[[1., 2.]]);
981    /// let s = a.mean_keepdim(1)?;
982    /// assert_eq!(s.to_vec2::<f32>()?, &[[0.5], [2.5]]);
983    /// let s = a.mean_keepdim((0, 1))?;
984    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.5]]);
985    /// # Ok::<(), candle_core::Error>(())
986    /// ```
987    pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
988        let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
989        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
990        let scale = 1f64 / (reduced_dim as f64);
991        self.sum_impl(mean_dims, true)? * scale
992    }
993
    /// Returns the mean of elements in the input tensor. The mean is performed over the given
    /// dimensions and, compared to `mean_keepdim`, these dimensions are squeezed rather than
    /// kept.
997    pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
998        let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
999        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1000        let scale = 1f64 / (reduced_dim as f64);
1001        self.sum_impl(mean_dims, false)? * scale
1002    }
1003
1004    /// Returns the unbiased variance over the selected dimension.
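    /// A minimal illustrative sketch (unbiased variance, i.e. divided by `n - 1`):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1f32, 3.], [5., 7.]], &Device::Cpu)?;
    /// let v = a.var_keepdim(0)?;
    /// assert_eq!(v.to_vec2::<f32>()?, &[[8., 8.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```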
1005    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1006        let dim = dim.to_index(self.shape(), "var")?;
1007        let mean = self.mean_keepdim(dim)?;
1008        let squares = self.broadcast_sub(&mean)?.sqr()?;
1009        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
1010    }
1011
1012    /// Returns the unbiased variance over the selected dimension.
1013    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
1014        let dim = dim.to_index(self.shape(), "var")?;
1015        self.var_keepdim(dim)?.squeeze(dim)
1016    }
1017
1018    /// Gathers the maximum value across the selected dimension. The resulting shape has the same
    /// number of dimensions as the original tensor and the selected dimension has a single element.
1020    pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1021        self.reduce_impl(dim, true, ReduceOp::Max)
1022    }
1023
1024    /// Similar to `max_keepdim` but the target dimension is squeezed.
1025    pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
1026        self.reduce_impl(dim, false, ReduceOp::Max)
1027    }
1028
1029    /// Gathers the minimum value across the selected dimension. The resulting shape has the same
    /// number of dimensions as the original tensor and the selected dimension has a single element.
1031    pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1032        self.reduce_impl(dim, true, ReduceOp::Min)
1033    }
1034
1035    /// Similar to `min_keepdim` but the target dimension is squeezed.
1036    pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
1037        self.reduce_impl(dim, false, ReduceOp::Min)
1038    }
1039
1040    pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1041        self.reduce_impl(dim, true, ReduceOp::ArgMax)
1042    }
1043
1044    /// Similar to `argmax_keepdim` but the target dimension is squeezed.
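    /// A minimal illustrative sketch (the indexes are returned as `u32` values):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1f32, 5., 2.], [3., 0., 4.]], &Device::Cpu)?;
    /// let idx = a.argmax(1)?;
    /// assert_eq!(idx.to_vec1::<u32>()?, &[1, 2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```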
1045    pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
1046        self.reduce_impl(dim, false, ReduceOp::ArgMax)
1047    }
1048
1049    pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1050        self.reduce_impl(dim, true, ReduceOp::ArgMin)
1051    }
1052
1053    /// Similar to `argmin_keepdim` but the target dimension is squeezed.
1054    pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
1055        self.reduce_impl(dim, false, ReduceOp::ArgMin)
1056    }
1057
1058    /// Element-wise comparison between two tensors, e.g. equality, greater than, ... The actual
1059    /// comparison operation is specified by the `op` argument.
1060    ///
1061    /// The returned tensor has the same shape as the original tensors and uses `u8` elements.
1062    pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
1063        let rhs = match rhs.to_tensor_scalar()? {
1064            crate::scalar::TensorScalar::Tensor(rhs) => rhs,
1065            crate::scalar::TensorScalar::Scalar(rhs) => rhs
1066                .to_dtype(self.dtype())?
1067                .to_device(self.device())?
1068                .broadcast_as(self.shape())?,
1069        };
1070        let shape = self.same_shape_binary_op(&rhs, "cmp")?;
1071        let storage = self
1072            .storage()
1073            .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
1074        let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
1075        Ok(from_storage(storage, shape.dims(), op, false))
1076    }
1077
1078    /// Element-wise equality.
1079    pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1080        self.cmp(rhs, CmpOp::Eq)
1081    }
1082
1083    /// Element-wise non-equality.
1084    pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1085        self.cmp(rhs, CmpOp::Ne)
1086    }
1087
1088    /// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
1089    /// rhs` and 0 otherwise.
1090    pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1091        self.cmp(rhs, CmpOp::Lt)
1092    }
1093
1094    /// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
1095    /// rhs` and 0 otherwise.
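    /// A minimal sketch comparing against a scalar (values are illustrative):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let b = a.gt(1.5)?;
    /// assert_eq!(b.to_vec1::<u8>()?, &[0, 1, 1]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```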
1096    pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1097        self.cmp(rhs, CmpOp::Gt)
1098    }
1099
1100    /// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
1101    /// rhs` and 0 otherwise.
1102    pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1103        self.cmp(rhs, CmpOp::Ge)
1104    }
1105
1106    /// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
1107    /// rhs` and 0 otherwise.
1108    pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1109        self.cmp(rhs, CmpOp::Le)
1110    }
1111
    /// Clamps the tensor values to be between `min` and `max`.
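    /// A minimal sketch using scalar bounds (values are illustrative):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[-1f32, 0.5, 2.0], &Device::Cpu)?;
    /// let b = a.clamp(0.0, 1.0)?;
    /// assert_eq!(b.to_vec1::<f32>()?, &[0.0, 0.5, 1.0]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```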
1113    pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
1114        self.maximum(min)?.minimum(max)
1115    }
1116
    /// Interpolates the input tensor to the target size `target_size`, taking the value of the nearest element.
1118    ///
1119    /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
1120    /// tensor also has three dimensions, `(batch, channels, target_size)`.
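    /// A minimal sketch (assuming nearest-neighbour index selection as described above):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[[1f32, 2.]]], &Device::Cpu)?;
    /// let b = a.interpolate1d(4)?;
    /// // Nearest-neighbour upsampling roughly yields [[[1., 1., 2., 2.]]].
    /// assert_eq!(b.dims(), &[1, 1, 4]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```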
1121    pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
1122        let (n, c, _l) = self.dims3()?;
1123        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
1124        let storage = self
1125            .storage()
1126            .upsample_nearest1d(self.layout(), target_size)?;
1127        Ok(from_storage(storage, (n, c, target_size), op, false))
1128    }
1129
1130    /// Alias for `interpolate1d`.
1131    pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
1132        self.interpolate1d(target_size)
1133    }
1134
    /// Interpolates the input tensor to the `(target_h, target_w)` size, taking the value of the
1136    /// nearest element.
1137    ///
1138    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1139    /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
1140    pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1141        let (n, c, _h, _w) = self.dims4()?;
1142        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
1143            arg,
1144            target_h,
1145            target_w,
1146        });
1147        let storage = self
1148            .storage()
1149            .upsample_nearest2d(self.layout(), target_h, target_w)?;
1150        Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1151    }
1152
1153    /// Alias for `interpolate2d`.
1154    pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1155        self.interpolate2d(target_h, target_w)
1156    }
1157
1158    /// 2D average pooling over an input tensor with multiple channels.
1159    ///
1160    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1161    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
    /// the last two dimensions using a kernel of size `sz`. The returned element is the average
1163    /// value over the kernel window.
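    /// A minimal sketch on a `(1, 1, 2, 2)` input (values are illustrative):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::arange(0f32, 4f32, &Device::Cpu)?.reshape((1, 1, 2, 2))?;
    /// let b = a.avg_pool2d(2)?;
    /// assert_eq!(b.dims(), &[1, 1, 1, 1]);
    /// // The single output value is the average of 0, 1, 2, 3.
    /// assert_eq!(b.flatten_all()?.to_vec1::<f32>()?, &[1.5]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```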
1164    pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1165        let sz = sz.to_usize2();
1166        self.avg_pool2d_with_stride(sz, sz)
1167    }
1168
1169    /// Same as `avg_pool2d` but with a `stride` that can be set to a value different from the
1170    /// kernel size.
1171    pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
1172        &self,
1173        kernel_size: T,
1174        stride: T,
1175    ) -> Result<Self> {
1176        let kernel_size = kernel_size.to_usize2();
1177        let stride = stride.to_usize2();
1178        let (n, c, h, w) = self.dims4()?;
1179        if h < kernel_size.0 || w < kernel_size.1 {
1180            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1181        }
1182        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
1183        let h_out = (h - kernel_size.0) / stride.0 + 1;
1184        let w_out = (w - kernel_size.1) / stride.1 + 1;
1185        let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
1186            arg,
1187            kernel_size,
1188            stride,
1189        });
1190        let storage = self
1191            .storage()
1192            .avg_pool2d(self.layout(), kernel_size, stride)?;
1193        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1194    }
1195
1196    /// 2D max pooling over an input tensor with multiple channels.
1197    ///
1198    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1199    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
    /// the last two dimensions using a kernel of size `sz`; the returned element is the maximum
1201    /// value over the kernel window.
1202    pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1203        let sz = sz.to_usize2();
1204        self.max_pool2d_with_stride(sz, sz)
1205    }
1206
1207    /// Same as `max_pool2d` but with a `stride` that can be set to a value different from the
1208    /// kernel size.
1209    pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
1210        &self,
1211        kernel_size: T,
1212        stride: T,
1213    ) -> Result<Self> {
1214        let kernel_size = kernel_size.to_usize2();
1215        let stride = stride.to_usize2();
1216        let (n, c, h, w) = self.dims4()?;
1217        if h < kernel_size.0 || w < kernel_size.1 {
1218            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1219        }
1220        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
1221        let h_out = (h - kernel_size.0) / stride.0 + 1;
1222        let w_out = (w - kernel_size.1) / stride.1 + 1;
1223        let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
1224            arg,
1225            kernel_size,
1226            stride,
1227        });
1228        let storage = self
1229            .storage()
1230            .max_pool2d(self.layout(), kernel_size, stride)?;
1231        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1232    }
1233
1234    /// Returns the matrix-multiplication of the input tensor with the other provided tensor.
1235    ///
1236    /// # Arguments
1237    ///
1238    /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`.
1239    /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`.
1240    ///
1241    /// The resulting tensor has dimensions `b1, b2, ..., bi, m, n`.
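    /// A minimal 2D sketch (values are illustrative):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let b = Tensor::new(&[[5f32, 6.], [7., 8.]], &Device::Cpu)?;
    /// let c = a.matmul(&b)?;
    /// assert_eq!(c.to_vec2::<f32>()?, &[[19., 22.], [43., 50.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```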
1242    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
1243        let a_dims = self.shape().dims();
1244        let b_dims = rhs.shape().dims();
1245
1246        let dim = a_dims.len();
1247
1248        if dim < 2 || b_dims.len() != dim {
1249            Err(Error::ShapeMismatchBinaryOp {
1250                lhs: self.shape().clone(),
1251                rhs: rhs.shape().clone(),
1252                op: "matmul",
1253            }
1254            .bt())?
1255        }
1256
1257        let m = a_dims[dim - 2];
1258        let k = a_dims[dim - 1];
1259        let k2 = b_dims[dim - 2];
1260        let n = b_dims[dim - 1];
1261
1262        let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
1263        if c_shape.elem_count() == 0 || k == 0 {
1264            return Tensor::zeros(c_shape, self.dtype(), self.device());
1265        }
1266        let batching: usize = a_dims[..dim - 2].iter().product();
1267        let batching_b: usize = b_dims[..dim - 2].iter().product();
1268        if k != k2 || batching != batching_b {
1269            Err(Error::ShapeMismatchBinaryOp {
1270                lhs: self.shape().clone(),
1271                rhs: rhs.shape().clone(),
1272                op: "matmul",
1273            }
1274            .bt())?
1275        }
1276
1277        let storage = self.storage().matmul(
1278            &rhs.storage(),
1279            (batching, m, n, k),
1280            self.layout(),
1281            rhs.layout(),
1282        )?;
1283        let op = BackpropOp::new2(self, rhs, Op::Matmul);
1284        Ok(from_storage(storage, c_shape, op, false))
1285    }
1286
1287    /// Matrix-multiplication with broadcasting support.
1288    ///
    /// Compared to `matmul`, the two matrices are allowed to have different dimensions as long as
1290    /// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
1291    /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
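    /// A minimal sketch of the broadcasting behaviour (the shapes follow the example above):
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::ones((2, 1, 3, 4), DType::F32, &Device::Cpu)?;
    /// let b = Tensor::ones((5, 4, 2), DType::F32, &Device::Cpu)?;
    /// let c = a.broadcast_matmul(&b)?;
    /// assert_eq!(c.dims(), &[2, 5, 3, 2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```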
1292    pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
1293        let lhs = self;
1294        let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
1295        let l_broadcast = l_shape != *lhs.shape();
1296        let r_broadcast = r_shape != *rhs.shape();
        // TODO: Avoid concretising the broadcasted matrices via contiguous.
1298        match (l_broadcast, r_broadcast) {
1299            (true, true) => lhs
1300                .broadcast_as(&l_shape)?
1301                .contiguous()?
1302                .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1303            (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1304            (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
1305            (false, false) => lhs.matmul(rhs),
1306        }
1307    }
1308
    /// Returns a tensor with the same shape as the input tensor; the values are taken from
    /// `on_true` at the positions where the input tensor value is not zero, and from `on_false`
    /// at the positions where the input tensor is equal to zero.
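    /// A minimal sketch using a `u8` mask (values are illustrative):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let mask = Tensor::new(&[1u8, 0, 1], &Device::Cpu)?;
    /// let on_true = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let on_false = Tensor::new(&[10f32, 20., 30.], &Device::Cpu)?;
    /// let res = mask.where_cond(&on_true, &on_false)?;
    /// assert_eq!(res.to_vec1::<f32>()?, &[1., 20., 3.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```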
1312    pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
1313        let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
1314        let shape = self.same_shape_binary_op(on_false, "where_cond")?;
1315        let storage = self.storage().where_cond(
1316            self.layout(),
1317            &on_true.storage(),
1318            on_true.layout(),
1319            &on_false.storage(),
1320            on_false.layout(),
1321        )?;
1322        let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
1323        Ok(from_storage(storage, shape, op, false))
1324    }
1325
    /// Returns a tensor with the values from the `self` tensor at the indexes corresponding to
    /// the values held in the `ids` tensor.
1328    ///
1329    /// # Arguments
1330    ///
1331    /// * `self` - A tensor with dimensions `v, h`.
1332    /// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
1333    ///
1334    /// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
1335    /// vocabulary size, and `h` the hidden size.
1336    ///
1337    /// ```rust
1338    /// use candle_core::{Tensor, Device};
1339    /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1340    /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
1341    /// let emb = values.embedding(&ids)?;
1342    /// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
1343    /// # Ok::<(), candle_core::Error>(())
1344    /// ```
1345    pub fn embedding(&self, ids: &Self) -> Result<Self> {
1346        if self.rank() != 2 || ids.rank() != 1 {
1347            Err(Error::ShapeMismatchBinaryOp {
1348                lhs: self.shape().clone(),
1349                rhs: ids.shape().clone(),
1350                op: "embedding",
1351            }
1352            .bt())?
1353        }
1354        self.index_select(ids, 0)
1355    }
1356
1357    fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
1358        let source_dims = source.dims();
1359        let self_dims = self.dims();
1360        let mismatch = if source_dims.len() != self_dims.len() {
1361            true
1362        } else {
1363            let mut mismatch = false;
1364            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1365                if i != dim && d1 != d2 {
1366                    mismatch = true;
1367                    break;
1368                }
1369            }
1370            mismatch
1371        };
1372        if mismatch {
1373            Err(Error::ShapeMismatchBinaryOp {
1374                op: "scatter (self, src)",
1375                lhs: self.shape().clone(),
1376                rhs: source.shape().clone(),
1377            }
1378            .bt())?
1379        }
1380        if indexes.dims() != source.dims() {
1381            Err(Error::ShapeMismatchBinaryOp {
1382                op: "scatter (indexes, src)",
1383                lhs: indexes.shape().clone(),
1384                rhs: source.shape().clone(),
1385            }
1386            .bt())?
1387        }
1388        Ok(())
1389    }
1390
1391    pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1392        let dim = dim.to_index(self.shape(), "scatter")?;
1393        self.scatter_checks(indexes, source, dim)?;
1394        let shape = self.shape();
1395        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1396        self.storage()
1397            .copy_strided_src(&mut storage, 0, self.layout())?;
1398        let layout = Layout::contiguous(shape);
1399        storage.scatter_set(
1400            &layout,
1401            &indexes.storage(),
1402            indexes.layout(),
1403            &source.storage(),
1404            source.layout(),
1405            dim,
1406        )?;
1407        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1408            Op::Scatter(t1, t2, t3, dim)
1409        });
1410        Ok(from_storage(storage, self.shape(), op, false))
1411    }
1412
1413    pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1414        if self.same_storage(source) {
1415            crate::bail!("cannot use scatter_set when self and src share their storage")
1416        }
1417        let dim = dim.to_index(self.shape(), "scatter-set")?;
1418        self.scatter_checks(indexes, source, dim)?;
1419        self.storage_mut().scatter_set(
1420            self.layout(),
1421            &indexes.storage(),
1422            indexes.layout(),
1423            &source.storage(),
1424            source.layout(),
1425            dim,
1426        )?;
1427        Ok(())
1428    }
1429
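    /// Adds the values from `source` into a copy of `self` at the positions given by `indexes`
    /// along the dimension `dim`. A small usage sketch; the expected values below are worked out
    /// from these semantics:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let source = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    /// let indexes = Tensor::new(&[[0u32, 1, 0], [1, 0, 1]], &Device::Cpu)?;
    /// // For dim 0, result[indexes[i][j]][j] += source[i][j].
    /// let res = t.scatter_add(&indexes, &source, 0)?;
    /// assert_eq!(res.to_vec2::<f32>()?, &[[1., 5., 3.], [4., 2., 6.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```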
1430    pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1431        let dim = dim.to_index(self.shape(), "scatter-add")?;
1432        self.scatter_checks(indexes, source, dim)?;
1433        let shape = self.shape();
1434        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1435        self.storage()
1436            .copy_strided_src(&mut storage, 0, self.layout())?;
1437        let layout = Layout::contiguous(shape);
1438        storage.scatter_add(
1439            &layout,
1440            &indexes.storage(),
1441            indexes.layout(),
1442            &source.storage(),
1443            source.layout(),
1444            dim,
1445        )?;
1446        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1447            Op::ScatterAdd(t1, t2, t3, dim)
1448        });
1449        Ok(from_storage(storage, self.shape(), op, false))
1450    }
1451
1452    pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1453        if self.same_storage(source) {
1454            crate::bail!("cannot use scatter_add_set when self and src share their storage")
1455        }
1456        let dim = dim.to_index(self.shape(), "scatter-add-set")?;
1457        self.scatter_checks(indexes, source, dim)?;
1458        self.storage_mut().scatter_add(
1459            self.layout(),
1460            &indexes.storage(),
1461            indexes.layout(),
1462            &source.storage(),
1463            source.layout(),
1464            dim,
1465        )?;
1466        Ok(())
1467    }
1468
1469    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
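    ///
    /// A short sketch of the expected behavior, embedding a `(1, 2)` source into a `(3, 2)` target
    /// starting at row 1:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((3, 2), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::new(&[[1f32, 2.]], &Device::Cpu)?;
    /// let res = t.slice_scatter(&src, 0, 1)?;
    /// assert_eq!(res.to_vec2::<f32>()?, &[[0., 0.], [1., 2.], [0., 0.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```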
1470    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
1471        let dim = dim.to_index(self.shape(), "slice-scatter")?;
1472        if dim == 0 {
1473            self.slice_scatter0(src, start)
1474        } else {
1475            // TODO: Maybe we want to add a more efficient implementation at some point.
1476            self.transpose(0, dim)?
1477                .slice_scatter0(&src.transpose(0, dim)?, start)?
1478                .transpose(0, dim)
1479        }
1480    }
1481
1482    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
1483    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
1484        if self.dtype() != src.dtype() {
1485            Err(Error::DTypeMismatchBinaryOp {
1486                lhs: self.dtype(),
1487                rhs: src.dtype(),
1488                op: "slice-scatter",
1489            }
1490            .bt())?
1491        }
1492        if self.device().location() != src.device().location() {
1493            Err(Error::DeviceMismatchBinaryOp {
1494                lhs: self.device().location(),
1495                rhs: src.device().location(),
1496                op: "slice-scatter",
1497            }
1498            .bt())?
1499        }
1500        if self.rank() != src.rank() {
1501            Err(Error::UnexpectedNumberOfDims {
1502                expected: self.rank(),
1503                got: src.rank(),
1504                shape: src.shape().clone(),
1505            }
1506            .bt())?
1507        }
1508        let shape_ok =
1509            self.dims()
1510                .iter()
1511                .zip(src.dims().iter())
1512                .enumerate()
1513                .all(|(dim_idx, (&d1, &d2))| {
1514                    if 0 == dim_idx {
1515                        d2 + start <= d1
1516                    } else {
1517                        d1 == d2
1518                    }
1519                });
1520        if !shape_ok {
1521            Err(Error::ShapeMismatchBinaryOp {
1522                op: "slice-scatter (self, src)",
1523                lhs: self.shape().clone(),
1524                rhs: src.shape().clone(),
1525            }
1526            .bt())?
1527        }
1528        let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
1529        self.storage()
1530            .copy_strided_src(&mut storage, 0, self.layout())?;
1531        let offset = start * src.dims()[1..].iter().product::<usize>();
1532        src.storage()
1533            .copy_strided_src(&mut storage, offset, src.layout())?;
1534        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
1535        Ok(from_storage(storage, self.shape(), op, false))
1536    }
1537
1538    /// Accumulate elements from `source` at the positions given by `indexes` and add them to `self`.
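    ///
    /// A short sketch of the expected behavior; the output values are worked out from the
    /// description above:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((3, 2), DType::F32, &Device::Cpu)?;
    /// let source = Tensor::new(&[[1f32, 1.], [2., 2.]], &Device::Cpu)?;
    /// let indexes = Tensor::new(&[0u32, 2u32], &Device::Cpu)?;
    /// // For dim 0, result[indexes[i]][j] += source[i][j].
    /// let res = t.index_add(&indexes, &source, 0)?;
    /// assert_eq!(res.to_vec2::<f32>()?, &[[1., 1.], [0., 0.], [2., 2.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```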
1539    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1540        let dim = dim.to_index(self.shape(), "index-add")?;
1541        let source_dims = source.dims();
1542        let self_dims = self.dims();
1543        let mismatch = if source_dims.len() != self_dims.len() {
1544            true
1545        } else {
1546            let mut mismatch = false;
1547            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1548                if i != dim && d1 != d2 {
1549                    mismatch = true;
1550                    break;
1551                }
1552            }
1553            mismatch
1554        };
1555        if mismatch {
1556            Err(Error::ShapeMismatchBinaryOp {
1557                op: "index-add (self, source)",
1558                lhs: self.shape().clone(),
1559                rhs: source.shape().clone(),
1560            }
1561            .bt())?
1562        }
1563        // The number of elements in `indexes` must match the size of the dimension on which the
1564        // add is performed for the `source` tensor (the index values from `indexes` refer to
1565        // positions along `dim` in the target tensor `self`).
1566        let indexes_len = indexes.dims1()?;
1567        if source_dims[dim] != indexes_len {
1568            Err(Error::ShapeMismatchBinaryOp {
1569                op: "index-add (ids, source))",
1570                lhs: indexes.shape().clone(),
1571                rhs: source.shape().clone(),
1572            }
1573            .bt())?
1574        }
1575        let storage = self.storage().index_add(
1576            self.layout(),
1577            &indexes.storage(),
1578            indexes.layout(),
1579            &source.storage(),
1580            source.layout(),
1581            dim,
1582        )?;
1583        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1584            Op::IndexAdd(t1, t2, t3, dim)
1585        });
1586        Ok(from_storage(storage, self.shape(), op, false))
1587    }
1588
1589    /// Gather values across the target dimension.
1590    ///
1591    /// # Arguments
1592    ///
1593    /// * `self` - The input tensor.
1594    /// * `indexes` - The indices of the elements to gather; it must have the same number of dimensions
1595    ///   as `self`, and `indexes.dims()[d] <= self.dims()[d]` must hold for all dimensions `d != dim`.
1596    /// * `dim` - the target dimension.
1597    ///
1598    /// The resulting tensor has the same shape as `indexes` and uses values from `self` indexed on
1599    /// dimension `dim` by the values in `indexes`.
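    ///
    /// A short sketch of the expected behavior; the output values are worked out from the
    /// description above:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let indexes = Tensor::new(&[[0u32], [1u32]], &Device::Cpu)?;
    /// // For dim 1, result[i][j] = t[i][indexes[i][j]].
    /// let g = t.gather(&indexes, 1)?;
    /// assert_eq!(g.to_vec2::<f32>()?, &[[1.], [4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```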
1600    pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1601        let dim = dim.to_index(self.shape(), "gather")?;
1602
1603        let self_dims = self.dims();
1604        let indexes_dims = indexes.dims();
1605        let mismatch = if indexes_dims.len() != self_dims.len() {
1606            true
1607        } else {
1608            let mut mismatch = false;
1609            for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
1610                if i != dim && d1 < d2 {
1611                    mismatch = true;
1612                    break;
1613                }
1614            }
1615            mismatch
1616        };
1617        if mismatch {
1618            Err(Error::ShapeMismatchBinaryOp {
1619                op: "gather",
1620                lhs: self.shape().clone(),
1621                rhs: indexes.shape().clone(),
1622            }
1623            .bt())?
1624        }
1625        let storage =
1626            self.storage()
1627                .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
1628        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
1629        Ok(from_storage(storage, indexes.shape(), op, false))
1630    }
1631
1632    /// Select values for the input tensor at the target indexes across the specified dimension.
1633    ///
1634    /// The `indexes` argument is an integer tensor with a single dimension.
1635    /// The output has the same number of dimensions as the `self` input. The target dimension of
1636    /// the output has the same length as `indexes` and its values are taken from `self` using
1637    /// the indexes from `indexes`. Other dimensions have the same number of elements as the input
1638    /// tensor.
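    ///
    /// A short sketch of the expected behavior:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Device::Cpu)?;
    /// let indexes = Tensor::new(&[2u32, 0u32], &Device::Cpu)?;
    /// let s = t.index_select(&indexes, 0)?;
    /// assert_eq!(s.to_vec2::<f32>()?, &[[5., 6.], [1., 2.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```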
1639    pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1640        let dim = dim.to_index(self.shape(), "index-select")?;
1641        let indexes_len = match indexes.dims() {
1642            [l] => *l,
1643            _ => Err(Error::ShapeMismatchBinaryOp {
1644                lhs: self.shape().clone(),
1645                rhs: indexes.shape().clone(),
1646                op: "index-select",
1647            }
1648            .bt())?,
1649        };
1650        let storage = self.storage().index_select(
1651            &indexes.storage(),
1652            self.layout(),
1653            indexes.layout(),
1654            dim,
1655        )?;
1656        let mut dims = self.dims().to_vec();
1657        dims[dim] = indexes_len;
1658        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
1659        Ok(from_storage(storage, dims, op, false))
1660    }
1661
1662    /// Returns an iterator over the positions of the elements in the storage when ranging over
1663    /// the index tuples in lexicographic order.
1664    pub fn strided_index(&self) -> crate::StridedIndex {
1665        self.layout.strided_index()
1666    }
1667
1668    /// Similar to `strided_index` but returns the position of the start of each contiguous block
1669    /// as well as the length of the contiguous blocks. For a contiguous tensor, the iterator
1670    /// only returns the start offset, and the block size is the total number of elements in the
1671    /// tensor.
1672    pub fn strided_blocks(&self) -> crate::StridedBlocks {
1673        self.layout.strided_blocks()
1674    }
1675
1676    /// Returns the data contained in a 1D tensor as a vector of scalar values.
1677    pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
1678        if self.rank() != 1 {
1679            Err(Error::UnexpectedNumberOfDims {
1680                expected: 1,
1681                got: self.rank(),
1682                shape: self.shape().clone(),
1683            }
1684            .bt())?
1685        }
1686        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1687            let data = S::cpu_storage_as_slice(cpu_storage)?;
1688            let data = match self.layout.contiguous_offsets() {
1689                Some((o1, o2)) => data[o1..o2].to_vec(),
1690                None => self.strided_index().map(|i| data[i]).collect(),
1691            };
1692            Ok::<Vec<_>, Error>(data)
1693        };
1694        match &*self.storage() {
1695            Storage::Cpu(storage) => from_cpu_storage(storage),
1696            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1697            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1698        }
1699    }
1700
1701    /// Returns the data contained in a 2D tensor as a vector of vector of scalar values.
1702    pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
1703        let (dim1, dim2) = self.dims2()?;
1704        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1705            let data = S::cpu_storage_as_slice(cpu_storage)?;
1706            let mut rows = vec![];
1707            match self.layout.contiguous_offsets() {
1708                Some((o1, o2)) => {
1709                    let data = &data[o1..o2];
1710                    for idx_row in 0..dim1 {
1711                        rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
1712                    }
1713                }
1714                None => {
1715                    let mut src_index = self.strided_index();
1716                    for _idx_row in 0..dim1 {
1717                        let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
1718                        rows.push(row)
1719                    }
1720                    assert!(src_index.next().is_none());
1721                }
1722            }
1723            Ok(rows)
1724        };
1725        match &*self.storage() {
1726            Storage::Cpu(storage) => from_cpu_storage(storage),
1727            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1728            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1729        }
1730    }
1731
1732    /// Returns the data contained in a 3D tensor.
1733    pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
1734        let (dim1, dim2, dim3) = self.dims3()?;
1735        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1736            let data = S::cpu_storage_as_slice(cpu_storage)?;
1737            let mut top_rows = vec![];
1738            match self.layout.contiguous_offsets() {
1739                Some((o1, o2)) => {
1740                    let data = &data[o1..o2];
1741                    let dim23 = dim2 * dim3;
1742                    for idx1 in 0..dim1 {
1743                        let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
1744                        let mut rows = vec![];
1745                        for idx2 in 0..dim2 {
1746                            rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
1747                        }
1748                        top_rows.push(rows);
1749                    }
1750                }
1751                None => {
1752                    let mut src_index = self.strided_index();
1753                    for _idx in 0..dim1 {
1754                        let mut rows = vec![];
1755                        for _jdx in 0..dim2 {
1756                            let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
1757                            rows.push(row)
1758                        }
1759                        top_rows.push(rows);
1760                    }
1761                    assert!(src_index.next().is_none());
1762                }
1763            }
1764            Ok(top_rows)
1765        };
1766        match &*self.storage() {
1767            Storage::Cpu(storage) => from_cpu_storage(storage),
1768            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1769            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1770        }
1771    }
1772
1773    /// The dtype for the elements stored in the input tensor.
1774    pub fn dtype(&self) -> DType {
1775        self.dtype
1776    }
1777
1778    /// The device on which the input tensor is located.
1779    pub fn device(&self) -> &Device {
1780        &self.device
1781    }
1782
1783    /// The tensor shape, i.e. dimension sizes on each axis.
1784    pub fn shape(&self) -> &Shape {
1785        self.layout().shape()
1786    }
1787
1788    /// The dimension size for this tensor on each axis.
1789    pub fn dims(&self) -> &[usize] {
1790        self.shape().dims()
1791    }
1792
1793    /// The dimension size for a specified dimension index.
1794    pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
1795        let dim = dim.to_index(self.shape(), "dim")?;
1796        Ok(self.dims()[dim])
1797    }
1798
1799    /// The layout of the input tensor: this stores the shape of the tensor as well as the
1800    /// strides and the start offset to apply to the underlying storage.
1801    pub fn layout(&self) -> &Layout {
1802        &self.layout
1803    }
1804
1805    pub fn stride(&self) -> &[usize] {
1806        self.layout.stride()
1807    }
1808
1809    /// The number of dimensions for this tensor, 0 for a scalar tensor, 1 for a 1D tensor, etc.
1810    pub fn rank(&self) -> usize {
1811        self.shape().rank()
1812    }
1813
1814    /// The number of elements stored in this tensor.
1815    pub fn elem_count(&self) -> usize {
1816        self.shape().elem_count()
1817    }
1818
1819    /// The unique identifier for this tensor.
1820    pub fn id(&self) -> TensorId {
1821        self.id
1822    }
1823
1824    /// Whether this tensor is a variable or not. A variable is a tensor for which gradients are
1825    /// tracked and on which backpropagation can be performed.
1826    pub fn is_variable(&self) -> bool {
1827        self.is_variable
1828    }
1829
1830    pub(crate) fn op(&self) -> &Option<Op> {
1831        &self.op
1832    }
1833
1834    /// Computes the max of all the elements in this tensor and returns a tensor holding this
1835    /// scalar with zero dimensions.
1836    ///
1837    /// ```rust
1838    /// use candle_core::{Tensor, Device};
1839    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1840    /// let tensor = tensor.max_all()?;
1841    /// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
1842    /// # Ok::<(), candle_core::Error>(())
1843    /// ```
1844    pub fn max_all(&self) -> Result<Tensor> {
1845        if self.rank() == 0 {
1846            Ok(self.clone())
1847        } else {
1848            self.flatten_all()?.max(0)
1849        }
1850    }
1851
1852    /// Computes the min of all the elements in this tensor and returns a tensor holding this
1853    /// scalar with zero dimensions.
1854    ///
1855    /// ```rust
1856    /// use candle_core::{Tensor, Device};
1857    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1858    /// let tensor = tensor.min_all()?;
1859    /// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
1860    /// # Ok::<(), candle_core::Error>(())
1861    /// ```
1862    pub fn min_all(&self) -> Result<Tensor> {
1863        if self.rank() == 0 {
1864            Ok(self.clone())
1865        } else {
1866            self.flatten_all()?.min(0)
1867        }
1868    }
1869
1870    /// Computes the sum of all the elements in this tensor and returns a tensor holding this
1871    /// scalar with zero dimensions.
1872    ///
1873    /// ```rust
1874    /// use candle_core::{Tensor, Device};
1875    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1876    /// let tensor = tensor.sum_all()?;
1877    /// assert_eq!(tensor.to_scalar::<f32>()?, 15.);
1878    /// # Ok::<(), candle_core::Error>(())
1879    /// ```
1880    pub fn sum_all(&self) -> Result<Tensor> {
1881        let dims: Vec<_> = (0..self.rank()).collect();
1882        self.sum(dims)
1883    }
1884
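    /// Computes the mean of all the elements in this tensor, i.e. the sum of all the elements
    /// divided by the element count. A small sketch of the expected result:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let tensor = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let tensor = tensor.mean_all()?;
    /// assert_eq!(tensor.to_scalar::<f32>()?, 2.5);
    /// # Ok::<(), candle_core::Error>(())
    /// ```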
1885    pub fn mean_all(&self) -> Result<Tensor> {
1886        self.sum_all()? / self.elem_count() as f64
1887    }
1888
1889    fn flatten_<D1: Dim, D2: Dim>(
1890        &self,
1891        start_dim: Option<D1>,
1892        end_dim: Option<D2>,
1893    ) -> Result<Tensor> {
1894        if self.rank() == 0 {
1895            self.reshape(1)
1896        } else {
1897            let start_dim = match start_dim {
1898                None => 0,
1899                Some(dim) => dim.to_index(self.shape(), "flatten")?,
1900            };
1901            let end_dim = match end_dim {
1902                None => self.rank() - 1,
1903                Some(dim) => dim.to_index(self.shape(), "flatten")?,
1904            };
1905            if start_dim < end_dim {
1906                let dims = self.dims();
1907                let mut dst_dims = dims[..start_dim].to_vec();
1908                dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
1909                if end_dim + 1 < dims.len() {
1910                    dst_dims.extend(&dims[end_dim + 1..]);
1911                }
1912                self.reshape(dst_dims)
1913            } else {
1914                Ok(self.clone())
1915            }
1916        }
1917    }
1918
1919    /// Flattens the input tensor on the dimension indexes from `start_dim` to `end_dim` (both
1920    /// inclusive).
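    ///
    /// For instance, flattening dimensions 1 and 2 of a `(2, 3, 4)` tensor gives a `(2, 12)` tensor:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    /// let b = a.flatten(1, 2)?;
    /// assert_eq!(b.shape().dims(), &[2, 12]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```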
1921    pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
1922        self.flatten_(Some(start_dim), Some(end_dim))
1923    }
1924
1925    /// Flattens the input tensor on the dimension indexes from `0` to `end_dim` (inclusive).
1926    pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
1927        self.flatten_(None::<usize>, Some(end_dim))
1928    }
1929
1930    /// Flattens the input tensor on the dimension indexes from `start_dim` (inclusive) to the last
1931    /// dimension.
1932    pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
1933        self.flatten_(Some(start_dim), None::<usize>)
1934    }
1935
1936    /// Flattens the input tensor by reshaping it into a one dimension tensor.
1937    ///
1938    /// ```rust
1939    /// use candle_core::{Tensor, Device};
1940    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1941    /// let tensor = tensor.flatten_all()?;
1942    /// assert_eq!(tensor.to_vec1::<f32>()?, &[0., 1., 2., 3., 4., 5.]);
1943    /// # Ok::<(), candle_core::Error>(())
1944    /// ```
1945    pub fn flatten_all(&self) -> Result<Tensor> {
1946        self.flatten_(None::<usize>, None::<usize>)
1947    }
1948
1949    /// Returns the sub-tensor fixing the index at `i` on the first dimension.
1950    ///
1951    /// ```rust
1952    /// use candle_core::{Tensor, Device};
1953    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1954    /// let t = tensor.get(0)?;
1955    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 1.]);
1956    /// let t = tensor.get(1)?;
1957    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
1958    /// # Ok::<(), candle_core::Error>(())
1959    /// ```
1960    pub fn get(&self, i: usize) -> Result<Tensor> {
1961        let dims = self.dims();
1962        if dims.is_empty() {
1963            Ok(self.clone())
1964        } else {
1965            self.narrow(0, i, 1)?.reshape(&dims[1..])
1966        }
1967    }
1968
1969    /// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
1970    ///
1971    /// ```rust
1972    /// use candle_core::{Tensor, Device};
1973    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1974    /// let t = tensor.get_on_dim(1, 0)?;
1975    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
1976    /// let t = tensor.get_on_dim(1, 1)?;
1977    /// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
1978    /// let t = tensor.get_on_dim(0, 1)?;
1979    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
1980    /// # Ok::<(), candle_core::Error>(())
1981    /// ```
1982    pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
1983        let dim = dim.to_index(self.shape(), "get_on_dim")?;
1984        self.narrow(dim, index, 1)?.squeeze(dim)
1985    }
1986
1987    /// Returns a tensor that is a transposed version of the input, the last two dimensions of the
1988    /// input are swapped.
1989    ///
1990    /// ```rust
1991    /// use candle_core::{Tensor, Device};
1992    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1993    /// let tensor = tensor.t()?;
1994    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[0.0, 2.0, 4.0], [1.0, 3.0, 5.0]]);
1995    /// # Ok::<(), candle_core::Error>(())
1996    /// ```
1997    pub fn t(&self) -> Result<Tensor> {
1998        let rank = self.rank();
1999        if rank < 2 {
2000            Err(Error::UnexpectedNumberOfDims {
2001                expected: 2,
2002                got: rank,
2003                shape: self.shape().clone(),
2004            }
2005            .bt())?
2006        }
2007        self.transpose(rank - 2, rank - 1)
2008    }
2009
2010    /// Returns a tensor that is a transposed version of the input, the given dimensions are
2011    /// swapped.
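    ///
    /// For instance, swapping the two dimensions of a `(2, 3)` tensor yields a `(3, 2)` tensor:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let tensor = Tensor::arange(0u32, 6u32, &Device::Cpu)?.reshape((2, 3))?;
    /// let tensor = tensor.transpose(0, 1)?;
    /// assert_eq!(tensor.to_vec2::<u32>()?, &[[0, 3], [1, 4], [2, 5]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```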
2012    pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
2013        let dim1 = dim1.to_index(self.shape(), "transpose")?;
2014        let dim2 = dim2.to_index(self.shape(), "transpose")?;
2015        if dim1 == dim2 {
2016            return Ok(self.clone());
2017        }
2018        let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
2019        let tensor_ = Tensor_ {
2020            id: TensorId::new(),
2021            storage: self.storage.clone(),
2022            layout: self.layout.transpose(dim1, dim2)?,
2023            op,
2024            is_variable: false,
2025            dtype: self.dtype,
2026            device: self.device.clone(),
2027        };
2028        Ok(Tensor(Arc::new(tensor_)))
2029    }
2030
2031    /// Returns a tensor with the same data as the input where the dimensions have been permuted.
2032    /// dims must be a permutation, i.e. include each dimension index exactly once.
2033    ///
2034    /// ```rust
2035    /// use candle_core::{Tensor, Device};
2036    /// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
2037    /// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
2038    /// let tensor = tensor.permute((2, 3, 1, 0))?;
2039    /// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
2040    /// # Ok::<(), candle_core::Error>(())
2041    /// ```
2042    pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
2043        let dims = dims.to_indexes(self.shape(), "permute")?;
2044        // O(n^2) permutation check but these arrays are small.
2045        let is_permutation =
2046            dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
2047        if !is_permutation {
2048            bail!(
2049                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
2050                self.dims(),
2051                dims
2052            )
2053        }
2054        let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
2055        let tensor_ = Tensor_ {
2056            id: TensorId::new(),
2057            storage: self.storage.clone(),
2058            layout: self.layout.permute(&dims)?,
2059            op,
2060            is_variable: false,
2061            dtype: self.dtype,
2062            device: self.device.clone(),
2063        };
2064        Ok(Tensor(Arc::new(tensor_)))
2065    }
2066
2067    /// Returns true if the data is stored in a C contiguous (aka row major) way.
2068    pub fn is_contiguous(&self) -> bool {
2069        self.layout.is_contiguous()
2070    }
2071
2072    /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
2073    pub fn is_fortran_contiguous(&self) -> bool {
2074        self.layout.is_fortran_contiguous()
2075    }
2076
2077    /// Compared to `clone`, this copies the actual storage and so may fail if the device runs
2078    /// out of memory.
2079    pub fn copy(&self) -> Result<Tensor> {
2080        let op = BackpropOp::new1(self, Op::Copy);
2081        let tensor_ = Tensor_ {
2082            id: TensorId::new(),
2083            storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
2084            layout: self.layout.clone(),
2085            op,
2086            is_variable: false,
2087            dtype: self.dtype,
2088            device: self.device.clone(),
2089        };
2090        Ok(Tensor(Arc::new(tensor_)))
2091    }
2092
2093    /// Returns a new tensor detached from the current graph, gradients are not propagated through
2094    /// this new node. The storage of this tensor is shared with the initial tensor.
2095    ///
2096    /// If the tensor is already detached from the computation graph, the same tensor is returned.
2097    pub fn detach(&self) -> Tensor {
2098        if self.op.is_none() && !self.is_variable {
2099            self.clone()
2100        } else {
2101            let tensor_ = Tensor_ {
2102                id: TensorId::new(),
2103                storage: self.storage.clone(),
2104                layout: self.layout.clone(),
2105                op: BackpropOp::none(),
2106                is_variable: false,
2107                dtype: self.dtype,
2108                device: self.device.clone(),
2109            };
2110            Tensor(Arc::new(tensor_))
2111        }
2112    }
2113
2114    /// If the target device is the same as the tensor device, only a shallow copy is performed.
2115    pub fn to_device(&self, device: &Device) -> Result<Tensor> {
2116        if self.device().same_device(device) {
2117            Ok(self.clone())
2118        } else {
2119            let storage = match (&*self.storage(), device) {
2120                (Storage::Cpu(storage), Device::Cuda(cuda)) => {
2121                    Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
2122                }
2123                (Storage::Cpu(storage), Device::Metal(metal)) => {
2124                    Storage::Metal(metal.storage_from_cpu_storage(storage)?)
2125                }
2126                (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2127                (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2128                (Storage::Cuda(storage), Device::Cuda(cuda)) => {
2129                    // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
2130                    // are the same.
2131                    let cpu_storage = storage.to_cpu_storage()?;
2132                    Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
2133                }
2134                (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
2135                _ => {
2136                    bail!(
2137                        "not implemented yet, self.device: {:?}, device: {:?}",
2138                        self.device(),
2139                        device
2140                    )
2141                }
2142            };
2143            let op = BackpropOp::new1(self, Op::ToDevice);
2144            let tensor_ = Tensor_ {
2145                id: TensorId::new(),
2146                storage: Arc::new(RwLock::new(storage)),
2147                layout: self.layout.clone(),
2148                op,
2149                is_variable: false,
2150                dtype: self.dtype,
2151                device: device.clone(),
2152            };
2153            Ok(Tensor(Arc::new(tensor_)))
2154        }
2155    }
2156
2157    /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
2158    /// on the left.
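    ///
    /// For instance, broadcasting a `(3,)` tensor with a left shape of `(2, 4)` yields a
    /// `(2, 4, 3)` tensor:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let t = t.broadcast_left((2, 4))?;
    /// assert_eq!(t.dims(), &[2, 4, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```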
2159    pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
2160        let left_shape = left_shape.into();
2161        let mut dims = left_shape.into_dims();
2162        dims.extend(self.dims());
2163        self.broadcast_as(dims)
2164    }
2165
2166    /// Broadcast the input tensor to the target shape. This returns an error if the input shape is
2167    /// not compatible with the target shape.
2168    ///
2169    /// If the input shape is `i_1, i_2, ... i_k`, the target shape has to have `k` dimensions or
2170    /// more and shape `j_1, ..., j_l, t_1, t_2, ..., t_k`. The dimensions `j_1` to `j_l` can have
2171    /// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If
2172    /// `i_a` is equal to 1, any value can be used.
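    ///
    /// A short sketch of the expected behavior, broadcasting a `(3, 1)` tensor to `(3, 2)`:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32], [2.], [3.]], &Device::Cpu)?;
    /// let t = t.broadcast_as((3, 2))?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[1., 1.], [2., 2.], [3., 3.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```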
2173    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2174        let tensor_ = Tensor_ {
2175            id: TensorId::new(),
2176            storage: self.storage.clone(),
2177            layout: self.layout.broadcast_as(shape)?,
2178            op: BackpropOp::new1(self, Op::Broadcast),
2179            is_variable: false,
2180            dtype: self.dtype,
2181            device: self.device.clone(),
2182        };
2183        Ok(Tensor(Arc::new(tensor_)))
2184    }
2185
2186    /// An alias for broadcast_as.
2187    pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2188        self.broadcast_as(shape)
2189    }
2190
2191    /// Casts the input tensor to the target `dtype`.
2192    ///
2193    /// ```rust
2194    /// use candle_core::{Tensor, Device};
2195    /// let tensor = Tensor::new(3.14159265358979f64, &Device::Cpu)?;
2196    /// assert_eq!(tensor.to_scalar::<f64>()?, 3.14159265358979);
2197    /// let tensor = tensor.to_dtype(candle_core::DType::F32)?;
2198    /// assert_eq!(tensor.to_scalar::<f32>()?, 3.1415927);
2199    /// # Ok::<(), candle_core::Error>(())
2200    /// ```
2201    pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2202        if self.dtype() == dtype {
2203            Ok(self.clone())
2204        } else {
2205            let shape = self.shape();
2206            let storage = self.storage().to_dtype(self.layout(), dtype)?;
2207            let op = BackpropOp::new1(self, Op::ToDType);
2208            Ok(from_storage(storage, shape.clone(), op, false))
2209        }
2210    }
2211
2212    /// Returns a tensor that is in row major order. This is the same as the original tensor if it
2213    /// was already contiguous, otherwise a copy is triggered.
2214    pub fn contiguous(&self) -> Result<Tensor> {
2215        if self.is_contiguous() {
2216            Ok(self.clone())
2217        } else {
2218            let shape = self.shape();
2219            let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2220            self.storage()
2221                .copy_strided_src(&mut storage, 0, self.layout())?;
2222            let op = BackpropOp::new1(self, Op::Copy);
2223            Ok(from_storage(storage, shape.clone(), op, false))
2224        }
2225    }
2226
2227    /// Returns a tensor that is in row major order. This always makes a copy.
2228    pub fn force_contiguous(&self) -> Result<Tensor> {
2229        let shape = self.shape();
2230        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2231        self.storage()
2232            .copy_strided_src(&mut storage, 0, self.layout())?;
2233        let op = BackpropOp::new1(self, Op::Copy);
2234        Ok(from_storage(storage, shape.clone(), op, false))
2235    }
2236
2237    /// Create a variable based on the values currently stored in a tensor. The storage is always
2238    /// copied.
2239    pub(crate) fn make_var(&self) -> Result<Tensor> {
2240        let shape = self.shape().clone();
2241        let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2242        self.storage()
2243            .copy_strided_src(&mut storage, 0, self.layout())?;
2244        Ok(from_storage(storage, shape, BackpropOp::none(), true))
2245    }
2246
2247    /// Reshape returns a tensor with the target shape provided that the number of elements of the
2248    /// original tensor is the same.
2249    /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
2250    /// a new storage and copies the data over, the returned tensor is always contiguous.
2251    ///
2252    /// The shape can be specified using a tuple of `usize` and at most one `()` in which case
2253    /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
2254    /// as to match the number of elements in the tensor.
2255    ///
2256    /// ```rust
2257    /// # use candle_core::{Tensor, DType, Device, D};
2258    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2259    ///
2260    /// let c = a.reshape((1, 6))?;
2261    /// assert_eq!(c.shape().dims(), &[1, 6]);
2262    ///
2263    /// let c = a.reshape((3, 2))?;
2264    /// assert_eq!(c.shape().dims(), &[3, 2]);
2265    ///
2266    /// let c = a.reshape((2, (), 1))?;
2267    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
2268    ///
2269    /// # Ok::<(), candle_core::Error>(())
2270    /// ```
2271    pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
2272        let shape = s.into_shape(self.elem_count())?;
2273        if shape.elem_count() != self.elem_count() {
2274            return Err(Error::ShapeMismatchBinaryOp {
2275                lhs: self.shape().clone(),
2276                rhs: shape,
2277                op: "reshape",
2278            }
2279            .bt());
2280        }
2281        let op = BackpropOp::new1(self, Op::Reshape);
2282        if self.is_contiguous() {
2283            let tensor_ = Tensor_ {
2284                id: TensorId::new(),
2285                storage: self.storage.clone(),
2286                layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
2287                op,
2288                is_variable: false,
2289                dtype: self.dtype,
2290                device: self.device.clone(),
2291            };
2292            Ok(Tensor(Arc::new(tensor_)))
2293        } else {
2294            let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2295            self.storage()
2296                .copy_strided_src(&mut storage, 0, self.layout())?;
2297            Ok(from_storage(storage, shape, op, false))
2298        }
2299    }
2300
2301    /// Creates a new tensor with the specified dimension removed if its size was one.
2302    ///
2303    /// ```rust
2304    /// # use candle_core::{Tensor, DType, Device, D};
2305    /// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
2306    ///
2307    /// let c = a.squeeze(2)?;
2308    /// assert_eq!(c.shape().dims(), &[2, 3]);
2309    ///
2310    /// let c = a.squeeze(D::Minus1)?;
2311    /// assert_eq!(c.shape().dims(), &[2, 3]);
2312    /// # Ok::<(), candle_core::Error>(())
2313    /// ```
2314    pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
2315        // The PyTorch semantics are to return the same tensor if the target dimension
2316        // does not have a size of 1.
2317        let dims = self.dims();
2318        let dim = dim.to_index(self.shape(), "squeeze")?;
2319        if dims[dim] == 1 {
2320            let mut dims = dims.to_vec();
2321            let mut strides = self.stride().to_vec();
2322            dims.remove(dim);
2323            strides.remove(dim);
2324            let tensor_ = Tensor_ {
2325                id: TensorId::new(),
2326                storage: self.storage.clone(),
2327                layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2328                op: BackpropOp::new1(self, Op::Reshape),
2329                is_variable: false,
2330                dtype: self.dtype,
2331                device: self.device.clone(),
2332            };
2333            Ok(Tensor(Arc::new(tensor_)))
2334        } else {
2335            Ok(self.clone())
2336        }
2337    }
2338
2339    /// Creates a new tensor with a dimension of size one inserted at the specified position.
2340    ///
2341    /// ```rust
2342    /// # use candle_core::{Tensor, DType, Device, D};
2343    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2344    ///
2345    /// let c = a.unsqueeze(0)?;
2346    /// assert_eq!(c.shape().dims(), &[1, 2, 3]);
2347    ///
2348    /// let c = a.unsqueeze(D::Minus1)?;
2349    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
2350    /// # Ok::<(), candle_core::Error>(())
2351    /// ```
2352    pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
2353        let mut dims = self.dims().to_vec();
2354        let mut strides = self.stride().to_vec();
2355        let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
2356        // Cannot panic because to_index_plus_one already checks dimensions
2357        dims.insert(dim, 1);
2358        // Any stride would work here, but we pick one so as to maximize the probability to remain
2359        // C contiguous.
2360        let stride = if dim < strides.len() { strides[dim] } else { 1 };
2361        strides.insert(dim, stride);
2362        let tensor_ = Tensor_ {
2363            id: TensorId::new(),
2364            storage: self.storage.clone(),
2365            layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2366            op: BackpropOp::new1(self, Op::Reshape),
2367            is_variable: false,
2368            dtype: self.dtype,
2369            device: self.device.clone(),
2370        };
2371        Ok(Tensor(Arc::new(tensor_)))
2372    }
2373
2374    /// Stacks two or more tensors along a particular dimension.
2375    ///
2376    /// All tensors must have the same rank, and the output has one additional dimension.
2377    ///
2378    /// ```rust
2379    /// # use candle_core::{Tensor, DType, Device};
2380    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2381    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2382    ///
2383    /// let c = Tensor::stack(&[&a, &b], 0)?;
2384    /// assert_eq!(c.shape().dims(), &[2, 2, 3]);
2385    ///
2386    /// let c = Tensor::stack(&[&a, &b], 2)?;
2387    /// assert_eq!(c.shape().dims(), &[2, 3, 2]);
2388    /// # Ok::<(), candle_core::Error>(())
2389    /// ```
2390    pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
2391        if args.is_empty() {
2392            Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
2393        }
2394        let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
2395        let args = args
2396            .iter()
2397            .map(|t| t.as_ref().unsqueeze(dim))
2398            .collect::<Result<Vec<_>>>()?;
2399        Self::cat(&args, dim)
2400    }
2401
2402    /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
2403    /// input tensor values and `right` elements after.
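    ///
    /// For instance, padding a `(2, 2)` tensor with one zero on each side of the last dimension:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let t = t.pad_with_zeros(1, 1, 1)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[0., 1., 2., 0.], [0., 3., 4., 0.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```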
2404    pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2405        if left == 0 && right == 0 {
2406            Ok(self.clone())
2407        } else if left == 0 {
2408            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2409            let mut dims = self.dims().to_vec();
2410            dims[dim] = right;
2411            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2412            Tensor::cat(&[self, &right], dim)
2413        } else if right == 0 {
2414            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2415            let mut dims = self.dims().to_vec();
2416            dims[dim] = left;
2417            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2418            Tensor::cat(&[&left, self], dim)
2419        } else {
2420            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2421            let mut dims = self.dims().to_vec();
2422            dims[dim] = left;
2423            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2424            dims[dim] = right;
2425            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2426            Tensor::cat(&[&left, self, &right], dim)
2427        }
2428    }
2429
2430    /// Pad the input tensor by replicating the boundary values along dimension `dim`. This adds `left`
2431    /// elements before the input tensor values and `right` elements after.
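    ///
    /// For instance, padding a `(2, 2)` tensor on the first dimension with one replicated row on
    /// the left and two on the right:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let t = t.pad_with_same(0, 1, 2)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[1., 2.], [1., 2.], [3., 4.], [3., 4.], [3., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```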
2432    pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2433        if left == 0 && right == 0 {
2434            Ok(self.clone())
2435        } else if self.elem_count() == 0 {
2436            bail!("cannot use pad_with_same on an empty tensor")
2437        } else if left == 0 {
2438            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2439            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2440            let mut v = vec![self];
2441            for _ in 0..right {
2442                v.push(&r)
2443            }
2444            Tensor::cat(&v, dim)
2445        } else if right == 0 {
2446            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2447            let l = self.narrow(dim, 0, 1)?;
2448            let mut v = vec![];
2449            for _ in 0..left {
2450                v.push(&l)
2451            }
2452            v.push(self);
2453            Tensor::cat(&v, dim)
2454        } else {
2455            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2456            let l = self.narrow(dim, 0, 1)?;
2457            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2458            let mut v = vec![];
2459            for _ in 0..left {
2460                v.push(&l)
2461            }
2462            v.push(self);
2463            for _ in 0..right {
2464                v.push(&r)
2465            }
2466            Tensor::cat(&v, dim)
2467        }
2468    }
2469
2470    /// Run the `forward` method of `m` on `self`.
2471    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
2472        m.forward(self)
2473    }
2474
2475    /// Run the `forward` method of `m` on `self`.
2476    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
2477        m.forward_t(self, train)
2478    }
2479
2480    pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
2481        self.storage.read().unwrap()
2482    }
2483
2484    pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
2485        self.storage.write().unwrap()
2486    }
2487
2488    // If we extend the visibility of this function to be usable outside of this crate, we should
2489    // make it unsafe.
2490    pub(crate) fn storage_mut_and_layout(
2491        &self,
2492    ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
2493        let storage = self.storage.write().unwrap();
2494        (storage, &self.layout)
2495    }
2496
2497    /// The storage used by this tensor, together with the layout to use to access it safely.
2498    pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
2499        let storage = self.storage.read().unwrap();
2500        (storage, &self.layout)
2501    }
2502
2503    pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
2504        let lhs: &RwLock<Storage> = self.storage.as_ref();
2505        let rhs: &RwLock<Storage> = rhs.storage.as_ref();
2506        std::ptr::eq(lhs, rhs)
2507    }
2508
2509    /// Normalize a 'relative' axis value: positive values are kept as-is, negative
2510    /// values mean counting the dimensions from the back.
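    ///
    /// For instance, for a rank-2 tensor, axis `-1` normalizes to `1`:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// assert_eq!(t.normalize_axis(0)?, 0);
    /// assert_eq!(t.normalize_axis(-1)?, 1);
    /// # Ok::<(), candle_core::Error>(())
    /// ```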
2511    pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
2512        let rank = self.rank() as i64;
2513        if rank <= axis {
2514            bail!("axis {axis} is too large, tensor rank {rank}")
2515        } else if 0 <= axis {
2516            Ok(axis as usize)
2517        } else {
2518            let naxis = rank + axis;
2519            if naxis < 0 {
2520                bail!("axis {axis} is too small, tensor rank {rank}")
2521            }
2522            Ok(naxis as usize)
2523        }
2524    }
2525
2526    /// Returns a lower triangular matrix of ones of size n by n.
2527    pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2528        let t = Tensor::arange(0u32, n as u32, device)?;
2529        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2530        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2531        t1.le(&t2)?.to_dtype(dtype)
2532    }
2533
2534    /// Returns an upper triangular matrix of ones of size n by n.
2535    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2536        let t = Tensor::arange(0u32, n as u32, device)?;
2537        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2538        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2539        t1.ge(&t2)?.to_dtype(dtype)
2540    }
2541
2542    /// Returns a matrix with a diagonal of ones of size n by n.
2543    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2544        let t = Tensor::arange(0u32, n as u32, device)?;
2545        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2546        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2547        t1.eq(&t2)?.to_dtype(dtype)
2548    }
2549
2550    /// Returns the cumulative sum of elements of the input tensor summed over the specified
2551    /// dimension.
2552    ///
2553    /// This operation is most efficient when dim is the last dimension of the tensor.
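    ///
    /// A short sketch of the expected behavior:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
    /// let t = t.cumsum(0)?;
    /// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 6., 10.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```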
2554    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
2555        let dim = dim.to_index(self.shape(), "cumsum")?;
2556        let rank = self.rank();
2557        if rank == 0 {
2558            return Ok(self.clone());
2559        }
2560        let n_axis = self.dim(dim)?;
2561        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
2562        if rank == 1 {
2563            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
2564        } else {
2565            let last = rank - 1;
2566            let t = self.transpose(dim, last)?;
2567            let t = t.broadcast_matmul(&triu)?;
2568            t.transpose(dim, last)
2569        }
2570    }
2571
2572    /// Returns a copy of `self` where the values within `ranges` have been replaced with the
2573    /// content of `src`.
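    ///
    /// A short sketch of the expected behavior, assigning a `(1, 2)` source into the middle row of
    /// a `(3, 3)` tensor:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((3, 3), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::ones((1, 2), DType::F32, &Device::Cpu)?;
    /// let res = t.slice_assign(&[1..2, 0..2], &src)?;
    /// assert_eq!(res.to_vec2::<f32>()?, &[[0., 0., 0.], [1., 1., 0.], [0., 0., 0.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```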
2574    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
2575        &self,
2576        ranges: &[D],
2577        src: &Tensor,
2578    ) -> Result<Self> {
2579        let src_dims = src.dims();
2580        let self_dims = self.dims();
2581        if self_dims.len() != src_dims.len() {
2582            bail!(
2583                "slice-assign requires input with the same rank {} <> {}",
2584                self_dims.len(),
2585                src_dims.len()
2586            )
2587        }
2588        if self_dims.len() != ranges.len() {
2589            bail!(
2590                "slice-assign requires input with the same rank as there are ranges {} <> {}",
2591                self_dims.len(),
2592                ranges.len()
2593            )
2594        }
2595        let mut src = src.clone();
2596        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
2597        for (i, range) in ranges.iter().enumerate() {
2598            let start_included = match range.start_bound() {
2599                std::ops::Bound::Unbounded => 0,
2600                std::ops::Bound::Included(v) => *v,
2601                std::ops::Bound::Excluded(v) => *v + 1,
2602            };
2603            let end_excluded = match range.end_bound() {
2604                std::ops::Bound::Unbounded => self_dims[i],
2605                std::ops::Bound::Included(v) => *v + 1,
2606                std::ops::Bound::Excluded(v) => *v,
2607            };
2608            if end_excluded <= start_included {
2609                bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
2610            }
2611            if self_dims[i] < end_excluded {
2612                bail!(
2613                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
2614                    self_dims[i]
2615                )
2616            }
2617            if end_excluded - start_included != src_dims[i] {
2618                bail!(
2619                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
2620                )
2621            }
2622            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
2623            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
2624        }
2625        mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
2626    }
2627
2628    /// Returns log(sum(exp(tensor), dim)).
2629    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
2630        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
2631        if sum_dims.is_empty() {
2632            return Ok(self.clone());
2633        }
2634        let max = sum_dims[1..]
2635            .iter()
2636            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
2637                max.max_keepdim(dim)
2638            })?;
2639        let exp = self.broadcast_sub(&max)?.exp()?;
2640        let sum = exp.sum(sum_dims.clone())?;
2641
2642        sum.log()? + max.squeeze_dims(&sum_dims)
2643    }
2644
2645    /// Pointwise pow operation.
2646    pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
2647        rhs.mul(&self.log()?)?.exp()
2648    }
2649
2650    /// Broadcasting version of `pow`.
2651    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
2652        rhs.broadcast_mul(&self.log()?)?.exp()
2653    }
2654
2655    /// Returns a new tensor with the order of elements reversed along the specified dimensions.
2656    /// This function makes a copy of the tensor’s data.
2657    ///
2658    /// ```rust
2659    /// # use candle_core::{Tensor, Device};
2660    /// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
2661    /// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
2662    /// let t_flipped = t.flip(&[0])?;
2663    /// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
2664    /// # Ok::<(), candle_core::Error>(())
2665    /// ```
2666    pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
2667        let mut result = self.clone();
2668        for &dim in dims.iter() {
2669            let size = result.dim(dim)?;
2670            let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
2671            let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
2672            result = result.index_select(&indices_tensor, dim)?;
2673        }
2674        Ok(result)
2675    }
2676}
2677
2678macro_rules! bin_trait {
2679    ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => {
2680        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for Tensor {
2681            type Output = Result<Tensor>;
2682
2683            fn $fn1(self, rhs: B) -> Self::Output {
2684                Tensor::$fn1(&self, rhs.borrow())
2685            }
2686        }
2687
2688        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for &Tensor {
2689            type Output = Result<Tensor>;
2690
2691            fn $fn1(self, rhs: B) -> Self::Output {
2692                Tensor::$fn1(&self, rhs.borrow())
2693            }
2694        }
2695
2696        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
2697            type Output = Result<Tensor>;
2698
2699            fn $fn1(self, rhs: Tensor) -> Self::Output {
2700                Tensor::$fn1(self?.borrow(), &rhs)
2701            }
2702        }
2703
2704        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
2705            type Output = Result<Tensor>;
2706
2707            fn $fn1(self, rhs: &Tensor) -> Self::Output {
2708                Tensor::$fn1(self?.borrow(), rhs)
2709            }
2710        }
2711
2712        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
2713            type Output = Result<Tensor>;
2714
2715            fn $fn1(self, rhs: Result<B>) -> Self::Output {
2716                Tensor::$fn1(&self, rhs?.borrow())
2717            }
2718        }
2719
2720        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for &Tensor {
2721            type Output = Result<Tensor>;
2722
2723            fn $fn1(self, rhs: Result<B>) -> Self::Output {
2724                Tensor::$fn1(&self, rhs?.borrow())
2725            }
2726        }
2727
2728        impl std::ops::$trait<f64> for Tensor {
2729            type Output = Result<Tensor>;
2730
2731            fn $fn1(self, rhs: f64) -> Self::Output {
2732                self.affine($mul(rhs), $add(rhs))
2733            }
2734        }
2735
2736        impl std::ops::$trait<f64> for &Tensor {
2737            type Output = Result<Tensor>;
2738
2739            fn $fn1(self, rhs: f64) -> Self::Output {
2740                self.affine($mul(rhs), $add(rhs))
2741            }
2742        }
2743    };
2744}
2745
2746bin_trait!(Add, add, |_| 1., |v| v);
2747bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
2748bin_trait!(Mul, mul, |v| v, |_| 0.);
2749bin_trait!(Div, div, |v| 1. / v, |_| 0.);
2750
2751impl std::ops::Add<Tensor> for f64 {
2752    type Output = Result<Tensor>;
2753
2754    fn add(self, rhs: Tensor) -> Self::Output {
2755        rhs + self
2756    }
2757}
2758
2759impl std::ops::Add<&Tensor> for f64 {
2760    type Output = Result<Tensor>;
2761
2762    fn add(self, rhs: &Tensor) -> Self::Output {
2763        rhs + self
2764    }
2765}
2766
2767impl std::ops::Mul<Tensor> for f64 {
2768    type Output = Result<Tensor>;
2769
2770    fn mul(self, rhs: Tensor) -> Self::Output {
2771        rhs * self
2772    }
2773}
2774
2775impl std::ops::Mul<&Tensor> for f64 {
2776    type Output = Result<Tensor>;
2777
2778    fn mul(self, rhs: &Tensor) -> Self::Output {
2779        rhs * self
2780    }
2781}
2782
2783impl std::ops::Sub<Tensor> for f64 {
2784    type Output = Result<Tensor>;
2785
2786    fn sub(self, rhs: Tensor) -> Self::Output {
2787        rhs.affine(-1., self)
2788    }
2789}
2790
2791impl std::ops::Sub<&Tensor> for f64 {
2792    type Output = Result<Tensor>;
2793
2794    fn sub(self, rhs: &Tensor) -> Self::Output {
2795        rhs.affine(-1., self)
2796    }
2797}
2798
2799impl std::ops::Div<Tensor> for f64 {
2800    type Output = Result<Tensor>;
2801
2802    #[allow(clippy::suspicious_arithmetic_impl)]
2803    fn div(self, rhs: Tensor) -> Self::Output {
2804        rhs.recip()? * self
2805    }
2806}
2807
2808impl std::ops::Div<&Tensor> for f64 {
2809    type Output = Result<Tensor>;
2810
2811    #[allow(clippy::suspicious_arithmetic_impl)]
2812    fn div(self, rhs: &Tensor) -> Self::Output {
2813        rhs.recip()? * self
2814    }
2815}