candle_core/tensor.rs

1//! Tensors are N-dimensional matrices of elements using a single data type.
2#![allow(clippy::redundant_closure_call)]
3use crate::backend::{BackendDevice, BackendStorage};
4use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
5use crate::scalar::TensorOrScalar;
6use crate::shape::{Dim, Dims, ShapeWithOneHole};
7use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
8use std::sync::{Arc, RwLock};
9
10/// Unique identifier for tensors.
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
12pub struct TensorId(usize);
13
14impl TensorId {
15    fn new() -> Self {
16        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
17        use std::sync::atomic;
18        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
19        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
20    }
21}
22
23pub struct Tensor_ {
24    id: TensorId,
25    // As we provide inner mutability on the tensor content, the alternatives are:
26    // - Using a mutex, this would have the highest cost when retrieving the storage but would
27    //   prevent errors when concurrent access takes place. A mutex would also be subject to
28    //   deadlocks, for example with the current code, if the same tensor were used twice in a
29    //   single binary op.
30    // - Using a RefCell would have some intermediate cost: borrow checking would be verified
31    //   dynamically, but the resulting tensors would not be Send or Sync.
32    // - Using an UnsafeCell would have the lowest cost but undefined behavior on concurrent
33    //   accesses.
34    // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
35    // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables, but
36    // that's tricky to encode in the current setup.
37    storage: Arc<RwLock<Storage>>,
38    layout: Layout,
39    op: BackpropOp,
40    is_variable: bool,
41    dtype: DType,
42    device: Device,
43}
44
45impl AsRef<Tensor> for Tensor {
46    fn as_ref(&self) -> &Tensor {
47        self
48    }
49}
50
51// Tensors are refcounted so that cloning is cheap when building the op graph.
52// Storages are also refcounted independently so that it's possible to avoid
53// copying the storage for operations that only modify the shape or stride.
54#[derive(Clone)]
55/// The core struct for manipulating tensors.
56///
57/// ```rust
58/// use candle_core::{Tensor, DType, Device};
59///
60/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
61/// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
62///
63/// let c = a.matmul(&b)?;
64/// # Ok::<(), candle_core::Error>(())
65/// ```
66///
67/// Tensors are reference counted with [`Arc`] so cloning them is cheap.
68pub struct Tensor(Arc<Tensor_>);
69
70impl std::ops::Deref for Tensor {
71    type Target = Tensor_;
72
73    fn deref(&self) -> &Self::Target {
74        self.0.as_ref()
75    }
76}
77
78macro_rules! unary_op {
79    ($fn_name:ident, $op_name:ident) => {
80        pub fn $fn_name(&self) -> Result<Self> {
81            let shape = self.shape();
82            if shape.elem_count() == 0 {
83                return Ok(self.clone());
84            }
85            let storage = self
86                .storage()
87                .unary_impl::<crate::op::$op_name>(self.layout())?;
88            let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
89            Ok(from_storage(storage, shape.clone(), op, false))
90        }
91    };
92}
93
94macro_rules! binary_op {
95    ($fn_name:ident, $op_name:ident) => {
96        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
97            let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
98            if shape.elem_count() == 0 {
99                return Ok(self.clone());
100            }
101            let storage = self.storage().binary_impl::<crate::op::$op_name>(
102                &*rhs.storage(),
103                self.layout(),
104                rhs.layout(),
105            )?;
106            let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
107            Ok(from_storage(storage, shape.clone(), op, false))
108        }
109    };
110}
111
112macro_rules! binary_op_scalar {
113    ($fn_name:ident, $op_name:ident) => {
114        pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
115            let rhs = match rhs.to_tensor_scalar()? {
116                crate::scalar::TensorScalar::Tensor(rhs) => rhs,
117                crate::scalar::TensorScalar::Scalar(rhs) => rhs
118                    .to_dtype(self.dtype())?
119                    .to_device(self.device())?
120                    .broadcast_as(self.shape())?,
121            };
122            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
123            if self.elem_count() == 0 {
124                return Ok(self.clone());
125            }
126            let storage = self.storage().binary_impl::<crate::op::$op_name>(
127                &*rhs.storage(),
128                self.layout(),
129                rhs.layout(),
130            )?;
131            let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
132            Ok(from_storage(storage, shape.clone(), op, false))
133        }
134    };
135}
136
137macro_rules! broadcast_binary_op {
138    ($fn_name:ident, $inner_fn_name:ident) => {
139        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
140            let lhs = self;
141            let shape = lhs
142                .shape()
143                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
144            let l_broadcast = shape != *lhs.shape();
145            let r_broadcast = shape != *rhs.shape();
146            match (l_broadcast, r_broadcast) {
147                (true, true) => lhs
148                    .broadcast_as(&shape)?
149                    .$inner_fn_name(&rhs.broadcast_as(&shape)?),
150                (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
151                (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
152                (false, false) => lhs.$inner_fn_name(rhs),
153            }
154        }
155    };
156}
157
158/// Creates a fresh tensor structure based on a storage and a shape; this uses contiguous strides.
159pub(crate) fn from_storage<S: Into<Shape>>(
160    storage: Storage,
161    shape: S,
162    op: BackpropOp,
163    is_variable: bool,
164) -> Tensor {
165    let dtype = storage.dtype();
166    let device = storage.device();
167    let tensor_ = Tensor_ {
168        id: TensorId::new(),
169        storage: Arc::new(RwLock::new(storage)),
170        layout: Layout::contiguous(shape),
171        op,
172        is_variable,
173        dtype,
174        device,
175    };
176    Tensor(Arc::new(tensor_))
177}
178
179impl Tensor {
180    pub(crate) fn ones_impl<S: Into<Shape>>(
181        shape: S,
182        dtype: DType,
183        device: &Device,
184        is_variable: bool,
185    ) -> Result<Self> {
186        let none = BackpropOp::none();
187        let shape = shape.into();
188        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
189        let layout = Layout::contiguous(shape.clone());
190        storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
191        Ok(from_storage(storage, shape, none, is_variable))
192    }
193
194    /// Creates a new tensor filled with ones.
195    ///
196    /// ```rust
197    /// use candle_core::{Tensor, DType, Device};
198    /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
199    /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
200    /// // a == b
201    /// # Ok::<(), candle_core::Error>(())
202    /// ```
203    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
204        Self::ones_impl(shape, dtype, device, false)
205    }
206
207    pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
208        self.storage_mut().const_set(value, self.layout())
209    }
210
211    pub fn zero_set(&self) -> Result<()> {
212        self.const_set(crate::scalar::Scalar::zero(self.dtype()))
213    }
214
215    pub fn one_set(&self) -> Result<()> {
216        self.const_set(crate::scalar::Scalar::one(self.dtype()))
217    }
218
219    /// Creates a new tensor filled with ones, with the same shape, dtype, and device as the input tensor.
220    ///
221    /// ```rust
222    /// use candle_core::{Tensor, DType, Device};
223    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
224    /// let b = a.ones_like()?;
225    /// // b == a + 1
226    /// # Ok::<(), candle_core::Error>(())
227    /// ```
228    pub fn ones_like(&self) -> Result<Self> {
229        Tensor::ones(self.shape(), self.dtype(), self.device())
230    }
231
232    // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from
233    // the variable module.
234    pub(crate) fn zeros_impl<S: Into<Shape>>(
235        shape: S,
236        dtype: DType,
237        device: &Device,
238        is_variable: bool,
239    ) -> Result<Self> {
240        let none = BackpropOp::none();
241        let shape = shape.into();
242        let storage = device.zeros(&shape, dtype)?;
243        Ok(from_storage(storage, shape, none, is_variable))
244    }
245
246    /// Creates a new tensor filled with zeros.
247    ///
248    /// ```rust
249    /// use candle_core::{Tensor, DType, Device};
250    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
251    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
252    /// // a == b
253    /// # Ok::<(), candle_core::Error>(())
254    /// ```
255    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
256        Self::zeros_impl(shape, dtype, device, false)
257    }
258
259    /// Creates a new tensor filled with zeros, with the same shape, dtype, and device as the
260    /// input tensor.
261    ///
262    /// ```rust
263    /// use candle_core::{Tensor, DType, Device};
264    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
265    /// let b = a.zeros_like()?;
266    /// // b is on CPU f32.
267    /// # Ok::<(), candle_core::Error>(())
268    /// ```
269    pub fn zeros_like(&self) -> Result<Self> {
270        Tensor::zeros(self.shape(), self.dtype(), self.device())
271    }
272
273    // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from
274    // the variable module.
275    pub(crate) unsafe fn empty_impl<S: Into<Shape>>(
276        shape: S,
277        dtype: DType,
278        device: &Device,
279        is_variable: bool,
280    ) -> Result<Self> {
281        let none = BackpropOp::none();
282        let shape = shape.into();
283        let storage = device.alloc_uninit(&shape, dtype)?;
284        Ok(from_storage(storage, shape, none, is_variable))
285    }
286
287    /// Creates a new tensor filled with uninitialized memory.
288    ///
289    /// # Safety
290    /// This returns uninitialized memory.
291    ///
292    /// ```rust
293    /// use candle_core::{Tensor, DType, Device};
294    /// let a = unsafe { Tensor::empty((2, 3), DType::F32, &Device::Cpu)? };
295    /// // The elements of `a` are uninitialized and must be written before being read.
296    /// # Ok::<(), candle_core::Error>(())
297    /// ```
298    pub unsafe fn empty<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
299        Self::empty_impl(shape, dtype, device, false)
300    }
301
302    /// Creates a new tensor filled with uninitialized memory of the same shape, dtype, and device as the other
303    /// tensor.
304    ///
305    /// # Safety
306    /// This returns uninitialized memory.
307    ///
308    /// ```rust
309    /// use candle_core::{Tensor, DType, Device};
310    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
311    /// let b = unsafe { a.empty_like()? };
312    /// # Ok::<(), candle_core::Error>(())
313    /// ```
314    pub unsafe fn empty_like(&self) -> Result<Self> {
315        Tensor::empty(self.shape(), self.dtype(), self.device())
316    }
317
318    pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
319        lo: T,
320        up: T,
321        s: S,
322        device: &Device,
323        is_variable: bool,
324    ) -> Result<Self> {
325        let s = s.into();
326        let storage = device.rand_uniform(lo, up, &s)?;
327        let none = BackpropOp::none();
328        Ok(from_storage(storage, s, none, is_variable))
329    }
330
331    pub(crate) fn rand_f64_impl<S: Into<Shape>>(
332        lo: f64,
333        up: f64,
334        s: S,
335        dtype: DType,
336        device: &Device,
337        is_variable: bool,
338    ) -> Result<Self> {
339        let s = s.into();
340        let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
341        let none = BackpropOp::none();
342        Ok(from_storage(storage, s, none, is_variable))
343    }
344
345    /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
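    ///
    /// The output values are random, so this small sketch only checks the resulting shape:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::rand(0f32, 1f32, (2, 3), &Device::Cpu)?;
    /// assert_eq!(a.dims(), &[2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```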
346    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
347        lo: T,
348        up: T,
349        s: S,
350        device: &Device,
351    ) -> Result<Self> {
352        Self::rand_impl(lo, up, s, device, false)
353    }
354
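    /// Creates a new tensor with the same shape, dtype, and device as `self`, filled with values
    /// sampled uniformly between `lo` and `up`.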
355    pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
356        Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
357    }
358
359    pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
360        mean: T,
361        std: T,
362        s: S,
363        device: &Device,
364        is_variable: bool,
365    ) -> Result<Self> {
366        let s = s.into();
367        let storage = device.rand_normal(mean, std, &s)?;
368        let none = BackpropOp::none();
369        Ok(from_storage(storage, s, none, is_variable))
370    }
371
372    pub(crate) fn randn_f64_impl<S: Into<Shape>>(
373        mean: f64,
374        std: f64,
375        s: S,
376        dtype: DType,
377        device: &Device,
378        is_variable: bool,
379    ) -> Result<Self> {
380        let s = s.into();
381        let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
382        let none = BackpropOp::none();
383        Ok(from_storage(storage, s, none, is_variable))
384    }
385
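    /// Creates a new tensor with the same shape, dtype, and device as `self`, filled with values
    /// sampled from a normal distribution with the specified `mean` and standard deviation `stdev`.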
386    pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
387        Tensor::randn_f64_impl(
388            mean,
389            stdev,
390            self.shape(),
391            self.dtype(),
392            self.device(),
393            false,
394        )
395    }
396
397    /// Creates a new tensor initialized with values sampled from a normal distribution with the
398    /// specified `mean` and standard deviation `std`.
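    ///
    /// The output values are random, so this small sketch only checks the resulting shape:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::randn(0f32, 1f32, (2, 3), &Device::Cpu)?;
    /// assert_eq!(a.dims(), &[2, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```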
399    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
400        mean: T,
401        std: T,
402        s: S,
403        device: &Device,
404    ) -> Result<Self> {
405        Self::randn_impl(mean, std, s, device, false)
406    }
407
408    pub(crate) fn new_impl<A: crate::device::NdArray>(
409        array: A,
410        shape: Shape,
411        device: &Device,
412        is_variable: bool,
413    ) -> Result<Self> {
414        let n: usize = shape.elem_count();
415        let buffer_size: usize = array.shape()?.elem_count();
416        if buffer_size != n {
417            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
418        }
419        let storage = device.storage(array)?;
420        let none = BackpropOp::none();
421        Ok(from_storage(storage, shape, none, is_variable))
422    }
423
424    /// Creates a new tensor on the specified device using the content and shape of the input.
425    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
426        let shape = array.shape()?;
427        Self::new_impl(array, shape, device, false)
428    }
429
430    /// Returns a new tensor with all the elements having the same specified value.
431    ///```rust
432    /// use candle_core::{Tensor, Device};
433    /// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
434    ///
435    /// assert_eq!(a.to_vec2::<f64>()?, &[
436    ///     [3.5, 3.5, 3.5, 3.5],
437    ///     [3.5, 3.5, 3.5, 3.5],
438    /// ]);
439    /// # Ok::<(), candle_core::Error>(())
    /// ```
440    pub fn full<D: crate::WithDType, S: Into<Shape>>(
441        value: D,
442        shape: S,
443        device: &Device,
444    ) -> Result<Self> {
445        let none = BackpropOp::none();
446        let shape = shape.into();
447        let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
448        let layout = Layout::contiguous(shape.clone());
449        storage.const_set(value.to_scalar(), &layout)?;
450        Ok(from_storage(storage, shape, none, false))
451    }
452
453    /// Creates a new 1D tensor from an iterator.
454    ///```rust
455    /// use candle_core::{Tensor, Device};
456    /// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
457    ///
458    /// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
459    /// # Ok::<(), candle_core::Error>(())
460    /// ```
461    pub fn from_iter<D: crate::WithDType>(
462        iter: impl IntoIterator<Item = D>,
463        device: &Device,
464    ) -> Result<Self> {
465        let data = iter.into_iter().collect::<Vec<_>>();
466        let len = data.len();
467        Self::from_vec_impl(data, len, device, false)
468    }
469
470    /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
471    /// difference `1` from `start`.
472    ///```rust
473    /// use candle_core::{Tensor, Device};
474    /// let a = Tensor::arange(2., 5., &Device::Cpu)?;
475    ///
476    /// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
477    /// # Ok::<(), candle_core::Error>(())
478    /// ```
479    pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
480        Self::arange_step(start, end, D::one(), device)
481    }
482
483    /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
484    /// difference `step` from `start`.
485    ///```rust
486    /// use candle_core::{Tensor, Device};
487    /// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
488    ///
489    /// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
490    /// # Ok::<(), candle_core::Error>(())
491    /// ```
492    pub fn arange_step<D: crate::WithDType>(
493        start: D,
494        end: D,
495        step: D,
496        device: &Device,
497    ) -> Result<Self> {
498        if D::is_zero(&step) {
499            bail!("step cannot be zero")
500        }
501        let mut data = vec![];
502        let mut current = start;
503        if step >= D::zero() {
504            while current < end {
505                data.push(current);
506                current += step;
507            }
508        } else {
509            while current > end {
510                data.push(current);
511                current += step;
512            }
513        }
514        let len = data.len();
515        Self::from_vec_impl(data, len, device, false)
516    }
517
518    pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
519        data: Vec<D>,
520        shape: S,
521        device: &Device,
522        is_variable: bool,
523    ) -> Result<Self> {
524        let shape = shape.into_shape(data.len())?;
525        let storage = device.storage_owned(data)?;
526        let none = BackpropOp::none();
527        Ok(from_storage(storage, shape, none, is_variable))
528    }
529
530    /// Creates a new tensor initialized with values from the input vector. The number of elements
531    /// in this vector must be the same as the number of elements defined by the shape.
532    /// If the device is the CPU, no data copy is made.
533    ///```rust
534    /// use candle_core::{Tensor, Device};
535    /// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
536    ///
537    /// assert_eq!(a.to_vec2::<f64>()?, &[
538    ///     [1., 2., 3.],
539    ///     [4., 5., 6.]
540    /// ]);
541    /// # Ok::<(), candle_core::Error>(())
542    /// ```
543    pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
544        data: Vec<D>,
545        shape: S,
546        device: &Device,
547    ) -> Result<Self> {
548        Self::from_vec_impl(data, shape, device, false)
549    }
550
551    /// Creates a new tensor initialized with values from the input slice. The number of elements
552    /// in this slice must be the same as the number of elements defined by the shape.
553    ///```rust
554    /// use candle_core::{Tensor, Device};
555    /// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
556    /// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
557    ///
558    /// assert_eq!(a.to_vec2::<f64>()?, &[
559    ///     [2., 3., 4.],
560    ///     [5., 6., 7.]
561    /// ]);
562    /// # Ok::<(), candle_core::Error>(())
563    /// ```
564    pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
565        array: &[D],
566        shape: S,
567        device: &Device,
568    ) -> Result<Self> {
569        let shape = shape.into_shape(array.len())?;
570        let storage = device.storage_from_slice(array)?;
571        let none = BackpropOp::none();
572        Ok(from_storage(storage, shape, none, false))
573    }
574
575    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
576        let lhs = self.shape();
577        let rhs = rhs.shape();
578        if lhs != rhs {
579            Err(Error::ShapeMismatchBinaryOp {
580                lhs: lhs.clone(),
581                rhs: rhs.clone(),
582                op,
583            }
584            .bt())
585        } else {
586            Ok(lhs)
587        }
588    }
589
590    /// Returns true if the computation graph should track this op, that is, if it is
591    /// a variable or if it has variables among its dependencies.
592    pub fn track_op(&self) -> bool {
593        self.is_variable || self.op.is_some()
594    }
595
596    /// Creates a fresh tensor structure based on a storage and a shape.
597    ///
598    /// # Note
599    /// - This uses contiguous strides
600    /// - Ensure the shape is compatible with the shape of the storage.
601    pub fn from_storage<S: Into<Shape>>(
602        storage: Storage,
603        shape: S,
604        op: BackpropOp,
605        is_variable: bool,
606    ) -> Tensor {
607        from_storage(storage, shape, op, is_variable)
608    }
609
610    // TODO: Also make an inplace version or a pre-allocated? This could be tricky
611    // if this can create cycles in the compute graph.
612    binary_op!(add, Add);
613    binary_op!(mul, Mul);
614    binary_op!(sub, Sub);
615    binary_op!(div, Div);
616    binary_op_scalar!(maximum, Maximum);
617    binary_op_scalar!(minimum, Minimum);
618    broadcast_binary_op!(broadcast_add, add);
619    broadcast_binary_op!(broadcast_mul, mul);
620    broadcast_binary_op!(broadcast_sub, sub);
621    broadcast_binary_op!(broadcast_div, div);
622    broadcast_binary_op!(broadcast_maximum, maximum);
623    broadcast_binary_op!(broadcast_minimum, minimum);
624    broadcast_binary_op!(broadcast_eq, eq);
625    broadcast_binary_op!(broadcast_ne, ne);
626    broadcast_binary_op!(broadcast_lt, lt);
627    broadcast_binary_op!(broadcast_le, le);
628    broadcast_binary_op!(broadcast_gt, gt);
629    broadcast_binary_op!(broadcast_ge, ge);
630
631    unary_op!(recip, Recip);
632    unary_op!(neg, Neg);
633    unary_op!(exp, Exp);
634    unary_op!(log, Log);
635    unary_op!(sin, Sin);
636    unary_op!(cos, Cos);
637    unary_op!(tanh, Tanh);
638    unary_op!(abs, Abs);
639    unary_op!(sqr, Sqr);
640    unary_op!(sqrt, Sqrt);
641    unary_op!(gelu, Gelu);
642    unary_op!(gelu_erf, GeluErf);
643    unary_op!(erf, Erf);
644    unary_op!(relu, Relu);
645    unary_op!(silu, Silu);
646    unary_op!(ceil, Ceil);
647    unary_op!(floor, Floor);
648    unary_op!(round, Round);
649    unary_op!(sign, Sign);
650
651    /// Rounds each element of the input tensor to the specified number of decimals.
652    ///
653    /// If the number of decimals is negative, it specifies the number of positions to the left of
654    /// the decimal point.
655    pub fn round_to(&self, decimals: i32) -> Result<Self> {
656        let mult = 10f64.powi(decimals);
657        (self * mult)?.round()? * (1f64 / mult)
658    }
659
660    /// Retrieves the single scalar value held in the tensor. If the tensor is not of rank zero,
661    /// an error is returned instead.
662    pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
663        if self.rank() != 0 {
664            Err(Error::UnexpectedNumberOfDims {
665                expected: 0,
666                got: self.rank(),
667                shape: self.shape().clone(),
668            }
669            .bt())?
670        }
671        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
672            let data = S::cpu_storage_as_slice(cpu_storage)?;
673            Ok::<_, Error>(data[self.layout().start_offset()])
674        };
675        match &*self.storage() {
676            Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
677            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
678            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
679        }
680    }
681
682    /// An alias for `to_scalar`.
683    pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
684        self.to_scalar::<S>()
685    }
686
687    /// Repeat this tensor along the specified dimensions.
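    ///
    /// A small example repeating a `(2, 2)` tensor twice along the first dimension:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let b = a.repeat((2, 1))?;
    /// assert_eq!(b.to_vec2::<f32>()?, &[[1., 2.], [3., 4.], [1., 2.], [3., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```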
688    pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
689        // Similar to PyTorch, we extend the number of dimensions of self if needed.
690        let repeats = shape.into();
691        let repeats = repeats.dims();
692        let mut inp = if self.rank() < repeats.len() {
693            let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
694            self.reshape(shape)?
695        } else {
696            self.clone()
697        };
698        for (idx, &repeat) in repeats.iter().enumerate() {
699            if repeat > 1 {
700                inp = Tensor::cat(&vec![&inp; repeat], idx)?
701            }
702        }
703        Ok(inp)
704    }
705
706    /// Creates grids of coordinates specified by the 1D inputs.
707    ///
708    /// # Arguments
709    ///
710    /// * `args` - A slice of 1D tensors.
711    /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
712    ///   first dimension corresponds to the cardinality of the second input and the second
713    ///   dimension corresponds to the cardinality of the first input. If ij is selected, the
714    ///   dimensions are in the same order as the cardinality of the inputs.
715    ///
716    /// # Examples
717    ///
718    /// ```rust
719    /// use candle_core::{Tensor, Device, Shape};
720    /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
721    /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
722    ///
723    /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
724    ///
725    /// assert_eq!(grids_xy.len(), 2);
726    /// assert_eq!(grids_xy[0].dims(), &[3, 3]);
727    ///
728    /// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
729    /// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
730    ///
731    /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
732    ///
733    /// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
734    /// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
735    /// # Ok::<(), candle_core::Error>(())
736    /// ```
737    ///
738    /// # Errors
739    ///
740    /// * Will return `Err` if `args` contains fewer than 2 tensors.
741    ///
742    pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
743        if args.len() <= 1 {
744            Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
745        }
746        let args: Vec<_> = if xy_indexing {
747            args.iter().rev().collect()
748        } else {
749            args.iter().collect()
750        };
751
752        let mut shape = Vec::with_capacity(args.len());
753        for arg in args.iter() {
754            shape.push(arg.as_ref().dims1()?)
755        }
756
757        let mut grids = Vec::with_capacity(args.len());
758        for idx in 0..args.len() {
759            let mut ones = vec![1usize; args.len()];
760            ones[idx] = shape[idx];
761            let arg = args[idx].as_ref().reshape(ones)?;
762            let mut repeats = shape.clone();
763            repeats[idx] = 1;
764            let repeated_tensor = arg.repeat(repeats)?;
765            grids.push(repeated_tensor);
766        }
767        if xy_indexing {
768            grids.reverse();
769        }
770        Ok(grids)
771    }
772
773    /// This operation multiplies the input tensor by `mul`, then adds `add`, and returns the result.
774    /// The input values `mul` and `add` are cast to the appropriate type so some rounding might
775    /// be performed.
776    ///
777    /// ```rust
778    /// use candle_core::{Tensor, Device};
779    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
780    /// let a = a.affine(4., -2.)?;
781    /// assert_eq!(a.to_vec2::<f32>()?, &[[-2.0, 2.0], [6.0, 10.0]]);
782    /// # Ok::<(), candle_core::Error>(())
783    /// ```
784    pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
785        if self.elem_count() == 0 {
786            return Ok(self.clone());
787        }
788        let storage = self.storage().affine(self.layout(), mul, add)?;
789        let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
790        Ok(from_storage(storage, self.shape(), op, false))
791    }
792
793    /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
794    pub fn elu(&self, alpha: f64) -> Result<Self> {
795        if self.elem_count() == 0 {
796            return Ok(self.clone());
797        }
798        let storage = self.storage().elu(self.layout(), alpha)?;
799        let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
800        Ok(from_storage(storage, self.shape(), op, false))
801    }
802
803    /// Raise the tensor to some float exponent `e`.
804    pub fn powf(&self, e: f64) -> Result<Self> {
805        if self.elem_count() == 0 {
806            return Ok(self.clone());
807        }
808        let storage = self.storage().powf(self.layout(), e)?;
809        let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
810        Ok(from_storage(storage, self.shape(), op, false))
811    }
812
813    pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
814        if dim >= self.dims().len() {
815            Err(Error::DimOutOfRange {
816                shape: self.shape().clone(),
817                dim: dim as i32,
818                op,
819            }
820            .bt())?
821        } else {
822            Ok(())
823        }
824    }
825
826    /// Splits a tensor into the specified number of chunks; this may return fewer chunks than
827    /// requested.
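    ///
    /// A small sketch: ten elements split into three chunks of sizes 4, 3, and 3:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::arange(0f32, 10f32, &Device::Cpu)?;
    /// let chunks = a.chunk(3, 0)?;
    /// assert_eq!(chunks.len(), 3);
    /// assert_eq!(chunks[0].dims(), &[4]);
    /// assert_eq!(chunks[2].dims(), &[3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```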
828    pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
829        let dim = dim.to_index(self.shape(), "chunk")?;
830        let size = self.dim(dim)?;
831        if size < chunks {
832            (0..size).map(|i| self.narrow(dim, i, 1)).collect()
833        } else {
834            let chunk_size = size / chunks;
835            let cnt_additional = size % chunks;
836            let mut tensors = vec![];
837            let mut sum_chunk_size = 0;
838            for i in 0..chunks {
839                let chunk_size = if i < cnt_additional {
840                    chunk_size + 1
841                } else {
842                    chunk_size
843                };
844                let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
845                tensors.push(tensor);
846                sum_chunk_size += chunk_size
847            }
848            Ok(tensors)
849        }
850    }
851
852    /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
853    /// ranges from `start` to `start + len`.
854    /// ```
855    /// use candle_core::{Tensor, Device};
856    /// let a = Tensor::new(&[
857    ///     [0f32, 1., 2.],
858    ///     [3.  , 4., 5.],
859    ///     [6.  , 7., 8.]
860    /// ], &Device::Cpu)?;
861    ///
862    /// let b = a.narrow(0, 1, 2)?;
863    /// assert_eq!(b.shape().dims(), &[2, 3]);
864    /// assert_eq!(b.to_vec2::<f32>()?, &[
865    ///     [3., 4., 5.],
866    ///     [6., 7., 8.]
867    /// ]);
868    ///
869    /// let c = a.narrow(1, 1, 1)?;
870    /// assert_eq!(c.shape().dims(), &[3, 1]);
871    /// assert_eq!(c.to_vec2::<f32>()?, &[
872    ///     [1.],
873    ///     [4.],
874    ///     [7.]
875    /// ]);
876    /// # Ok::<(), candle_core::Error>(())
877    /// ```
878    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
879        let dims = self.dims();
880        let dim = dim.to_index(self.shape(), "narrow")?;
881        let err = |msg| {
882            Err::<(), _>(
883                Error::NarrowInvalidArgs {
884                    shape: self.shape().clone(),
885                    dim,
886                    start,
887                    len,
888                    msg,
889                }
890                .bt(),
891            )
892        };
893        if start > dims[dim] {
894            err("start > dim_len")?
895        }
896        if start.saturating_add(len) > dims[dim] {
897            err("start + len > dim_len")?
898        }
899        if start == 0 && dims[dim] == len {
900            Ok(self.clone())
901        } else {
902            let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
903            let layout = self.layout().narrow(dim, start, len)?;
904            let tensor_ = Tensor_ {
905                id: TensorId::new(),
906                storage: self.storage.clone(),
907                layout,
908                op,
909                is_variable: false,
910                dtype: self.dtype,
911                device: self.device.clone(),
912            };
913            Ok(Tensor(Arc::new(tensor_)))
914        }
915    }
916
917    fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
918        match dims {
919            [] => Ok(self),
920            [i] => self.squeeze(*i),
921            dims => {
922                let dims = self
923                    .dims()
924                    .iter()
925                    .enumerate()
926                    .filter_map(|(dim_idx, &v)| {
927                        if dims.contains(&dim_idx) {
928                            None
929                        } else {
930                            Some(v)
931                        }
932                    })
933                    .collect::<Vec<_>>();
934                self.reshape(dims)
935            }
936        }
937    }
938
939    fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
940        let dim = dim.to_index(self.shape(), op.name())?;
941        let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
942        let mut dims = self.dims().to_vec();
943        dims[dim] = 1;
944        let op = match op {
945            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
946                BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
947            }
948            ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
949        };
950        let res = from_storage(storage, dims, op, false);
951        if keepdim {
952            Ok(res)
953        } else {
954            res.squeeze_dims(&[dim])
955        }
956    }
957
958    fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
959        let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
960        let storage = self
961            .storage()
962            .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
963        let mut dims = self.dims().to_vec();
964        for &sum_dim in sum_dims.iter() {
965            dims[sum_dim] = 1
966        }
967        let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
968        let sum = from_storage(storage, dims, op, false);
969        if keepdim {
970            Ok(sum)
971        } else {
972            sum.squeeze_dims(&sum_dims)
973        }
974    }
975
976    /// Roll the tensor input along the given dimension.
977    /// Elements that are shifted beyond the last position are re-introduced at the first position.
978    ///
979    /// ```rust
980    /// # use candle_core::{Tensor, Device};
981    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
982    /// let tensor = tensor.roll(1, 0)?;
983    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
984    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
985    /// let tensor = tensor.roll(-1, 0)?;
986    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
987    /// # Ok::<(), candle_core::Error>(())
988    /// ```
989    pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
990    where
991        D: Dim + Clone,
992    {
993        let dim = dim.to_index(self.shape(), "roll")?;
994        let dim_size = self.dim(dim)?;
995        let shift = shift.rem_euclid(dim_size as i32) as usize;
996        if shift == 0 {
997            Ok(self.clone())
998        } else {
999            let a = self.narrow(dim, 0, dim_size - shift)?;
1000            let b = self.narrow(dim, dim_size - shift, shift)?;
1001            Tensor::cat(&[&b, &a], dim)
1002        }
1003    }
1004
1005    /// Returns the sum of the elements in the input tensor. The sum is performed over the
1006    /// dimensions specified in `sum_dims`.
1007    ///
1008    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
1009    /// that the number of elements for each dimension index in `sum_dims` is 1.
1010    ///
1011    /// ```rust
1012    /// use candle_core::{Tensor, Device};
1013    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
1014    /// let s = a.sum_keepdim(0)?;
1015    /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
1016    /// let s = a.sum_keepdim(1)?;
1017    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
1018    /// let s = a.sum_keepdim((0, 1))?;
1019    /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
1020    /// # Ok::<(), candle_core::Error>(())
1021    /// ```
1022    pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
1023        self.sum_impl(sum_dims, true)
1024    }
1025
1026    /// Returns the sum of the elements in the input tensor. The sum is performed over the
1027    /// dimensions specified in `sum_dims`; compared to `sum_keepdim`, these dimensions are
1028    /// squeezed rather than kept.
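    ///
    /// A small example mirroring the `sum_keepdim` one above, with the reduced dimension squeezed:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    /// assert_eq!(a.sum(0)?.to_vec1::<f32>()?, &[2., 4.]);
    /// assert_eq!(a.sum(1)?.to_vec1::<f32>()?, &[1., 5.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```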
1029    pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
1030        self.sum_impl(sum_dims, false)
1031    }
1032
1033    /// Returns the mean of the elements in the input tensor. The mean is performed over the
1034    /// dimensions specified in `mean_dims`.
1035    ///
1036    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
1037    /// that the number of elements for each dimension index in `mean_dims` is 1.
1038    ///
1039    /// ```rust
1040    /// use candle_core::{Tensor, Device};
1041    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
1042    /// let s = a.mean_keepdim(0)?;
1043    /// assert_eq!(s.to_vec2::<f32>()?, &[[1., 2.]]);
1044    /// let s = a.mean_keepdim(1)?;
1045    /// assert_eq!(s.to_vec2::<f32>()?, &[[0.5], [2.5]]);
1046    /// let s = a.mean_keepdim((0, 1))?;
1047    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.5]]);
1048    /// # Ok::<(), candle_core::Error>(())
1049    /// ```
1050    pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
1051        let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
1052        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1053        let scale = 1f64 / (reduced_dim as f64);
1054        self.sum_impl(mean_dims, true)? * scale
1055    }
1056
1057    /// Returns the mean of the elements in the input tensor. The mean is performed over the
1058    /// dimensions specified in `mean_dims`; compared to `mean_keepdim`, these dimensions are
1059    /// squeezed rather than kept.
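    ///
    /// A small example mirroring the `mean_keepdim` one above, with the reduced dimension squeezed:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    /// assert_eq!(a.mean(0)?.to_vec1::<f32>()?, &[1., 2.]);
    /// assert_eq!(a.mean(1)?.to_vec1::<f32>()?, &[0.5, 2.5]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```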
1060    pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
1061        let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
1062        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1063        let scale = 1f64 / (reduced_dim as f64);
1064        self.sum_impl(mean_dims, false)? * scale
1065    }
1066
1067    /// Returns the unbiased variance over the selected dimension.
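    ///
    /// A small sketch: each row of this tensor has an unbiased variance of 1:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    /// assert_eq!(a.var_keepdim(1)?.to_vec2::<f32>()?, &[[1.], [1.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```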
1068    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1069        let dim = dim.to_index(self.shape(), "var")?;
1070        let mean = self.mean_keepdim(dim)?;
1071        let squares = self.broadcast_sub(&mean)?.sqr()?;
1072        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
1073    }
1074
1075    /// Returns the unbiased variance over the selected dimension.
1076    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
1077        let dim = dim.to_index(self.shape(), "var")?;
1078        self.var_keepdim(dim)?.squeeze(dim)
1079    }
1080
1081    /// Gathers the maximum value across the selected dimension. The resulting shape has the same
1082    /// number of dimensions as the original tensor and the selected dimension has a single element.
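    ///
    /// A small example reducing along each dimension in turn:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 5., 2.], [3., 1., 4.]], &Device::Cpu)?;
    /// assert_eq!(a.max_keepdim(1)?.to_vec2::<f32>()?, &[[5.], [4.]]);
    /// assert_eq!(a.max_keepdim(0)?.to_vec2::<f32>()?, &[[3., 5., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```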
1083    pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1084        self.reduce_impl(dim, true, ReduceOp::Max)
1085    }
1086
1087    /// Similar to `max_keepdim` but the target dimension is squeezed.
1088    pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
1089        self.reduce_impl(dim, false, ReduceOp::Max)
1090    }
1091
1092    /// Gathers the minimum value across the selected dimension. The resulting shape has the same
1093    /// number of dimensions as the original tensor and the selected dimension has a single element.
1094    pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1095        self.reduce_impl(dim, true, ReduceOp::Min)
1096    }
1097
1098    /// Similar to `min_keepdim` but the target dimension is squeezed.
1099    pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
1100        self.reduce_impl(dim, false, ReduceOp::Min)
1101    }
1102
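    /// Gathers the index of the maximum value across the selected dimension. The selected
    /// dimension is kept with a single element.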
1103    pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1104        self.reduce_impl(dim, true, ReduceOp::ArgMax)
1105    }
1106
1107    /// Similar to `argmax_keepdim` but the target dimension is squeezed.
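    ///
    /// A small sketch; note that the returned indices are `u32` values:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 5., 2.], [3., 1., 4.]], &Device::Cpu)?;
    /// assert_eq!(a.argmax(1)?.to_vec1::<u32>()?, &[1u32, 2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```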
1108    pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
1109        self.reduce_impl(dim, false, ReduceOp::ArgMax)
1110    }
1111
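    /// Gathers the index of the minimum value across the selected dimension. The selected
    /// dimension is kept with a single element.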
1112    pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1113        self.reduce_impl(dim, true, ReduceOp::ArgMin)
1114    }
1115
1116    /// Similar to `argmin_keepdim` but the target dimension is squeezed.
1117    pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
1118        self.reduce_impl(dim, false, ReduceOp::ArgMin)
1119    }
1120
1121    /// Element-wise comparison between two tensors, e.g. equality, greater than, ... The actual
1122    /// comparison operation is specified by the `op` argument.
1123    ///
1124    /// The returned tensor has the same shape as the original tensors and uses `u8` elements.
1125    pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
1126        let rhs = match rhs.to_tensor_scalar()? {
1127            crate::scalar::TensorScalar::Tensor(rhs) => rhs,
1128            crate::scalar::TensorScalar::Scalar(rhs) => rhs
1129                .to_dtype(self.dtype())?
1130                .to_device(self.device())?
1131                .broadcast_as(self.shape())?,
1132        };
1133        let shape = self.same_shape_binary_op(&rhs, "cmp")?;
1134        let storage = self
1135            .storage()
1136            .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
1137        let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
1138        Ok(from_storage(storage, shape.dims(), op, false))
1139    }
1140
1141    /// Element-wise equality.
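    ///
    /// A small sketch; as with `cmp`, the result uses `u8` values:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    /// let b = Tensor::new(&[[0f32, 4.], [2., 5.]], &Device::Cpu)?;
    /// assert_eq!(a.eq(&b)?.to_vec2::<u8>()?, &[[1u8, 0], [1, 0]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```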
1142    pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1143        self.cmp(rhs, CmpOp::Eq)
1144    }
1145
1146    /// Element-wise non-equality.
1147    pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1148        self.cmp(rhs, CmpOp::Ne)
1149    }
1150
1151    /// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
1152    /// rhs` and 0 otherwise.
1153    pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1154        self.cmp(rhs, CmpOp::Lt)
1155    }
1156
1157    /// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
1158    /// rhs` and 0 otherwise.
1159    pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1160        self.cmp(rhs, CmpOp::Gt)
1161    }
1162
1163    /// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
1164    /// rhs` and 0 otherwise.
1165    pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1166        self.cmp(rhs, CmpOp::Ge)
1167    }
1168
1169    /// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
1170    /// rhs` and 0 otherwise.
1171    pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1172        self.cmp(rhs, CmpOp::Le)
1173    }
1174
1175    /// Clamp the tensor values to be between `min` and `max`.
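    ///
    /// A small sketch using scalar bounds, which the `TensorOrScalar` arguments accept:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.]], &Device::Cpu)?;
    /// assert_eq!(a.clamp(1f32, 4f32)?.to_vec2::<f32>()?, &[[1., 1., 2.], [3., 4., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```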
1176    pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
1177        self.maximum(min)?.minimum(max)
1178    }
1179
1180    /// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
1181    ///
1182    /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
1183    /// tensor also has three dimensions, `(batch, channels, target_size)`.
1184    pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
1185        let (n, c, _l) = self.dims3()?;
1186        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
1187        let storage = self
1188            .storage()
1189            .upsample_nearest1d(self.layout(), target_size)?;
1190        Ok(from_storage(storage, (n, c, target_size), op, false))
1191    }
1192
1193    /// Alias for `interpolate1d`.
1194    pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
1195        self.interpolate1d(target_size)
1196    }
1197
1198    /// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
1199    /// nearest element.
1200    ///
1201    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1202    /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
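    ///
    /// A small sketch; only the output shape is checked here:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0f32, 4f32, &Device::Cpu)?.reshape((1, 1, 2, 2))?;
    /// let up = t.interpolate2d(4, 4)?;
    /// assert_eq!(up.dims(), &[1, 1, 4, 4]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```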
1203    pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1204        let (n, c, _h, _w) = self.dims4()?;
1205        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
1206            arg,
1207            target_h,
1208            target_w,
1209        });
1210        let storage = self
1211            .storage()
1212            .upsample_nearest2d(self.layout(), target_h, target_w)?;
1213        Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1214    }
1215
1216    /// Alias for `interpolate2d`.
1217    pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1218        self.interpolate2d(target_h, target_w)
1219    }
1220
1221    /// Bilinear interpolation to resize the input tensor to the specified size.
1222    ///
1223    /// The input tensor should have four dimensions: `(batch, channels, h, w)`.
1224    /// The returned tensor also has four dimensions: `(batch, channels, target_h, target_w)`.
1225    ///
1226    /// # Arguments
1227    ///
1228    /// * `target_h` - Target height
1229    /// * `target_w` - Target width  
1230    /// * `align_corners` - If true, corner pixels are aligned. If false, pixels are treated as
1231    ///   areas (this matches PyTorch's default behavior).
1232    ///
1233    /// # Example
1234    ///
1235    /// ```rust
1236    /// use candle_core::{Tensor, Device};
1237    /// # fn main() -> candle_core::Result<()> {
1238    /// let t = Tensor::arange(0f32, 16f32, &Device::Cpu)?.reshape((1, 1, 4, 4))?;
1239    /// let upsampled = t.upsample_bilinear2d(8, 8, false)?;
1240    /// assert_eq!(upsampled.dims(), &[1, 1, 8, 8]);
1241    /// # Ok(())
1242    /// # }
1243    /// ```
1244    pub fn upsample_bilinear2d(
1245        &self,
1246        target_h: usize,
1247        target_w: usize,
1248        align_corners: bool,
1249    ) -> Result<Self> {
1250        let (n, c, _h, _w) = self.dims4()?;
1251        let op = BackpropOp::new1(self, |arg| Op::UpsampleBilinear2D {
1252            arg,
1253            target_h,
1254            target_w,
1255            align_corners,
1256        });
1257        // Pass None for scale factors (size mode)
1258        let storage = self.storage().upsample_bilinear2d(
1259            self.layout(),
1260            target_h,
1261            target_w,
1262            align_corners,
1263            None,
1264            None,
1265        )?;
1266        Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1267    }
1268
1269    /// Bilinear interpolation using scale factors.
1270    ///
1271    /// Similar to `upsample_bilinear2d` but uses scale factors instead of absolute sizes.
1272    /// This matches PyTorch's `interpolate(scale_factor=...)` behavior.
1273    ///
1274    /// # Arguments
1275    ///
1276    /// * `scale_h` - Height scaling factor
1277    /// * `scale_w` - Width scaling factor
1278    /// * `align_corners` - If true, corner pixels are aligned
1279    ///
1280    /// # Example
1281    ///
1282    /// ```rust
1283    /// use candle_core::{Tensor, Device};
1284    /// # fn main() -> candle_core::Result<()> {
1285    /// let t = Tensor::arange(0f32, 16f32, &Device::Cpu)?.reshape((1, 1, 4, 4))?;
1286    /// // Scale by 2x in both dimensions
1287    /// let upsampled = t.upsample_bilinear2d_with_scale(2.0, 2.0, false)?;
1288    /// assert_eq!(upsampled.dims(), &[1, 1, 8, 8]);
1289    /// # Ok(())
1290    /// # }
1291    /// ```
1292    pub fn upsample_bilinear2d_with_scale(
1293        &self,
1294        scale_h: f64,
1295        scale_w: f64,
1296        align_corners: bool,
1297    ) -> Result<Self> {
1298        let (n, c, height_in, width_in) = self.dims4()?;
1299
1300        // Calculate output size (floor, matching PyTorch)
1301        let height_out = (height_in as f64 * scale_h).floor() as usize;
1302        let width_out = (width_in as f64 * scale_w).floor() as usize;
1303
1304        // Early return if size unchanged
1305        if height_in == height_out && width_in == width_out {
1306            return Ok(self.clone());
1307        }
1308
1309        let op = BackpropOp::new1(self, |arg| Op::UpsampleBilinear2D {
1310            arg,
1311            target_h: height_out,
1312            target_w: width_out,
1313            align_corners,
1314        });
1315
1316        // Pass original scale factors (scale_factor mode)
1317        // This ensures PyTorch-compatible scale calculation
1318        let storage = self.storage().upsample_bilinear2d(
1319            self.layout(),
1320            height_out,
1321            width_out,
1322            align_corners,
1323            Some(scale_h),
1324            Some(scale_w),
1325        )?;
1326        Ok(from_storage(
1327            storage,
1328            (n, c, height_out, width_out),
1329            op,
1330            false,
1331        ))
1332    }
1333
1334    /// 2D average pooling over an input tensor with multiple channels.
1335    ///
1336    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1337    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
1338    /// the two last dimensions using a kernel of size `sz`. The returned element is the average
1339    /// value over the kernel window.
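    ///
    /// A small example on a `(1, 1, 4, 4)` input with a 2x2 kernel:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0f32, 16f32, &Device::Cpu)?.reshape((1, 1, 4, 4))?;
    /// let pooled = t.avg_pool2d(2)?;
    /// assert_eq!(pooled.dims(), &[1, 1, 2, 2]);
    /// assert_eq!(pooled.squeeze(0)?.squeeze(0)?.to_vec2::<f32>()?, &[[2.5, 4.5], [10.5, 12.5]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```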
1340    pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1341        let sz = sz.to_usize2();
1342        self.avg_pool2d_with_stride(sz, sz)
1343    }
1344
1345    /// Same as `avg_pool2d` but with a `stride` that can be set to a value different from the
1346    /// kernel size.
1347    pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
1348        &self,
1349        kernel_size: T,
1350        stride: T,
1351    ) -> Result<Self> {
1352        let kernel_size = kernel_size.to_usize2();
1353        let stride = stride.to_usize2();
1354        let (n, c, h, w) = self.dims4()?;
1355        if h < kernel_size.0 || w < kernel_size.1 {
1356            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1357        }
1358        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
1359        let h_out = (h - kernel_size.0) / stride.0 + 1;
1360        let w_out = (w - kernel_size.1) / stride.1 + 1;
1361        let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
1362            arg,
1363            kernel_size,
1364            stride,
1365        });
1366        let storage = self
1367            .storage()
1368            .avg_pool2d(self.layout(), kernel_size, stride)?;
1369        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1370    }
1371
1372    /// 2D max pooling over an input tensor with multiple channels.
1373    ///
1374    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1375    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
1376    /// the two last dimensions using a kernel of size `sz`, the returned element is the maximum
1377    /// value over the kernel window.
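    ///
    /// A small example on a `(1, 1, 4, 4)` input with a 2x2 kernel:
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0f32, 16f32, &Device::Cpu)?.reshape((1, 1, 4, 4))?;
    /// let pooled = t.max_pool2d(2)?;
    /// assert_eq!(pooled.squeeze(0)?.squeeze(0)?.to_vec2::<f32>()?, &[[5., 7.], [13., 15.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```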
1378    pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1379        let sz = sz.to_usize2();
1380        self.max_pool2d_with_stride(sz, sz)
1381    }
1382
1383    /// Same as `max_pool2d` but with a `stride` that can be set to a value different from the
1384    /// kernel size.
1385    pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
1386        &self,
1387        kernel_size: T,
1388        stride: T,
1389    ) -> Result<Self> {
1390        let kernel_size = kernel_size.to_usize2();
1391        let stride = stride.to_usize2();
1392        let (n, c, h, w) = self.dims4()?;
1393        if h < kernel_size.0 || w < kernel_size.1 {
1394            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1395        }
1396        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
1397        let h_out = (h - kernel_size.0) / stride.0 + 1;
1398        let w_out = (w - kernel_size.1) / stride.1 + 1;
1399        let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
1400            arg,
1401            kernel_size,
1402            stride,
1403        });
1404        let storage = self
1405            .storage()
1406            .max_pool2d(self.layout(), kernel_size, stride)?;
1407        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1408    }
1409
1410    /// Computes the dot product of two 1D tensors.
1411    ///
1412    /// - If inputs are 1D vectors (`[n]`), returns their scalar dot product.
1413    /// - Returns an error if the shapes are not compatible (both inputs must be 1D and of the same length)
1414    /// - Not supported for integer dtypes
1415    ///
1416    /// # Example (vectors)
1417    /// ```rust
1418    /// use candle_core::{Tensor, Device};
1419    /// let t1 = Tensor::new(&[1.0, 2.0, 3.0], &Device::Cpu)?;
1420    /// let t2 = Tensor::new(&[4.0, 5.0, 6.0], &Device::Cpu)?;
1421    /// let res = t1.dot(&t2)?;
1422    /// assert_eq!(res.to_scalar::<f64>()?, 32.);
1423    /// # Ok::<(), candle_core::Error>(())
1424    /// ```
1425    pub fn dot(&self, rhs: &Self) -> Result<Self> {
1426        if self.dims().len() != 1 || rhs.dims().len() != 1 {
1427            return Err(Error::ShapeMismatchBinaryOp {
1428                lhs: self.shape().clone(),
1429                rhs: rhs.shape().clone(),
1430                op: "dot",
1431            });
1432        }
1433
1434        (self * rhs).and_then(|ret| ret.sum_all())
1435    }
1436
1437    /// Computes the **Frobenius norm** (L2 norm of all elements) of the tensor.
1438    /// - Output is `sqrt(sum(x^2))`.
1439    /// - Always returns a scalar (`[]` shape).
1440    ///
1441    /// # Example
1442    /// ```rust
1443    /// use candle_core::{Tensor, Device};
1444    /// let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
1445    /// let norm = t.norm()?;
1446    /// assert_eq!(norm.to_scalar::<f64>()?, 5.);
1447    /// # Ok::<(), candle_core::Error>(())
1448    /// ```
1449    pub fn norm(&self) -> Result<Self> {
1450        if self.dtype().is_int() {
1451            bail!("norm not supported for integer dtypes");
1452        }
1453
1454        self.sqr().and_then(|x| x.sum_all()).and_then(|x| x.sqrt())
1455    }
1456
1457    /// Performs strict matrix-vector multiplication (`[m, n] * [n] = [m]`).
1458    ///
1459    /// - If `self` is a matrix (`[m, n]`) and `rhs` is a vector (`[n]`), returns a vector (`[m]`).
1460    /// - **No broadcasting**: returns a shape mismatch error if `self` is not 2D or if `rhs` is not 1D with a matching size.
1461    ///
1462    /// # Example
1463    /// ```rust
1464    /// use candle_core::{Tensor, Device};
1465    /// let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
1466    /// let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
1467    /// let res = mat.mv(&vec)?;
1468    /// assert_eq!(res.to_vec1::<f64>()?, [6., 15.]);
1469    /// # Ok::<(), candle_core::Error>(())
1470    /// ```
1471    pub fn mv(&self, rhs: &Self) -> Result<Self> {
1472        // Strict shape checks
1473        let lhs_dims = self.dims();
1474        let rhs_dims = rhs.dims();
1475        if lhs_dims.len() != 2 || rhs_dims.len() != 1 || lhs_dims[1] != rhs_dims[0] {
1476            return Err(Error::ShapeMismatchBinaryOp {
1477                lhs: self.shape().clone(),
1478                rhs: rhs.shape().clone(),
1479                op: "mv",
1480            });
1481        }
1482
1483        // Direct matmul after ensuring rhs is column vector
1484        self.matmul(&rhs.unsqueeze(1)?)?.squeeze(1)
1485    }
1486
1487    /// Returns the matrix-multiplication of the input tensor with the other provided tensor.
1488    ///
1489    /// # Arguments
1490    ///
1491    /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`.
1492    /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`.
1493    ///
1494    /// The resulting tensor has dimensions `b1, b2, ..., bi, m, n`.
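    ///
    /// A shape-only sketch for the batched case, using all-zero inputs for illustration:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    /// let b = Tensor::zeros((2, 4, 5), DType::F32, &Device::Cpu)?;
    /// let c = a.matmul(&b)?;
    /// assert_eq!(c.dims(), &[2, 3, 5]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```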
1495    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
1496        let a_dims = self.shape().dims();
1497        let b_dims = rhs.shape().dims();
1498
1499        let dim = a_dims.len();
1500
1501        if dim < 2 || b_dims.len() != dim {
1502            Err(Error::ShapeMismatchBinaryOp {
1503                lhs: self.shape().clone(),
1504                rhs: rhs.shape().clone(),
1505                op: "matmul",
1506            }
1507            .bt())?
1508        }
1509
1510        let m = a_dims[dim - 2];
1511        let k = a_dims[dim - 1];
1512        let k2 = b_dims[dim - 2];
1513        let n = b_dims[dim - 1];
1514
1515        let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
1516        if c_shape.elem_count() == 0 || k == 0 {
1517            return Tensor::zeros(c_shape, self.dtype(), self.device());
1518        }
1519        let batching: usize = a_dims[..dim - 2].iter().product();
1520        let batching_b: usize = b_dims[..dim - 2].iter().product();
1521        if k != k2 || batching != batching_b {
1522            Err(Error::ShapeMismatchBinaryOp {
1523                lhs: self.shape().clone(),
1524                rhs: rhs.shape().clone(),
1525                op: "matmul",
1526            }
1527            .bt())?
1528        }
1529
1530        let storage = self.storage().matmul(
1531            &rhs.storage(),
1532            (batching, m, n, k),
1533            self.layout(),
1534            rhs.layout(),
1535        )?;
1536        let op = BackpropOp::new2(self, rhs, Op::Matmul);
1537        Ok(from_storage(storage, c_shape, op, false))
1538    }
1539
1540    /// Matrix-multiplication with broadcasting support.
1541    ///
1542    /// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as
1543    /// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
1544    /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
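    ///
    /// A shape-only sketch of this broadcasting behaviour, using all-zero inputs for illustration:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let lhs = Tensor::zeros((2, 1, 3, 4), DType::F32, &Device::Cpu)?;
    /// let rhs = Tensor::zeros((5, 4, 6), DType::F32, &Device::Cpu)?;
    /// let out = lhs.broadcast_matmul(&rhs)?;
    /// assert_eq!(out.dims(), &[2, 5, 3, 6]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```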
1545    pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
1546        let lhs = self;
1547        let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
1548        let l_broadcast = l_shape != *lhs.shape();
1549        let r_broadcast = r_shape != *rhs.shape();
1550        // TODO: Avoid concretising the broadcasted matrixes via contiguous.
1551        match (l_broadcast, r_broadcast) {
1552            (true, true) => lhs
1553                .broadcast_as(&l_shape)?
1554                .contiguous()?
1555                .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1556            (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1557            (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
1558            (false, false) => lhs.matmul(rhs),
1559        }
1560    }
1561
1562    /// Returns a tensor with the same shape as the input tensor, the values are taken from
1563    /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
1564    /// input tensor is equal to zero.
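    ///
    /// A small sketch using an integer mask, with illustrative values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let cond = Tensor::new(&[1u32, 0, 0, 1], &Device::Cpu)?;
    /// let on_true = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
    /// let on_false = Tensor::new(&[-1f32, -2., -3., -4.], &Device::Cpu)?;
    /// let res = cond.where_cond(&on_true, &on_false)?;
    /// assert_eq!(res.to_vec1::<f32>()?, &[1., -2., -3., 4.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```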
1565    pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
1566        let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
1567        let shape = self.same_shape_binary_op(on_false, "where_cond")?;
1568        let storage = self.storage().where_cond(
1569            self.layout(),
1570            &on_true.storage(),
1571            on_true.layout(),
1572            &on_false.storage(),
1573            on_false.layout(),
1574        )?;
1575        let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
1576        Ok(from_storage(storage, shape, op, false))
1577    }
1578
1579    /// Returns a tensor with the values from the `self` tensor at the indexes corresponding to the
1580    /// values held in the `ids` tensor.
1581    ///
1582    /// # Arguments
1583    ///
1584    /// * `self` - A tensor with dimensions `v, h`.
1585    /// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
1586    ///
1587    /// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
1588    /// vocabulary size, and `h` the hidden size.
1589    ///
1590    /// ```rust
1591    /// use candle_core::{Tensor, Device};
1592    /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1593    /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
1594    /// let emb = values.embedding(&ids)?;
1595    /// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
1596    /// # Ok::<(), candle_core::Error>(())
1597    /// ```
1598    pub fn embedding(&self, ids: &Self) -> Result<Self> {
1599        if self.rank() != 2 || ids.rank() != 1 {
1600            Err(Error::ShapeMismatchBinaryOp {
1601                lhs: self.shape().clone(),
1602                rhs: ids.shape().clone(),
1603                op: "embedding",
1604            }
1605            .bt())?
1606        }
1607        self.index_select(ids, 0)
1608    }
1609
1610    fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
1611        let source_dims = source.dims();
1612        let self_dims = self.dims();
1613        let mismatch = if source_dims.len() != self_dims.len() {
1614            true
1615        } else {
1616            let mut mismatch = false;
1617            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1618                if i != dim && d1 != d2 {
1619                    mismatch = true;
1620                    break;
1621                }
1622            }
1623            mismatch
1624        };
1625        if mismatch {
1626            Err(Error::ShapeMismatchBinaryOp {
1627                op: "scatter (self, src)",
1628                lhs: self.shape().clone(),
1629                rhs: source.shape().clone(),
1630            }
1631            .bt())?
1632        }
1633        if indexes.dims() != source.dims() {
1634            Err(Error::ShapeMismatchBinaryOp {
1635                op: "scatter (indexes, src)",
1636                lhs: indexes.shape().clone(),
1637                rhs: source.shape().clone(),
1638            }
1639            .bt())?
1640        }
1641        Ok(())
1642    }
1643
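    /// Writes the values of `source` into a copy of `self` at the positions given by `indexes`
    /// along dimension `dim`: for `dim == 0` this sets `out[indexes[i][j]][j] = source[i][j]`.
    /// The original `self` is left unchanged. A small sketch with illustrative, non-colliding
    /// indexes:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let init = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[[0u32, 1, 0], [1, 0, 1]], &Device::Cpu)?;
    /// let out = init.scatter(&ids, &src, 0)?;
    /// assert_eq!(out.to_vec2::<f32>()?, &[[1., 5., 3.], [4., 2., 6.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```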
1644    pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1645        let dim = dim.to_index(self.shape(), "scatter")?;
1646        self.scatter_checks(indexes, source, dim)?;
1647        let shape = self.shape();
1648        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1649        self.storage()
1650            .copy_strided_src(&mut storage, 0, self.layout())?;
1651        let layout = Layout::contiguous(shape);
1652        storage.scatter_set(
1653            &layout,
1654            &indexes.storage(),
1655            indexes.layout(),
1656            &source.storage(),
1657            source.layout(),
1658            dim,
1659        )?;
1660        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1661            Op::Scatter(t1, t2, t3, dim)
1662        });
1663        Ok(from_storage(storage, self.shape(), op, false))
1664    }
1665
1666    pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1667        if self.same_storage(source) {
1668            crate::bail!("cannot use scatter_set when self and src share their storage")
1669        }
1670        let dim = dim.to_index(self.shape(), "scatter-set")?;
1671        self.scatter_checks(indexes, source, dim)?;
1672        self.storage_mut().scatter_set(
1673            self.layout(),
1674            &indexes.storage(),
1675            indexes.layout(),
1676            &source.storage(),
1677            source.layout(),
1678            dim,
1679        )?;
1680        Ok(())
1681    }
1682
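    /// Same as `scatter` but the values of `source` are added to the existing values of `self`
    /// instead of replacing them. A small sketch with illustrative values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let init = Tensor::new(&[[1f32, 1., 1.], [1., 1., 1.]], &Device::Cpu)?;
    /// let src = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[[0u32, 0, 0], [1, 1, 1]], &Device::Cpu)?;
    /// let out = init.scatter_add(&ids, &src, 0)?;
    /// assert_eq!(out.to_vec2::<f32>()?, &[[2., 3., 4.], [5., 6., 7.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```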
1683    pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1684        let dim = dim.to_index(self.shape(), "scatter-add")?;
1685        self.scatter_checks(indexes, source, dim)?;
1686        let shape = self.shape();
1687        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1688        self.storage()
1689            .copy_strided_src(&mut storage, 0, self.layout())?;
1690        let layout = Layout::contiguous(shape);
1691        storage.scatter_add(
1692            &layout,
1693            &indexes.storage(),
1694            indexes.layout(),
1695            &source.storage(),
1696            source.layout(),
1697            dim,
1698        )?;
1699        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1700            Op::ScatterAdd(t1, t2, t3, dim)
1701        });
1702        Ok(from_storage(storage, self.shape(), op, false))
1703    }
1704
1705    pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1706        if self.same_storage(source) {
1707            crate::bail!("cannot use scatter_add_set when self and src share their storage")
1708        }
1709        let dim = dim.to_index(self.shape(), "scatter-add-set")?;
1710        self.scatter_checks(indexes, source, dim)?;
1711        self.storage_mut().scatter_add(
1712            self.layout(),
1713            &indexes.storage(),
1714            indexes.layout(),
1715            &source.storage(),
1716            source.layout(),
1717            dim,
1718        )?;
1719        Ok(())
1720    }
1721
1722    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
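    ///
    /// A small sketch overwriting a single row, with illustrative values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((3, 3), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::new(&[[1f32, 2., 3.]], &Device::Cpu)?;
    /// let out = t.slice_scatter(&src, 0, 1)?;
    /// assert_eq!(out.to_vec2::<f32>()?, &[[0., 0., 0.], [1., 2., 3.], [0., 0., 0.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```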
1723    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
1724        let dim = dim.to_index(self.shape(), "slice-scatter")?;
1725        if dim == 0 {
1726            self.slice_scatter0(src, start)
1727        } else {
1728            // TODO: Maybe we want to add a more efficient implementation at some point.
1729            self.transpose(0, dim)?
1730                .slice_scatter0(&src.transpose(0, dim)?, start)?
1731                .transpose(0, dim)
1732        }
1733    }
1734
1735    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
1736    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
1737        if self.dtype() != src.dtype() {
1738            Err(Error::DTypeMismatchBinaryOp {
1739                lhs: self.dtype(),
1740                rhs: src.dtype(),
1741                op: "slice-scatter",
1742            }
1743            .bt())?
1744        }
1745        if self.device().location() != src.device().location() {
1746            Err(Error::DeviceMismatchBinaryOp {
1747                lhs: self.device().location(),
1748                rhs: src.device().location(),
1749                op: "slice-scatter",
1750            }
1751            .bt())?
1752        }
1753        if self.rank() != src.rank() {
1754            Err(Error::UnexpectedNumberOfDims {
1755                expected: self.rank(),
1756                got: src.rank(),
1757                shape: src.shape().clone(),
1758            }
1759            .bt())?
1760        }
1761        let shape_ok =
1762            self.dims()
1763                .iter()
1764                .zip(src.dims().iter())
1765                .enumerate()
1766                .all(|(dim_idx, (&d1, &d2))| {
1767                    if 0 == dim_idx {
1768                        d2 + start <= d1
1769                    } else {
1770                        d1 == d2
1771                    }
1772                });
1773        if !shape_ok {
1774            Err(Error::ShapeMismatchBinaryOp {
1775                op: "slice-scatter (self, src)",
1776                lhs: self.shape().clone(),
1777                rhs: src.shape().clone(),
1778            }
1779            .bt())?
1780        }
1781        let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
1782        self.storage()
1783            .copy_strided_src(&mut storage, 0, self.layout())?;
1784        let offset = start * src.dims()[1..].iter().product::<usize>();
1785        src.storage()
1786            .copy_strided_src(&mut storage, offset, src.layout())?;
1787        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
1788        Ok(from_storage(storage, self.shape(), op, false))
1789    }
1790
1791    /// Accumulates the elements from `source` into a copy of `self`, at the positions along `dim` given by `indexes`.
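    ///
    /// A small sketch along dimension 1, with illustrative values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 1., 1.], [1., 1., 1.]], &Device::Cpu)?;
    /// let source = Tensor::new(&[[10f32, 20.], [30., 40.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[0u32, 2u32], &Device::Cpu)?;
    /// let out = t.index_add(&ids, &source, 1)?;
    /// assert_eq!(out.to_vec2::<f32>()?, &[[11., 1., 21.], [31., 1., 41.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```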
1792    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1793        let dim = dim.to_index(self.shape(), "index-add")?;
1794        let source_dims = source.dims();
1795        let self_dims = self.dims();
1796        let mismatch = if source_dims.len() != self_dims.len() {
1797            true
1798        } else {
1799            let mut mismatch = false;
1800            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1801                if i != dim && d1 != d2 {
1802                    mismatch = true;
1803                    break;
1804                }
1805            }
1806            mismatch
1807        };
1808        if mismatch {
1809            Err(Error::ShapeMismatchBinaryOp {
1810                op: "index-add (self, source)",
1811                lhs: self.shape().clone(),
1812                rhs: source.shape().clone(),
1813            }
1814            .bt())?
1815        }
1816        // The number of elements in indexes must match the size of the dimension on which the
1817        // add is performed in the source tensor (the index values from `indexes` refer to
1818        // positions in the target tensor self).
1819        let indexes_len = indexes.dims1()?;
1820        if source_dims[dim] != indexes_len {
1821            Err(Error::ShapeMismatchBinaryOp {
1822                op: "index-add (ids, source))",
1823                lhs: indexes.shape().clone(),
1824                rhs: source.shape().clone(),
1825            }
1826            .bt())?
1827        }
1828        let storage = self.storage().index_add(
1829            self.layout(),
1830            &indexes.storage(),
1831            indexes.layout(),
1832            &source.storage(),
1833            source.layout(),
1834            dim,
1835        )?;
1836        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1837            Op::IndexAdd(t1, t2, t3, dim)
1838        });
1839        Ok(from_storage(storage, self.shape(), op, false))
1840    }
1841
1842    /// Gather values across the target dimension.
1843    ///
1844    /// # Arguments
1845    ///
1846    /// * `self` - The input tensor.
1847    /// * `indexes` - The indices of elements to gather; this should have the same number of dimensions as
1848    ///   `self`, and `indexes.dims()[d] <= self.dims()[d]` for all dimensions `d != dim`.
1849    /// * `dim` - the target dimension.
1850    ///
1851    /// The resulting tensor has the same shape as `indexes` and uses values from `self` indexed on
1852    /// dimension `dim` by the values in `indexes`.
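    ///
    /// A small sketch gathering along the last dimension, with illustrative values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[[0u32], [1u32]], &Device::Cpu)?;
    /// let gathered = t.gather(&ids, 1)?;
    /// assert_eq!(gathered.to_vec2::<f32>()?, &[[1.], [4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```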
1853    pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1854        let dim = dim.to_index(self.shape(), "gather")?;
1855
1856        let self_dims = self.dims();
1857        let indexes_dims = indexes.dims();
1858        let mismatch = if indexes_dims.len() != self_dims.len() {
1859            true
1860        } else {
1861            let mut mismatch = false;
1862            for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
1863                if i != dim && d1 < d2 {
1864                    mismatch = true;
1865                    break;
1866                }
1867            }
1868            mismatch
1869        };
1870        if mismatch {
1871            Err(Error::ShapeMismatchBinaryOp {
1872                op: "gather",
1873                lhs: self.shape().clone(),
1874                rhs: indexes.shape().clone(),
1875            }
1876            .bt())?
1877        }
1878        let storage =
1879            self.storage()
1880                .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
1881        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
1882        Ok(from_storage(storage, indexes.shape(), op, false))
1883    }
1884
1885    /// Selects values from the input tensor at the target indexes along the specified dimension.
1886    ///
1887    /// The `indexes` argument is an int tensor with a single dimension.
1888    /// The output has the same number of dimensions as the `self` input. The target dimension of
1889    /// the output has the same length as `indexes` and its values are taken from `self` using
1890    /// the indexes from `indexes`. Other dimensions have the same number of elements as the input
1891    /// tensor.
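    ///
    /// A small sketch selecting two columns, with illustrative values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    /// let ids = Tensor::new(&[2u32, 0u32], &Device::Cpu)?;
    /// let selected = t.index_select(&ids, 1)?;
    /// assert_eq!(selected.to_vec2::<f32>()?, &[[3., 1.], [6., 4.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```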
1892    pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1893        let dim = dim.to_index(self.shape(), "index-select")?;
1894        let indexes_len = match indexes.dims() {
1895            [l] => *l,
1896            _ => Err(Error::ShapeMismatchBinaryOp {
1897                lhs: self.shape().clone(),
1898                rhs: indexes.shape().clone(),
1899                op: "index-select",
1900            }
1901            .bt())?,
1902        };
1903        let storage = self.storage().index_select(
1904            &indexes.storage(),
1905            self.layout(),
1906            indexes.layout(),
1907            dim,
1908        )?;
1909        let mut dims = self.dims().to_vec();
1910        dims[dim] = indexes_len;
1911        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
1912        Ok(from_storage(storage, dims, op, false))
1913    }
1914
1915    /// Returns an iterator over position of the elements in the storage when ranging over the
1916    /// index tuples in lexicographic order.
1917    pub fn strided_index(&self) -> crate::StridedIndex<'_> {
1918        self.layout.strided_index()
1919    }
1920
1921    /// Similar to `strided_index` but returns the position of the start of each contiguous block
1922    /// as well as the length of the contiguous blocks. For a contiguous tensor, the index iterator
1923    /// will only return the start offset and the size would be the number of elements in the
1924    /// tensor.
1925    pub fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
1926        self.layout.strided_blocks()
1927    }
1928
1929    /// Returns the data contained in a 1D tensor as a vector of scalar values.
1930    pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
1931        if self.rank() != 1 {
1932            Err(Error::UnexpectedNumberOfDims {
1933                expected: 1,
1934                got: self.rank(),
1935                shape: self.shape().clone(),
1936            }
1937            .bt())?
1938        }
1939        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1940            let data = S::cpu_storage_as_slice(cpu_storage)?;
1941            let data = match self.layout.contiguous_offsets() {
1942                Some((o1, o2)) => data[o1..o2].to_vec(),
1943                None => self.strided_index().map(|i| data[i]).collect(),
1944            };
1945            Ok::<Vec<_>, Error>(data)
1946        };
1947        match &*self.storage() {
1948            Storage::Cpu(storage) => from_cpu_storage(storage),
1949            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1950            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1951        }
1952    }
1953
1954    /// Returns the data contained in a 2D tensor as a vector of vector of scalar values.
1955    pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
1956        let (dim1, dim2) = self.dims2()?;
1957        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1958            let data = S::cpu_storage_as_slice(cpu_storage)?;
1959            let mut rows = vec![];
1960            match self.layout.contiguous_offsets() {
1961                Some((o1, o2)) => {
1962                    let data = &data[o1..o2];
1963                    for idx_row in 0..dim1 {
1964                        rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
1965                    }
1966                }
1967                None => {
1968                    let mut src_index = self.strided_index();
1969                    for _idx_row in 0..dim1 {
1970                        let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
1971                        rows.push(row)
1972                    }
1973                    assert!(src_index.next().is_none());
1974                }
1975            }
1976            Ok(rows)
1977        };
1978        match &*self.storage() {
1979            Storage::Cpu(storage) => from_cpu_storage(storage),
1980            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1981            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1982        }
1983    }
1984
1985    /// Returns the data contained in a 3D tensor.
1986    pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
1987        let (dim1, dim2, dim3) = self.dims3()?;
1988        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1989            let data = S::cpu_storage_as_slice(cpu_storage)?;
1990            let mut top_rows = vec![];
1991            match self.layout.contiguous_offsets() {
1992                Some((o1, o2)) => {
1993                    let data = &data[o1..o2];
1994                    let dim23 = dim2 * dim3;
1995                    for idx1 in 0..dim1 {
1996                        let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
1997                        let mut rows = vec![];
1998                        for idx2 in 0..dim2 {
1999                            rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
2000                        }
2001                        top_rows.push(rows);
2002                    }
2003                }
2004                None => {
2005                    let mut src_index = self.strided_index();
2006                    for _idx in 0..dim1 {
2007                        let mut rows = vec![];
2008                        for _jdx in 0..dim2 {
2009                            let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
2010                            rows.push(row)
2011                        }
2012                        top_rows.push(rows);
2013                    }
2014                    assert!(src_index.next().is_none());
2015                }
2016            }
2017            Ok(top_rows)
2018        };
2019        match &*self.storage() {
2020            Storage::Cpu(storage) => from_cpu_storage(storage),
2021            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2022            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2023        }
2024    }
2025
2026    /// The dtype for the elements stored in the input tensor.
2027    pub fn dtype(&self) -> DType {
2028        self.dtype
2029    }
2030
2031    /// The device on which the input tensor is located.
2032    pub fn device(&self) -> &Device {
2033        &self.device
2034    }
2035
2036    /// The tensor shape, i.e. dimension sizes on each axis.
2037    pub fn shape(&self) -> &Shape {
2038        self.layout().shape()
2039    }
2040
2041    /// The dimension size for this tensor on each axis.
2042    pub fn dims(&self) -> &[usize] {
2043        self.shape().dims()
2044    }
2045
2046    /// The dimension size for a specified dimension index.
2047    pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
2048        let dim = dim.to_index(self.shape(), "dim")?;
2049        Ok(self.dims()[dim])
2050    }
2051
2052    /// The layout of the input tensor: this stores both the shape of the tensor as well as the
2053    /// strides and the start offset to apply to the underlying storage.
2054    pub fn layout(&self) -> &Layout {
2055        &self.layout
2056    }
2057
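    /// The strides of this tensor's layout, i.e. the number of storage elements to skip to move
    /// by one along each dimension.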
2058    pub fn stride(&self) -> &[usize] {
2059        self.layout.stride()
2060    }
2061
2062    /// The number of dimensions for this tensor, 0 for a scalar tensor, 1 for a 1D tensor, etc.
2063    pub fn rank(&self) -> usize {
2064        self.shape().rank()
2065    }
2066
2067    /// The number of elements stored in this tensor.
2068    pub fn elem_count(&self) -> usize {
2069        self.shape().elem_count()
2070    }
2071
2072    /// The unique identifier for this tensor.
2073    pub fn id(&self) -> TensorId {
2074        self.id
2075    }
2076
2077    /// Whether this tensor is a variable or not. A variable is a tensor for which gradients are
2078    /// tracked and on which backpropagation can be performed.
2079    pub fn is_variable(&self) -> bool {
2080        self.is_variable
2081    }
2082
2083    pub(crate) fn op(&self) -> &Option<Op> {
2084        &self.op
2085    }
2086
2087    /// Computes the max of all the elements in this tensor and returns a tensor holding this
2088    /// scalar with zero dimensions.
2089    ///
2090    /// ```rust
2091    /// use candle_core::{Tensor, Device};
2092    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2093    /// let tensor = tensor.max_all()?;
2094    /// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
2095    /// # Ok::<(), candle_core::Error>(())
2096    /// ```
2097    pub fn max_all(&self) -> Result<Tensor> {
2098        if self.rank() == 0 {
2099            Ok(self.clone())
2100        } else {
2101            self.flatten_all()?.max(0)
2102        }
2103    }
2104
2105    /// Computes the min of all the elements in this tensor and returns a tensor holding this
2106    /// scalar with zero dimensions.
2107    ///
2108    /// ```rust
2109    /// use candle_core::{Tensor, Device};
2110    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2111    /// let tensor = tensor.min_all()?;
2112    /// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
2113    /// # Ok::<(), candle_core::Error>(())
2114    /// ```
2115    pub fn min_all(&self) -> Result<Tensor> {
2116        if self.rank() == 0 {
2117            Ok(self.clone())
2118        } else {
2119            self.flatten_all()?.min(0)
2120        }
2121    }
2122
2123    /// Computes the sum of all the elements in this tensor and returns a tensor holding this
2124    /// scalar with zero dimensions.
2125    ///
2126    /// ```rust
2127    /// use candle_core::{Tensor, Device};
2128    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2129    /// let tensor = tensor.sum_all()?;
2130    /// assert_eq!(tensor.to_scalar::<f32>()?, 15.);
2131    /// # Ok::<(), candle_core::Error>(())
2132    /// ```
2133    pub fn sum_all(&self) -> Result<Tensor> {
2134        let dims: Vec<_> = (0..self.rank()).collect();
2135        self.sum(dims)
2136    }
2137
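    /// Computes the mean of all the elements in this tensor and returns a tensor holding this
    /// scalar with zero dimensions. A small sketch with illustrative values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let tensor = tensor.mean_all()?;
    /// assert_eq!(tensor.to_scalar::<f32>()?, 2.5);
    /// # Ok::<(), candle_core::Error>(())
    /// ```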
2138    pub fn mean_all(&self) -> Result<Tensor> {
2139        self.sum_all()? / self.elem_count() as f64
2140    }
2141
2142    fn flatten_<D1: Dim, D2: Dim>(
2143        &self,
2144        start_dim: Option<D1>,
2145        end_dim: Option<D2>,
2146    ) -> Result<Tensor> {
2147        if self.rank() == 0 {
2148            self.reshape(1)
2149        } else {
2150            let start_dim = match start_dim {
2151                None => 0,
2152                Some(dim) => dim.to_index(self.shape(), "flatten")?,
2153            };
2154            let end_dim = match end_dim {
2155                None => self.rank() - 1,
2156                Some(dim) => dim.to_index(self.shape(), "flatten")?,
2157            };
2158            if start_dim < end_dim {
2159                let dims = self.dims();
2160                let mut dst_dims = dims[..start_dim].to_vec();
2161                dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
2162                if end_dim + 1 < dims.len() {
2163                    dst_dims.extend(&dims[end_dim + 1..]);
2164                }
2165                self.reshape(dst_dims)
2166            } else {
2167                Ok(self.clone())
2168            }
2169        }
2170    }
2171
2172    /// Flattens the input tensor on the dimension indexes from `start_dim` to `end_dim` (both
2173    /// inclusive).
2174    pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
2175        self.flatten_(Some(start_dim), Some(end_dim))
2176    }
2177
2178    /// Flattens the input tensor on the dimension indexes from `0` to `end_dim` (inclusive).
2179    pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
2180        self.flatten_(None::<usize>, Some(end_dim))
2181    }
2182
2183    /// Flattens the input tensor on the dimension indexes from `start_dim` (inclusive) to the last
2184    /// dimension.
2185    pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
2186        self.flatten_(Some(start_dim), None::<usize>)
2187    }
2188
2189    /// Flattens the input tensor by reshaping it into a one-dimensional tensor.
2190    ///
2191    /// ```rust
2192    /// use candle_core::{Tensor, Device};
2193    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2194    /// let tensor = tensor.flatten_all()?;
2195    /// assert_eq!(tensor.to_vec1::<f32>()?, &[0., 1., 2., 3., 4., 5.]);
2196    /// # Ok::<(), candle_core::Error>(())
2197    /// ```
2198    pub fn flatten_all(&self) -> Result<Tensor> {
2199        self.flatten_(None::<usize>, None::<usize>)
2200    }
2201
2202    /// Returns the sub-tensor fixing the index at `i` on the first dimension.
2203    ///
2204    /// ```rust
2205    /// use candle_core::{Tensor, Device};
2206    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2207    /// let t = tensor.get(0)?;
2208    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 1.]);
2209    /// let t = tensor.get(1)?;
2210    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
2211    /// # Ok::<(), candle_core::Error>(())
2212    /// ```
2213    pub fn get(&self, i: usize) -> Result<Tensor> {
2214        let dims = self.dims();
2215        if dims.is_empty() {
2216            Ok(self.clone())
2217        } else {
2218            self.narrow(0, i, 1)?.reshape(&dims[1..])
2219        }
2220    }
2221
2222    /// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
2223    ///
2224    /// ```rust
2225    /// use candle_core::{Tensor, Device};
2226    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2227    /// let t = tensor.get_on_dim(1, 0)?;
2228    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
2229    /// let t = tensor.get_on_dim(1, 1)?;
2230    /// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
2231    /// let t = tensor.get_on_dim(0, 1)?;
2232    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
2233    /// # Ok::<(), candle_core::Error>(())
2234    /// ```
2235    pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
2236        let dim = dim.to_index(self.shape(), "get_on_dim")?;
2237        self.narrow(dim, index, 1)?.squeeze(dim)
2238    }
2239
2240    /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
2241    /// input are swapped.
2242    ///
2243    /// ```rust
2244    /// use candle_core::{Tensor, Device};
2245    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2246    /// let tensor = tensor.t()?;
2247    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[0.0, 2.0, 4.0], [1.0, 3.0, 5.0]]);
2248    /// # Ok::<(), candle_core::Error>(())
2249    /// ```
2250    pub fn t(&self) -> Result<Tensor> {
2251        let rank = self.rank();
2252        if rank < 2 {
2253            Err(Error::UnexpectedNumberOfDims {
2254                expected: 2,
2255                got: rank,
2256                shape: self.shape().clone(),
2257            }
2258            .bt())?
2259        }
2260        self.transpose(rank - 2, rank - 1)
2261    }
2262
2263    /// Returns a tensor that is a transposed version of the input, the given dimensions are
2264    /// swapped.
2265    pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
2266        let dim1 = dim1.to_index(self.shape(), "transpose")?;
2267        let dim2 = dim2.to_index(self.shape(), "transpose")?;
2268        if dim1 == dim2 {
2269            return Ok(self.clone());
2270        }
2271        let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
2272        let tensor_ = Tensor_ {
2273            id: TensorId::new(),
2274            storage: self.storage.clone(),
2275            layout: self.layout.transpose(dim1, dim2)?,
2276            op,
2277            is_variable: false,
2278            dtype: self.dtype,
2279            device: self.device.clone(),
2280        };
2281        Ok(Tensor(Arc::new(tensor_)))
2282    }
2283
2284    /// Returns a tensor with the same data as the input where the dimensions have been permuted.
2285    /// `dims` must be a permutation, i.e. include each dimension index exactly once.
2286    ///
2287    /// ```rust
2288    /// use candle_core::{Tensor, Device};
2289    /// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
2290    /// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
2291    /// let tensor = tensor.permute((2, 3, 1, 0))?;
2292    /// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
2293    /// # Ok::<(), candle_core::Error>(())
2294    /// ```
2295    pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
2296        let dims = dims.to_indexes(self.shape(), "permute")?;
2297        // O(n^2) permutation check but these arrays are small.
2298        let is_permutation =
2299            dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
2300        if !is_permutation {
2301            bail!(
2302                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
2303                self.dims(),
2304                dims
2305            )
2306        }
2307        let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
2308        let tensor_ = Tensor_ {
2309            id: TensorId::new(),
2310            storage: self.storage.clone(),
2311            layout: self.layout.permute(&dims)?,
2312            op,
2313            is_variable: false,
2314            dtype: self.dtype,
2315            device: self.device.clone(),
2316        };
2317        Ok(Tensor(Arc::new(tensor_)))
2318    }
2319
2320    /// Returns true if the data is stored in a C contiguous (aka row major) way.
2321    pub fn is_contiguous(&self) -> bool {
2322        self.layout.is_contiguous()
2323    }
2324
2325    /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
2326    pub fn is_fortran_contiguous(&self) -> bool {
2327        self.layout.is_fortran_contiguous()
2328    }
2329
2330    /// Compared to clone, this copies the actual storage and may therefore fail if the device
2331    /// runs out of memory.
2332    pub fn copy(&self) -> Result<Tensor> {
2333        let op = BackpropOp::new1(self, Op::Copy);
2334        let tensor_ = Tensor_ {
2335            id: TensorId::new(),
2336            storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
2337            layout: self.layout.clone(),
2338            op,
2339            is_variable: false,
2340            dtype: self.dtype,
2341            device: self.device.clone(),
2342        };
2343        Ok(Tensor(Arc::new(tensor_)))
2344    }
2345
2346    /// Returns a new tensor detached from the current graph, gradients are not propagated through
2347    /// this new node. The storage of this tensor is shared with the initial tensor.
2348    ///
2349    /// If the tensor is already detached from the computation graph, the same tensor is returned.
2350    pub fn detach(&self) -> Tensor {
2351        if self.op.is_none() && !self.is_variable {
2352            self.clone()
2353        } else {
2354            let tensor_ = Tensor_ {
2355                id: TensorId::new(),
2356                storage: self.storage.clone(),
2357                layout: self.layout.clone(),
2358                op: BackpropOp::none(),
2359                is_variable: false,
2360                dtype: self.dtype,
2361                device: self.device.clone(),
2362            };
2363            Tensor(Arc::new(tensor_))
2364        }
2365    }
2366
2367    /// If the target device is the same as the tensor device, only a shallow copy is performed.
2368    pub fn to_device(&self, device: &Device) -> Result<Tensor> {
2369        if self.device().same_device(device) {
2370            Ok(self.clone())
2371        } else {
2372            let storage = match (&*self.storage(), device) {
2373                (Storage::Cpu(storage), Device::Cuda(cuda)) => {
2374                    Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
2375                }
2376                (Storage::Cpu(storage), Device::Metal(metal)) => {
2377                    Storage::Metal(metal.storage_from_cpu_storage(storage)?)
2378                }
2379                (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2380                (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2381                (Storage::Cuda(storage), Device::Cuda(cuda)) => {
2382                    // can't clone storage if it's the same device because of the underlying device ptr
2383                    let dst_storage = storage.transfer_to_device(cuda)?;
2384                    Storage::Cuda(dst_storage)
2385                }
2386                (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
2387                _ => {
2388                    bail!(
2389                        "not implemented yet, self.device: {:?}, device: {:?}",
2390                        self.device(),
2391                        device
2392                    )
2393                }
2394            };
2395            let op = BackpropOp::new1(self, Op::ToDevice);
2396            let tensor_ = Tensor_ {
2397                id: TensorId::new(),
2398                storage: Arc::new(RwLock::new(storage)),
2399                layout: self.layout.clone(),
2400                op,
2401                is_variable: false,
2402                dtype: self.dtype,
2403                device: device.clone(),
2404            };
2405            Ok(Tensor(Arc::new(tensor_)))
2406        }
2407    }
2408
2409    /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
2410    /// on the left.
2411    pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
2412        let left_shape = left_shape.into();
2413        let mut dims = left_shape.into_dims();
2414        dims.extend(self.dims());
2415        self.broadcast_as(dims)
2416    }
2417
2418    /// Broadcast the input tensor to the target shape. This returns an error if the input shape is
2419    /// not compatible with the target shape.
2420    ///
2421    /// If the input shape is `i_1, i_2, ... i_k`, the target shape has to have `k` dimensions or
2422    /// more and shape `j_1, ..., j_l, t_1, t_2, ..., t_k`. The dimensions `j_1` to `j_l` can have
2423    /// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If
2424    /// `i_a` is equal to 1, any value can be used.
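    ///
    /// A small sketch broadcasting a `(3, 1)` tensor to `(3, 2)`, with illustrative values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32], [2.], [3.]], &Device::Cpu)?;
    /// let t = t.broadcast_as((3, 2))?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[1., 1.], [2., 2.], [3., 3.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```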
2425    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2426        let tensor_ = Tensor_ {
2427            id: TensorId::new(),
2428            storage: self.storage.clone(),
2429            layout: self.layout.broadcast_as(shape)?,
2430            op: BackpropOp::new1(self, Op::Broadcast),
2431            is_variable: false,
2432            dtype: self.dtype,
2433            device: self.device.clone(),
2434        };
2435        Ok(Tensor(Arc::new(tensor_)))
2436    }
2437
2438    /// An alias for broadcast_as.
2439    pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2440        self.broadcast_as(shape)
2441    }
2442
2443    /// Casts the input tensor to the target `dtype`.
2444    ///
2445    /// ```rust
2446    /// use candle_core::{Tensor, Device};
2447    /// let tensor = Tensor::new(3.14159265358979f64, &Device::Cpu)?;
2448    /// assert_eq!(tensor.to_scalar::<f64>()?, 3.14159265358979);
2449    /// let tensor = tensor.to_dtype(candle_core::DType::F32)?;
2450    /// assert_eq!(tensor.to_scalar::<f32>()?, 3.1415927);
2451    /// # Ok::<(), candle_core::Error>(())
2452    /// ```
2453    pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2454        if self.dtype() == dtype {
2455            Ok(self.clone())
2456        } else {
2457            let shape = self.shape();
2458            let storage = self.storage().to_dtype(self.layout(), dtype)?;
2459            let op = BackpropOp::new1(self, Op::ToDType);
2460            Ok(from_storage(storage, shape.clone(), op, false))
2461        }
2462    }
2463
2464    /// Returns a tensor that is in row major order. This is the same as the original tensor if it
2465    /// was already contiguous, otherwise a copy is triggered.
2466    pub fn contiguous(&self) -> Result<Tensor> {
2467        if self.is_contiguous() {
2468            Ok(self.clone())
2469        } else {
2470            let shape = self.shape();
2471            let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2472            self.storage()
2473                .copy_strided_src(&mut storage, 0, self.layout())?;
2474            let op = BackpropOp::new1(self, Op::Copy);
2475            Ok(from_storage(storage, shape.clone(), op, false))
2476        }
2477    }
2478
2479    /// Returns a tensor that is in row major order. This always makes a copy.
2480    pub fn force_contiguous(&self) -> Result<Tensor> {
2481        let shape = self.shape();
2482        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2483        self.storage()
2484            .copy_strided_src(&mut storage, 0, self.layout())?;
2485        let op = BackpropOp::new1(self, Op::Copy);
2486        Ok(from_storage(storage, shape.clone(), op, false))
2487    }
2488
2489    /// Create a variable based on the values currently stored in a tensor. The storage is always
2490    /// copied.
2491    pub(crate) fn make_var(&self) -> Result<Tensor> {
2492        let shape = self.shape().clone();
2493        let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2494        self.storage()
2495            .copy_strided_src(&mut storage, 0, self.layout())?;
2496        Ok(from_storage(storage, shape, BackpropOp::none(), true))
2497    }
2498
2499    /// Reshape returns a tensor with the target shape provided that the number of elements of the
2500    /// original tensor is the same.
2501    /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
2502    /// a new storage and copies the data over, the returned tensor is always contiguous.
2503    ///
2504    /// The shape can be specified using a tuple of `usize` and at most one `()` in which case
2505    /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
2506    /// as to match the number of elements in the tensor.
2507    ///
2508    /// ```rust
2509    /// # use candle_core::{Tensor, DType, Device, D};
2510    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2511    ///
2512    /// let c = a.reshape((1, 6))?;
2513    /// assert_eq!(c.shape().dims(), &[1, 6]);
2514    ///
2515    /// let c = a.reshape((3, 2))?;
2516    /// assert_eq!(c.shape().dims(), &[3, 2]);
2517    ///
2518    /// let c = a.reshape((2, (), 1))?;
2519    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
2520    ///
2521    /// # Ok::<(), candle_core::Error>(())
2522    /// ```
2523    pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
2524        let shape = s.into_shape(self.elem_count())?;
2525        if shape.elem_count() != self.elem_count() {
2526            return Err(Error::ShapeMismatchBinaryOp {
2527                lhs: self.shape().clone(),
2528                rhs: shape,
2529                op: "reshape",
2530            }
2531            .bt());
2532        }
2533        let op = BackpropOp::new1(self, Op::Reshape);
2534        if self.is_contiguous() {
2535            let tensor_ = Tensor_ {
2536                id: TensorId::new(),
2537                storage: self.storage.clone(),
2538                layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
2539                op,
2540                is_variable: false,
2541                dtype: self.dtype,
2542                device: self.device.clone(),
2543            };
2544            Ok(Tensor(Arc::new(tensor_)))
2545        } else {
2546            let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2547            self.storage()
2548                .copy_strided_src(&mut storage, 0, self.layout())?;
2549            Ok(from_storage(storage, shape, op, false))
2550        }
2551    }
2552
2553    /// Creates a new tensor with the specified dimension removed if its size was one.
2554    ///
2555    /// ```rust
2556    /// # use candle_core::{Tensor, DType, Device, D};
2557    /// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
2558    ///
2559    /// let c = a.squeeze(2)?;
2560    /// assert_eq!(c.shape().dims(), &[2, 3]);
2561    ///
2562    /// let c = a.squeeze(D::Minus1)?;
2563    /// assert_eq!(c.shape().dims(), &[2, 3]);
2564    /// # Ok::<(), candle_core::Error>(())
2565    /// ```
2566    pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
2567        // The PyTorch semantics are to return the same tensor if the target dimension
2568        // does not have a size of 1.
2569        let dims = self.dims();
2570        let dim = dim.to_index(self.shape(), "squeeze")?;
2571        if dims[dim] == 1 {
2572            let mut dims = dims.to_vec();
2573            let mut strides = self.stride().to_vec();
2574            dims.remove(dim);
2575            strides.remove(dim);
2576            let tensor_ = Tensor_ {
2577                id: TensorId::new(),
2578                storage: self.storage.clone(),
2579                layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2580                op: BackpropOp::new1(self, Op::Reshape),
2581                is_variable: false,
2582                dtype: self.dtype,
2583                device: self.device.clone(),
2584            };
2585            Ok(Tensor(Arc::new(tensor_)))
2586        } else {
2587            Ok(self.clone())
2588        }
2589    }
2590
2591    /// Creates a new tensor with a dimension of size one inserted at the specified position.
2592    ///
2593    /// ```rust
2594    /// # use candle_core::{Tensor, DType, Device, D};
2595    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2596    ///
2597    /// let c = a.unsqueeze(0)?;
2598    /// assert_eq!(c.shape().dims(), &[1, 2, 3]);
2599    ///
2600    /// let c = a.unsqueeze(D::Minus1)?;
2601    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
2602    /// # Ok::<(), candle_core::Error>(())
2603    /// ```
2604    pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
2605        let mut dims = self.dims().to_vec();
2606        let mut strides = self.stride().to_vec();
2607        let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
2608        // Cannot panic because to_index_plus_one already checks dimensions
2609        dims.insert(dim, 1);
2610        // Any stride would work here, but we pick one so as to maximize the probability to remain
2611        // C contiguous.
2612        let stride = if dim < strides.len() { strides[dim] } else { 1 };
2613        strides.insert(dim, stride);
2614        let tensor_ = Tensor_ {
2615            id: TensorId::new(),
2616            storage: self.storage.clone(),
2617            layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2618            op: BackpropOp::new1(self, Op::Reshape),
2619            is_variable: false,
2620            dtype: self.dtype,
2621            device: self.device.clone(),
2622        };
2623        Ok(Tensor(Arc::new(tensor_)))
2624    }
2625
2626    /// Stacks two or more tensors along a particular dimension.
2627    ///
2628    /// All tensors must have the same rank, and the output has one additional dimension.
2629    ///
2630    /// ```rust
2631    /// # use candle_core::{Tensor, DType, Device};
2632    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2633    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2634    ///
2635    /// let c = Tensor::stack(&[&a, &b], 0)?;
2636    /// assert_eq!(c.shape().dims(), &[2, 2, 3]);
2637    ///
2638    /// let c = Tensor::stack(&[&a, &b], 2)?;
2639    /// assert_eq!(c.shape().dims(), &[2, 3, 2]);
2640    /// # Ok::<(), candle_core::Error>(())
2641    /// ```
2642    pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
2643        if args.is_empty() {
2644            Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
2645        }
2646        let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
2647        let args = args
2648            .iter()
2649            .map(|t| t.as_ref().unsqueeze(dim))
2650            .collect::<Result<Vec<_>>>()?;
2651        Self::cat(&args, dim)
2652    }
2653
2654    /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
2655    /// input tensor values and `right` elements after.
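    ///
    /// A small sketch on a 1D CPU tensor, with illustrative values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let t = t.pad_with_zeros(0, 1, 2)?;
    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 1., 2., 3., 0., 0.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```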
2656    pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2657        if left == 0 && right == 0 {
2658            Ok(self.clone())
2659        } else if left == 0 {
2660            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2661            let mut dims = self.dims().to_vec();
2662            dims[dim] = right;
2663            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2664            Tensor::cat(&[self, &right], dim)
2665        } else if right == 0 {
2666            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2667            let mut dims = self.dims().to_vec();
2668            dims[dim] = left;
2669            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2670            Tensor::cat(&[&left, self], dim)
2671        } else {
2672            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2673            let mut dims = self.dims().to_vec();
2674            dims[dim] = left;
2675            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2676            dims[dim] = right;
2677            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2678            Tensor::cat(&[&left, self, &right], dim)
2679        }
2680    }
2681
2682    /// Pad the input tensor by replicating its edge values along dimension `dim`. This adds `left`
2683    /// copies of the first value before the input tensor values and `right` copies of the last value after.
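    ///
    /// A small sketch on a 1D CPU tensor, with illustrative values:
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let t = t.pad_with_same(0, 1, 2)?;
    /// assert_eq!(t.to_vec1::<f32>()?, &[1., 1., 2., 3., 3., 3.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```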
2684    pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2685        if left == 0 && right == 0 {
2686            Ok(self.clone())
2687        } else if self.elem_count() == 0 {
2688            bail!("cannot use pad_with_same on an empty tensor")
2689        } else if left == 0 {
2690            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2691            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2692            let mut v = vec![self];
2693            for _ in 0..right {
2694                v.push(&r)
2695            }
2696            Tensor::cat(&v, dim)
2697        } else if right == 0 {
2698            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2699            let l = self.narrow(dim, 0, 1)?;
2700            let mut v = vec![];
2701            for _ in 0..left {
2702                v.push(&l)
2703            }
2704            v.push(self);
2705            Tensor::cat(&v, dim)
2706        } else {
2707            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2708            let l = self.narrow(dim, 0, 1)?;
2709            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2710            let mut v = vec![];
2711            for _ in 0..left {
2712                v.push(&l)
2713            }
2714            v.push(self);
2715            for _ in 0..right {
2716                v.push(&r)
2717            }
2718            Tensor::cat(&v, dim)
2719        }
2720    }
2721
2722    /// Run the `forward` method of `m` on `self`.
2723    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
2724        m.forward(self)
2725    }
2726
2727    /// Run the `forward_t` method of `m` on `self`, with `train` indicating whether this is a training run.
2728    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
2729        m.forward_t(self, train)
2730    }
2731
2732    pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
2733        self.storage.read().unwrap()
2734    }
2735
2736    pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
2737        self.storage.write().unwrap()
2738    }
2739
2740    // If we extend the visibility of this function to be usable outside of this crate, we should
2741    // make it unsafe.
2742    pub(crate) fn storage_mut_and_layout(
2743        &self,
2744    ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
2745        let storage = self.storage.write().unwrap();
2746        (storage, &self.layout)
2747    }
2748
2749    /// The storage used by this tensor, together with the layout to use to access it safely.
2750    pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
2751        let storage = self.storage.read().unwrap();
2752        (storage, &self.layout)
2753    }
2754
2755    pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
2756        let lhs: &RwLock<Storage> = self.storage.as_ref();
2757        let rhs: &RwLock<Storage> = rhs.storage.as_ref();
2758        std::ptr::eq(lhs, rhs)
2759    }
2760
2761    /// Normalize a 'relative' axis value: positive values are kept as-is while negative
2762    /// values count the dimensions from the back.
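    ///
    /// A minimal sketch, assuming a rank 3 tensor on the CPU:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    /// assert_eq!(t.normalize_axis(-1)?, 2);
    /// assert_eq!(t.normalize_axis(1)?, 1);
    /// # Ok::<(), candle_core::Error>(())
    /// ```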
2763    pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
2764        let rank = self.rank() as i64;
2765        if rank <= axis {
2766            bail!("axis {axis} is too large, tensor rank {rank}")
2767        } else if 0 <= axis {
2768            Ok(axis as usize)
2769        } else {
2770            let naxis = rank + axis;
2771            if naxis < 0 {
2772                bail!("axis {axis} is too small, tensor rank {rank}")
2773            }
2774            Ok(naxis as usize)
2775        }
2776    }
2777
2778    /// Returns a lower triangular matrix of ones of size n by n.
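    ///
    /// A minimal sketch for `n = 3`:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::tril2(3, DType::F32, &Device::Cpu)?;
    /// assert_eq!(
    ///     t.to_vec2::<f32>()?,
    ///     &[[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]
    /// );
    /// # Ok::<(), candle_core::Error>(())
    /// ```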
2779    pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2780        let t = Tensor::arange(0u32, n as u32, device)?;
2781        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2782        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2783        t1.le(&t2)?.to_dtype(dtype)
2784    }
2785
2786    /// Returns an upper triangular matrix of ones of size n by n.
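    ///
    /// A minimal sketch for `n = 3`:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::triu2(3, DType::F32, &Device::Cpu)?;
    /// assert_eq!(
    ///     t.to_vec2::<f32>()?,
    ///     &[[1.0, 1.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]]
    /// );
    /// # Ok::<(), candle_core::Error>(())
    /// ```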
2787    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2788        let t = Tensor::arange(0u32, n as u32, device)?;
2789        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2790        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2791        t1.ge(&t2)?.to_dtype(dtype)
2792    }
2793
2794    /// Returns a matrix with a diagonal of ones of size n by n.
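    ///
    /// A minimal sketch for `n = 3`:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::eye(3, DType::F32, &Device::Cpu)?;
    /// assert_eq!(
    ///     t.to_vec2::<f32>()?,
    ///     &[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    /// );
    /// # Ok::<(), candle_core::Error>(())
    /// ```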
2795    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2796        let t = Tensor::arange(0u32, n as u32, device)?;
2797        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2798        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2799        t1.eq(&t2)?.to_dtype(dtype)
2800    }
2801
2802    /// Returns the cumulative sum of the elements of the input tensor along the specified
2803    /// dimension.
2804    ///
2805    /// This operation is most efficient when dim is the last dimension of the tensor.
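    ///
    /// A minimal sketch on a rank 1 tensor:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
    /// assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [1.0, 3.0, 6.0, 10.0]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```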
2806    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
2807        let dim = dim.to_index(self.shape(), "cumsum")?;
2808        let rank = self.rank();
2809        if rank == 0 {
2810            return Ok(self.clone());
2811        }
2812        let n_axis = self.dim(dim)?;
2813        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
2814        if rank == 1 {
2815            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
2816        } else {
2817            let last = rank - 1;
2818            let t = self.transpose(dim, last)?;
2819            let t = t.broadcast_matmul(&triu)?;
2820            t.transpose(dim, last)
2821        }
2822    }
2823
2824    /// Returns a copy of `self` where the values within `ranges` have been replaced with the
2825    /// content of `src`.
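    ///
    /// A minimal sketch overwriting part of the second row of a zero tensor:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let t = Tensor::zeros((3, 3), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::ones((1, 2), DType::F32, &Device::Cpu)?;
    /// let t = t.slice_assign(&[1..2, 0..2], &src)?;
    /// assert_eq!(
    ///     t.to_vec2::<f32>()?,
    ///     &[[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
    /// );
    /// # Ok::<(), candle_core::Error>(())
    /// ```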
2826    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
2827        &self,
2828        ranges: &[D],
2829        src: &Tensor,
2830    ) -> Result<Self> {
2831        let src_dims = src.dims();
2832        let self_dims = self.dims();
2833        if self_dims.len() != src_dims.len() {
2834            bail!(
2835                "slice-assign requires self and src to have the same rank ({} <> {})",
2836                self_dims.len(),
2837                src_dims.len()
2838            )
2839        }
2840        if self_dims.len() != ranges.len() {
2841            bail!(
2842                "slice-assign requires as many ranges as the input rank ({} <> {})",
2843                self_dims.len(),
2844                ranges.len()
2845            )
2846        }
2847        let mut src = src.clone();
2848        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
2849        for (i, range) in ranges.iter().enumerate() {
2850            let start_included = match range.start_bound() {
2851                std::ops::Bound::Unbounded => 0,
2852                std::ops::Bound::Included(v) => *v,
2853                std::ops::Bound::Excluded(v) => *v + 1,
2854            };
2855            let end_excluded = match range.end_bound() {
2856                std::ops::Bound::Unbounded => self_dims[i],
2857                std::ops::Bound::Included(v) => *v + 1,
2858                std::ops::Bound::Excluded(v) => *v,
2859            };
2860            if end_excluded <= start_included {
2861                bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
2862            }
2863            if self_dims[i] < end_excluded {
2864                bail!(
2865                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
2866                    self_dims[i]
2867                )
2868            }
2869            if end_excluded - start_included != src_dims[i] {
2870                bail!(
2871                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
2872                )
2873            }
2874            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
2875            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
2876        }
2877        mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
2878    }
2879
2880    /// Returns log(sum(exp(tensor), dim)).
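    ///
    /// A minimal sketch, reducing over the last dimension:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, Device};
    /// let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    /// let lse = t.log_sum_exp(1)?;
    /// // One value per row, each close to ln(e^a + e^b + e^c) for that row.
    /// assert_eq!(lse.dims(), &[2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```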
2881    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
2882        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
2883        if sum_dims.is_empty() {
2884            return Ok(self.clone());
2885        }
2886        let max = sum_dims[1..]
2887            .iter()
2888            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
2889                max.max_keepdim(dim)
2890            })?;
2891        let exp = self.broadcast_sub(&max)?.exp()?;
2892        let sum = exp.sum(sum_dims.clone())?;
2893
2894        sum.log()? + max.squeeze_dims(&sum_dims)
2895    }
2896
2897    /// Pointwise pow operation.
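    ///
    /// This is computed as `exp(rhs * ln(self))`, so it is only meaningful for a positive base.
    /// A minimal sketch:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, Device};
    /// let base = Tensor::new(&[2f32, 3., 4.], &Device::Cpu)?;
    /// let exp = Tensor::new(&[2f32, 2., 2.], &Device::Cpu)?;
    /// let t = base.pow(&exp)?;
    /// // The values are close to [4., 9., 16.] up to floating point rounding.
    /// assert_eq!(t.dims(), &[3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```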
2898    pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
2899        rhs.mul(&self.log()?)?.exp()
2900    }
2901
2902    /// Broadcasting version of `pow`.
2903    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
2904        rhs.broadcast_mul(&self.log()?)?.exp()
2905    }
2906
2907    /// Returns a new tensor with the order of elements reversed along the specified dimensions.
2908    /// This function makes a copy of the tensor’s data.
2909    ///
2910    /// ```rust
2911    /// # use candle_core::{Tensor, Device};
2912    /// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
2913    /// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
2914    /// let t_flipped = t.flip(&[0])?;
2915    /// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
2916    /// # Ok::<(), candle_core::Error>(())
2917    /// ```
2918    pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
2919        let mut result = self.clone();
2920        for &dim in dims.iter() {
2921            let size = result.dim(dim)?;
2922            let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
2923            let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
2924            result = result.index_select(&indices_tensor, dim)?;
2925        }
2926        Ok(result)
2927    }
2928
2929    /// Returns a view which contains all slices of size `size` from the `self` tensor along the
2930    /// dimension `dim`, stepped by `step`.
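    ///
    /// A minimal sketch on a rank 1 tensor:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, Device};
    /// let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
    /// let t = t.unfold(0, 2, 2)?;
    /// assert_eq!(t.to_vec2::<f32>()?, &[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```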
2931    pub fn unfold<D: Dim>(&self, dim: D, size: usize, step: usize) -> Result<Self> {
2932        // https://github.com/pytorch/pytorch/blob/75b0720a97ac5d82e8a7a1a6ae7c5f7a87d7183d/aten/src/ATen/native/TensorShape.cpp#L3785-L3804
2933        let mut sizes = self.dims().to_vec();
2934        let mut strides = self.stride().to_vec();
2935
2936        let dim = dim.to_index(self.shape(), "unfold")?;
2937
2938        let max_len = if self.dims().is_empty() {
2939            1
2940        } else {
2941            sizes[dim]
2942        };
2943        if size > max_len {
2944            bail!(
2945                "unfold: maximum size for tensor at dimension {dim} is {max_len} but size is {size}"
2946            )
2947        }
2948        sizes.push(size);
2949        strides.push(if self.dims().is_empty() {
2950            1
2951        } else {
2952            strides[dim]
2953        });
2954
2955        if !self.dims().is_empty() {
2956            sizes[dim] = ((sizes[dim] as f32 - size as f32) / step as f32 + 1.) as usize;
2957            strides[dim] *= step;
2958        }
2959
2960        let tensor_ = Tensor_ {
2961            id: TensorId::new(),
2962            storage: self.storage.clone(),
2963            layout: Layout::new(sizes.into(), strides, self.layout.start_offset()),
2964            op: BackpropOp::new1(self, Op::Reshape),
2965            is_variable: false,
2966            dtype: self.dtype,
2967            device: self.device.clone(),
2968        };
2969        Ok(Tensor(Arc::new(tensor_)))
2970    }
2971}
2972
2973macro_rules! bin_trait {
2974    ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => {
2975        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for Tensor {
2976            type Output = Result<Tensor>;
2977
2978            fn $fn1(self, rhs: B) -> Self::Output {
2979                Tensor::$fn1(&self, rhs.borrow())
2980            }
2981        }
2982
2983        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for &Tensor {
2984            type Output = Result<Tensor>;
2985
2986            fn $fn1(self, rhs: B) -> Self::Output {
2987                Tensor::$fn1(&self, rhs.borrow())
2988            }
2989        }
2990
2991        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
2992            type Output = Result<Tensor>;
2993
2994            fn $fn1(self, rhs: Tensor) -> Self::Output {
2995                Tensor::$fn1(self?.borrow(), &rhs)
2996            }
2997        }
2998
2999        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
3000            type Output = Result<Tensor>;
3001
3002            fn $fn1(self, rhs: &Tensor) -> Self::Output {
3003                Tensor::$fn1(self?.borrow(), rhs)
3004            }
3005        }
3006
3007        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
3008            type Output = Result<Tensor>;
3009
3010            fn $fn1(self, rhs: Result<B>) -> Self::Output {
3011                Tensor::$fn1(&self, rhs?.borrow())
3012            }
3013        }
3014
3015        impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for &Tensor {
3016            type Output = Result<Tensor>;
3017
3018            fn $fn1(self, rhs: Result<B>) -> Self::Output {
3019                Tensor::$fn1(&self, rhs?.borrow())
3020            }
3021        }
3022
3023        impl std::ops::$trait<f64> for Tensor {
3024            type Output = Result<Tensor>;
3025
3026            fn $fn1(self, rhs: f64) -> Self::Output {
3027                self.affine($mul(rhs), $add(rhs))
3028            }
3029        }
3030
3031        impl std::ops::$trait<f64> for &Tensor {
3032            type Output = Result<Tensor>;
3033
3034            fn $fn1(self, rhs: f64) -> Self::Output {
3035                self.affine($mul(rhs), $add(rhs))
3036            }
3037        }
3038    };
3039}
3040
3041bin_trait!(Add, add, |_| 1., |v| v);
3042bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
3043bin_trait!(Mul, mul, |v| v, |_| 0.);
3044bin_trait!(Div, div, |v| 1. / v, |_| 0.);
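
// The operator impls generated above, together with the f64 left-hand-side impls below, let
// scalars combine with tensors through the usual operators while returning a `Result<Tensor>`.
// A minimal usage sketch, assuming a CPU tensor:
//
//     let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
//     let t = ((&t * 2.)? + 1.)?; // affine transform: 2 * t + 1
//     let t = (1. - t)?;          // f64 on the left-hand side also works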
3045
3046impl std::ops::Add<Tensor> for f64 {
3047    type Output = Result<Tensor>;
3048
3049    fn add(self, rhs: Tensor) -> Self::Output {
3050        rhs + self
3051    }
3052}
3053
3054impl std::ops::Add<&Tensor> for f64 {
3055    type Output = Result<Tensor>;
3056
3057    fn add(self, rhs: &Tensor) -> Self::Output {
3058        rhs + self
3059    }
3060}
3061
3062impl std::ops::Mul<Tensor> for f64 {
3063    type Output = Result<Tensor>;
3064
3065    fn mul(self, rhs: Tensor) -> Self::Output {
3066        rhs * self
3067    }
3068}
3069
3070impl std::ops::Mul<&Tensor> for f64 {
3071    type Output = Result<Tensor>;
3072
3073    fn mul(self, rhs: &Tensor) -> Self::Output {
3074        rhs * self
3075    }
3076}
3077
3078impl std::ops::Sub<Tensor> for f64 {
3079    type Output = Result<Tensor>;
3080
3081    fn sub(self, rhs: Tensor) -> Self::Output {
3082        rhs.affine(-1., self)
3083    }
3084}
3085
3086impl std::ops::Sub<&Tensor> for f64 {
3087    type Output = Result<Tensor>;
3088
3089    fn sub(self, rhs: &Tensor) -> Self::Output {
3090        rhs.affine(-1., self)
3091    }
3092}
3093
3094impl std::ops::Div<Tensor> for f64 {
3095    type Output = Result<Tensor>;
3096
3097    #[allow(clippy::suspicious_arithmetic_impl)]
3098    fn div(self, rhs: Tensor) -> Self::Output {
3099        rhs.recip()? * self
3100    }
3101}
3102
3103impl std::ops::Div<&Tensor> for f64 {
3104    type Output = Result<Tensor>;
3105
3106    #[allow(clippy::suspicious_arithmetic_impl)]
3107    fn div(self, rhs: &Tensor) -> Self::Output {
3108        rhs.recip()? * self
3109    }
3110}
3111
3112impl<S: Into<Shape>> From<(Storage, S)> for Tensor {
3113    fn from((storage, shape): (Storage, S)) -> Self {
3114        from_storage(storage, shape, BackpropOp::none(), false)
3115    }
3116}