candle_core/tensor.rs

1//! Tensors are N-dimensional matrices of elements using a single data type.
2#![allow(clippy::redundant_closure_call)]
3use crate::backend::{BackendDevice, BackendStorage};
4use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
5use crate::scalar::TensorOrScalar;
6use crate::shape::{Dim, Dims, ShapeWithOneHole};
7use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
8use std::sync::{Arc, RwLock};
9
10/// Unique identifier for tensors.
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
12pub struct TensorId(usize);
13
14impl TensorId {
15    fn new() -> Self {
16        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
17        use std::sync::atomic;
18        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
19        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
20    }
21}
22
23pub struct Tensor_ {
24    id: TensorId,
25    // As we provide inner mutability on the tensor content, the alternatives are:
26    // - Using a mutex, this would have the highest cost when retrieving the storage but would
27    //   prevent errors when concurrent access takes place. Mutex would also be subject to
28    //   deadlocks for example using the current code if the same tensor is used twice by a single
29    //   binary op.
30    // - Using a refcell would have some intermediate cost, borrow checking would be
31    //   verified dynamically, but the resulting tensors would not be Send or Sync.
32    // - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent
33    //   accesses.
34    // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
35    // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but
36    // that's tricky to encode in the current setup.
37    storage: Arc<RwLock<Storage>>,
38    layout: Layout,
39    op: BackpropOp,
40    is_variable: bool,
41    dtype: DType,
42    device: Device,
43}
44
45impl AsRef<Tensor> for Tensor {
46    fn as_ref(&self) -> &Tensor {
47        self
48    }
49}
50
51// Tensors are refcounted so that cloning is cheap when building the op graph.
52// Storages are also refcounted independently so that it's possible to avoid
53// copying the storage for operations that only modify the shape or stride.
54#[derive(Clone)]
55/// The core struct for manipulating tensors.
56///
57/// ```rust
58/// use candle_core::{Tensor, DType, Device};
59///
60/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
61/// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
62///
63/// let c = a.matmul(&b)?;
64/// # Ok::<(), candle_core::Error>(())
65/// ```
66///
67/// Tensors are reference counted with [`Arc`] so cloning them is cheap.
68pub struct Tensor(Arc<Tensor_>);
69
70impl std::ops::Deref for Tensor {
71    type Target = Tensor_;
72
73    fn deref(&self) -> &Self::Target {
74        self.0.as_ref()
75    }
76}
77
78macro_rules! unary_op {
79    ($fn_name:ident, $op_name:ident) => {
80        pub fn $fn_name(&self) -> Result<Self> {
81            let shape = self.shape();
82            if shape.elem_count() == 0 {
83                return Ok(self.clone());
84            }
85            let storage = self
86                .storage()
87                .unary_impl::<crate::op::$op_name>(self.layout())?;
88            let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
89            Ok(from_storage(storage, shape.clone(), op, false))
90        }
91    };
92}
93
94macro_rules! binary_op {
95    ($fn_name:ident, $op_name:ident) => {
96        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
97            let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
98            if shape.elem_count() == 0 {
99                return Ok(self.clone());
100            }
101            let storage = self.storage().binary_impl::<crate::op::$op_name>(
102                &*rhs.storage(),
103                self.layout(),
104                rhs.layout(),
105            )?;
106            let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
107            Ok(from_storage(storage, shape.clone(), op, false))
108        }
109    };
110}
111
112macro_rules! binary_op_scalar {
113    ($fn_name:ident, $op_name:ident) => {
114        pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
115            let rhs = match rhs.to_tensor_scalar()? {
116                crate::scalar::TensorScalar::Tensor(rhs) => rhs,
117                crate::scalar::TensorScalar::Scalar(rhs) => rhs
118                    .to_dtype(self.dtype())?
119                    .to_device(self.device())?
120                    .broadcast_as(self.shape())?,
121            };
122            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
123            if self.elem_count() == 0 {
124                return Ok(self.clone());
125            }
126            let storage = self.storage().binary_impl::<crate::op::$op_name>(
127                &*rhs.storage(),
128                self.layout(),
129                rhs.layout(),
130            )?;
131            let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
132            Ok(from_storage(storage, shape.clone(), op, false))
133        }
134    };
135}
136
137macro_rules! broadcast_binary_op {
138    ($fn_name:ident, $inner_fn_name:ident) => {
139        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
140            let lhs = self;
141            let shape = lhs
142                .shape()
143                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
144            let l_broadcast = shape != *lhs.shape();
145            let r_broadcast = shape != *rhs.shape();
146            match (l_broadcast, r_broadcast) {
147                (true, true) => lhs
148                    .broadcast_as(&shape)?
149                    .$inner_fn_name(&rhs.broadcast_as(&shape)?),
150                (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
151                (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
152                (false, false) => lhs.$inner_fn_name(rhs),
153            }
154        }
155    };
156}
157
158/// Creates a fresh tensor structure based on a storage and a shape; this uses contiguous strides.
159pub(crate) fn from_storage<S: Into<Shape>>(
160    storage: Storage,
161    shape: S,
162    op: BackpropOp,
163    is_variable: bool,
164) -> Tensor {
165    let dtype = storage.dtype();
166    let device = storage.device();
167    let tensor_ = Tensor_ {
168        id: TensorId::new(),
169        storage: Arc::new(RwLock::new(storage)),
170        layout: Layout::contiguous(shape),
171        op,
172        is_variable,
173        dtype,
174        device,
175    };
176    Tensor(Arc::new(tensor_))
177}
178
179impl Tensor {
180    pub(crate) fn ones_impl<S: Into<Shape>>(
181        shape: S,
182        dtype: DType,
183        device: &Device,
184        is_variable: bool,
185    ) -> Result<Self> {
186        let none = BackpropOp::none();
187        let shape = shape.into();
188        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
189        let layout = Layout::contiguous(shape.clone());
190        storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
191        Ok(from_storage(storage, shape, none, is_variable))
192    }
193
194    /// Creates a new tensor filled with ones.
195    ///
196    /// ```rust
197    /// use candle_core::{Tensor, DType, Device};
198    /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
199    /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
200    /// // a == b
201    /// # Ok::<(), candle_core::Error>(())
202    /// ```
203    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
204        Self::ones_impl(shape, dtype, device, false)
205    }
206
207    pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
208        self.storage_mut().const_set(value, self.layout())
209    }
210
211    pub fn zero_set(&self) -> Result<()> {
212        self.const_set(crate::scalar::Scalar::zero(self.dtype()))
213    }
214
215    pub fn one_set(&self) -> Result<()> {
216        self.const_set(crate::scalar::Scalar::one(self.dtype()))
217    }
218
219    /// Creates a new tensor filled with ones with the same shape, dtype, and device as this tensor.
220    ///
221    /// ```rust
222    /// use candle_core::{Tensor, DType, Device};
223    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
224    /// let b = a.ones_like()?;
225    /// // b == a + 1
226    /// # Ok::<(), candle_core::Error>(())
227    /// ```
228    pub fn ones_like(&self) -> Result<Self> {
229        Tensor::ones(self.shape(), self.dtype(), self.device())
230    }
231
232    // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from
233    // the variable module.
234    pub(crate) fn zeros_impl<S: Into<Shape>>(
235        shape: S,
236        dtype: DType,
237        device: &Device,
238        is_variable: bool,
239    ) -> Result<Self> {
240        let none = BackpropOp::none();
241        let shape = shape.into();
242        let storage = device.zeros(&shape, dtype)?;
243        Ok(from_storage(storage, shape, none, is_variable))
244    }
245
246    /// Creates a new tensor filled with zeros.
247    ///
248    /// ```rust
249    /// use candle_core::{Tensor, DType, Device};
250    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
251    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
252    /// // a == b
253    /// # Ok::<(), candle_core::Error>(())
254    /// ```
255    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
256        Self::zeros_impl(shape, dtype, device, false)
257    }
258
259    /// Creates a new tensor filled with zeros with the same shape, dtype, and device as this
260    /// tensor.
261    ///
262    /// ```rust
263    /// use candle_core::{Tensor, DType, Device};
264    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
265    /// let b = a.zeros_like()?;
266    /// // b is on CPU f32.
267    /// # Ok::<(), candle_core::Error>(())
268    /// ```
269    pub fn zeros_like(&self) -> Result<Self> {
270        Tensor::zeros(self.shape(), self.dtype(), self.device())
271    }
272
273    pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
274        lo: T,
275        up: T,
276        s: S,
277        device: &Device,
278        is_variable: bool,
279    ) -> Result<Self> {
280        let s = s.into();
281        let storage = device.rand_uniform(lo, up, &s)?;
282        let none = BackpropOp::none();
283        Ok(from_storage(storage, s, none, is_variable))
284    }
285
286    pub(crate) fn rand_f64_impl<S: Into<Shape>>(
287        lo: f64,
288        up: f64,
289        s: S,
290        dtype: DType,
291        device: &Device,
292        is_variable: bool,
293    ) -> Result<Self> {
294        let s = s.into();
295        let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
296        let none = BackpropOp::none();
297        Ok(from_storage(storage, s, none, is_variable))
298    }
299
300    /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
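    ///
    /// A minimal usage sketch (illustrative; the values are random so only the shape is checked):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// // Sample a 2x3 tensor with values drawn uniformly between 0 and 1.
    /// let a = Tensor::rand(0f32, 1f32, (2, 3), &Device::Cpu)?;
    /// assert_eq!(a.dims(), &[2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```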
301    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
302        lo: T,
303        up: T,
304        s: S,
305        device: &Device,
306    ) -> Result<Self> {
307        Self::rand_impl(lo, up, s, device, false)
308    }
309
310    pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
311        Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
312    }
313
314    pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
315        mean: T,
316        std: T,
317        s: S,
318        device: &Device,
319        is_variable: bool,
320    ) -> Result<Self> {
321        let s = s.into();
322        let storage = device.rand_normal(mean, std, &s)?;
323        let none = BackpropOp::none();
324        Ok(from_storage(storage, s, none, is_variable))
325    }
326
327    pub(crate) fn randn_f64_impl<S: Into<Shape>>(
328        mean: f64,
329        std: f64,
330        s: S,
331        dtype: DType,
332        device: &Device,
333        is_variable: bool,
334    ) -> Result<Self> {
335        let s = s.into();
336        let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
337        let none = BackpropOp::none();
338        Ok(from_storage(storage, s, none, is_variable))
339    }
340
341    pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
342        Tensor::randn_f64_impl(
343            mean,
344            stdev,
345            self.shape(),
346            self.dtype(),
347            self.device(),
348            false,
349        )
350    }
351
352    /// Creates a new tensor initialized with values sampled from a normal distribution with the
353    /// specified `mean` and standard deviation `std`.
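    ///
    /// A minimal usage sketch (illustrative; the values are random so only the shape is checked):
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// // Sample a 2x3 tensor from a normal distribution with mean 0 and standard deviation 1.
    /// let a = Tensor::randn(0f32, 1f32, (2, 3), &Device::Cpu)?;
    /// assert_eq!(a.dims(), &[2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```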
354    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
355        mean: T,
356        std: T,
357        s: S,
358        device: &Device,
359    ) -> Result<Self> {
360        Self::randn_impl(mean, std, s, device, false)
361    }
362
363    pub(crate) fn new_impl<A: crate::device::NdArray>(
364        array: A,
365        shape: Shape,
366        device: &Device,
367        is_variable: bool,
368    ) -> Result<Self> {
369        let n: usize = shape.elem_count();
370        let buffer_size: usize = array.shape()?.elem_count();
371        if buffer_size != n {
372            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
373        }
374        let storage = device.storage(array)?;
375        let none = BackpropOp::none();
376        Ok(from_storage(storage, shape, none, is_variable))
377    }
378
379    /// Creates a new tensor on the specified device using the content and shape of the input.
380    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
381        let shape = array.shape()?;
382        Self::new_impl(array, shape, device, false)
383    }
384
385    /// Returns a new tensor with all the elements having the same specified value.
386    ///```rust
387    /// use candle_core::{Tensor, Device};
388    /// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
389    ///
390    /// assert_eq!(a.to_vec2::<f64>()?, &[
391    ///     [3.5, 3.5, 3.5, 3.5],
392    ///     [3.5, 3.5, 3.5, 3.5],
393    /// ]);
394    /// # Ok::<(), candle_core::Error>(())
    /// ```
395    pub fn full<D: crate::WithDType, S: Into<Shape>>(
396        value: D,
397        shape: S,
398        device: &Device,
399    ) -> Result<Self> {
400        let none = BackpropOp::none();
401        let shape = shape.into();
402        let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
403        let layout = Layout::contiguous(shape.clone());
404        storage.const_set(value.to_scalar(), &layout)?;
405        Ok(from_storage(storage, shape, none, false))
406    }
407
408    /// Creates a new 1D tensor from an iterator.
409    ///```rust
410    /// use candle_core::{Tensor, Device};
411    /// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
412    ///
413    /// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
414    /// # Ok::<(), candle_core::Error>(())
415    /// ```
416    pub fn from_iter<D: crate::WithDType>(
417        iter: impl IntoIterator<Item = D>,
418        device: &Device,
419    ) -> Result<Self> {
420        let data = iter.into_iter().collect::<Vec<_>>();
421        let len = data.len();
422        Self::from_vec_impl(data, len, device, false)
423    }
424
425    /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
426    /// difference `1` from `start`.
427    ///```rust
428    /// use candle_core::{Tensor, Device};
429    /// let a = Tensor::arange(2., 5., &Device::Cpu)?;
430    ///
431    /// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
432    /// # Ok::<(), candle_core::Error>(())
433    /// ```
434    pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
435        Self::arange_step(start, end, D::one(), device)
436    }
437
438    /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
439    /// difference `step` from `start`.
440    ///```rust
441    /// use candle_core::{Tensor, Device};
442    /// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
443    ///
444    /// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
445    /// # Ok::<(), candle_core::Error>(())
446    /// ```
447    pub fn arange_step<D: crate::WithDType>(
448        start: D,
449        end: D,
450        step: D,
451        device: &Device,
452    ) -> Result<Self> {
453        if D::is_zero(&step) {
454            bail!("step cannot be zero")
455        }
456        let mut data = vec![];
457        let mut current = start;
458        if step >= D::zero() {
459            while current < end {
460                data.push(current);
461                current += step;
462            }
463        } else {
464            while current > end {
465                data.push(current);
466                current += step;
467            }
468        }
469        let len = data.len();
470        Self::from_vec_impl(data, len, device, false)
471    }
472
473    pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
474        data: Vec<D>,
475        shape: S,
476        device: &Device,
477        is_variable: bool,
478    ) -> Result<Self> {
479        let shape = shape.into_shape(data.len())?;
480        let storage = device.storage_owned(data)?;
481        let none = BackpropOp::none();
482        Ok(from_storage(storage, shape, none, is_variable))
483    }
484
485    /// Creates a new tensor initialized with values from the input vector. The number of elements
486    /// in this vector must be the same as the number of elements defined by the shape.
487    /// If the device is the CPU, no data copy is made.
488    ///```rust
489    /// use candle_core::{Tensor, Device};
490    /// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
491    ///
492    /// assert_eq!(a.to_vec2::<f64>()?, &[
493    ///     [1., 2., 3.],
494    ///     [4., 5., 6.]
495    /// ]);
496    /// # Ok::<(), candle_core::Error>(())
497    /// ```
498    pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
499        data: Vec<D>,
500        shape: S,
501        device: &Device,
502    ) -> Result<Self> {
503        Self::from_vec_impl(data, shape, device, false)
504    }
505
506    /// Creates a new tensor initialized with values from the input slice. The number of elements
507    /// in this slice must be the same as the number of elements defined by the shape.
508    ///```rust
509    /// use candle_core::{Tensor, Device};
510    /// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
511    /// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
512    ///
513    /// assert_eq!(a.to_vec2::<f64>()?, &[
514    ///     [2., 3., 4.],
515    ///     [5., 6., 7.]
516    /// ]);
517    /// # Ok::<(), candle_core::Error>(())
518    /// ```
519    pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
520        array: &[D],
521        shape: S,
522        device: &Device,
523    ) -> Result<Self> {
524        let shape = shape.into_shape(array.len())?;
525        let storage = device.storage_from_slice(array)?;
526        let none = BackpropOp::none();
527        Ok(from_storage(storage, shape, none, false))
528    }
529
530    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
531        let lhs = self.shape();
532        let rhs = rhs.shape();
533        if lhs != rhs {
534            Err(Error::ShapeMismatchBinaryOp {
535                lhs: lhs.clone(),
536                rhs: rhs.clone(),
537                op,
538            }
539            .bt())
540        } else {
541            Ok(lhs)
542        }
543    }
544
545    /// Returns true if the computation graph should track this op, that is, if it is
546    /// a variable or if it has some variables among its dependencies.
547    pub fn track_op(&self) -> bool {
548        self.is_variable || self.op.is_some()
549    }
550
551    // TODO: Also make an inplace version or a pre-allocated? This could be tricky
552    // if this can create cycles in the compute graph.
553    binary_op!(add, Add);
554    binary_op!(mul, Mul);
555    binary_op!(sub, Sub);
556    binary_op!(div, Div);
557    binary_op_scalar!(maximum, Maximum);
558    binary_op_scalar!(minimum, Minimum);
559    broadcast_binary_op!(broadcast_add, add);
560    broadcast_binary_op!(broadcast_mul, mul);
561    broadcast_binary_op!(broadcast_sub, sub);
562    broadcast_binary_op!(broadcast_div, div);
563    broadcast_binary_op!(broadcast_maximum, maximum);
564    broadcast_binary_op!(broadcast_minimum, minimum);
565    broadcast_binary_op!(broadcast_eq, eq);
566    broadcast_binary_op!(broadcast_ne, ne);
567    broadcast_binary_op!(broadcast_lt, lt);
568    broadcast_binary_op!(broadcast_le, le);
569    broadcast_binary_op!(broadcast_gt, gt);
570    broadcast_binary_op!(broadcast_ge, ge);
571
572    unary_op!(recip, Recip);
573    unary_op!(neg, Neg);
574    unary_op!(exp, Exp);
575    unary_op!(log, Log);
576    unary_op!(sin, Sin);
577    unary_op!(cos, Cos);
578    unary_op!(tanh, Tanh);
579    unary_op!(abs, Abs);
580    unary_op!(sqr, Sqr);
581    unary_op!(sqrt, Sqrt);
582    unary_op!(gelu, Gelu);
583    unary_op!(gelu_erf, GeluErf);
584    unary_op!(erf, Erf);
585    unary_op!(relu, Relu);
586    unary_op!(silu, Silu);
587    unary_op!(ceil, Ceil);
588    unary_op!(floor, Floor);
589    unary_op!(round, Round);
590    unary_op!(sign, Sign);
591
592    /// Rounds each element of the input tensor to the specified number of decimals.
593    ///
594    /// If the number of decimals is negative, it specifies the number of positions to the left of
595    /// the decimal point.
596    pub fn round_to(&self, decimals: i32) -> Result<Self> {
597        let mult = 10f64.powi(decimals);
598        (self * mult)?.round()? * (1f64 / mult)
599    }
600
601    /// Retrieves the single scalar value held in the tensor. If the tensor has a non-zero
602    /// rank, an error is returned instead.
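    ///
    /// A small sketch: reducing a 1D tensor to a rank-0 tensor before extracting the value.
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[2f32, 3., 4.], &Device::Cpu)?;
    /// // `sum` squeezes the reduced dimension, leaving a rank-0 tensor.
    /// let s = a.sum(0)?;
    /// assert_eq!(s.to_scalar::<f32>()?, 9.0);
    /// # Ok::<(), candle_core::Error>(())
    /// ```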
603    pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
604        if self.rank() != 0 {
605            Err(Error::UnexpectedNumberOfDims {
606                expected: 0,
607                got: self.rank(),
608                shape: self.shape().clone(),
609            }
610            .bt())?
611        }
612        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
613            let data = S::cpu_storage_as_slice(cpu_storage)?;
614            Ok::<_, Error>(data[self.layout().start_offset()])
615        };
616        match &*self.storage() {
617            Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
618            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
619            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
620        }
621    }
622
623    /// An alias for `to_scalar`.
624    pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
625        self.to_scalar::<S>()
626    }
627
628    /// Repeats this tensor along its dimensions, `shape` specifying the number of repeats per dimension.
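    ///
    /// A small sketch (illustrative) repeating along the first dimension:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1u32, 2], [3, 4]], &Device::Cpu)?;
    /// // Repeat twice along dim 0 and once (i.e. no repetition) along dim 1.
    /// let b = a.repeat((2, 1))?;
    /// assert_eq!(b.to_vec2::<u32>()?, &[[1, 2], [3, 4], [1, 2], [3, 4]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```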
629    pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
630        // Similar to PyTorch, we extend the number of dimensions of self if needed.
631        let repeats = shape.into();
632        let repeats = repeats.dims();
633        let mut inp = if self.rank() < repeats.len() {
634            let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
635            self.reshape(shape)?
636        } else {
637            self.clone()
638        };
639        for (idx, &repeat) in repeats.iter().enumerate() {
640            if repeat > 1 {
641                inp = Tensor::cat(&vec![&inp; repeat], idx)?
642            }
643        }
644        Ok(inp)
645    }
646
647    /// Creates grids of coordinates specified by the 1D inputs.
648    ///
649    /// # Arguments
650    ///
651    /// * `args` - A slice of 1D tensors.
652    /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
653    ///   first dimension corresponds to the cardinality of the second input and the second
654    ///   dimension corresponds to the cardinality of the first input. If ij is selected, the
655    ///   dimensions are in the same order as the cardinality of the inputs.
656    ///
657    /// # Examples
658    ///
659    /// ```rust
660    /// use candle_core::{Tensor, Device, Shape};
661    /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
662    /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
663    ///
664    /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
665    ///
666    /// assert_eq!(grids_xy.len(), 2);
667    /// assert_eq!(grids_xy[0].dims(), &[3, 3]);
668    ///
669    /// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
670    /// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
671    ///
672    /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
673    ///
674    /// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
675    /// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
676    /// # Ok::<(), candle_core::Error>(())
677    /// ```
678    ///
679    /// # Errors
680    ///
681    /// * Will return `Err` if `args` contains less than 2 tensors.
682    ///
683    pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
684        if args.len() <= 1 {
685            Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
686        }
687        let args: Vec<_> = if xy_indexing {
688            args.iter().rev().collect()
689        } else {
690            args.iter().collect()
691        };
692
693        let mut shape = Vec::with_capacity(args.len());
694        for arg in args.iter() {
695            shape.push(arg.as_ref().dims1()?)
696        }
697
698        let mut grids = Vec::with_capacity(args.len());
699        for idx in 0..args.len() {
700            let mut ones = vec![1usize; args.len()];
701            ones[idx] = shape[idx];
702            let arg = args[idx].as_ref().reshape(ones)?;
703            let mut repeats = shape.clone();
704            repeats[idx] = 1;
705            let repeated_tensor = arg.repeat(repeats)?;
706            grids.push(repeated_tensor);
707        }
708        if xy_indexing {
709            grids.reverse();
710        }
711        Ok(grids)
712    }
713
714    /// This operation multiplies the input tensor by `mul`, then adds `add`, and returns the result.
715    /// The input values `mul` and `add` are cast to the appropriate type so some rounding might
716    /// be performed.
717    ///
718    /// ```rust
719    /// use candle_core::{Tensor, Device};
720    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
721    /// let a = a.affine(4., -2.)?;
722    /// assert_eq!(a.to_vec2::<f32>()?, &[[-2.0, 2.0], [6.0, 10.0]]);
723    /// # Ok::<(), candle_core::Error>(())
724    /// ```
725    pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
726        if self.elem_count() == 0 {
727            return Ok(self.clone());
728        }
729        let storage = self.storage().affine(self.layout(), mul, add)?;
730        let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
731        Ok(from_storage(storage, self.shape(), op, false))
732    }
733
734    /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
735    pub fn elu(&self, alpha: f64) -> Result<Self> {
736        if self.elem_count() == 0 {
737            return Ok(self.clone());
738        }
739        let storage = self.storage().elu(self.layout(), alpha)?;
740        let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
741        Ok(from_storage(storage, self.shape(), op, false))
742    }
743
744    /// Raise the tensor to some float exponent `e`.
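    ///
    /// A small sketch (illustrative), squaring each element:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// // Elementwise power; the result is approximately [1., 4., 9.] up to float precision.
    /// let b = a.powf(2.)?;
    /// assert_eq!(b.dims(), &[3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```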
745    pub fn powf(&self, e: f64) -> Result<Self> {
746        if self.elem_count() == 0 {
747            return Ok(self.clone());
748        }
749        let storage = self.storage().powf(self.layout(), e)?;
750        let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
751        Ok(from_storage(storage, self.shape(), op, false))
752    }
753
754    pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
755        if dim >= self.dims().len() {
756            Err(Error::DimOutOfRange {
757                shape: self.shape().clone(),
758                dim: dim as i32,
759                op,
760            }
761            .bt())?
762        } else {
763            Ok(())
764        }
765    }
766
767    /// Splits a tensor into the specified number of chunks; this may return fewer chunks than
768    /// specified.
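    ///
    /// A small sketch splitting a vector of six elements into three chunks:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
    /// let chunks = a.chunk(3, 0)?;
    /// assert_eq!(chunks.len(), 3);
    /// assert_eq!(chunks[0].to_vec1::<f32>()?, &[0., 1.]);
    /// assert_eq!(chunks[2].to_vec1::<f32>()?, &[4., 5.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```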
769    pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
770        let dim = dim.to_index(self.shape(), "chunk")?;
771        let size = self.dim(dim)?;
772        if size < chunks {
773            (0..size).map(|i| self.narrow(dim, i, 1)).collect()
774        } else {
775            let chunk_size = size / chunks;
776            let cnt_additional = size % chunks;
777            let mut tensors = vec![];
778            let mut sum_chunk_size = 0;
779            for i in 0..chunks {
780                let chunk_size = if i < cnt_additional {
781                    chunk_size + 1
782                } else {
783                    chunk_size
784                };
785                let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
786                tensors.push(tensor);
787                sum_chunk_size += chunk_size
788            }
789            Ok(tensors)
790        }
791    }
792
793    /// Returns a new tensor that is a narrowed version of the input; along the dimension `dim`
794    /// it ranges from `start` to `start + len`.
795    /// ```
796    /// use candle_core::{Tensor, Device};
797    /// let a = Tensor::new(&[
798    ///     [0f32, 1., 2.],
799    ///     [3.  , 4., 5.],
800    ///     [6.  , 7., 8.]
801    /// ], &Device::Cpu)?;
802    ///
803    /// let b = a.narrow(0, 1, 2)?;
804    /// assert_eq!(b.shape().dims(), &[2, 3]);
805    /// assert_eq!(b.to_vec2::<f32>()?, &[
806    ///     [3., 4., 5.],
807    ///     [6., 7., 8.]
808    /// ]);
809    ///
810    /// let c = a.narrow(1, 1, 1)?;
811    /// assert_eq!(c.shape().dims(), &[3, 1]);
812    /// assert_eq!(c.to_vec2::<f32>()?, &[
813    ///     [1.],
814    ///     [4.],
815    ///     [7.]
816    /// ]);
817    /// # Ok::<(), candle_core::Error>(())
818    /// ```
819    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
820        let dims = self.dims();
821        let dim = dim.to_index(self.shape(), "narrow")?;
822        let err = |msg| {
823            Err::<(), _>(
824                Error::NarrowInvalidArgs {
825                    shape: self.shape().clone(),
826                    dim,
827                    start,
828                    len,
829                    msg,
830                }
831                .bt(),
832            )
833        };
834        if start > dims[dim] {
835            err("start > dim_len")?
836        }
837        if start.saturating_add(len) > dims[dim] {
838            err("start + len > dim_len")?
839        }
840        if start == 0 && dims[dim] == len {
841            Ok(self.clone())
842        } else {
843            let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
844            let layout = self.layout().narrow(dim, start, len)?;
845            let tensor_ = Tensor_ {
846                id: TensorId::new(),
847                storage: self.storage.clone(),
848                layout,
849                op,
850                is_variable: false,
851                dtype: self.dtype,
852                device: self.device.clone(),
853            };
854            Ok(Tensor(Arc::new(tensor_)))
855        }
856    }
857
858    fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
859        match dims {
860            [] => Ok(self),
861            [i] => self.squeeze(*i),
862            dims => {
863                let dims = self
864                    .dims()
865                    .iter()
866                    .enumerate()
867                    .filter_map(|(dim_idx, &v)| {
868                        if dims.contains(&dim_idx) {
869                            None
870                        } else {
871                            Some(v)
872                        }
873                    })
874                    .collect::<Vec<_>>();
875                self.reshape(dims)
876            }
877        }
878    }
879
880    fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
881        let dim = dim.to_index(self.shape(), op.name())?;
882        let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
883        let mut dims = self.dims().to_vec();
884        dims[dim] = 1;
885        let op = match op {
886            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
887                BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
888            }
889            ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
890        };
891        let res = from_storage(storage, dims, op, false);
892        if keepdim {
893            Ok(res)
894        } else {
895            res.squeeze_dims(&[dim])
896        }
897    }
898
899    fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
900        let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
901        let storage = self
902            .storage()
903            .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
904        let mut dims = self.dims().to_vec();
905        for &sum_dim in sum_dims.iter() {
906            dims[sum_dim] = 1
907        }
908        let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
909        let sum = from_storage(storage, dims, op, false);
910        if keepdim {
911            Ok(sum)
912        } else {
913            sum.squeeze_dims(&sum_dims)
914        }
915    }
916
917    /// Rolls the input tensor along the given dimension.
918    /// Elements that are shifted beyond the last position are re-introduced at the first position.
919    ///
920    /// ```rust
921    /// # use candle_core::{Tensor, Device};
922    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
923    /// let tensor = tensor.roll(1, 0)?;
924    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
925    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
926    /// let tensor = tensor.roll(-1, 0)?;
927    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
928    /// # Ok::<(), candle_core::Error>(())
929    /// ```
930    pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
931    where
932        D: Dim + Clone,
933    {
934        let dim = dim.to_index(self.shape(), "roll")?;
935        let dim_size = self.dim(dim)?;
936        let shift = shift.rem_euclid(dim_size as i32) as usize;
937        if shift == 0 {
938            Ok(self.clone())
939        } else {
940            let a = self.narrow(dim, 0, dim_size - shift)?;
941            let b = self.narrow(dim, dim_size - shift, shift)?;
942            Tensor::cat(&[&b, &a], dim)
943        }
944    }
945
946    /// Returns the sum of the elements of the input tensor. The sum is performed over the
947    /// dimensions given in `sum_dims`.
948    ///
949    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
950    /// that the number of elements for each dimension index in `sum_dims` is 1.
951    ///
952    /// ```rust
953    /// use candle_core::{Tensor, Device};
954    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
955    /// let s = a.sum_keepdim(0)?;
956    /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
957    /// let s = a.sum_keepdim(1)?;
958    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
959    /// let s = a.sum_keepdim((0, 1))?;
960    /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
961    /// # Ok::<(), candle_core::Error>(())
962    /// ```
963    pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
964        self.sum_impl(sum_dims, true)
965    }
966
967    /// Returns the sum of the elements of the input tensor. The sum is performed over the
968    /// dimensions given in `sum_dims`; compared to `sum_keepdim`, these dimensions are squeezed
969    /// rather than kept.
970    pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
971        self.sum_impl(sum_dims, false)
972    }
973
974    /// Returns the mean of the elements of the input tensor. The mean is performed over the
975    /// dimensions given in `mean_dims`.
976    ///
977    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
978    /// that the number of elements for each dimension index in `mean_dims` is 1.
979    ///
980    /// ```rust
981    /// use candle_core::{Tensor, Device};
982    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
983    /// let s = a.mean_keepdim(0)?;
984    /// assert_eq!(s.to_vec2::<f32>()?, &[[1., 2.]]);
985    /// let s = a.mean_keepdim(1)?;
986    /// assert_eq!(s.to_vec2::<f32>()?, &[[0.5], [2.5]]);
987    /// let s = a.mean_keepdim((0, 1))?;
988    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.5]]);
989    /// # Ok::<(), candle_core::Error>(())
990    /// ```
991    pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
992        let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
993        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
994        let scale = 1f64 / (reduced_dim as f64);
995        self.sum_impl(mean_dims, true)? * scale
996    }
997
998    /// Returns the mean of the elements of the input tensor. The mean is performed over the
999    /// dimensions given in `mean_dims`; compared to `mean_keepdim`, these dimensions are squeezed
1000    /// rather than kept.
1001    pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
1002        let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
1003        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1004        let scale = 1f64 / (reduced_dim as f64);
1005        self.sum_impl(mean_dims, false)? * scale
1006    }
1007
1008    /// Returns the unbiased variance over the selected dimension.
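    ///
    /// A small sketch (illustrative): the unbiased variance of `[1, 2, 3]` is 1.
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// // The reduced dimension is kept with a single element.
    /// let v = a.var_keepdim(0)?;
    /// assert_eq!(v.to_vec1::<f32>()?, &[1.0]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```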
1009    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1010        let dim = dim.to_index(self.shape(), "var")?;
1011        let mean = self.mean_keepdim(dim)?;
1012        let squares = self.broadcast_sub(&mean)?.sqr()?;
1013        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
1014    }
1015
1016    /// Returns the unbiased variance over the selected dimension.
1017    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
1018        let dim = dim.to_index(self.shape(), "var")?;
1019        self.var_keepdim(dim)?.squeeze(dim)
1020    }
1021
1022    /// Gathers the maximum value across the selected dimension. The resulting shape has the same
1023    /// number of dimensions as the original tensor and the selected dimension has a single element.
1024    pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1025        self.reduce_impl(dim, true, ReduceOp::Max)
1026    }
1027
1028    /// Similar to `max_keepdim` but the target dimension is squeezed.
1029    pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
1030        self.reduce_impl(dim, false, ReduceOp::Max)
1031    }
1032
1033    /// Gathers the minimum value across the selected dimension. The resulting shape has the same
1034    /// number of dimensions as the original tensor and the selected dimension has a single element.
1035    pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1036        self.reduce_impl(dim, true, ReduceOp::Min)
1037    }
1038
1039    /// Similar to `min_keepdim` but the target dimension is squeezed.
1040    pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
1041        self.reduce_impl(dim, false, ReduceOp::Min)
1042    }
1043
1044    pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1045        self.reduce_impl(dim, true, ReduceOp::ArgMax)
1046    }
1047
1048    /// Similar to `argmax_keepdim` but the target dimension is squeezed.
1049    pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
1050        self.reduce_impl(dim, false, ReduceOp::ArgMax)
1051    }
1052
1053    pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1054        self.reduce_impl(dim, true, ReduceOp::ArgMin)
1055    }
1056
1057    /// Similar to `argmin_keepdim` but the target dimension is squeezed.
1058    pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
1059        self.reduce_impl(dim, false, ReduceOp::ArgMin)
1060    }
1061
1062    /// Element-wise comparison between two tensors, e.g. equality, greater than, ... The actual
1063    /// comparison operation is specified by the `op` argument.
1064    ///
1065    /// The returned tensor has the same shape as the original tensors and uses `u8` elements.
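    ///
    /// A small sketch using the `eq` helper, which calls `cmp` with `CmpOp::Eq`:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[1u32, 2, 3], &Device::Cpu)?;
    /// let b = Tensor::new(&[1u32, 0, 3], &Device::Cpu)?;
    /// // The result uses `u8` elements, 1 where the values match and 0 otherwise.
    /// let mask = a.eq(&b)?;
    /// assert_eq!(mask.to_vec1::<u8>()?, &[1, 0, 1]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```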
1066    pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
1067        let rhs = match rhs.to_tensor_scalar()? {
1068            crate::scalar::TensorScalar::Tensor(rhs) => rhs,
1069            crate::scalar::TensorScalar::Scalar(rhs) => rhs
1070                .to_dtype(self.dtype())?
1071                .to_device(self.device())?
1072                .broadcast_as(self.shape())?,
1073        };
1074        let shape = self.same_shape_binary_op(&rhs, "cmp")?;
1075        let storage = self
1076            .storage()
1077            .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
1078        let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
1079        Ok(from_storage(storage, shape.dims(), op, false))
1080    }
1081
1082    /// Element-wise equality.
1083    pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1084        self.cmp(rhs, CmpOp::Eq)
1085    }
1086
1087    /// Element-wise non-equality.
1088    pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1089        self.cmp(rhs, CmpOp::Ne)
1090    }
1091
1092    /// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
1093    /// rhs` and 0 otherwise.
1094    pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1095        self.cmp(rhs, CmpOp::Lt)
1096    }
1097
1098    /// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
1099    /// rhs` and 0 otherwise.
1100    pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1101        self.cmp(rhs, CmpOp::Gt)
1102    }
1103
1104    /// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
1105    /// rhs` and 0 otherwise.
1106    pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1107        self.cmp(rhs, CmpOp::Ge)
1108    }
1109
1110    /// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
1111    /// rhs` and 0 otherwise.
1112    pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1113        self.cmp(rhs, CmpOp::Le)
1114    }
1115
1116    /// Clamp the tensor values to be between `min` and `max`.
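    ///
    /// A small sketch (illustrative) clamping values into `[0, 1]`:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[-1f32, 0.5, 2.0], &Device::Cpu)?;
    /// // Values below 0 become 0 and values above 1 become 1.
    /// let b = a.clamp(0f32, 1f32)?;
    /// assert_eq!(b.to_vec1::<f32>()?, &[0.0, 0.5, 1.0]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```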
1117    pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
1118        self.maximum(min)?.minimum(max)
1119    }
1120
1121    /// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
1122    ///
1123    /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
1124    /// tensor also has three dimensions, `(batch, channels, target_size)`.
1125    pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
1126        let (n, c, _l) = self.dims3()?;
1127        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
1128        let storage = self
1129            .storage()
1130            .upsample_nearest1d(self.layout(), target_size)?;
1131        Ok(from_storage(storage, (n, c, target_size), op, false))
1132    }
1133
1134    /// Alias for `interpolate1d`.
1135    pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
1136        self.interpolate1d(target_size)
1137    }
1138
1139    /// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
1140    /// nearest element.
1141    ///
1142    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1143    /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
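    ///
    /// A shape-level sketch (illustrative) of a 2x nearest-neighbor upscale on a `(1, 1, 2, 2)` input:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::arange(0f32, 4., &Device::Cpu)?.reshape((1, 1, 2, 2))?;
    /// let b = a.interpolate2d(4, 4)?;
    /// assert_eq!(b.dims(), &[1, 1, 4, 4]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```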
1144    pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1145        let (n, c, _h, _w) = self.dims4()?;
1146        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
1147            arg,
1148            target_h,
1149            target_w,
1150        });
1151        let storage = self
1152            .storage()
1153            .upsample_nearest2d(self.layout(), target_h, target_w)?;
1154        Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1155    }
1156
1157    /// Alias for `interpolate2d`.
1158    pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1159        self.interpolate2d(target_h, target_w)
1160    }
1161
1162    /// 2D average pooling over an input tensor with multiple channels.
1163    ///
1164    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1165    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
1166    /// the two last dimensions using a kernel of size `sz`. The returned element is the average
1167    /// value over the kernel window.
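    ///
    /// A small sketch (illustrative): averaging 2x2 blocks of a `(1, 1, 4, 4)` input.
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::arange(0f32, 16., &Device::Cpu)?.reshape((1, 1, 4, 4))?;
    /// let b = a.avg_pool2d(2)?;
    /// assert_eq!(b.dims(), &[1, 1, 2, 2]);
    /// // Each output element is the average of the corresponding 2x2 window.
    /// assert_eq!(b.reshape((2, 2))?.to_vec2::<f32>()?, &[[2.5, 4.5], [10.5, 12.5]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```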
1168    pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1169        let sz = sz.to_usize2();
1170        self.avg_pool2d_with_stride(sz, sz)
1171    }
1172
1173    /// Same as `avg_pool2d` but with a `stride` that can be set to a value different from the
1174    /// kernel size.
1175    pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
1176        &self,
1177        kernel_size: T,
1178        stride: T,
1179    ) -> Result<Self> {
1180        let kernel_size = kernel_size.to_usize2();
1181        let stride = stride.to_usize2();
1182        let (n, c, h, w) = self.dims4()?;
1183        if h < kernel_size.0 || w < kernel_size.1 {
1184            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1185        }
1186        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
1187        let h_out = (h - kernel_size.0) / stride.0 + 1;
1188        let w_out = (w - kernel_size.1) / stride.1 + 1;
1189        let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
1190            arg,
1191            kernel_size,
1192            stride,
1193        });
1194        let storage = self
1195            .storage()
1196            .avg_pool2d(self.layout(), kernel_size, stride)?;
1197        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1198    }
1199
1200    /// 2D max pooling over an input tensor with multiple channels.
1201    ///
1202    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1203    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
1204    /// the two last dimensions using a kernel of size `sz`, the returned element is the maximum
1205    /// value over the kernel window.
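    ///
    /// A small sketch (illustrative): taking the maximum over 2x2 blocks of a `(1, 1, 4, 4)` input.
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::arange(0f32, 16., &Device::Cpu)?.reshape((1, 1, 4, 4))?;
    /// let b = a.max_pool2d(2)?;
    /// // Each output element is the maximum of the corresponding 2x2 window.
    /// assert_eq!(b.reshape((2, 2))?.to_vec2::<f32>()?, &[[5., 7.], [13., 15.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```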
1206    pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1207        let sz = sz.to_usize2();
1208        self.max_pool2d_with_stride(sz, sz)
1209    }
1210
1211    /// Same as `max_pool2d` but with a `stride` that can be set to a value different from the
1212    /// kernel size.
1213    pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
1214        &self,
1215        kernel_size: T,
1216        stride: T,
1217    ) -> Result<Self> {
1218        let kernel_size = kernel_size.to_usize2();
1219        let stride = stride.to_usize2();
1220        let (n, c, h, w) = self.dims4()?;
1221        if h < kernel_size.0 || w < kernel_size.1 {
1222            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1223        }
1224        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
1225        let h_out = (h - kernel_size.0) / stride.0 + 1;
1226        let w_out = (w - kernel_size.1) / stride.1 + 1;
1227        let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
1228            arg,
1229            kernel_size,
1230            stride,
1231        });
1232        let storage = self
1233            .storage()
1234            .max_pool2d(self.layout(), kernel_size, stride)?;
1235        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1236    }
1237
1238    /// Returns the matrix-multiplication of the input tensor with the other provided tensor.
1239    ///
1240    /// # Arguments
1241    ///
1242    /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`.
1243    /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`.
1244    ///
1245    /// The resulting tensor has dimensions `b1, b2, ..., bi, m, n`.
1246    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
1247        let a_dims = self.shape().dims();
1248        let b_dims = rhs.shape().dims();
1249
1250        let dim = a_dims.len();
1251
1252        if dim < 2 || b_dims.len() != dim {
1253            Err(Error::ShapeMismatchBinaryOp {
1254                lhs: self.shape().clone(),
1255                rhs: rhs.shape().clone(),
1256                op: "matmul",
1257            }
1258            .bt())?
1259        }
1260
1261        let m = a_dims[dim - 2];
1262        let k = a_dims[dim - 1];
1263        let k2 = b_dims[dim - 2];
1264        let n = b_dims[dim - 1];
1265
1266        let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
1267        if c_shape.elem_count() == 0 || k == 0 {
1268            return Tensor::zeros(c_shape, self.dtype(), self.device());
1269        }
1270        let batching: usize = a_dims[..dim - 2].iter().product();
1271        let batching_b: usize = b_dims[..dim - 2].iter().product();
1272        if k != k2 || batching != batching_b {
1273            Err(Error::ShapeMismatchBinaryOp {
1274                lhs: self.shape().clone(),
1275                rhs: rhs.shape().clone(),
1276                op: "matmul",
1277            }
1278            .bt())?
1279        }
1280
1281        let storage = self.storage().matmul(
1282            &rhs.storage(),
1283            (batching, m, n, k),
1284            self.layout(),
1285            rhs.layout(),
1286        )?;
1287        let op = BackpropOp::new2(self, rhs, Op::Matmul);
1288        Ok(from_storage(storage, c_shape, op, false))
1289    }
1290
1291    /// Matrix-multiplication with broadcasting support.
1292    ///
1293    /// Compared to `matmul`, the two matrices are allowed to have different dimensions as long as
1294    /// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
1295    /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
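    ///
    /// A shape-level sketch (illustrative): a batched lhs multiplied by a single rhs matrix.
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let lhs = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    /// let rhs = Tensor::zeros((4, 5), DType::F32, &Device::Cpu)?;
    /// // The rhs is broadcast over the batch dimension of the lhs.
    /// let c = lhs.broadcast_matmul(&rhs)?;
    /// assert_eq!(c.dims(), &[2, 3, 5]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```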
1296    pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
1297        let lhs = self;
1298        let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
1299        let l_broadcast = l_shape != *lhs.shape();
1300        let r_broadcast = r_shape != *rhs.shape();
1301        // TODO: Avoid concretising the broadcasted matrices via contiguous.
1302        match (l_broadcast, r_broadcast) {
1303            (true, true) => lhs
1304                .broadcast_as(&l_shape)?
1305                .contiguous()?
1306                .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1307            (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1308            (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
1309            (false, false) => lhs.matmul(rhs),
1310        }
1311    }
1312
1313    /// Returns a tensor with the same shape as the input tensor, the values are taken from
1314    /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
1315    /// input tensor is equal to zero.
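    ///
    /// A small sketch (illustrative) with a `u8` mask tensor:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let mask = Tensor::new(&[1u8, 0, 1], &Device::Cpu)?;
    /// let on_true = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let on_false = Tensor::new(&[7f32, 8., 9.], &Device::Cpu)?;
    /// // Non-zero mask positions take the `on_true` value, zero positions the `on_false` value.
    /// let c = mask.where_cond(&on_true, &on_false)?;
    /// assert_eq!(c.to_vec1::<f32>()?, &[1., 8., 3.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```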
1316    pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
1317        let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
1318        let shape = self.same_shape_binary_op(on_false, "where_cond")?;
1319        let storage = self.storage().where_cond(
1320            self.layout(),
1321            &on_true.storage(),
1322            on_true.layout(),
1323            &on_false.storage(),
1324            on_false.layout(),
1325        )?;
1326        let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
1327        Ok(from_storage(storage, shape, op, false))
1328    }
1329
1330    /// Returns a tensor with the values from the `self` tensor at the indices corresponding to the
1331    /// values held in the `ids` tensor.
1332    ///
1333    /// # Arguments
1334    ///
1335    /// * `self` - A tensor with dimensions `v, h`.
1336    /// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
1337    ///
1338    /// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
1339    /// vocabulary size, and `h` the hidden size.
1340    ///
1341    /// ```rust
1342    /// use candle_core::{Tensor, Device};
1343    /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1344    /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
1345    /// let emb = values.embedding(&ids)?;
1346    /// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
1347    /// # Ok::<(), candle_core::Error>(())
1348    /// ```
1349    pub fn embedding(&self, ids: &Self) -> Result<Self> {
1350        if self.rank() != 2 || ids.rank() != 1 {
1351            Err(Error::ShapeMismatchBinaryOp {
1352                lhs: self.shape().clone(),
1353                rhs: ids.shape().clone(),
1354                op: "embedding",
1355            }
1356            .bt())?
1357        }
1358        self.index_select(ids, 0)
1359    }
1360
1361    fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
1362        let source_dims = source.dims();
1363        let self_dims = self.dims();
1364        let mismatch = if source_dims.len() != self_dims.len() {
1365            true
1366        } else {
1367            let mut mismatch = false;
1368            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1369                if i != dim && d1 != d2 {
1370                    mismatch = true;
1371                    break;
1372                }
1373            }
1374            mismatch
1375        };
1376        if mismatch {
1377            Err(Error::ShapeMismatchBinaryOp {
1378                op: "scatter (self, src)",
1379                lhs: self.shape().clone(),
1380                rhs: source.shape().clone(),
1381            }
1382            .bt())?
1383        }
1384        if indexes.dims() != source.dims() {
1385            Err(Error::ShapeMismatchBinaryOp {
1386                op: "scatter (indexes, src)",
1387                lhs: indexes.shape().clone(),
1388                rhs: source.shape().clone(),
1389            }
1390            .bt())?
1391        }
1392        Ok(())
1393    }
1394
1395    pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1396        let dim = dim.to_index(self.shape(), "scatter")?;
1397        self.scatter_checks(indexes, source, dim)?;
1398        let shape = self.shape();
1399        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1400        self.storage()
1401            .copy_strided_src(&mut storage, 0, self.layout())?;
1402        let layout = Layout::contiguous(shape);
1403        storage.scatter_set(
1404            &layout,
1405            &indexes.storage(),
1406            indexes.layout(),
1407            &source.storage(),
1408            source.layout(),
1409            dim,
1410        )?;
1411        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1412            Op::Scatter(t1, t2, t3, dim)
1413        });
1414        Ok(from_storage(storage, self.shape(), op, false))
1415    }
1416
1417    pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1418        if self.same_storage(source) {
1419            crate::bail!("cannot use scatter_set when self and src share their storage")
1420        }
1421        let dim = dim.to_index(self.shape(), "scatter-set")?;
1422        self.scatter_checks(indexes, source, dim)?;
1423        self.storage_mut().scatter_set(
1424            self.layout(),
1425            &indexes.storage(),
1426            indexes.layout(),
1427            &source.storage(),
1428            source.layout(),
1429            dim,
1430        )?;
1431        Ok(())
1432    }
1433
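    /// Returns a copy of `self` where the values from `source` have been added at the positions
    /// given by `indexes` along the dimension `dim`.
    ///
    /// A minimal sketch of the expected behavior, assuming PyTorch-style scatter-add semantics and
    /// using hypothetical example values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((3, 2), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[[0u32, 1u32], [2u32, 1u32]], &Device::Cpu)?;
    /// // For dim 0, out[ids[i][j]][j] += src[i][j].
    /// let out = t.scatter_add(&ids, &src, 0)?;
    /// assert_eq!(out.to_vec2::<f32>()?, &[[1., 0.], [0., 6.], [3., 0.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```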
1434    pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1435        let dim = dim.to_index(self.shape(), "scatter-add")?;
1436        self.scatter_checks(indexes, source, dim)?;
1437        let shape = self.shape();
1438        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1439        self.storage()
1440            .copy_strided_src(&mut storage, 0, self.layout())?;
1441        let layout = Layout::contiguous(shape);
1442        storage.scatter_add(
1443            &layout,
1444            &indexes.storage(),
1445            indexes.layout(),
1446            &source.storage(),
1447            source.layout(),
1448            dim,
1449        )?;
1450        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1451            Op::ScatterAdd(t1, t2, t3, dim)
1452        });
1453        Ok(from_storage(storage, self.shape(), op, false))
1454    }
1455
1456    pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1457        if self.same_storage(source) {
1458            crate::bail!("cannot use scatter_add_set when self and src share their storage")
1459        }
1460        let dim = dim.to_index(self.shape(), "scatter-add-set")?;
1461        self.scatter_checks(indexes, source, dim)?;
1462        self.storage_mut().scatter_add(
1463            self.layout(),
1464            &indexes.storage(),
1465            indexes.layout(),
1466            &source.storage(),
1467            source.layout(),
1468            dim,
1469        )?;
1470        Ok(())
1471    }
1472
1473    /// Embeds the values of the `src` tensor into the `self` tensor, starting at offset `start` on the specified dimension.
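    ///
    /// A minimal sketch with hypothetical example values, embedding `src` along dimension 1
    /// starting at offset 1:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((2, 4), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::ones((2, 2), DType::F32, &Device::Cpu)?;
    /// let out = t.slice_scatter(&src, 1, 1)?;
    /// assert_eq!(out.to_vec2::<f32>()?, &[[0., 1., 1., 0.], [0., 1., 1., 0.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```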
1474    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
1475        let dim = dim.to_index(self.shape(), "slice-scatter")?;
1476        if dim == 0 {
1477            self.slice_scatter0(src, start)
1478        } else {
1479            // TODO: Maybe we want to add a more efficient implementation at some point.
1480            self.transpose(0, dim)?
1481                .slice_scatter0(&src.transpose(0, dim)?, start)?
1482                .transpose(0, dim)
1483        }
1484    }
1485
1486    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
1487    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
1488        if self.dtype() != src.dtype() {
1489            Err(Error::DTypeMismatchBinaryOp {
1490                lhs: self.dtype(),
1491                rhs: src.dtype(),
1492                op: "slice-scatter",
1493            }
1494            .bt())?
1495        }
1496        if self.device().location() != src.device().location() {
1497            Err(Error::DeviceMismatchBinaryOp {
1498                lhs: self.device().location(),
1499                rhs: src.device().location(),
1500                op: "slice-scatter",
1501            }
1502            .bt())?
1503        }
1504        if self.rank() != src.rank() {
1505            Err(Error::UnexpectedNumberOfDims {
1506                expected: self.rank(),
1507                got: src.rank(),
1508                shape: src.shape().clone(),
1509            }
1510            .bt())?
1511        }
1512        let shape_ok =
1513            self.dims()
1514                .iter()
1515                .zip(src.dims().iter())
1516                .enumerate()
1517                .all(|(dim_idx, (&d1, &d2))| {
1518                    if 0 == dim_idx {
1519                        d2 + start <= d1
1520                    } else {
1521                        d1 == d2
1522                    }
1523                });
1524        if !shape_ok {
1525            Err(Error::ShapeMismatchBinaryOp {
1526                op: "slice-scatter (self, src)",
1527                lhs: self.shape().clone(),
1528                rhs: src.shape().clone(),
1529            }
1530            .bt())?
1531        }
1532        let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
1533        self.storage()
1534            .copy_strided_src(&mut storage, 0, self.layout())?;
1535        let offset = start * src.dims()[1..].iter().product::<usize>();
1536        src.storage()
1537            .copy_strided_src(&mut storage, offset, src.layout())?;
1538        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
1539        Ok(from_storage(storage, self.shape(), op, false))
1540    }
1541
1542    /// Accumulate elements from `source` at the positions given by `indexes` and add them to `self`.
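    ///
    /// A minimal sketch with hypothetical example values, assuming PyTorch-style index-add
    /// semantics along dimension 0:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((3, 2), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[0u32, 2u32], &Device::Cpu)?;
    /// // For dim 0, out[ids[i]] += src[i].
    /// let out = t.index_add(&ids, &src, 0)?;
    /// assert_eq!(out.to_vec2::<f32>()?, &[[1., 2.], [0., 0.], [3., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```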
1543    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1544        let dim = dim.to_index(self.shape(), "index-add")?;
1545        let source_dims = source.dims();
1546        let self_dims = self.dims();
1547        let mismatch = if source_dims.len() != self_dims.len() {
1548            true
1549        } else {
1550            let mut mismatch = false;
1551            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1552                if i != dim && d1 != d2 {
1553                    mismatch = true;
1554                    break;
1555                }
1556            }
1557            mismatch
1558        };
1559        if mismatch {
1560            Err(Error::ShapeMismatchBinaryOp {
1561                op: "index-add (self, source)",
1562                lhs: self.shape().clone(),
1563                rhs: source.shape().clone(),
1564            }
1565            .bt())?
1566        }
1567        // The number of elements in `indexes` must match the size of the dimension on which
1568        // the add is performed in the source tensor (the index values from `indexes` refer to
1569        // positions in the target tensor `self`).
1570        let indexes_len = indexes.dims1()?;
1571        if source_dims[dim] != indexes_len {
1572            Err(Error::ShapeMismatchBinaryOp {
1573                op: "index-add (ids, source)",
1574                lhs: indexes.shape().clone(),
1575                rhs: source.shape().clone(),
1576            }
1577            .bt())?
1578        }
1579        let storage = self.storage().index_add(
1580            self.layout(),
1581            &indexes.storage(),
1582            indexes.layout(),
1583            &source.storage(),
1584            source.layout(),
1585            dim,
1586        )?;
1587        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1588            Op::IndexAdd(t1, t2, t3, dim)
1589        });
1590        Ok(from_storage(storage, self.shape(), op, false))
1591    }
1592
1593    /// Gather values across the target dimension.
1594    ///
1595    /// # Arguments
1596    ///
1597    /// * `self` - The input tensor.
1598    /// * `indexes` - The indices of elements to gather. This should have the same number of dimensions as
1599    ///   `self`, and `indexes.dims()[d] <= self.dims()[d]` for all dimensions `d != dim`.
1600    /// * `dim` - the target dimension.
1601    ///
1602    /// The resulting tensor has the same shape as `indexes` and uses values from `self` indexed on
1603    /// dimension `dim` by the values in `indexes`.
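    ///
    /// A minimal sketch with hypothetical example values, assuming PyTorch-style gather semantics:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[[0u32, 0u32], [1u32, 0u32]], &Device::Cpu)?;
    /// // For dim 1, out[i][j] = t[i][ids[i][j]].
    /// let out = t.gather(&ids, 1)?;
    /// assert_eq!(out.to_vec2::<f32>()?, &[[1., 1.], [4., 3.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```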
1604    pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1605        let dim = dim.to_index(self.shape(), "gather")?;
1606
1607        let self_dims = self.dims();
1608        let indexes_dims = indexes.dims();
1609        let mismatch = if indexes_dims.len() != self_dims.len() {
1610            true
1611        } else {
1612            let mut mismatch = false;
1613            for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
1614                if i != dim && d1 < d2 {
1615                    mismatch = true;
1616                    break;
1617                }
1618            }
1619            mismatch
1620        };
1621        if mismatch {
1622            Err(Error::ShapeMismatchBinaryOp {
1623                op: "gather",
1624                lhs: self.shape().clone(),
1625                rhs: indexes.shape().clone(),
1626            }
1627            .bt())?
1628        }
1629        let storage =
1630            self.storage()
1631                .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
1632        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
1633        Ok(from_storage(storage, indexes.shape(), op, false))
1634    }
1635
1636    /// Select values for the input tensor at the target indexes across the specified dimension.
1637    ///
1638    /// The `indexes` argument is an integer tensor with a single dimension.
1639    /// The output has the same number of dimensions as the `self` input. The target dimension of
1640    /// the output has the same length as `indexes` and its values are taken from `self` using
1641    /// the indices from `indexes`. The other dimensions have the same number of elements as the
1642    /// input tensor.
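    ///
    /// For instance, selecting rows of a 2D tensor (hypothetical example values):
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[2u32, 0u32], &Device::Cpu)?;
    /// let s = t.index_select(&ids, 0)?;
    /// assert_eq!(s.to_vec2::<f32>()?, &[[4., 5.], [0., 1.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```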
1643    pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1644        let dim = dim.to_index(self.shape(), "index-select")?;
1645        let indexes_len = match indexes.dims() {
1646            [l] => *l,
1647            _ => Err(Error::ShapeMismatchBinaryOp {
1648                lhs: self.shape().clone(),
1649                rhs: indexes.shape().clone(),
1650                op: "index-select",
1651            }
1652            .bt())?,
1653        };
1654        let storage = self.storage().index_select(
1655            &indexes.storage(),
1656            self.layout(),
1657            indexes.layout(),
1658            dim,
1659        )?;
1660        let mut dims = self.dims().to_vec();
1661        dims[dim] = indexes_len;
1662        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
1663        Ok(from_storage(storage, dims, op, false))
1664    }
1665
1666    /// Returns an iterator over the positions of the elements in the storage when ranging over the
1667    /// index tuples in lexicographic order.
1668    pub fn strided_index(&self) -> crate::StridedIndex {
1669        self.layout.strided_index()
1670    }
1671
1672    /// Similar to `strided_index` but returns the position of the start of each contiguous block
1673    /// as well as the length of that block. For a contiguous tensor, the iterator returns only
1674    /// the start offset, and the size is the number of elements in the
1675    /// tensor.
1676    pub fn strided_blocks(&self) -> crate::StridedBlocks {
1677        self.layout.strided_blocks()
1678    }
1679
1680    /// Returns the data contained in a 1D tensor as a vector of scalar values.
1681    pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
1682        if self.rank() != 1 {
1683            Err(Error::UnexpectedNumberOfDims {
1684                expected: 1,
1685                got: self.rank(),
1686                shape: self.shape().clone(),
1687            }
1688            .bt())?
1689        }
1690        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1691            let data = S::cpu_storage_as_slice(cpu_storage)?;
1692            let data = match self.layout.contiguous_offsets() {
1693                Some((o1, o2)) => data[o1..o2].to_vec(),
1694                None => self.strided_index().map(|i| data[i]).collect(),
1695            };
1696            Ok::<Vec<_>, Error>(data)
1697        };
1698        match &*self.storage() {
1699            Storage::Cpu(storage) => from_cpu_storage(storage),
1700            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1701            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1702        }
1703    }
1704
1705    /// Returns the data contained in a 2D tensor as a vector of vector of scalar values.
1706    pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
1707        let (dim1, dim2) = self.dims2()?;
1708        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1709            let data = S::cpu_storage_as_slice(cpu_storage)?;
1710            let mut rows = vec![];
1711            match self.layout.contiguous_offsets() {
1712                Some((o1, o2)) => {
1713                    let data = &data[o1..o2];
1714                    for idx_row in 0..dim1 {
1715                        rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
1716                    }
1717                }
1718                None => {
1719                    let mut src_index = self.strided_index();
1720                    for _idx_row in 0..dim1 {
1721                        let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
1722                        rows.push(row)
1723                    }
1724                    assert!(src_index.next().is_none());
1725                }
1726            }
1727            Ok(rows)
1728        };
1729        match &*self.storage() {
1730            Storage::Cpu(storage) => from_cpu_storage(storage),
1731            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1732            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1733        }
1734    }
1735
1736    /// Returns the data contained in a 3D tensor.
1737    pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
1738        let (dim1, dim2, dim3) = self.dims3()?;
1739        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1740            let data = S::cpu_storage_as_slice(cpu_storage)?;
1741            let mut top_rows = vec![];
1742            match self.layout.contiguous_offsets() {
1743                Some((o1, o2)) => {
1744                    let data = &data[o1..o2];
1745                    let dim23 = dim2 * dim3;
1746                    for idx1 in 0..dim1 {
1747                        let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
1748                        let mut rows = vec![];
1749                        for idx2 in 0..dim2 {
1750                            rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
1751                        }
1752                        top_rows.push(rows);
1753                    }
1754                }
1755                None => {
1756                    let mut src_index = self.strided_index();
1757                    for _idx in 0..dim1 {
1758                        let mut rows = vec![];
1759                        for _jdx in 0..dim2 {
1760                            let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
1761                            rows.push(row)
1762                        }
1763                        top_rows.push(rows);
1764                    }
1765                    assert!(src_index.next().is_none());
1766                }
1767            }
1768            Ok(top_rows)
1769        };
1770        match &*self.storage() {
1771            Storage::Cpu(storage) => from_cpu_storage(storage),
1772            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1773            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1774        }
1775    }
1776
1777    /// The dtype for the elements stored in the input tensor.
1778    pub fn dtype(&self) -> DType {
1779        self.dtype
1780    }
1781
1782    /// The device on which the input tensor is located.
1783    pub fn device(&self) -> &Device {
1784        &self.device
1785    }
1786
1787    /// The tensor shape, i.e. dimension sizes on each axis.
1788    pub fn shape(&self) -> &Shape {
1789        self.layout().shape()
1790    }
1791
1792    /// The dimension size for this tensor on each axis.
1793    pub fn dims(&self) -> &[usize] {
1794        self.shape().dims()
1795    }
1796
1797    /// The dimension size for a specified dimension index.
1798    pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
1799        let dim = dim.to_index(self.shape(), "dim")?;
1800        Ok(self.dims()[dim])
1801    }
1802
1803    /// The layout of the input tensor, this stores both the shape of the tensor as well as the
1804    /// strides and the start offset to apply to the underlying storage.
1805    pub fn layout(&self) -> &Layout {
1806        &self.layout
1807    }
1808
1809    pub fn stride(&self) -> &[usize] {
1810        self.layout.stride()
1811    }
1812
1813    /// The number of dimensions for this tensor, 0 for a scalar tensor, 1 for a 1D tensor, etc.
1814    pub fn rank(&self) -> usize {
1815        self.shape().rank()
1816    }
1817
1818    /// The number of elements stored in this tensor.
1819    pub fn elem_count(&self) -> usize {
1820        self.shape().elem_count()
1821    }
1822
1823    /// The unique identifier for this tensor.
1824    pub fn id(&self) -> TensorId {
1825        self.id
1826    }
1827
1828    /// Whether this tensor is a variable or not. A variable is a tensor for which gradients are
1829    /// tracked and on which backpropagation can be performed.
1830    pub fn is_variable(&self) -> bool {
1831        self.is_variable
1832    }
1833
1834    pub(crate) fn op(&self) -> &Option<Op> {
1835        &self.op
1836    }
1837
1838    /// Computes the max of all the elements in this tensor and returns a tensor holding this
1839    /// scalar with zero dimensions.
1840    ///
1841    /// ```rust
1842    /// use candle_core::{Tensor, Device};
1843    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1844    /// let tensor = tensor.max_all()?;
1845    /// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
1846    /// # Ok::<(), candle_core::Error>(())
1847    /// ```
1848    pub fn max_all(&self) -> Result<Tensor> {
1849        if self.rank() == 0 {
1850            Ok(self.clone())
1851        } else {
1852            self.flatten_all()?.max(0)
1853        }
1854    }
1855
1856    /// Computes the min of all the elements in this tensor and returns a tensor holding this
1857    /// scalar with zero dimensions.
1858    ///
1859    /// ```rust
1860    /// use candle_core::{Tensor, Device};
1861    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1862    /// let tensor = tensor.min_all()?;
1863    /// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
1864    /// # Ok::<(), candle_core::Error>(())
1865    /// ```
1866    pub fn min_all(&self) -> Result<Tensor> {
1867        if self.rank() == 0 {
1868            Ok(self.clone())
1869        } else {
1870            self.flatten_all()?.min(0)
1871        }
1872    }
1873
1874    /// Computes the sum of all the elements in this tensor and returns a tensor holding this
1875    /// scalar with zero dimensions.
1876    ///
1877    /// ```rust
1878    /// use candle_core::{Tensor, Device};
1879    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1880    /// let tensor = tensor.sum_all()?;
1881    /// assert_eq!(tensor.to_scalar::<f32>()?, 15.);
1882    /// # Ok::<(), candle_core::Error>(())
1883    /// ```
1884    pub fn sum_all(&self) -> Result<Tensor> {
1885        let dims: Vec<_> = (0..self.rank()).collect();
1886        self.sum(dims)
1887    }
1888
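    /// Computes the mean of all the elements in this tensor and returns a tensor holding this
    /// scalar with zero dimensions (a sketch mirroring the `sum_all` example above):
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let tensor = tensor.mean_all()?;
    /// assert_eq!(tensor.to_scalar::<f32>()?, 2.5);
    /// # Ok::<(), candle_core::Error>(())
    /// ```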
1889    pub fn mean_all(&self) -> Result<Tensor> {
1890        self.sum_all()? / self.elem_count() as f64
1891    }
1892
1893    fn flatten_<D1: Dim, D2: Dim>(
1894        &self,
1895        start_dim: Option<D1>,
1896        end_dim: Option<D2>,
1897    ) -> Result<Tensor> {
1898        if self.rank() == 0 {
1899            self.reshape(1)
1900        } else {
1901            let start_dim = match start_dim {
1902                None => 0,
1903                Some(dim) => dim.to_index(self.shape(), "flatten")?,
1904            };
1905            let end_dim = match end_dim {
1906                None => self.rank() - 1,
1907                Some(dim) => dim.to_index(self.shape(), "flatten")?,
1908            };
1909            if start_dim < end_dim {
1910                let dims = self.dims();
1911                let mut dst_dims = dims[..start_dim].to_vec();
1912                dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
1913                if end_dim + 1 < dims.len() {
1914                    dst_dims.extend(&dims[end_dim + 1..]);
1915                }
1916                self.reshape(dst_dims)
1917            } else {
1918                Ok(self.clone())
1919            }
1920        }
1921    }
1922
1923    /// Flattens the input tensor on the dimension indexes from `start_dim` to `end_dim` (both
1924    /// inclusive).
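    ///
    /// A small sketch with hypothetical shapes:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    /// // Flatten dimensions 1 and 2 together.
    /// let b = a.flatten(1, 2)?;
    /// assert_eq!(b.shape().dims(), &[2, 12]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```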
1925    pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
1926        self.flatten_(Some(start_dim), Some(end_dim))
1927    }
1928
1929    /// Flattens the input tensor on the dimension indexes from `0` to `end_dim` (inclusive).
1930    pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
1931        self.flatten_(None::<usize>, Some(end_dim))
1932    }
1933
1934    /// Flattens the input tensor on the dimension indexes from `start_dim` (inclusive) to the last
1935    /// dimension.
1936    pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
1937        self.flatten_(Some(start_dim), None::<usize>)
1938    }
1939
1940    /// Flattens the input tensor by reshaping it into a one dimension tensor.
1941    ///
1942    /// ```rust
1943    /// use candle_core::{Tensor, Device};
1944    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1945    /// let tensor = tensor.flatten_all()?;
1946    /// assert_eq!(tensor.to_vec1::<f32>()?, &[0., 1., 2., 3., 4., 5.]);
1947    /// # Ok::<(), candle_core::Error>(())
1948    /// ```
1949    pub fn flatten_all(&self) -> Result<Tensor> {
1950        self.flatten_(None::<usize>, None::<usize>)
1951    }
1952
1953    /// Returns the sub-tensor fixing the index at `i` on the first dimension.
1954    ///
1955    /// ```rust
1956    /// use candle_core::{Tensor, Device};
1957    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1958    /// let t = tensor.get(0)?;
1959    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 1.]);
1960    /// let t = tensor.get(1)?;
1961    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
1962    /// # Ok::<(), candle_core::Error>(())
1963    /// ```
1964    pub fn get(&self, i: usize) -> Result<Tensor> {
1965        let dims = self.dims();
1966        if dims.is_empty() {
1967            Ok(self.clone())
1968        } else {
1969            self.narrow(0, i, 1)?.reshape(&dims[1..])
1970        }
1971    }
1972
1973    /// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
1974    ///
1975    /// ```rust
1976    /// use candle_core::{Tensor, Device};
1977    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1978    /// let t = tensor.get_on_dim(1, 0)?;
1979    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
1980    /// let t = tensor.get_on_dim(1, 1)?;
1981    /// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
1982    /// let t = tensor.get_on_dim(0, 1)?;
1983    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
1984    /// # Ok::<(), candle_core::Error>(())
1985    /// ```
1986    pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
1987        let dim = dim.to_index(self.shape(), "get_on_dim")?;
1988        self.narrow(dim, index, 1)?.squeeze(dim)
1989    }
1990
1991    /// Returns a tensor that is a transposed version of the input, with the last two dimensions of
1992    /// the input swapped.
1993    ///
1994    /// ```rust
1995    /// use candle_core::{Tensor, Device};
1996    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1997    /// let tensor = tensor.t()?;
1998    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[0.0, 2.0, 4.0], [1.0, 3.0, 5.0]]);
1999    /// # Ok::<(), candle_core::Error>(())
2000    /// ```
2001    pub fn t(&self) -> Result<Tensor> {
2002        let rank = self.rank();
2003        if rank < 2 {
2004            Err(Error::UnexpectedNumberOfDims {
2005                expected: 2,
2006                got: rank,
2007                shape: self.shape().clone(),
2008            }
2009            .bt())?
2010        }
2011        self.transpose(rank - 2, rank - 1)
2012    }
2013
2014    /// Returns a tensor that is a transposed version of the input, the given dimensions are
2015    /// swapped.
2016    pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
2017        let dim1 = dim1.to_index(self.shape(), "transpose")?;
2018        let dim2 = dim2.to_index(self.shape(), "transpose")?;
2019        if dim1 == dim2 {
2020            return Ok(self.clone());
2021        }
2022        let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
2023        let tensor_ = Tensor_ {
2024            id: TensorId::new(),
2025            storage: self.storage.clone(),
2026            layout: self.layout.transpose(dim1, dim2)?,
2027            op,
2028            is_variable: false,
2029            dtype: self.dtype,
2030            device: self.device.clone(),
2031        };
2032        Ok(Tensor(Arc::new(tensor_)))
2033    }
2034
2035    /// Returns a tensor with the same data as the input where the dimensions have been permuted.
2036    /// `dims` must be a permutation, i.e. it must include each dimension index exactly once.
2037    ///
2038    /// ```rust
2039    /// use candle_core::{Tensor, Device};
2040    /// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
2041    /// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
2042    /// let tensor = tensor.permute((2, 3, 1, 0))?;
2043    /// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
2044    /// # Ok::<(), candle_core::Error>(())
2045    /// ```
2046    pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
2047        let dims = dims.to_indexes(self.shape(), "permute")?;
2048        // O(n^2) permutation check but these arrays are small.
2049        let is_permutation =
2050            dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
2051        if !is_permutation {
2052            bail!(
2053                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
2054                self.dims(),
2055                dims
2056            )
2057        }
2058        let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
2059        let tensor_ = Tensor_ {
2060            id: TensorId::new(),
2061            storage: self.storage.clone(),
2062            layout: self.layout.permute(&dims)?,
2063            op,
2064            is_variable: false,
2065            dtype: self.dtype,
2066            device: self.device.clone(),
2067        };
2068        Ok(Tensor(Arc::new(tensor_)))
2069    }
2070
2071    /// Returns true if the data is stored in a C contiguous (aka row major) way.
2072    pub fn is_contiguous(&self) -> bool {
2073        self.layout.is_contiguous()
2074    }
2075
2076    /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
2077    pub fn is_fortran_contiguous(&self) -> bool {
2078        self.layout.is_fortran_contiguous()
2079    }
2080
2081    /// Compared to clone, this copies the actual storage and may therefore fail if the device
2082    /// runs out of memory.
2083    pub fn copy(&self) -> Result<Tensor> {
2084        let op = BackpropOp::new1(self, Op::Copy);
2085        let tensor_ = Tensor_ {
2086            id: TensorId::new(),
2087            storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
2088            layout: self.layout.clone(),
2089            op,
2090            is_variable: false,
2091            dtype: self.dtype,
2092            device: self.device.clone(),
2093        };
2094        Ok(Tensor(Arc::new(tensor_)))
2095    }
2096
2097    /// Returns a new tensor detached from the current graph, gradients are not propagated through
2098    /// this new node. The storage of this tensor is shared with the initial tensor.
2099    ///
2100    /// If the tensor is already detached from the computation graph, the same tensor is returned.
2101    pub fn detach(&self) -> Tensor {
2102        if self.op.is_none() && !self.is_variable {
2103            self.clone()
2104        } else {
2105            let tensor_ = Tensor_ {
2106                id: TensorId::new(),
2107                storage: self.storage.clone(),
2108                layout: self.layout.clone(),
2109                op: BackpropOp::none(),
2110                is_variable: false,
2111                dtype: self.dtype,
2112                device: self.device.clone(),
2113            };
2114            Tensor(Arc::new(tensor_))
2115        }
2116    }
2117
2118    /// If the target device is the same as the tensor device, only a shallow copy is performed.
2119    pub fn to_device(&self, device: &Device) -> Result<Tensor> {
2120        if self.device().same_device(device) {
2121            Ok(self.clone())
2122        } else {
2123            let storage = match (&*self.storage(), device) {
2124                (Storage::Cpu(storage), Device::Cuda(cuda)) => {
2125                    Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
2126                }
2127                (Storage::Cpu(storage), Device::Metal(metal)) => {
2128                    Storage::Metal(metal.storage_from_cpu_storage(storage)?)
2129                }
2130                (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2131                (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2132                (Storage::Cuda(storage), Device::Cuda(cuda)) => {
2133                    // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
2134                    // are the same.
2135                    let cpu_storage = storage.to_cpu_storage()?;
2136                    Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
2137                }
2138                (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
2139                _ => {
2140                    bail!(
2141                        "not implemented yet, self.device: {:?}, device: {:?}",
2142                        self.device(),
2143                        device
2144                    )
2145                }
2146            };
2147            let op = BackpropOp::new1(self, Op::ToDevice);
2148            let tensor_ = Tensor_ {
2149                id: TensorId::new(),
2150                storage: Arc::new(RwLock::new(storage)),
2151                layout: self.layout.clone(),
2152                op,
2153                is_variable: false,
2154                dtype: self.dtype,
2155                device: device.clone(),
2156            };
2157            Ok(Tensor(Arc::new(tensor_)))
2158        }
2159    }
2160
2161    /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
2162    /// on the left.
2163    pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
2164        let left_shape = left_shape.into();
2165        let mut dims = left_shape.into_dims();
2166        dims.extend(self.dims());
2167        self.broadcast_as(dims)
2168    }
2169
2170    /// Broadcast the input tensor to the target shape. This returns an error if the input shape is
2171    /// not compatible with the target shape.
2172    ///
2173    /// If the input shape is `i_1, i_2, ... i_k`, the target shape has to have `k` dimensions or
2174    /// more and shape `j_1, ..., j_l, t_1, t_2, ..., t_k`. The dimensions `j_1` to `j_l` can have
2175    /// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If
2176    /// `i_a` is equal to 1, any value can be used.
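    ///
    /// For instance, with hypothetical example shapes:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((3, 1), DType::F32, &Device::Cpu)?;
    /// // A leading dimension is added and the trailing dimension of size 1 is expanded to 4.
    /// let b = a.broadcast_as((2, 3, 4))?;
    /// assert_eq!(b.shape().dims(), &[2, 3, 4]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```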
2177    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2178        let tensor_ = Tensor_ {
2179            id: TensorId::new(),
2180            storage: self.storage.clone(),
2181            layout: self.layout.broadcast_as(shape)?,
2182            op: BackpropOp::new1(self, Op::Broadcast),
2183            is_variable: false,
2184            dtype: self.dtype,
2185            device: self.device.clone(),
2186        };
2187        Ok(Tensor(Arc::new(tensor_)))
2188    }
2189
2190    /// An alias for broadcast_as.
2191    pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2192        self.broadcast_as(shape)
2193    }
2194
2195    /// Casts the input tensor to the target `dtype`.
2196    ///
2197    /// ```rust
2198    /// use candle_core::{Tensor, Device};
2199    /// let tensor = Tensor::new(3.14159265358979f64, &Device::Cpu)?;
2200    /// assert_eq!(tensor.to_scalar::<f64>()?, 3.14159265358979);
2201    /// let tensor = tensor.to_dtype(candle_core::DType::F32)?;
2202    /// assert_eq!(tensor.to_scalar::<f32>()?, 3.1415927);
2203    /// # Ok::<(), candle_core::Error>(())
2204    /// ```
2205    pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2206        if self.dtype() == dtype {
2207            Ok(self.clone())
2208        } else {
2209            let shape = self.shape();
2210            let storage = self.storage().to_dtype(self.layout(), dtype)?;
2211            let op = BackpropOp::new1(self, Op::ToDType);
2212            Ok(from_storage(storage, shape.clone(), op, false))
2213        }
2214    }
2215
2216    /// Returns a tensor that is in row major order. This is the same as the original tensor if it
2217    /// was already contiguous, otherwise a copy is triggered.
2218    pub fn contiguous(&self) -> Result<Tensor> {
2219        if self.is_contiguous() {
2220            Ok(self.clone())
2221        } else {
2222            let shape = self.shape();
2223            let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2224            self.storage()
2225                .copy_strided_src(&mut storage, 0, self.layout())?;
2226            let op = BackpropOp::new1(self, Op::Copy);
2227            Ok(from_storage(storage, shape.clone(), op, false))
2228        }
2229    }
2230
2231    /// Returns a tensor that is in row major order. This always makes a copy.
2232    pub fn force_contiguous(&self) -> Result<Tensor> {
2233        let shape = self.shape();
2234        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2235        self.storage()
2236            .copy_strided_src(&mut storage, 0, self.layout())?;
2237        let op = BackpropOp::new1(self, Op::Copy);
2238        Ok(from_storage(storage, shape.clone(), op, false))
2239    }
2240
2241    /// Create a variable based on the values currently stored in a tensor. The storage is always
2242    /// copied.
2243    pub(crate) fn make_var(&self) -> Result<Tensor> {
2244        let shape = self.shape().clone();
2245        let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2246        self.storage()
2247            .copy_strided_src(&mut storage, 0, self.layout())?;
2248        Ok(from_storage(storage, shape, BackpropOp::none(), true))
2249    }
2250
2251    /// Reshape returns a tensor with the target shape provided that the number of elements of the
2252    /// original tensor is the same.
2253    /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
2254    /// a new storage and copies the data over, the returned tensor is always contiguous.
2255    ///
2256    /// The shape can be specified using a tuple of `usize` and at most one `()` in which case
2257    /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
2258    /// as to match the number of elements in the tensor.
2259    ///
2260    /// ```rust
2261    /// # use candle_core::{Tensor, DType, Device, D};
2262    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2263    ///
2264    /// let c = a.reshape((1, 6))?;
2265    /// assert_eq!(c.shape().dims(), &[1, 6]);
2266    ///
2267    /// let c = a.reshape((3, 2))?;
2268    /// assert_eq!(c.shape().dims(), &[3, 2]);
2269    ///
2270    /// let c = a.reshape((2, (), 1))?;
2271    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
2272    ///
2273    /// # Ok::<(), candle_core::Error>(())
2274    /// ```
2275    pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
2276        let shape = s.into_shape(self.elem_count())?;
2277        if shape.elem_count() != self.elem_count() {
2278            return Err(Error::ShapeMismatchBinaryOp {
2279                lhs: self.shape().clone(),
2280                rhs: shape,
2281                op: "reshape",
2282            }
2283            .bt());
2284        }
2285        let op = BackpropOp::new1(self, Op::Reshape);
2286        if self.is_contiguous() {
2287            let tensor_ = Tensor_ {
2288                id: TensorId::new(),
2289                storage: self.storage.clone(),
2290                layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
2291                op,
2292                is_variable: false,
2293                dtype: self.dtype,
2294                device: self.device.clone(),
2295            };
2296            Ok(Tensor(Arc::new(tensor_)))
2297        } else {
2298            let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2299            self.storage()
2300                .copy_strided_src(&mut storage, 0, self.layout())?;
2301            Ok(from_storage(storage, shape, op, false))
2302        }
2303    }
2304
2305    /// Creates a new tensor with the specified dimension removed if its size was one.
2306    ///
2307    /// ```rust
2308    /// # use candle_core::{Tensor, DType, Device, D};
2309    /// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
2310    ///
2311    /// let c = a.squeeze(2)?;
2312    /// assert_eq!(c.shape().dims(), &[2, 3]);
2313    ///
2314    /// let c = a.squeeze(D::Minus1)?;
2315    /// assert_eq!(c.shape().dims(), &[2, 3]);
2316    /// # Ok::<(), candle_core::Error>(())
2317    /// ```
2318    pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
2319        // The PyTorch semantics are to return the same tensor if the target dimension
2320        // does not have a size of 1.
2321        let dims = self.dims();
2322        let dim = dim.to_index(self.shape(), "squeeze")?;
2323        if dims[dim] == 1 {
2324            let mut dims = dims.to_vec();
2325            let mut strides = self.stride().to_vec();
2326            dims.remove(dim);
2327            strides.remove(dim);
2328            let tensor_ = Tensor_ {
2329                id: TensorId::new(),
2330                storage: self.storage.clone(),
2331                layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2332                op: BackpropOp::new1(self, Op::Reshape),
2333                is_variable: false,
2334                dtype: self.dtype,
2335                device: self.device.clone(),
2336            };
2337            Ok(Tensor(Arc::new(tensor_)))
2338        } else {
2339            Ok(self.clone())
2340        }
2341    }
2342
2343    /// Creates a new tensor with a dimension of size one inserted at the specified position.
2344    ///
2345    /// ```rust
2346    /// # use candle_core::{Tensor, DType, Device, D};
2347    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2348    ///
2349    /// let c = a.unsqueeze(0)?;
2350    /// assert_eq!(c.shape().dims(), &[1, 2, 3]);
2351    ///
2352    /// let c = a.unsqueeze(D::Minus1)?;
2353    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
2354    /// # Ok::<(), candle_core::Error>(())
2355    /// ```
2356    pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
2357        let mut dims = self.dims().to_vec();
2358        let mut strides = self.stride().to_vec();
2359        let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
2360        // Cannot panic because to_index_plus_one already checks dimensions
2361        dims.insert(dim, 1);
2362        // Any stride would work here, but we pick one so as to maximize the chance that the
2363        // result remains C contiguous.
2364        let stride = if dim < strides.len() { strides[dim] } else { 1 };
2365        strides.insert(dim, stride);
2366        let tensor_ = Tensor_ {
2367            id: TensorId::new(),
2368            storage: self.storage.clone(),
2369            layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2370            op: BackpropOp::new1(self, Op::Reshape),
2371            is_variable: false,
2372            dtype: self.dtype,
2373            device: self.device.clone(),
2374        };
2375        Ok(Tensor(Arc::new(tensor_)))
2376    }
2377
2378    /// Stacks two or more tensors along a particular dimension.
2379    ///
2380    /// All tensors must have the same rank, and the output has one additional dimension
2381    ///
2382    /// ```rust
2383    /// # use candle_core::{Tensor, DType, Device};
2384    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2385    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2386    ///
2387    /// let c = Tensor::stack(&[&a, &b], 0)?;
2388    /// assert_eq!(c.shape().dims(), &[2, 2, 3]);
2389    ///
2390    /// let c = Tensor::stack(&[&a, &b], 2)?;
2391    /// assert_eq!(c.shape().dims(), &[2, 3, 2]);
2392    /// # Ok::<(), candle_core::Error>(())
2393    /// ```
2394    pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
2395        if args.is_empty() {
2396            Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
2397        }
2398        let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
2399        let args = args
2400            .iter()
2401            .map(|t| t.as_ref().unsqueeze(dim))
2402            .collect::<Result<Vec<_>>>()?;
2403        Self::cat(&args, dim)
2404    }
2405
2406    /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
2407    /// input tensor values and `right` elements after.
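    ///
    /// For instance, padding along the second dimension with hypothetical example values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let t = t.pad_with_zeros(1, 1, 2)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[0., 1., 2., 0., 0.], [0., 3., 4., 0., 0.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```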
2408    pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2409        if left == 0 && right == 0 {
2410            Ok(self.clone())
2411        } else if left == 0 {
2412            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2413            let mut dims = self.dims().to_vec();
2414            dims[dim] = right;
2415            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2416            Tensor::cat(&[self, &right], dim)
2417        } else if right == 0 {
2418            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2419            let mut dims = self.dims().to_vec();
2420            dims[dim] = left;
2421            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2422            Tensor::cat(&[&left, self], dim)
2423        } else {
2424            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2425            let mut dims = self.dims().to_vec();
2426            dims[dim] = left;
2427            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2428            dims[dim] = right;
2429            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2430            Tensor::cat(&[&left, self, &right], dim)
2431        }
2432    }
2433
2434    /// Pad the input tensor by repeating the values at the edges along dimension `dim`. This adds
2435    /// `left` copies of the first slice before the input values and `right` copies of the last slice after.
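    ///
    /// For instance, padding along the second dimension with hypothetical example values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let t = t.pad_with_same(1, 1, 1)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[1., 1., 2., 2.], [3., 3., 4., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```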
2436    pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2437        if left == 0 && right == 0 {
2438            Ok(self.clone())
2439        } else if self.elem_count() == 0 {
2440            bail!("cannot use pad_with_same on an empty tensor")
2441        } else if left == 0 {
2442            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2443            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2444            let mut v = vec![self];
2445            for _ in 0..right {
2446                v.push(&r)
2447            }
2448            Tensor::cat(&v, dim)
2449        } else if right == 0 {
2450            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2451            let l = self.narrow(dim, 0, 1)?;
2452            let mut v = vec![];
2453            for _ in 0..left {
2454                v.push(&l)
2455            }
2456            v.push(self);
2457            Tensor::cat(&v, dim)
2458        } else {
2459            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2460            let l = self.narrow(dim, 0, 1)?;
2461            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2462            let mut v = vec![];
2463            for _ in 0..left {
2464                v.push(&l)
2465            }
2466            v.push(self);
2467            for _ in 0..right {
2468                v.push(&r)
2469            }
2470            Tensor::cat(&v, dim)
2471        }
2472    }
2473
2474    /// Run the `forward` method of `m` on `self`.
2475    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
2476        m.forward(self)
2477    }
2478
2479    /// Run the `forward` method of `m` on `self`.
2480    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
2481        m.forward_t(self, train)
2482    }
2483
2484    pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
2485        self.storage.read().unwrap()
2486    }
2487
2488    pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
2489        self.storage.write().unwrap()
2490    }
2491
2492    // If we extend the visibility of this function to be usable outside of this crate, we should
2493    // make it unsafe.
2494    pub(crate) fn storage_mut_and_layout(
2495        &self,
2496    ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
2497        let storage = self.storage.write().unwrap();
2498        (storage, &self.layout)
2499    }
2500
2501    /// The storage used by this tensor, together with the layout to use to access it safely.
2502    pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
2503        let storage = self.storage.read().unwrap();
2504        (storage, &self.layout)
2505    }
2506
2507    pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
2508        let lhs: &RwLock<Storage> = self.storage.as_ref();
2509        let rhs: &RwLock<Storage> = rhs.storage.as_ref();
2510        std::ptr::eq(lhs, rhs)
2511    }
2512
2513    /// Normalize a 'relative' axis value: positive values are kept, negative
2514    /// values mean counting the dimensions from the back.
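    ///
    /// For instance, with a rank 3 tensor:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    /// assert_eq!(t.normalize_axis(1)?, 1);
    /// assert_eq!(t.normalize_axis(-1)?, 2);
    /// # Ok::<(), candle_core::Error>(())
    /// ```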
2515    pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
2516        let rank = self.rank() as i64;
2517        if rank <= axis {
2518            bail!("axis {axis} is too large, tensor rank {rank}")
2519        } else if 0 <= axis {
2520            Ok(axis as usize)
2521        } else {
2522            let naxis = rank + axis;
2523            if naxis < 0 {
2524                bail!("axis {axis} is too small, tensor rank {rank}")
2525            }
2526            Ok(naxis as usize)
2527        }
2528    }
2529
2530    /// Returns a lower triangular matrix of ones of size n by n.
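    /// The diagonal is included, for instance:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::tril2(3, DType::U32, &Device::Cpu)?;
    /// assert_eq!(t.to_vec2::<u32>()?, &[[1, 0, 0], [1, 1, 0], [1, 1, 1]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```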
2531    pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2532        let t = Tensor::arange(0u32, n as u32, device)?;
2533        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2534        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2535        t1.le(&t2)?.to_dtype(dtype)
2536    }
2537
2538    /// Returns an upper triangular matrix of ones of size n by n.
2539    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2540        let t = Tensor::arange(0u32, n as u32, device)?;
2541        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2542        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2543        t1.ge(&t2)?.to_dtype(dtype)
2544    }
2545
2546    /// Returns a matrix with a diagonal of ones of size n by n.
2547    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2548        let t = Tensor::arange(0u32, n as u32, device)?;
2549        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2550        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2551        t1.eq(&t2)?.to_dtype(dtype)
2552    }
2553
2554    /// Returns the cumulative sum of the elements of the input tensor along the specified
2555    /// dimension.
2556    ///
2557    /// This operation is most efficient when dim is the last dimension of the tensor.
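    ///
    /// For instance, on a 1D tensor:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
    /// assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, &[1., 3., 6., 10.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```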
2558    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
2559        let dim = dim.to_index(self.shape(), "cumsum")?;
2560        let rank = self.rank();
2561        if rank == 0 {
2562            return Ok(self.clone());
2563        }
2564        let n_axis = self.dim(dim)?;
2565        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
2566        if rank == 1 {
2567            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
2568        } else {
2569            let last = rank - 1;
2570            let t = self.transpose(dim, last)?;
2571            let t = t.broadcast_matmul(&triu)?;
2572            t.transpose(dim, last)
2573        }
2574    }
2575
2576    /// Returns a copy of `self` where the values within `ranges` have been replaced with the
2577    /// content of `src`.
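    ///
    /// A minimal sketch with hypothetical example values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::ones((1, 2), DType::F32, &Device::Cpu)?;
    /// // Overwrite row 1, columns 2..4 with the content of `src`.
    /// let out = t.slice_assign(&[1..2, 2..4], &src)?;
    /// assert_eq!(
    ///     out.to_vec2::<f32>()?,
    ///     &[[0., 0., 0., 0.], [0., 0., 1., 1.], [0., 0., 0., 0.]]
    /// );
    /// # Ok::<(), candle_core::Error>(())
    /// ```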
2578    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
2579        &self,
2580        ranges: &[D],
2581        src: &Tensor,
2582    ) -> Result<Self> {
2583        let src_dims = src.dims();
2584        let self_dims = self.dims();
2585        if self_dims.len() != src_dims.len() {
2586            bail!(
2587                "slice-assign requires input with the same rank {} <> {}",
2588                self_dims.len(),
2589                src_dims.len()
2590            )
2591        }
2592        if self_dims.len() != ranges.len() {
2593            bail!(
2594                "slice-assign requires input with the same rank as there are ranges {} <> {}",
2595                self_dims.len(),
2596                ranges.len()
2597            )
2598        }
2599        let mut src = src.clone();
2600        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
2601        for (i, range) in ranges.iter().enumerate() {
2602            let start_included = match range.start_bound() {
2603                std::ops::Bound::Unbounded => 0,
2604                std::ops::Bound::Included(v) => *v,
2605                std::ops::Bound::Excluded(v) => *v + 1,
2606            };
2607            let end_excluded = match range.end_bound() {
2608                std::ops::Bound::Unbounded => self_dims[i],
2609                std::ops::Bound::Included(v) => *v + 1,
2610                std::ops::Bound::Excluded(v) => *v,
2611            };
2612            if end_excluded <= start_included {
2613                bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
2614            }
2615            if self_dims[i] < end_excluded {
2616                bail!(
2617                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
2618                    self_dims[i]
2619                )
2620            }
2621            if end_excluded - start_included != src_dims[i] {
2622                bail!(
2623                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
2624                )
2625            }
2626            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
2627            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
2628        }
2629        mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
2630    }
2631
2632    /// Returns log(sum(exp(tensor), dim)).
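    ///
    /// A minimal sketch: reducing over the last dimension of a `(2, 3)` tensor gives a tensor with
    /// shape `(2,)`, the values being approximately `ln(e^1 + e^2 + e^3)` and `ln(e^4 + e^5 + e^6)`.
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    /// let lse = t.log_sum_exp(1)?;
    /// assert_eq!(lse.dims(), &[2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```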
2633    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
2634        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
2635        if sum_dims.is_empty() {
2636            return Ok(self.clone());
2637        }
2638        let max = sum_dims[1..]
2639            .iter()
2640            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
2641                max.max_keepdim(dim)
2642            })?;
2643        let exp = self.broadcast_sub(&max)?.exp()?;
2644        let sum = exp.sum(sum_dims.clone())?;
2645
2646        sum.log()? + max.squeeze_dims(&sum_dims)
2647    }
2648
2649    /// Pointwise pow operation.
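    ///
    /// As this is computed as `exp(rhs * ln(self))`, it is only meaningful for strictly positive
    /// values of `self` and the results are approximate. A small sketch with hypothetical values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[2f32, 3.], &Device::Cpu)?;
    /// let b = Tensor::new(&[3f32, 2.], &Device::Cpu)?;
    /// let c = a.pow(&b)?; // approximately [8., 9.]
    /// assert_eq!(c.dims(), &[2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```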
2650    pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
2651        rhs.mul(&self.log()?)?.exp()
2652    }
2653
2654    /// Broadcasting version of `pow`.
2655    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
2656        rhs.broadcast_mul(&self.log()?)?.exp()
2657    }
2658
2659    /// Returns a new tensor with the order of elements reversed along the specified dimensions.
2660    /// This function makes a copy of the tensor’s data.
2661    ///
2662    /// ```rust
2663    /// # use candle_core::{Tensor, Device};
2664    /// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
2665    /// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
2666    /// let t_flipped = t.flip(&[0])?;
2667    /// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
2668    /// # Ok::<(), candle_core::Error>(())
2669    /// ```
2670    pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
2671        let mut result = self.clone();
2672        for &dim in dims.iter() {
2673            let size = result.dim(dim)?;
2674            let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
2675            let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
2676            result = result.index_select(&indices_tensor, dim)?;
2677        }
2678        Ok(result)
2679    }
2680}
2681
2682macro_rules! bin_trait {
2683    ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => {
2684        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for Tensor {
2685            type Output = Result<Tensor>;
2686
2687            fn $fn1(self, rhs: B) -> Self::Output {
2688                Tensor::$fn1(&self, rhs.borrow())
2689            }
2690        }
2691
2692        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for &Tensor {
2693            type Output = Result<Tensor>;
2694
2695            fn $fn1(self, rhs: B) -> Self::Output {
2696                Tensor::$fn1(&self, rhs.borrow())
2697            }
2698        }
2699
2700        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
2701            type Output = Result<Tensor>;
2702
2703            fn $fn1(self, rhs: Tensor) -> Self::Output {
2704                Tensor::$fn1(self?.borrow(), &rhs)
2705            }
2706        }
2707
2708        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
2709            type Output = Result<Tensor>;
2710
2711            fn $fn1(self, rhs: &Tensor) -> Self::Output {
2712                Tensor::$fn1(self?.borrow(), rhs)
2713            }
2714        }
2715
2716        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
2717            type Output = Result<Tensor>;
2718
2719            fn $fn1(self, rhs: Result<B>) -> Self::Output {
2720                Tensor::$fn1(&self, rhs?.borrow())
2721            }
2722        }
2723
2724        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for &Tensor {
2725            type Output = Result<Tensor>;
2726
2727            fn $fn1(self, rhs: Result<B>) -> Self::Output {
2728                Tensor::$fn1(&self, rhs?.borrow())
2729            }
2730        }
2731
2732        impl std::ops::$trait<f64> for Tensor {
2733            type Output = Result<Tensor>;
2734
2735            fn $fn1(self, rhs: f64) -> Self::Output {
2736                self.affine($mul(rhs), $add(rhs))
2737            }
2738        }
2739
2740        impl std::ops::$trait<f64> for &Tensor {
2741            type Output = Result<Tensor>;
2742
2743            fn $fn1(self, rhs: f64) -> Self::Output {
2744                self.affine($mul(rhs), $add(rhs))
2745            }
2746        }
2747    };
2748}
2749
2750bin_trait!(Add, add, |_| 1., |v| v);
2751bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
2752bin_trait!(Mul, mul, |v| v, |_| 0.);
2753bin_trait!(Div, div, |v| 1. / v, |_| 0.);
2754
2755impl std::ops::Add<Tensor> for f64 {
2756    type Output = Result<Tensor>;
2757
2758    fn add(self, rhs: Tensor) -> Self::Output {
2759        rhs + self
2760    }
2761}
2762
2763impl std::ops::Add<&Tensor> for f64 {
2764    type Output = Result<Tensor>;
2765
2766    fn add(self, rhs: &Tensor) -> Self::Output {
2767        rhs + self
2768    }
2769}
2770
2771impl std::ops::Mul<Tensor> for f64 {
2772    type Output = Result<Tensor>;
2773
2774    fn mul(self, rhs: Tensor) -> Self::Output {
2775        rhs * self
2776    }
2777}
2778
2779impl std::ops::Mul<&Tensor> for f64 {
2780    type Output = Result<Tensor>;
2781
2782    fn mul(self, rhs: &Tensor) -> Self::Output {
2783        rhs * self
2784    }
2785}
2786
2787impl std::ops::Sub<Tensor> for f64 {
2788    type Output = Result<Tensor>;
2789
2790    fn sub(self, rhs: Tensor) -> Self::Output {
2791        rhs.affine(-1., self)
2792    }
2793}
2794
2795impl std::ops::Sub<&Tensor> for f64 {
2796    type Output = Result<Tensor>;
2797
2798    fn sub(self, rhs: &Tensor) -> Self::Output {
2799        rhs.affine(-1., self)
2800    }
2801}
2802
2803impl std::ops::Div<Tensor> for f64 {
2804    type Output = Result<Tensor>;
2805
2806    #[allow(clippy::suspicious_arithmetic_impl)]
2807    fn div(self, rhs: Tensor) -> Self::Output {
2808        rhs.recip()? * self
2809    }
2810}
2811
2812impl std::ops::Div<&Tensor> for f64 {
2813    type Output = Result<Tensor>;
2814
2815    #[allow(clippy::suspicious_arithmetic_impl)]
2816    fn div(self, rhs: &Tensor) -> Self::Output {
2817        rhs.recip()? * self
2818    }
2819}