Skip to main content

hanzo_ml/
tensor.rs

1//! Tensors are N-dimensional matrixes of elements using a single data type.
2#![allow(clippy::redundant_closure_call)]
3use crate::backend::{BackendDevice, BackendStorage};
4use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
5use crate::scalar::TensorOrScalar;
6use crate::shape::{Dim, Dims, ShapeWithOneHole};
7use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
8use std::sync::{Arc, RwLock};
9
/// Unique identifier for tensors.
///
/// Identifiers are drawn from a single static counter, so successive calls to
/// `TensorId::new` yield distinct values (until the counter wraps around).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TensorId(usize);

impl TensorId {
    /// Returns a fresh identifier.
    fn new() -> Self {
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic::{AtomicUsize, Ordering};
        // The first identifier handed out is 1.
        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
        // `Relaxed` is enough: only the atomicity of the increment matters here, no other
        // memory is published through this counter.
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self(id)
    }
}
22
/// Reference-counted payload behind a [`Tensor`] handle.
pub struct Tensor_ {
    // Unique identifier, assigned when the tensor structure is created.
    id: TensorId,
    // As we provide inner mutability on the tensor content, the alternatives are:
    // - Using a mutex, this would have the highest cost when retrieving the storage but would
    //   prevent errors when concurrent access takes place. Mutex would also be subject to
    //   deadlocks for example using the current code if the same tensor is used twice by a single
    //   binary op.
    // - Using a refcell would have some intermediary cost, borrow checking would be
    //   verified dynamically, but the resulting tensors would not be send or sync.
    // - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent
    //   accesses.
    // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
    // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but
    // that's tricky to encode in the current setup.
    // The option retained here is a `RwLock` shared through an `Arc`.
    storage: Arc<RwLock<Storage>>,
    // Describes how this tensor views `storage` (shape, and the start offset read by
    // `to_scalar`).
    layout: Layout,
    // Backpropagation information describing how this tensor was produced.
    op: BackpropOp,
    // Marks the tensor as a variable; consulted by `track_op`.
    is_variable: bool,
    // `dtype` and `device` are copied from the storage when the tensor is built.
    dtype: DType,
    device: Device,
}
44
// Identity conversion, so that generic APIs bounded by `AsRef<Tensor>` (e.g. `meshgrid`)
// accept a `Tensor` directly.
impl AsRef<Tensor> for Tensor {
    fn as_ref(&self) -> &Tensor {
        self
    }
}
50
// Tensors are refcounted so that cloning is cheap when building the op graph.
// Storages are also refcounted independently so that it is possible to avoid
// copying the storage for operations that only modify the shape or stride.
#[derive(Clone)]
/// The core struct for manipulating tensors.
///
/// ```rust
/// use hanzo_ml::{Tensor, DType, Device};
///
/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
/// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
///
/// let c = a.matmul(&b)?;
/// # Ok::<(), hanzo_ml::Error>(())
/// ```
///
/// Tensors are reference counted with [`Arc`] so cloning them is cheap.
pub struct Tensor(Arc<Tensor_>);
69
70impl std::ops::Deref for Tensor {
71    type Target = Tensor_;
72
73    fn deref(&self) -> &Self::Target {
74        self.0.as_ref()
75    }
76}
77
// Defines the unary method `$fn_name` on `Tensor`. The computation is delegated to the
// storage's `unary_impl::<crate::op::$op_name>` and the result is recorded in the graph as
// `UnaryOp::$op_name`.
macro_rules! unary_op {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name(&self) -> Result<Self> {
            let shape = self.shape();
            // Nothing to compute for a tensor without elements: return a clone of the input.
            if shape.elem_count() == 0 {
                return Ok(self.clone());
            }
            let storage = self
                .storage()
                .unary_impl::<crate::op::$op_name>(self.layout())?;
            // Keep track of the operation for backpropagation.
            let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
            // The output is a new, contiguous, non-variable tensor with the input's shape.
            Ok(from_storage(storage, shape.clone(), op, false))
        }
    };
}
93
// Defines the binary method `$fn_name` on `Tensor` for two operands of identical shape.
// The computation is delegated to the storage's `binary_impl::<crate::op::$op_name>` and the
// result is recorded in the graph as `BinaryOp::$op_name`.
macro_rules! binary_op {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            // Fails with a shape-mismatch error unless both operands have the same shape.
            let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
            // Nothing to compute for tensors without elements: return a clone of `self`.
            if shape.elem_count() == 0 {
                return Ok(self.clone());
            }
            let storage = self.storage().binary_impl::<crate::op::$op_name>(
                &*rhs.storage(),
                self.layout(),
                rhs.layout(),
            )?;
            // Keep track of the operation for backpropagation.
            let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
            Ok(from_storage(storage, shape.clone(), op, false))
        }
    };
}
111
// Defines the binary method `$fn_name` on `Tensor` where the right-hand side is either a
// tensor or a scalar (anything implementing `TensorOrScalar`).
macro_rules! binary_op_scalar {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
            let rhs = match rhs.to_tensor_scalar()? {
                // A tensor operand is used as is.
                crate::scalar::TensorScalar::Tensor(rhs) => rhs,
                // A scalar operand is aligned with `self`: same dtype, same device, and
                // broadcast to the same shape.
                crate::scalar::TensorScalar::Scalar(rhs) => rhs
                    .to_dtype(self.dtype())?
                    .to_device(self.device())?
                    .broadcast_as(self.shape())?,
            };
            // Fails with a shape-mismatch error unless both operands have the same shape.
            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
            // Nothing to compute for tensors without elements: return a clone of `self`.
            if self.elem_count() == 0 {
                return Ok(self.clone());
            }
            let storage = self.storage().binary_impl::<crate::op::$op_name>(
                &*rhs.storage(),
                self.layout(),
                rhs.layout(),
            )?;
            // Keep track of the operation for backpropagation.
            let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
            Ok(from_storage(storage, shape.clone(), op, false))
        }
    };
}
136
// Defines the method `$fn_name` on `Tensor`: a broadcasting wrapper around the same-shape
// binary method `$inner_fn_name`.
macro_rules! broadcast_binary_op {
    ($fn_name:ident, $inner_fn_name:ident) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            let lhs = self;
            // Common shape of both operands; an error is returned if they are incompatible.
            let shape = lhs
                .shape()
                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
            // Only broadcast the operands whose shape differs from the common shape.
            let l_broadcast = shape != *lhs.shape();
            let r_broadcast = shape != *rhs.shape();
            match (l_broadcast, r_broadcast) {
                (true, true) => lhs
                    .broadcast_as(&shape)?
                    .$inner_fn_name(&rhs.broadcast_as(&shape)?),
                (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
                (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
                (false, false) => lhs.$inner_fn_name(rhs),
            }
        }
    };
}
157
158/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
159pub(crate) fn from_storage<S: Into<Shape>>(
160    storage: Storage,
161    shape: S,
162    op: BackpropOp,
163    is_variable: bool,
164) -> Tensor {
165    let dtype = storage.dtype();
166    let device = storage.device();
167    let tensor_ = Tensor_ {
168        id: TensorId::new(),
169        storage: Arc::new(RwLock::new(storage)),
170        layout: Layout::contiguous(shape),
171        op,
172        is_variable,
173        dtype,
174        device,
175    };
176    Tensor(Arc::new(tensor_))
177}
178
179impl Tensor {
180    pub(crate) fn ones_impl<S: Into<Shape>>(
181        shape: S,
182        dtype: DType,
183        device: &Device,
184        is_variable: bool,
185    ) -> Result<Self> {
186        let none = BackpropOp::none();
187        let shape = shape.into();
188        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
189        let layout = Layout::contiguous(shape.clone());
190        storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
191        Ok(from_storage(storage, shape, none, is_variable))
192    }
193
    /// Creates a new tensor filled with ones.
    ///
    /// The returned tensor is not a variable.
    ///
    /// ```rust
    /// use hanzo_ml::{Tensor, DType, Device};
    /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
    /// // a == b
    /// # Ok::<(), hanzo_ml::Error>(())
    /// ```
    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
        Self::ones_impl(shape, dtype, device, false)
    }
206
    /// Sets the elements covered by this tensor's layout to `value`, in place.
    ///
    /// This takes `&self`: the write goes through the tensor's shared storage (obtained with
    /// `storage_mut`), so other tensors sharing the same storage presumably observe the
    /// change as well.
    pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
        self.storage_mut().const_set(value, self.layout())
    }
210
    /// Sets the elements of this tensor to zero, in place.
    ///
    /// Shorthand for `const_set` with the zero scalar of this tensor's dtype.
    pub fn zero_set(&self) -> Result<()> {
        self.const_set(crate::scalar::Scalar::zero(self.dtype()))
    }
214
    /// Sets the elements of this tensor to one, in place.
    ///
    /// Shorthand for `const_set` with the one scalar of this tensor's dtype.
    pub fn one_set(&self) -> Result<()> {
        self.const_set(crate::scalar::Scalar::one(self.dtype()))
    }
218
    /// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
    ///
    /// The content of `self` is not read, only its shape, dtype and device are used.
    ///
    /// ```rust
    /// use hanzo_ml::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let b = a.ones_like()?;
    /// // b == a + 1
    /// # Ok::<(), hanzo_ml::Error>(())
    /// ```
    pub fn ones_like(&self) -> Result<Self> {
        Tensor::ones(self.shape(), self.dtype(), self.device())
    }
231
232    // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from
233    // the variable module.
234    pub(crate) fn zeros_impl<S: Into<Shape>>(
235        shape: S,
236        dtype: DType,
237        device: &Device,
238        is_variable: bool,
239    ) -> Result<Self> {
240        let none = BackpropOp::none();
241        let shape = shape.into();
242        let storage = device.zeros(&shape, dtype)?;
243        Ok(from_storage(storage, shape, none, is_variable))
244    }
245
    /// Creates a new tensor filled with zeros.
    ///
    /// The returned tensor is not a variable.
    ///
    /// ```rust
    /// use hanzo_ml::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
    /// // a == b
    /// # Ok::<(), hanzo_ml::Error>(())
    /// ```
    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
        Self::zeros_impl(shape, dtype, device, false)
    }
258
    /// Creates a new tensor filled with zeros with same shape, dtype, and device as the other
    /// tensor.
    ///
    /// The content of `self` is not read, only its shape, dtype and device are used.
    ///
    /// ```rust
    /// use hanzo_ml::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let b = a.zeros_like()?;
    /// // b is on CPU f32.
    /// # Ok::<(), hanzo_ml::Error>(())
    /// ```
    pub fn zeros_like(&self) -> Result<Self> {
        Tensor::zeros(self.shape(), self.dtype(), self.device())
    }
272
273    // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from
274    // the variable module.
275    pub(crate) unsafe fn empty_impl<S: Into<Shape>>(
276        shape: S,
277        dtype: DType,
278        device: &Device,
279        is_variable: bool,
280    ) -> Result<Self> {
281        let none = BackpropOp::none();
282        let shape = shape.into();
283        let storage = device.alloc_uninit(&shape, dtype)?;
284        Ok(from_storage(storage, shape, none, is_variable))
285    }
286
    /// Creates a new tensor filled with uninitialized memory.
    ///
    /// # Safety
    /// This returns uninitialized memory.
    ///
    /// ```rust
    /// use hanzo_ml::{Tensor, DType, Device};
    /// let a = unsafe { Tensor::empty((2, 3), DType::F32, &Device::Cpu)? };
    /// // The content of `a` is unspecified until it is written to.
    /// # Ok::<(), hanzo_ml::Error>(())
    /// ```
    pub unsafe fn empty<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
        Self::empty_impl(shape, dtype, device, false)
    }
301
    /// Creates a new tensor filled with uninitialized memory of the same shape, dtype, and device as the other
    /// tensor.
    ///
    /// The content of `self` is not read, only its shape, dtype and device are used.
    ///
    /// # Safety
    /// This returns uninitialized memory.
    ///
    /// ```rust
    /// use hanzo_ml::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let b = unsafe { a.empty_like()? };
    /// # Ok::<(), hanzo_ml::Error>(())
    /// ```
    pub unsafe fn empty_like(&self) -> Result<Self> {
        Tensor::empty(self.shape(), self.dtype(), self.device())
    }
317
318    pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
319        lo: T,
320        up: T,
321        s: S,
322        device: &Device,
323        is_variable: bool,
324    ) -> Result<Self> {
325        let s = s.into();
326        let storage = device.rand_uniform(lo, up, &s)?;
327        let none = BackpropOp::none();
328        Ok(from_storage(storage, s, none, is_variable))
329    }
330
331    pub(crate) fn rand_f64_impl<S: Into<Shape>>(
332        lo: f64,
333        up: f64,
334        s: S,
335        dtype: DType,
336        device: &Device,
337        is_variable: bool,
338    ) -> Result<Self> {
339        let s = s.into();
340        let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
341        let none = BackpropOp::none();
342        Ok(from_storage(storage, s, none, is_variable))
343    }
344
    /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
    ///
    /// The tensor has shape `s` and is not a variable.
    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
        lo: T,
        up: T,
        s: S,
        device: &Device,
    ) -> Result<Self> {
        Self::rand_impl(lo, up, s, device, false)
    }
354
    /// Creates a new tensor with the same shape, dtype, and device as `self`, initialized with
    /// values sampled uniformly between `lo` and `up`.
    ///
    /// The content of `self` is not read.
    pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self> {
        Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false)
    }
358
359    pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
360        mean: T,
361        std: T,
362        s: S,
363        device: &Device,
364        is_variable: bool,
365    ) -> Result<Self> {
366        let s = s.into();
367        let storage = device.rand_normal(mean, std, &s)?;
368        let none = BackpropOp::none();
369        Ok(from_storage(storage, s, none, is_variable))
370    }
371
372    pub(crate) fn randn_f64_impl<S: Into<Shape>>(
373        mean: f64,
374        std: f64,
375        s: S,
376        dtype: DType,
377        device: &Device,
378        is_variable: bool,
379    ) -> Result<Self> {
380        let s = s.into();
381        let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
382        let none = BackpropOp::none();
383        Ok(from_storage(storage, s, none, is_variable))
384    }
385
    /// Creates a new tensor with the same shape, dtype, and device as `self`, initialized with
    /// values sampled from a normal distribution with the specified `mean` and standard
    /// deviation `stdev`.
    ///
    /// The content of `self` is not read.
    pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self> {
        Tensor::randn_f64_impl(
            mean,
            stdev,
            self.shape(),
            self.dtype(),
            self.device(),
            false,
        )
    }
396
    /// Creates a new tensor initialized with values sampled from a normal distribution with the
    /// specified `mean` and standard deviation `std`.
    ///
    /// The tensor has shape `s` and is not a variable.
    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
        mean: T,
        std: T,
        s: S,
        device: &Device,
    ) -> Result<Self> {
        Self::randn_impl(mean, std, s, device, false)
    }
407
408    pub(crate) fn new_impl<A: crate::device::NdArray>(
409        array: A,
410        shape: Shape,
411        device: &Device,
412        is_variable: bool,
413    ) -> Result<Self> {
414        let n: usize = shape.elem_count();
415        let buffer_size: usize = array.shape()?.elem_count();
416        if buffer_size != n {
417            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
418        }
419        let storage = device.storage(array)?;
420        let none = BackpropOp::none();
421        Ok(from_storage(storage, shape, none, is_variable))
422    }
423
    /// Creates a new tensor on the specified device using the content and shape of the input.
    ///
    /// The shape is the one reported by `array.shape()`; the returned tensor is not a variable.
    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
        let shape = array.shape()?;
        Self::new_impl(array, shape, device, false)
    }
429
430    /// Returns a new tensor with all the elements having the same specified value.
431    ///```rust
432    /// use hanzo_ml::{Tensor, Device};
433    /// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
434    ///
435    /// assert_eq!(a.to_vec2::<f64>()?, &[
436    ///     [3.5, 3.5, 3.5, 3.5],
437    ///     [3.5, 3.5, 3.5, 3.5],
438    /// ]);
439    /// # Ok::<(), hanzo_ml::Error>(())
440    pub fn full<D: crate::WithDType, S: Into<Shape>>(
441        value: D,
442        shape: S,
443        device: &Device,
444    ) -> Result<Self> {
445        let none = BackpropOp::none();
446        let shape = shape.into();
447        let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
448        let layout = Layout::contiguous(shape.clone());
449        storage.const_set(value.to_scalar(), &layout)?;
450        Ok(from_storage(storage, shape, none, false))
451    }
452
453    /// Creates a new 1D tensor from an iterator.
454    ///```rust
455    /// use hanzo_ml::{Tensor, Device};
456    /// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
457    ///
458    /// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
459    /// # Ok::<(), hanzo_ml::Error>(())
460    /// ```
461    pub fn from_iter<D: crate::WithDType>(
462        iter: impl IntoIterator<Item = D>,
463        device: &Device,
464    ) -> Result<Self> {
465        let data = iter.into_iter().collect::<Vec<_>>();
466        let len = data.len();
467        Self::from_vec_impl(data, len, device, false)
468    }
469
    /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
    /// difference `1` from `start`.
    ///
    /// Equivalent to `arange_step` with a step of `D::one()`.
    ///```rust
    /// use hanzo_ml::{Tensor, Device};
    /// let a = Tensor::arange(2., 5., &Device::Cpu)?;
    ///
    /// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
    /// # Ok::<(), hanzo_ml::Error>(())
    /// ```
    pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
        Self::arange_step(start, end, D::one(), device)
    }
482
483    /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
484    /// difference `step` from `start`.
485    ///```rust
486    /// use hanzo_ml::{Tensor, Device};
487    /// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
488    ///
489    /// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
490    /// # Ok::<(), hanzo_ml::Error>(())
491    /// ```
492    pub fn arange_step<D: crate::WithDType>(
493        start: D,
494        end: D,
495        step: D,
496        device: &Device,
497    ) -> Result<Self> {
498        if D::is_zero(&step) {
499            bail!("step cannot be zero")
500        }
501        let mut data = vec![];
502        let mut current = start;
503        if step >= D::zero() {
504            while current < end {
505                data.push(current);
506                current += step;
507            }
508        } else {
509            while current > end {
510                data.push(current);
511                current += step;
512            }
513        }
514        let len = data.len();
515        Self::from_vec_impl(data, len, device, false)
516    }
517
518    pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
519        data: Vec<D>,
520        shape: S,
521        device: &Device,
522        is_variable: bool,
523    ) -> Result<Self> {
524        let shape = shape.into_shape(data.len())?;
525        let storage = device.storage_owned(data)?;
526        let none = BackpropOp::none();
527        Ok(from_storage(storage, shape, none, is_variable))
528    }
529
    /// Creates a new tensor initialized with values from the input vector. The number of elements
    /// in this vector must be the same as the number of elements defined by the shape.
    /// If the device is cpu, no data copy is made.
    ///
    /// The returned tensor is not a variable.
    ///```rust
    /// use hanzo_ml::{Tensor, Device};
    /// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
    ///
    /// assert_eq!(a.to_vec2::<f64>()?, &[
    ///     [1., 2., 3.],
    ///     [4., 5., 6.]
    /// ]);
    /// # Ok::<(), hanzo_ml::Error>(())
    /// ```
    pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
        data: Vec<D>,
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        Self::from_vec_impl(data, shape, device, false)
    }
550
551    /// Creates a new tensor initialized with values from the input slice. The number of elements
552    /// in this vector must be the same as the number of elements defined by the shape.
553    ///```rust
554    /// use hanzo_ml::{Tensor, Device};
555    /// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
556    /// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
557    ///
558    /// assert_eq!(a.to_vec2::<f64>()?, &[
559    ///     [2., 3., 4.],
560    ///     [5., 6., 7.]
561    /// ]);
562    /// # Ok::<(), hanzo_ml::Error>(())
563    /// ```
564    pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
565        array: &[D],
566        shape: S,
567        device: &Device,
568    ) -> Result<Self> {
569        let shape = shape.into_shape(array.len())?;
570        let storage = device.storage_from_slice(array)?;
571        let none = BackpropOp::none();
572        Ok(from_storage(storage, shape, none, false))
573    }
574
575    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
576        let lhs = self.shape();
577        let rhs = rhs.shape();
578        if lhs != rhs {
579            Err(Error::ShapeMismatchBinaryOp {
580                lhs: lhs.clone(),
581                rhs: rhs.clone(),
582                op,
583            }
584            .bt())
585        } else {
586            Ok(lhs)
587        }
588    }
589
    /// Returns true if the computation graph should track this op, that is if it is
    /// a variable or if it has some variable as dependencies.
    ///
    /// Concretely: the tensor is flagged as a variable, or its backprop op is set.
    pub fn track_op(&self) -> bool {
        self.is_variable || self.op.is_some()
    }
595
    /// Creates a fresh tensor structure based on a storage and a shape.
    ///
    /// This is the public entry point to the crate-level `from_storage` function.
    ///
    /// # Note
    /// - This uses contiguous strides
    /// - Ensure the shape is compatible with the shape of the storage: this function performs
    ///   no check itself.
    pub fn from_storage<S: Into<Shape>>(
        storage: Storage,
        shape: S,
        op: BackpropOp,
        is_variable: bool,
    ) -> Tensor {
        from_storage(storage, shape, op, is_variable)
    }
609
    // TODO: Also make an inplace version or a pre-allocated? This could be tricky
    // if this can create cycles in the compute graph.

    // Binary ops on two tensors of identical shape.
    binary_op!(add, Add);
    binary_op!(mul, Mul);
    binary_op!(sub, Sub);
    binary_op!(div, Div);
    // Binary ops whose right-hand side is either a tensor or a scalar.
    binary_op_scalar!(maximum, Maximum);
    binary_op_scalar!(minimum, Minimum);
    // Broadcasting variants: operands are broadcast to a common shape, then the same-shape
    // method given as second argument is called (`eq`, `ne`, `lt`, `le`, `gt` and `ge` are not
    // generated here).
    broadcast_binary_op!(broadcast_add, add);
    broadcast_binary_op!(broadcast_mul, mul);
    broadcast_binary_op!(broadcast_sub, sub);
    broadcast_binary_op!(broadcast_div, div);
    broadcast_binary_op!(broadcast_maximum, maximum);
    broadcast_binary_op!(broadcast_minimum, minimum);
    broadcast_binary_op!(broadcast_eq, eq);
    broadcast_binary_op!(broadcast_ne, ne);
    broadcast_binary_op!(broadcast_lt, lt);
    broadcast_binary_op!(broadcast_le, le);
    broadcast_binary_op!(broadcast_gt, gt);
    broadcast_binary_op!(broadcast_ge, ge);

    // Unary ops.
    unary_op!(recip, Recip);
    unary_op!(neg, Neg);
    unary_op!(exp, Exp);
    unary_op!(log, Log);
    unary_op!(sin, Sin);
    unary_op!(cos, Cos);
    unary_op!(tanh, Tanh);
    unary_op!(abs, Abs);
    unary_op!(sqr, Sqr);
    unary_op!(sqrt, Sqrt);
    unary_op!(gelu, Gelu);
    unary_op!(gelu_erf, GeluErf);
    unary_op!(erf, Erf);
    unary_op!(relu, Relu);
    unary_op!(silu, Silu);
    unary_op!(ceil, Ceil);
    unary_op!(floor, Floor);
    unary_op!(round, Round);
    unary_op!(sign, Sign);
650
    /// Rounds each element of the input tensor to `decimals` decimal places.
    ///
    /// The tensor is multiplied by `10^decimals`, rounded, then multiplied by the inverse of
    /// that factor.
    ///
    /// If the number of decimals is negative, it specifies the number of positions to the left of
    /// the decimal point.
    pub fn round_to(&self, decimals: i32) -> Result<Self> {
        let mult = 10f64.powi(decimals);
        (self * mult)?.round()? * (1f64 / mult)
    }
659
660    /// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
661    /// dimensions, an error is returned instead.
662    pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
663        if self.rank() != 0 {
664            Err(Error::UnexpectedNumberOfDims {
665                expected: 0,
666                got: self.rank(),
667                shape: self.shape().clone(),
668            }
669            .bt())?
670        }
671        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
672            let data = S::cpu_storage_as_slice(cpu_storage)?;
673            Ok::<_, Error>(data[self.layout().start_offset()])
674        };
675        match &*self.storage() {
676            Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
677            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
678            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
679            #[cfg(feature = "rocm")]
680            Storage::Rocm(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
681            #[cfg(feature = "vulkan")]
682            Storage::Vulkan(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
683        }
684    }
685
    /// An alias for `to_scalar`.
    ///
    /// Like `to_scalar`, this returns an error when the rank of the tensor is not zero.
    pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
        self.to_scalar::<S>()
    }
690
691    /// Repeat this tensor along the specified dimensions.
692    pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
693        // Similar to PyTorch, we extend the number of dimensions of self if needed.
694        let repeats = shape.into();
695        let repeats = repeats.dims();
696        let mut inp = if self.rank() < repeats.len() {
697            let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
698            self.reshape(shape)?
699        } else {
700            self.clone()
701        };
702        for (idx, &repeat) in repeats.iter().enumerate() {
703            if repeat > 1 {
704                inp = Tensor::cat(&vec![&inp; repeat], idx)?
705            }
706        }
707        Ok(inp)
708    }
709
710    /// Creates grids of coordinates specified by the 1D inputs.
711    ///
712    /// # Arguments
713    ///
714    /// * `args` - A slice of 1D tensors.
715    /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
716    ///   first dimension corresponds to the cardinality of the second input and the second
717    ///   dimension corresponds to the cardinality of the first input. If ij is selected, the
718    ///   dimensions are in the same order as the cardinality of the inputs.
719    ///
720    /// # Examples
721    ///
722    /// ```rust
723    /// use hanzo_ml::{Tensor, Device, Shape};
724    /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
725    /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
726    ///
727    /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
728    ///
729    /// assert_eq!(grids_xy.len(), 2);
730    /// assert_eq!(grids_xy[0].dims(), &[3, 3]);
731    ///
732    /// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
733    /// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
734    ///
735    /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
736    ///
737    /// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
738    /// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
739    /// # Ok::<(), hanzo_ml::Error>(())
740    /// ```
741    ///
742    /// # Errors
743    ///
744    /// * Will return `Err` if `args` contains less than 2 tensors.
745    ///
746    pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
747        if args.len() <= 1 {
748            Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
749        }
750        let args: Vec<_> = if xy_indexing {
751            args.iter().rev().collect()
752        } else {
753            args.iter().collect()
754        };
755
756        let mut shape = Vec::with_capacity(args.len());
757        for arg in args.iter() {
758            shape.push(arg.as_ref().dims1()?)
759        }
760
761        let mut grids = Vec::with_capacity(args.len());
762        for idx in 0..args.len() {
763            let mut ones = vec![1usize; args.len()];
764            ones[idx] = shape[idx];
765            let arg = args[idx].as_ref().reshape(ones)?;
766            let mut repeats = shape.clone();
767            repeats[idx] = 1;
768            let repeated_tensor = arg.repeat(repeats)?;
769            grids.push(repeated_tensor);
770        }
771        if xy_indexing {
772            grids.reverse();
773        }
774        Ok(grids)
775    }
776
    /// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
    /// The input values `mul` and `add` are casted to the appropriate type so some rounding might
    /// be performed.
    ///
    /// A tensor without any element is returned as is (cloned).
    ///
    /// ```rust
    /// use hanzo_ml::{Tensor, Device};
    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    /// let a = a.affine(4., -2.)?;
    /// assert_eq!(a.to_vec2::<f32>()?, &[[-2.0, 2.0], [6.0, 10.0]]);
    /// # Ok::<(), hanzo_ml::Error>(())
    /// ```
    pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
        if self.elem_count() == 0 {
            return Ok(self.clone());
        }
        let storage = self.storage().affine(self.layout(), mul, add)?;
        // Keep track of the operation and its parameters for backpropagation.
        let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
        Ok(from_storage(storage, self.shape(), op, false))
    }
796
    /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
    ///
    /// `alpha` is forwarded to the storage's `elu` implementation. A tensor without any
    /// element is returned as is (cloned).
    pub fn elu(&self, alpha: f64) -> Result<Self> {
        if self.elem_count() == 0 {
            return Ok(self.clone());
        }
        let storage = self.storage().elu(self.layout(), alpha)?;
        // Keep track of the operation and its parameter for backpropagation.
        let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
        Ok(from_storage(storage, self.shape(), op, false))
    }
806
    /// Raise the tensor to some float exponent `e`.
    ///
    /// A tensor without any element is returned as is (cloned).
    pub fn powf(&self, e: f64) -> Result<Self> {
        if self.elem_count() == 0 {
            return Ok(self.clone());
        }
        let storage = self.storage().powf(self.layout(), e)?;
        // Keep track of the operation and its exponent for backpropagation.
        let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
        Ok(from_storage(storage, self.shape(), op, false))
    }
816
817    pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
818        if dim >= self.dims().len() {
819            Err(Error::DimOutOfRange {
820                shape: self.shape().clone(),
821                dim: dim as i32,
822                op,
823            }
824            .bt())?
825        } else {
826            Ok(())
827        }
828    }
829
830    /// Split a tensor into the specified number of chunks, this may return less chunks than
831    /// specified.
832    pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
833        let dim = dim.to_index(self.shape(), "chunk")?;
834        let size = self.dim(dim)?;
835        if size < chunks {
836            (0..size).map(|i| self.narrow(dim, i, 1)).collect()
837        } else {
838            let chunk_size = size / chunks;
839            let cnt_additional = size % chunks;
840            let mut tensors = vec![];
841            let mut sum_chunk_size = 0;
842            for i in 0..chunks {
843                let chunk_size = if i < cnt_additional {
844                    chunk_size + 1
845                } else {
846                    chunk_size
847                };
848                let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
849                tensors.push(tensor);
850                sum_chunk_size += chunk_size
851            }
852            Ok(tensors)
853        }
854    }
855
856    /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
857    /// ranges from `start` to `start + len`.
858    /// ```
859    /// use hanzo_ml::{Tensor, Device};
860    /// let a = Tensor::new(&[
861    ///     [0f32, 1., 2.],
862    ///     [3.  , 4., 5.],
863    ///     [6.  , 7., 8.]
864    /// ], &Device::Cpu)?;
865    ///
866    /// let b = a.narrow(0, 1, 2)?;
867    /// assert_eq!(b.shape().dims(), &[2, 3]);
868    /// assert_eq!(b.to_vec2::<f32>()?, &[
869    ///     [3., 4., 5.],
870    ///     [6., 7., 8.]
871    /// ]);
872    ///
873    /// let c = a.narrow(1, 1, 1)?;
874    /// assert_eq!(c.shape().dims(), &[3, 1]);
875    /// assert_eq!(c.to_vec2::<f32>()?, &[
876    ///     [1.],
877    ///     [4.],
878    ///     [7.]
879    /// ]);
880    /// # Ok::<(), hanzo_ml::Error>(())
881    /// ```
882    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
883        let dims = self.dims();
884        let dim = dim.to_index(self.shape(), "narrow")?;
885        let err = |msg| {
886            Err::<(), _>(
887                Error::NarrowInvalidArgs {
888                    shape: self.shape().clone(),
889                    dim,
890                    start,
891                    len,
892                    msg,
893                }
894                .bt(),
895            )
896        };
897        if start > dims[dim] {
898            err("start > dim_len")?
899        }
900        if start.saturating_add(len) > dims[dim] {
901            err("start + len > dim_len")?
902        }
903        if start == 0 && dims[dim] == len {
904            Ok(self.clone())
905        } else {
906            let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
907            let layout = self.layout().narrow(dim, start, len)?;
908            let tensor_ = Tensor_ {
909                id: TensorId::new(),
910                storage: self.storage.clone(),
911                layout,
912                op,
913                is_variable: false,
914                dtype: self.dtype,
915                device: self.device.clone(),
916            };
917            Ok(Tensor(Arc::new(tensor_)))
918        }
919    }
920
921    fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
922        match dims {
923            [] => Ok(self),
924            [i] => self.squeeze(*i),
925            dims => {
926                let dims = self
927                    .dims()
928                    .iter()
929                    .enumerate()
930                    .filter_map(|(dim_idx, &v)| {
931                        if dims.contains(&dim_idx) {
932                            None
933                        } else {
934                            Some(v)
935                        }
936                    })
937                    .collect::<Vec<_>>();
938                self.reshape(dims)
939            }
940        }
941    }
942
943    fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
944        let dim = dim.to_index(self.shape(), op.name())?;
945        let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
946        let mut dims = self.dims().to_vec();
947        dims[dim] = 1;
948        let op = match op {
949            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
950                BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
951            }
952            ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
953        };
954        let res = from_storage(storage, dims, op, false);
955        if keepdim {
956            Ok(res)
957        } else {
958            res.squeeze_dims(&[dim])
959        }
960    }
961
962    fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
963        let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
964        let storage = self
965            .storage()
966            .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
967        let mut dims = self.dims().to_vec();
968        for &sum_dim in sum_dims.iter() {
969            dims[sum_dim] = 1
970        }
971        let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
972        let sum = from_storage(storage, dims, op, false);
973        if keepdim {
974            Ok(sum)
975        } else {
976            sum.squeeze_dims(&sum_dims)
977        }
978    }
979
980    /// Roll the tensor input along the given dimension.
981    /// Elements that are shifted beyond the last position are re-introduced at the first position.
982    ///
983    /// ```rust
984    /// # use hanzo_ml::{Tensor, Device};
985    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
986    /// let tensor = tensor.roll(1, 0)?;
987    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
988    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
989    /// let tensor = tensor.roll(-1, 0)?;
990    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
991    /// # Ok::<(), hanzo_ml::Error>(())
992    /// ```
993    pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
994    where
995        D: Dim + Clone,
996    {
997        let dim = dim.to_index(self.shape(), "roll")?;
998        let dim_size = self.dim(dim)?;
999        let shift = shift.rem_euclid(dim_size as i32) as usize;
1000        if shift == 0 {
1001            Ok(self.clone())
1002        } else {
1003            let a = self.narrow(dim, 0, dim_size - shift)?;
1004            let b = self.narrow(dim, dim_size - shift, shift)?;
1005            Tensor::cat(&[&b, &a], dim)
1006        }
1007    }
1008
1009    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
1010    /// input dimensions.
1011    ///
1012    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
1013    /// that the number of elements for each dimension index in `sum_dims` is 1.
1014    ///
1015    /// ```rust
1016    /// use hanzo_ml::{Tensor, Device};
1017    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
1018    /// let s = a.sum_keepdim(0)?;
1019    /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
1020    /// let s = a.sum_keepdim(1)?;
1021    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
1022    /// let s = a.sum_keepdim((0, 1))?;
1023    /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
1024    /// # Ok::<(), hanzo_ml::Error>(())
1025    /// ```
1026    pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
1027        self.sum_impl(sum_dims, true)
1028    }
1029
1030    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
1031    /// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than
1032    /// kept.
1033    pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
1034        self.sum_impl(sum_dims, false)
1035    }
1036
1037    /// Returns the mean of all elements in the input tensor. The mean is performed over all the
1038    /// input dimensions.
1039    ///
1040    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
1041    /// that the number of elements for each dimension index in `mean_dims` is 1.
1042    ///
1043    /// ```rust
1044    /// use hanzo_ml::{Tensor, Device};
1045    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
1046    /// let s = a.mean_keepdim(0)?;
1047    /// assert_eq!(s.to_vec2::<f32>()?, &[[1., 2.]]);
1048    /// let s = a.mean_keepdim(1)?;
1049    /// assert_eq!(s.to_vec2::<f32>()?, &[[0.5], [2.5]]);
1050    /// let s = a.mean_keepdim((0, 1))?;
1051    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.5]]);
1052    /// # Ok::<(), hanzo_ml::Error>(())
1053    /// ```
1054    pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
1055        let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
1056        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1057        let scale = 1f64 / (reduced_dim as f64);
1058        self.sum_impl(mean_dims, true)? * scale
1059    }
1060
1061    /// Returns the mean of all elements in the input tensor. The mean is performed over all the
1062    /// input dimensions and compared to `mean_keepdim` these dimensions are squeezed rather than
1063    /// kept.
1064    pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
1065        let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
1066        let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
1067        let scale = 1f64 / (reduced_dim as f64);
1068        self.sum_impl(mean_dims, false)? * scale
1069    }
1070
1071    /// Returns the unbiased variance over the selected dimension.
1072    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1073        let dim = dim.to_index(self.shape(), "var")?;
1074        let mean = self.mean_keepdim(dim)?;
1075        let squares = self.broadcast_sub(&mean)?.sqr()?;
1076        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
1077    }
1078
1079    /// Returns the unbiased variance over the selected dimension.
1080    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
1081        let dim = dim.to_index(self.shape(), "var")?;
1082        self.var_keepdim(dim)?.squeeze(dim)
1083    }
1084
1085    /// Gathers the maximum value across the selected dimension. The resulting shape has the same
1086    /// number of dimensions as the original tensor and the select dimension has a single element.
1087    pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1088        self.reduce_impl(dim, true, ReduceOp::Max)
1089    }
1090
1091    /// Similar to `max_keepdim` but the target dimension is squeezed.
1092    pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
1093        self.reduce_impl(dim, false, ReduceOp::Max)
1094    }
1095
1096    /// Gathers the minimum value across the selected dimension. The resulting shape has the same
1097    /// number of dimensions as the original tensor and the select dimension has a single element.
1098    pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1099        self.reduce_impl(dim, true, ReduceOp::Min)
1100    }
1101
1102    /// Similar to `min_keepdim` but the target dimension is squeezed.
1103    pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
1104        self.reduce_impl(dim, false, ReduceOp::Min)
1105    }
1106
1107    pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1108        self.reduce_impl(dim, true, ReduceOp::ArgMax)
1109    }
1110
1111    /// Similar to `argmax_keepdim` but the target dimension is squeezed.
1112    pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
1113        self.reduce_impl(dim, false, ReduceOp::ArgMax)
1114    }
1115
1116    pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
1117        self.reduce_impl(dim, true, ReduceOp::ArgMin)
1118    }
1119
1120    /// Similar to `argmin_keepdim` but the target dimension is squeezed.
1121    pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
1122        self.reduce_impl(dim, false, ReduceOp::ArgMin)
1123    }
1124
1125    /// Element-wise comparison between two tensors, e.g. equality, greater than, ... The actual
1126    /// comparison operation is specified by the `op` argument.
1127    ///
1128    /// The returned tensor has the same shape as the original tensors and uses `u8` elements.
1129    pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
1130        let rhs = match rhs.to_tensor_scalar()? {
1131            crate::scalar::TensorScalar::Tensor(rhs) => rhs,
1132            crate::scalar::TensorScalar::Scalar(rhs) => rhs
1133                .to_dtype(self.dtype())?
1134                .to_device(self.device())?
1135                .broadcast_as(self.shape())?,
1136        };
1137        let shape = self.same_shape_binary_op(&rhs, "cmp")?;
1138        let storage = self
1139            .storage()
1140            .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
1141        let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
1142        Ok(from_storage(storage, shape.dims(), op, false))
1143    }
1144
1145    /// Element-wise equality.
1146    pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1147        self.cmp(rhs, CmpOp::Eq)
1148    }
1149
1150    /// Element-wise non-equality.
1151    pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1152        self.cmp(rhs, CmpOp::Ne)
1153    }
1154
1155    /// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
1156    /// rhs` and 0 otherwise.
1157    pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1158        self.cmp(rhs, CmpOp::Lt)
1159    }
1160
1161    /// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
1162    /// rhs` and 0 otherwise.
1163    pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1164        self.cmp(rhs, CmpOp::Gt)
1165    }
1166
1167    /// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
1168    /// rhs` and 0 otherwise.
1169    pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1170        self.cmp(rhs, CmpOp::Ge)
1171    }
1172
1173    /// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
1174    /// rhs` and 0 otherwise.
1175    pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
1176        self.cmp(rhs, CmpOp::Le)
1177    }
1178
1179    /// Clamp the tensor values to be between `min` and `max`.
1180    pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
1181        self.maximum(min)?.minimum(max)
1182    }
1183
1184    /// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
1185    ///
1186    /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
1187    /// tensor also has three dimensions, `(batch, channels, target_size)`.
1188    pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
1189        let (n, c, _l) = self.dims3()?;
1190        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
1191        let storage = self
1192            .storage()
1193            .upsample_nearest1d(self.layout(), target_size)?;
1194        Ok(from_storage(storage, (n, c, target_size), op, false))
1195    }
1196
1197    /// Alias for `interpolate1d`.
1198    pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
1199        self.interpolate1d(target_size)
1200    }
1201
1202    /// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
1203    /// nearest element.
1204    ///
1205    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1206    /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
1207    pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1208        let (n, c, _h, _w) = self.dims4()?;
1209        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
1210            arg,
1211            target_h,
1212            target_w,
1213        });
1214        let storage = self
1215            .storage()
1216            .upsample_nearest2d(self.layout(), target_h, target_w)?;
1217        Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1218    }
1219
1220    /// Alias for `interpolate2d`.
1221    pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
1222        self.interpolate2d(target_h, target_w)
1223    }
1224
1225    /// Bilinear interpolation to resize the input tensor to the specified size.
1226    ///
1227    /// The input tensor should have four dimensions: `(batch, channels, h, w)`.
1228    /// The returned tensor also has four dimensions: `(batch, channels, target_h, target_w)`.
1229    ///
1230    /// # Arguments
1231    ///
1232    /// * `target_h` - Target height
1233    /// * `target_w` - Target width  
1234    /// * `align_corners` - If true, corner pixels are aligned. If false (default),
1235    ///   pixels are treated as areas (matches PyTorch default behavior).
1236    ///
1237    /// # Example
1238    ///
1239    /// ```rust
1240    /// use hanzo_ml::{Tensor, Device};
1241    /// # fn main() -> hanzo_ml::Result<()> {
1242    /// let t = Tensor::arange(0f32, 16f32, &Device::Cpu)?.reshape((1, 1, 4, 4))?;
1243    /// let upsampled = t.upsample_bilinear2d(8, 8, false)?;
1244    /// assert_eq!(upsampled.dims(), &[1, 1, 8, 8]);
1245    /// # Ok(())
1246    /// # }
1247    /// ```
1248    pub fn upsample_bilinear2d(
1249        &self,
1250        target_h: usize,
1251        target_w: usize,
1252        align_corners: bool,
1253    ) -> Result<Self> {
1254        let (n, c, _h, _w) = self.dims4()?;
1255        let op = BackpropOp::new1(self, |arg| Op::UpsampleBilinear2D {
1256            arg,
1257            target_h,
1258            target_w,
1259            align_corners,
1260        });
1261        // Pass None for scale factors (size mode)
1262        let storage = self.storage().upsample_bilinear2d(
1263            self.layout(),
1264            target_h,
1265            target_w,
1266            align_corners,
1267            None,
1268            None,
1269        )?;
1270        Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
1271    }
1272
1273    /// Bilinear interpolation using scale factors.
1274    ///
1275    /// Similar to `upsample_bilinear2d` but uses scale factors instead of absolute sizes.
1276    /// This matches PyTorch's `interpolate(scale_factor=...)` behavior.
1277    ///
1278    /// # Arguments
1279    ///
1280    /// * `scale_h` - Height scaling factor
1281    /// * `scale_w` - Width scaling factor
1282    /// * `align_corners` - If true, corner pixels are aligned
1283    ///
1284    /// # Example
1285    ///
1286    /// ```rust
1287    /// use hanzo_ml::{Tensor, Device};
1288    /// # fn main() -> hanzo_ml::Result<()> {
1289    /// let t = Tensor::arange(0f32, 16f32, &Device::Cpu)?.reshape((1, 1, 4, 4))?;
1290    /// // Scale by 2x in both dimensions
1291    /// let upsampled = t.upsample_bilinear2d_with_scale(2.0, 2.0, false)?;
1292    /// assert_eq!(upsampled.dims(), &[1, 1, 8, 8]);
1293    /// # Ok(())
1294    /// # }
1295    /// ```
1296    pub fn upsample_bilinear2d_with_scale(
1297        &self,
1298        scale_h: f64,
1299        scale_w: f64,
1300        align_corners: bool,
1301    ) -> Result<Self> {
1302        let (n, c, height_in, width_in) = self.dims4()?;
1303
1304        // Calculate output size (floor, matching PyTorch)
1305        let height_out = (height_in as f64 * scale_h).floor() as usize;
1306        let width_out = (width_in as f64 * scale_w).floor() as usize;
1307
1308        // Early return if size unchanged
1309        if height_in == height_out && width_in == width_out {
1310            return Ok(self.clone());
1311        }
1312
1313        let op = BackpropOp::new1(self, |arg| Op::UpsampleBilinear2D {
1314            arg,
1315            target_h: height_out,
1316            target_w: width_out,
1317            align_corners,
1318        });
1319
1320        // Pass original scale factors (scale_factor mode)
1321        // This ensures PyTorch-compatible scale calculation
1322        let storage = self.storage().upsample_bilinear2d(
1323            self.layout(),
1324            height_out,
1325            width_out,
1326            align_corners,
1327            Some(scale_h),
1328            Some(scale_w),
1329        )?;
1330        Ok(from_storage(
1331            storage,
1332            (n, c, height_out, width_out),
1333            op,
1334            false,
1335        ))
1336    }
1337
1338    /// 2D average pooling over an input tensor with multiple channels.
1339    ///
1340    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1341    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
1342    /// the two last dimensions using a kernel of size `sz`. The returned element is the average
1343    /// value over the kernel window.
1344    pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1345        let sz = sz.to_usize2();
1346        self.avg_pool2d_with_stride(sz, sz)
1347    }
1348
1349    /// Same as `avg_pool2d` but with a `stride` that can be set to a value different from the
1350    /// kernel size.
1351    pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
1352        &self,
1353        kernel_size: T,
1354        stride: T,
1355    ) -> Result<Self> {
1356        let kernel_size = kernel_size.to_usize2();
1357        let stride = stride.to_usize2();
1358        let (n, c, h, w) = self.dims4()?;
1359        if h < kernel_size.0 || w < kernel_size.1 {
1360            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1361        }
1362        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
1363        let h_out = (h - kernel_size.0) / stride.0 + 1;
1364        let w_out = (w - kernel_size.1) / stride.1 + 1;
1365        let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
1366            arg,
1367            kernel_size,
1368            stride,
1369        });
1370        let storage = self
1371            .storage()
1372            .avg_pool2d(self.layout(), kernel_size, stride)?;
1373        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1374    }
1375
1376    /// 2D max pooling over an input tensor with multiple channels.
1377    ///
1378    /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
1379    /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
1380    /// the two last dimensions using a kernel of size `sz`, the returned element is the maximum
1381    /// value over the kernel window.
1382    pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
1383        let sz = sz.to_usize2();
1384        self.max_pool2d_with_stride(sz, sz)
1385    }
1386
1387    /// Same as `max_pool2d` but with a `stride` that can be set to a value different from the
1388    /// kernel size.
1389    pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
1390        &self,
1391        kernel_size: T,
1392        stride: T,
1393    ) -> Result<Self> {
1394        let kernel_size = kernel_size.to_usize2();
1395        let stride = stride.to_usize2();
1396        let (n, c, h, w) = self.dims4()?;
1397        if h < kernel_size.0 || w < kernel_size.1 {
1398            bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
1399        }
1400        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
1401        let h_out = (h - kernel_size.0) / stride.0 + 1;
1402        let w_out = (w - kernel_size.1) / stride.1 + 1;
1403        let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
1404            arg,
1405            kernel_size,
1406            stride,
1407        });
1408        let storage = self
1409            .storage()
1410            .max_pool2d(self.layout(), kernel_size, stride)?;
1411        Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
1412    }
1413
1414    /// Computes the dot product of two 1D tensors.
1415    ///
1416    /// - If inputs are 1D vectors (`[n]`), returns their scalar dot product.
1417    /// - Panics if shapes are not compatible
1418    /// - Not supported for integer dtypes
1419    ///
1420    /// # Example (vectors)
1421    /// ```rust
1422    /// use hanzo_ml::{Tensor, Device};
1423    /// let t1 = Tensor::new(&[1.0, 2.0, 3.0], &Device::Cpu)?;
1424    /// let t2 = Tensor::new(&[4.0, 5.0, 6.0], &Device::Cpu)?;
1425    /// let res = t1.dot(&t2)?;
1426    /// assert_eq!(res.to_scalar::<f64>()?, 32.);
1427    /// # Ok::<(), hanzo_ml::Error>(())
1428    /// ```
1429    pub fn dot(&self, rhs: &Self) -> Result<Self> {
1430        if self.dims().len() != 1 || rhs.dims().len() != 1 {
1431            return Err(Error::ShapeMismatchBinaryOp {
1432                lhs: self.shape().clone(),
1433                rhs: rhs.shape().clone(),
1434                op: "dot",
1435            });
1436        }
1437
1438        (self * rhs).and_then(|ret| ret.sum_all())
1439    }
1440
1441    /// Computes the **Frobenius norm** (L2 norm of all elements) of the tensor.
1442    /// - Output is `sqrt(sum(x^2))`.
1443    /// - Always returns a scalar (`[]` shape).
1444    ///
1445    /// # Example
1446    /// ```rust
1447    /// use hanzo_ml::{Tensor, Device};
1448    /// let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
1449    /// let norm = t.norm()?;
1450    /// assert_eq!(norm.to_scalar::<f64>()?, 5.);
1451    /// # Ok::<(), hanzo_ml::Error>(())
1452    /// ```
1453    pub fn norm(&self) -> Result<Self> {
1454        if self.dtype().is_int() {
1455            bail!("norm not supported for integer dtypes");
1456        }
1457
1458        self.sqr().and_then(|x| x.sum_all()).and_then(|x| x.sqrt())
1459    }
1460
1461    /// Performs strict matrix-vector multiplication (`[m, n] * [n] = [m]`).
1462    ///
1463    /// - If `self` is a matrix (`[m, n]`) and `rhs` is a vector (`[n]`), returns a vector (`[m]`).
1464    /// - **No broadcasting**: Panics if `self` is not 2D or if `rhs` is not 1D with matching size.
1465    ///
1466    /// # Example
1467    /// ```rust
1468    /// use hanzo_ml::{Tensor, Device};
1469    /// let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
1470    /// let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
1471    /// let res = mat.mv(&vec)?;
1472    /// assert_eq!(res.to_vec1::<f64>()?, [6., 15.]);
1473    /// # Ok::<(), hanzo_ml::Error>(())
1474    /// ```
1475    pub fn mv(&self, rhs: &Self) -> Result<Self> {
1476        // Strict shape checks
1477        let lhs_dims = self.dims();
1478        let rhs_dims = rhs.dims();
1479        if lhs_dims.len() != 2 || rhs_dims.len() != 1 || lhs_dims[1] != rhs_dims[0] {
1480            return Err(Error::ShapeMismatchBinaryOp {
1481                lhs: self.shape().clone(),
1482                rhs: rhs.shape().clone(),
1483                op: "mv",
1484            });
1485        }
1486
1487        // Direct matmul after ensuring rhs is column vector
1488        self.matmul(&rhs.unsqueeze(1)?)?.squeeze(1)
1489    }
1490
1491    /// Returns the matrix-multiplication of the input tensor with the other provided tensor.
1492    ///
1493    /// # Arguments
1494    ///
1495    /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`.
1496    /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`.
1497    ///
1498    /// The resulting tensor has dimensions `b1, b2, ..., bi, m, n`.
1499    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
1500        let a_dims = self.shape().dims();
1501        let b_dims = rhs.shape().dims();
1502
1503        let dim = a_dims.len();
1504
1505        if dim < 2 || b_dims.len() != dim {
1506            Err(Error::ShapeMismatchBinaryOp {
1507                lhs: self.shape().clone(),
1508                rhs: rhs.shape().clone(),
1509                op: "matmul",
1510            }
1511            .bt())?
1512        }
1513
1514        let m = a_dims[dim - 2];
1515        let k = a_dims[dim - 1];
1516        let k2 = b_dims[dim - 2];
1517        let n = b_dims[dim - 1];
1518
1519        let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
1520        if c_shape.elem_count() == 0 || k == 0 {
1521            return Tensor::zeros(c_shape, self.dtype(), self.device());
1522        }
1523        let batching: usize = a_dims[..dim - 2].iter().product();
1524        let batching_b: usize = b_dims[..dim - 2].iter().product();
1525        if k != k2 || batching != batching_b {
1526            Err(Error::ShapeMismatchBinaryOp {
1527                lhs: self.shape().clone(),
1528                rhs: rhs.shape().clone(),
1529                op: "matmul",
1530            }
1531            .bt())?
1532        }
1533
1534        let storage = self.storage().matmul(
1535            &rhs.storage(),
1536            (batching, m, n, k),
1537            self.layout(),
1538            rhs.layout(),
1539        )?;
1540        let op = BackpropOp::new2(self, rhs, Op::Matmul);
1541        Ok(from_storage(storage, c_shape, op, false))
1542    }
1543
1544    /// Matrix-multiplication with broadcasting support.
1545    ///
1546    /// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as
1547    /// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
1548    /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
1549    pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
1550        let lhs = self;
1551        let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
1552        let l_broadcast = l_shape != *lhs.shape();
1553        let r_broadcast = r_shape != *rhs.shape();
1554        // TODO: Avoid concretising the broadcasted matrixes via contiguous.
1555        match (l_broadcast, r_broadcast) {
1556            (true, true) => lhs
1557                .broadcast_as(&l_shape)?
1558                .contiguous()?
1559                .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1560            (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
1561            (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
1562            (false, false) => lhs.matmul(rhs),
1563        }
1564    }
1565
1566    /// Returns a tensor with the same shape as the input tensor, the values are taken from
1567    /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
1568    /// input tensor is equal to zero.
1569    pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
1570        let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
1571        let shape = self.same_shape_binary_op(on_false, "where_cond")?;
1572        let storage = self.storage().where_cond(
1573            self.layout(),
1574            &on_true.storage(),
1575            on_true.layout(),
1576            &on_false.storage(),
1577            on_false.layout(),
1578        )?;
1579        let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
1580        Ok(from_storage(storage, shape, op, false))
1581    }
1582
1583    /// Returns a tensor with the values from the `self` tensor at the index corresponding to the
1584    /// values hold in the `ids` tensor.
1585    ///
1586    /// # Arguments
1587    ///
1588    /// * `self` - A tensor with dimensions `v, h`.
1589    /// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
1590    ///
1591    /// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
1592    /// vocabulary size, and `h` the hidden size.
1593    ///
1594    /// ```rust
1595    /// use hanzo_ml::{Tensor, Device};
1596    /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
1597    /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
1598    /// let emb = values.embedding(&ids)?;
1599    /// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
1600    /// # Ok::<(), hanzo_ml::Error>(())
1601    /// ```
1602    pub fn embedding(&self, ids: &Self) -> Result<Self> {
1603        if self.rank() != 2 || ids.rank() != 1 {
1604            Err(Error::ShapeMismatchBinaryOp {
1605                lhs: self.shape().clone(),
1606                rhs: ids.shape().clone(),
1607                op: "embedding",
1608            }
1609            .bt())?
1610        }
1611        self.index_select(ids, 0)
1612    }
1613
1614    fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
1615        let source_dims = source.dims();
1616        let self_dims = self.dims();
1617        let mismatch = if source_dims.len() != self_dims.len() {
1618            true
1619        } else {
1620            let mut mismatch = false;
1621            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1622                if i != dim && d1 != d2 {
1623                    mismatch = true;
1624                    break;
1625                }
1626            }
1627            mismatch
1628        };
1629        if mismatch {
1630            Err(Error::ShapeMismatchBinaryOp {
1631                op: "scatter (self, src)",
1632                lhs: self.shape().clone(),
1633                rhs: source.shape().clone(),
1634            }
1635            .bt())?
1636        }
1637        if indexes.dims() != source.dims() {
1638            Err(Error::ShapeMismatchBinaryOp {
1639                op: "scatter (indexes, src)",
1640                lhs: indexes.shape().clone(),
1641                rhs: source.shape().clone(),
1642            }
1643            .bt())?
1644        }
1645        Ok(())
1646    }
1647
1648    pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1649        let dim = dim.to_index(self.shape(), "scatter")?;
1650        self.scatter_checks(indexes, source, dim)?;
1651        let shape = self.shape();
1652        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1653        self.storage()
1654            .copy_strided_src(&mut storage, 0, self.layout())?;
1655        let layout = Layout::contiguous(shape);
1656        storage.scatter_set(
1657            &layout,
1658            &indexes.storage(),
1659            indexes.layout(),
1660            &source.storage(),
1661            source.layout(),
1662            dim,
1663        )?;
1664        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1665            Op::Scatter(t1, t2, t3, dim)
1666        });
1667        Ok(from_storage(storage, self.shape(), op, false))
1668    }
1669
1670    pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1671        if self.same_storage(source) {
1672            crate::bail!("cannot use slice_set when self and src share their storage")
1673        }
1674        let dim = dim.to_index(self.shape(), "scatter-set")?;
1675        self.scatter_checks(indexes, source, dim)?;
1676        self.storage_mut().scatter_set(
1677            self.layout(),
1678            &indexes.storage(),
1679            indexes.layout(),
1680            &source.storage(),
1681            source.layout(),
1682            dim,
1683        )?;
1684        Ok(())
1685    }
1686
1687    pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1688        let dim = dim.to_index(self.shape(), "scatter-add")?;
1689        self.scatter_checks(indexes, source, dim)?;
1690        let shape = self.shape();
1691        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
1692        self.storage()
1693            .copy_strided_src(&mut storage, 0, self.layout())?;
1694        let layout = Layout::contiguous(shape);
1695        storage.scatter_add(
1696            &layout,
1697            &indexes.storage(),
1698            indexes.layout(),
1699            &source.storage(),
1700            source.layout(),
1701            dim,
1702        )?;
1703        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1704            Op::ScatterAdd(t1, t2, t3, dim)
1705        });
1706        Ok(from_storage(storage, self.shape(), op, false))
1707    }
1708
1709    pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
1710        if self.same_storage(source) {
1711            crate::bail!("cannot use slice_set when self and src share their storage")
1712        }
1713        let dim = dim.to_index(self.shape(), "scatter-add-set")?;
1714        self.scatter_checks(indexes, source, dim)?;
1715        self.storage_mut().scatter_add(
1716            self.layout(),
1717            &indexes.storage(),
1718            indexes.layout(),
1719            &source.storage(),
1720            source.layout(),
1721            dim,
1722        )?;
1723        Ok(())
1724    }
1725
1726    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
1727    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
1728        let dim = dim.to_index(self.shape(), "slice-scatter")?;
1729        if dim == 0 {
1730            self.slice_scatter0(src, start)
1731        } else {
1732            // TODO: Maybe we want to add a more efficient implementation at some point.
1733            self.transpose(0, dim)?
1734                .slice_scatter0(&src.transpose(0, dim)?, start)?
1735                .transpose(0, dim)
1736        }
1737    }
1738
1739    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
1740    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
1741        if self.dtype() != src.dtype() {
1742            Err(Error::DTypeMismatchBinaryOp {
1743                lhs: self.dtype(),
1744                rhs: src.dtype(),
1745                op: "slice-scatter",
1746            }
1747            .bt())?
1748        }
1749        if self.device().location() != src.device.location() {
1750            Err(Error::DeviceMismatchBinaryOp {
1751                lhs: self.device().location(),
1752                rhs: src.device().location(),
1753                op: "slice-scatter",
1754            }
1755            .bt())?
1756        }
1757        if self.rank() != src.rank() {
1758            Err(Error::UnexpectedNumberOfDims {
1759                expected: self.rank(),
1760                got: src.rank(),
1761                shape: src.shape().clone(),
1762            }
1763            .bt())?
1764        }
1765        let shape_ok =
1766            self.dims()
1767                .iter()
1768                .zip(src.dims().iter())
1769                .enumerate()
1770                .all(|(dim_idx, (&d1, &d2))| {
1771                    if 0 == dim_idx {
1772                        d2 + start <= d1
1773                    } else {
1774                        d1 == d2
1775                    }
1776                });
1777        if !shape_ok {
1778            Err(Error::ShapeMismatchBinaryOp {
1779                op: "slice-scatter (self, src)",
1780                lhs: self.shape().clone(),
1781                rhs: src.shape().clone(),
1782            }
1783            .bt())?
1784        }
1785        let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
1786        self.storage()
1787            .copy_strided_src(&mut storage, 0, self.layout())?;
1788        let offset = start * src.dims()[1..].iter().product::<usize>();
1789        src.storage()
1790            .copy_strided_src(&mut storage, offset, src.layout())?;
1791        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
1792        Ok(from_storage(storage, self.shape(), op, false))
1793    }
1794
1795    /// Accumulate element from `source` at indexes `indexes` and add them to `self`.
1796    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
1797        let dim = dim.to_index(self.shape(), "index-add")?;
1798        let source_dims = source.dims();
1799        let self_dims = self.dims();
1800        let mismatch = if source_dims.len() != self_dims.len() {
1801            true
1802        } else {
1803            let mut mismatch = false;
1804            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
1805                if i != dim && d1 != d2 {
1806                    mismatch = true;
1807                    break;
1808                }
1809            }
1810            mismatch
1811        };
1812        if mismatch {
1813            Err(Error::ShapeMismatchBinaryOp {
1814                op: "index-add (self, source)",
1815                lhs: self.shape().clone(),
1816                rhs: source.shape().clone(),
1817            }
1818            .bt())?
1819        }
1820        // The number of element in indexes must match the dimension on which the add is
1821        // performed on the source tensor (and the index values from `indexes` are taken from
1822        // the target tensor self)
1823        let indexes_len = indexes.dims1()?;
1824        if source_dims[dim] != indexes_len {
1825            Err(Error::ShapeMismatchBinaryOp {
1826                op: "index-add (ids, source))",
1827                lhs: indexes.shape().clone(),
1828                rhs: source.shape().clone(),
1829            }
1830            .bt())?
1831        }
1832        let storage = self.storage().index_add(
1833            self.layout(),
1834            &indexes.storage(),
1835            indexes.layout(),
1836            &source.storage(),
1837            source.layout(),
1838            dim,
1839        )?;
1840        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
1841            Op::IndexAdd(t1, t2, t3, dim)
1842        });
1843        Ok(from_storage(storage, self.shape(), op, false))
1844    }
1845
1846    /// Gather values across the target dimension.
1847    ///
1848    /// # Arguments
1849    ///
1850    /// * `self` - The input tensor.
1851    /// * `indexes` - The indices of elements to gather, this should have same number of dimensions as `self`
1852    ///   and indexes.dims()[d] <= self.dims()[d] for all dimensions d != dim
1853    /// * `dim` - the target dimension.
1854    ///
1855    /// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
1856    /// dimension `dim` by the values in `indexes`.
1857    pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1858        let dim = dim.to_index(self.shape(), "gather")?;
1859
1860        let self_dims = self.dims();
1861        let indexes_dims = indexes.dims();
1862        let mismatch = if indexes_dims.len() != self_dims.len() {
1863            true
1864        } else {
1865            let mut mismatch = false;
1866            for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
1867                if i != dim && d1 < d2 {
1868                    mismatch = true;
1869                    break;
1870                }
1871            }
1872            mismatch
1873        };
1874        if mismatch {
1875            Err(Error::ShapeMismatchBinaryOp {
1876                op: "gather",
1877                lhs: self.shape().clone(),
1878                rhs: indexes.shape().clone(),
1879            }
1880            .bt())?
1881        }
1882        let storage =
1883            self.storage()
1884                .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
1885        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
1886        Ok(from_storage(storage, indexes.shape(), op, false))
1887    }
1888
1889    /// Select values for the input tensor at the target indexes across the specified dimension.
1890    ///
1891    /// The `indexes` is argument is an int tensor with a single dimension.
1892    /// The output has the same number of dimension as the `self` input. The target dimension of
1893    /// the output has length the length of `indexes` and the values are taken from `self` using
1894    /// the index from `indexes`. Other dimensions have the same number of elements as the input
1895    /// tensor.
1896    pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
1897        let dim = dim.to_index(self.shape(), "index-select")?;
1898        let indexes_len = match indexes.dims() {
1899            [l] => *l,
1900            _ => Err(Error::ShapeMismatchBinaryOp {
1901                lhs: self.shape().clone(),
1902                rhs: indexes.shape().clone(),
1903                op: "index-select",
1904            }
1905            .bt())?,
1906        };
1907        let storage = self.storage().index_select(
1908            &indexes.storage(),
1909            self.layout(),
1910            indexes.layout(),
1911            dim,
1912        )?;
1913        let mut dims = self.dims().to_vec();
1914        dims[dim] = indexes_len;
1915        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
1916        Ok(from_storage(storage, dims, op, false))
1917    }
1918
1919    /// Returns an iterator over position of the elements in the storage when ranging over the
1920    /// index tuples in lexicographic order.
1921    pub fn strided_index(&self) -> crate::StridedIndex<'_> {
1922        self.layout.strided_index()
1923    }
1924
1925    /// Similar to `strided_index` but returns the position of the start of each contiguous block
1926    /// as well as the length of the contiguous blocks. For a contiguous tensor, the index iterator
1927    /// will only return the start offset and the size would be the number of elements in the
1928    /// tensor.
1929    pub fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
1930        self.layout.strided_blocks()
1931    }
1932
1933    /// Returns the data contained in a 1D tensor as a vector of scalar values.
1934    pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
1935        if self.rank() != 1 {
1936            Err(Error::UnexpectedNumberOfDims {
1937                expected: 1,
1938                got: self.rank(),
1939                shape: self.shape().clone(),
1940            }
1941            .bt())?
1942        }
1943        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1944            let data = S::cpu_storage_as_slice(cpu_storage)?;
1945            let data = match self.layout.contiguous_offsets() {
1946                Some((o1, o2)) => data[o1..o2].to_vec(),
1947                None => self.strided_index().map(|i| data[i]).collect(),
1948            };
1949            Ok::<Vec<_>, Error>(data)
1950        };
1951        match &*self.storage() {
1952            Storage::Cpu(storage) => from_cpu_storage(storage),
1953            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1954            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1955            #[cfg(feature = "rocm")]
1956            Storage::Rocm(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1957            #[cfg(feature = "vulkan")]
1958            Storage::Vulkan(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1959        }
1960    }
1961
1962    /// Returns the data contained in a 2D tensor as a vector of vector of scalar values.
1963    pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
1964        let (dim1, dim2) = self.dims2()?;
1965        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
1966            let data = S::cpu_storage_as_slice(cpu_storage)?;
1967            let mut rows = vec![];
1968            match self.layout.contiguous_offsets() {
1969                Some((o1, o2)) => {
1970                    let data = &data[o1..o2];
1971                    for idx_row in 0..dim1 {
1972                        rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
1973                    }
1974                }
1975                None => {
1976                    let mut src_index = self.strided_index();
1977                    for _idx_row in 0..dim1 {
1978                        let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
1979                        rows.push(row)
1980                    }
1981                    assert!(src_index.next().is_none());
1982                }
1983            }
1984            Ok(rows)
1985        };
1986        match &*self.storage() {
1987            Storage::Cpu(storage) => from_cpu_storage(storage),
1988            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1989            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1990            #[cfg(feature = "rocm")]
1991            Storage::Rocm(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1992            #[cfg(feature = "vulkan")]
1993            Storage::Vulkan(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
1994        }
1995    }
1996
1997    /// Returns the data contained in a 3D tensor.
1998    pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
1999        let (dim1, dim2, dim3) = self.dims3()?;
2000        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
2001            let data = S::cpu_storage_as_slice(cpu_storage)?;
2002            let mut top_rows = vec![];
2003            match self.layout.contiguous_offsets() {
2004                Some((o1, o2)) => {
2005                    let data = &data[o1..o2];
2006                    let dim23 = dim2 * dim3;
2007                    for idx1 in 0..dim1 {
2008                        let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
2009                        let mut rows = vec![];
2010                        for idx2 in 0..dim2 {
2011                            rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
2012                        }
2013                        top_rows.push(rows);
2014                    }
2015                }
2016                None => {
2017                    let mut src_index = self.strided_index();
2018                    for _idx in 0..dim1 {
2019                        let mut rows = vec![];
2020                        for _jdx in 0..dim2 {
2021                            let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
2022                            rows.push(row)
2023                        }
2024                        top_rows.push(rows);
2025                    }
2026                    assert!(src_index.next().is_none());
2027                }
2028            }
2029            Ok(top_rows)
2030        };
2031        match &*self.storage() {
2032            Storage::Cpu(storage) => from_cpu_storage(storage),
2033            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2034            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2035            #[cfg(feature = "rocm")]
2036            Storage::Rocm(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2037            #[cfg(feature = "vulkan")]
2038            Storage::Vulkan(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
2039        }
2040    }
2041
2042    /// The dtype for the elements stored in the input tensor.
2043    pub fn dtype(&self) -> DType {
2044        self.dtype
2045    }
2046
2047    /// The device on which the input tensor is located.
2048    pub fn device(&self) -> &Device {
2049        &self.device
2050    }
2051
2052    /// The tensor shape, i.e. dimension sizes on each axis.
2053    pub fn shape(&self) -> &Shape {
2054        self.layout().shape()
2055    }
2056
2057    /// The dimension size for this tensor on each axis.
2058    pub fn dims(&self) -> &[usize] {
2059        self.shape().dims()
2060    }
2061
2062    /// The dimension size for a specified dimension index.
2063    pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
2064        let dim = dim.to_index(self.shape(), "dim")?;
2065        Ok(self.dims()[dim])
2066    }
2067
2068    /// The layout of the input tensor, this stores both the shape of the tensor as well as the
2069    /// strides and the start offset to apply to the underlying storage.
2070    pub fn layout(&self) -> &Layout {
2071        &self.layout
2072    }
2073
2074    pub fn stride(&self) -> &[usize] {
2075        self.layout.stride()
2076    }
2077
2078    /// The number of dimensions for this tensor, 0 for a scalar tensor, 1 for a 1D tensor, etc.
2079    pub fn rank(&self) -> usize {
2080        self.shape().rank()
2081    }
2082
2083    /// The number of elements stored in this tensor.
2084    pub fn elem_count(&self) -> usize {
2085        self.shape().elem_count()
2086    }
2087
2088    /// The unique identifier for this tensor.
2089    pub fn id(&self) -> TensorId {
2090        self.id
2091    }
2092
2093    /// Whether this tensor is a variable or not. A variable is a tensor for which gradient is
2094    /// tracked and on which backpropagation can be performed.
2095    pub fn is_variable(&self) -> bool {
2096        self.is_variable
2097    }
2098
2099    pub(crate) fn op(&self) -> &Option<Op> {
2100        &self.op
2101    }
2102
2103    /// Computes the max of all the elements in this tensor and returns a tensor holding this
2104    /// scalar with zero dimensions.
2105    ///
2106    /// ```rust
2107    /// use hanzo_ml::{Tensor, Device};
2108    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2109    /// let tensor = tensor.max_all()?;
2110    /// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
2111    /// # Ok::<(), hanzo_ml::Error>(())
2112    /// ```
2113    pub fn max_all(&self) -> Result<Tensor> {
2114        if self.rank() == 0 {
2115            Ok(self.clone())
2116        } else {
2117            self.flatten_all()?.max(0)
2118        }
2119    }
2120
2121    /// Computes the min of all the elements in this tensor and returns a tensor holding this
2122    /// scalar with zero dimensions.
2123    ///
2124    /// ```rust
2125    /// use hanzo_ml::{Tensor, Device};
2126    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2127    /// let tensor = tensor.min_all()?;
2128    /// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
2129    /// # Ok::<(), hanzo_ml::Error>(())
2130    /// ```
2131    pub fn min_all(&self) -> Result<Tensor> {
2132        if self.rank() == 0 {
2133            Ok(self.clone())
2134        } else {
2135            self.flatten_all()?.min(0)
2136        }
2137    }
2138
2139    /// Computes the sum of all the elements in this tensor and returns a tensor holding this
2140    /// scalar with zero dimensions.
2141    ///
2142    /// ```rust
2143    /// use hanzo_ml::{Tensor, Device};
2144    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2145    /// let tensor = tensor.sum_all()?;
2146    /// assert_eq!(tensor.to_scalar::<f32>()?, 15.);
2147    /// # Ok::<(), hanzo_ml::Error>(())
2148    /// ```
2149    pub fn sum_all(&self) -> Result<Tensor> {
2150        let dims: Vec<_> = (0..self.rank()).collect();
2151        self.sum(dims)
2152    }
2153
2154    pub fn mean_all(&self) -> Result<Tensor> {
2155        self.sum_all()? / self.elem_count() as f64
2156    }
2157
2158    fn flatten_<D1: Dim, D2: Dim>(
2159        &self,
2160        start_dim: Option<D1>,
2161        end_dim: Option<D2>,
2162    ) -> Result<Tensor> {
2163        if self.rank() == 0 {
2164            self.reshape(1)
2165        } else {
2166            let start_dim = match start_dim {
2167                None => 0,
2168                Some(dim) => dim.to_index(self.shape(), "flatten")?,
2169            };
2170            let end_dim = match end_dim {
2171                None => self.rank() - 1,
2172                Some(dim) => dim.to_index(self.shape(), "flatten")?,
2173            };
2174            if start_dim < end_dim {
2175                let dims = self.dims();
2176                let mut dst_dims = dims[..start_dim].to_vec();
2177                dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
2178                if end_dim + 1 < dims.len() {
2179                    dst_dims.extend(&dims[end_dim + 1..]);
2180                }
2181                self.reshape(dst_dims)
2182            } else {
2183                Ok(self.clone())
2184            }
2185        }
2186    }
2187
2188    /// Flattens the input tensor on the dimension indexes from `start_dim` to `end_dim` (both
2189    /// inclusive).
2190    pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
2191        self.flatten_(Some(start_dim), Some(end_dim))
2192    }
2193
2194    /// Flattens the input tensor on the dimension indexes from `0` to `end_dim` (inclusive).
2195    pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
2196        self.flatten_(None::<usize>, Some(end_dim))
2197    }
2198
2199    /// Flattens the input tensor on the dimension indexes from `start_dim` (inclusive) to the last
2200    /// dimension.
2201    pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
2202        self.flatten_(Some(start_dim), None::<usize>)
2203    }
2204
2205    /// Flattens the input tensor by reshaping it into a one dimension tensor.
2206    ///
2207    /// ```rust
2208    /// use hanzo_ml::{Tensor, Device};
2209    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2210    /// let tensor = tensor.flatten_all()?;
2211    /// assert_eq!(tensor.to_vec1::<f32>()?, &[0., 1., 2., 3., 4., 5.]);
2212    /// # Ok::<(), hanzo_ml::Error>(())
2213    /// ```
2214    pub fn flatten_all(&self) -> Result<Tensor> {
2215        self.flatten_(None::<usize>, None::<usize>)
2216    }
2217
2218    /// Returns the sub-tensor fixing the index at `i` on the first dimension.
2219    ///
2220    /// ```rust
2221    /// use hanzo_ml::{Tensor, Device};
2222    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2223    /// let t = tensor.get(0)?;
2224    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 1.]);
2225    /// let t = tensor.get(1)?;
2226    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
2227    /// # Ok::<(), hanzo_ml::Error>(())
2228    /// ```
2229    pub fn get(&self, i: usize) -> Result<Tensor> {
2230        let dims = self.dims();
2231        if dims.is_empty() {
2232            Ok(self.clone())
2233        } else {
2234            self.narrow(0, i, 1)?.reshape(&dims[1..])
2235        }
2236    }
2237
2238    /// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
2239    ///
2240    /// ```rust
2241    /// use hanzo_ml::{Tensor, Device};
2242    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2243    /// let t = tensor.get_on_dim(1, 0)?;
2244    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
2245    /// let t = tensor.get_on_dim(1, 1)?;
2246    /// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
2247    /// let t = tensor.get_on_dim(0, 1)?;
2248    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
2249    /// # Ok::<(), hanzo_ml::Error>(())
2250    /// ```
2251    pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
2252        let dim = dim.to_index(self.shape(), "get_on_dim")?;
2253        self.narrow(dim, index, 1)?.squeeze(dim)
2254    }
2255
2256    /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
2257    /// input are swapped.
2258    ///
2259    /// ```rust
2260    /// use hanzo_ml::{Tensor, Device};
2261    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
2262    /// let tensor = tensor.t()?;
2263    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[0.0, 2.0, 4.0], [1.0, 3.0, 5.0]]);
2264    /// # Ok::<(), hanzo_ml::Error>(())
2265    /// ```
2266    pub fn t(&self) -> Result<Tensor> {
2267        let rank = self.rank();
2268        if rank < 2 {
2269            Err(Error::UnexpectedNumberOfDims {
2270                expected: 2,
2271                got: rank,
2272                shape: self.shape().clone(),
2273            }
2274            .bt())?
2275        }
2276        self.transpose(rank - 2, rank - 1)
2277    }
2278
2279    /// Returns a tensor that is a transposed version of the input, the given dimensions are
2280    /// swapped.
2281    pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
2282        let dim1 = dim1.to_index(self.shape(), "transpose")?;
2283        let dim2 = dim2.to_index(self.shape(), "transpose")?;
2284        if dim1 == dim2 {
2285            return Ok(self.clone());
2286        }
2287        let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
2288        let tensor_ = Tensor_ {
2289            id: TensorId::new(),
2290            storage: self.storage.clone(),
2291            layout: self.layout.transpose(dim1, dim2)?,
2292            op,
2293            is_variable: false,
2294            dtype: self.dtype,
2295            device: self.device.clone(),
2296        };
2297        Ok(Tensor(Arc::new(tensor_)))
2298    }
2299
2300    /// Returns a tensor with the same data as the input where the dimensions have been permuted.
2301    /// dims must be a permutation, i.e. include each dimension index exactly once.
2302    ///
2303    /// ```rust
2304    /// use hanzo_ml::{Tensor, Device};
2305    /// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
2306    /// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
2307    /// let tensor = tensor.permute((2, 3, 1, 0))?;
2308    /// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
2309    /// # Ok::<(), hanzo_ml::Error>(())
2310    /// ```
2311    pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
2312        let dims = dims.to_indexes(self.shape(), "permute")?;
2313        // O(n^2) permutation check but these arrays are small.
2314        let is_permutation =
2315            dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
2316        if !is_permutation {
2317            bail!(
2318                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
2319                self.dims(),
2320                dims
2321            )
2322        }
2323        let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
2324        let tensor_ = Tensor_ {
2325            id: TensorId::new(),
2326            storage: self.storage.clone(),
2327            layout: self.layout.permute(&dims)?,
2328            op,
2329            is_variable: false,
2330            dtype: self.dtype,
2331            device: self.device.clone(),
2332        };
2333        Ok(Tensor(Arc::new(tensor_)))
2334    }
2335
2336    /// Returns true if the data is stored in a C contiguous (aka row major) way.
2337    pub fn is_contiguous(&self) -> bool {
2338        self.layout.is_contiguous()
2339    }
2340
2341    /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
2342    pub fn is_fortran_contiguous(&self) -> bool {
2343        self.layout.is_fortran_contiguous()
2344    }
2345
2346    /// Compared to clone, this copies the actual storage but may fail because of running out of
2347    /// memory.
2348    pub fn copy(&self) -> Result<Tensor> {
2349        let op = BackpropOp::new1(self, Op::Copy);
2350        let tensor_ = Tensor_ {
2351            id: TensorId::new(),
2352            storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
2353            layout: self.layout.clone(),
2354            op,
2355            is_variable: false,
2356            dtype: self.dtype,
2357            device: self.device.clone(),
2358        };
2359        Ok(Tensor(Arc::new(tensor_)))
2360    }
2361
2362    /// Returns a new tensor detached from the current graph, gradient are not propagated through
2363    /// this new node. The storage of this tensor is shared with the initial tensor.
2364    ///
2365    /// If the tensor is already detached from the computation graph, the same tensor is returned.
2366    pub fn detach(&self) -> Tensor {
2367        if self.op.is_none() && !self.is_variable {
2368            self.clone()
2369        } else {
2370            let tensor_ = Tensor_ {
2371                id: TensorId::new(),
2372                storage: self.storage.clone(),
2373                layout: self.layout.clone(),
2374                op: BackpropOp::none(),
2375                is_variable: false,
2376                dtype: self.dtype,
2377                device: self.device.clone(),
2378            };
2379            Tensor(Arc::new(tensor_))
2380        }
2381    }
2382
2383    /// If the target device is the same as the tensor device, only a shallow copy is performed.
2384    pub fn to_device(&self, device: &Device) -> Result<Tensor> {
2385        if self.device().same_device(device) {
2386            Ok(self.clone())
2387        } else {
2388            let storage = match (&*self.storage(), device) {
2389                (Storage::Cpu(storage), Device::Cuda(cuda)) => {
2390                    Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
2391                }
2392                (Storage::Cpu(storage), Device::Metal(metal)) => {
2393                    Storage::Metal(metal.storage_from_cpu_storage(storage)?)
2394                }
2395                (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2396                (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2397                #[cfg(feature = "rocm")]
2398                (Storage::Rocm(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2399                #[cfg(feature = "vulkan")]
2400                (Storage::Vulkan(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
2401                #[cfg(feature = "rocm")]
2402                (Storage::Cpu(storage), Device::Rocm(rocm)) => {
2403                    Storage::Rocm(rocm.storage_from_cpu_storage(storage)?)
2404                }
2405                #[cfg(feature = "vulkan")]
2406                (Storage::Cpu(storage), Device::Vulkan(vulkan)) => {
2407                    Storage::Vulkan(vulkan.storage_from_cpu_storage(storage)?)
2408                }
2409                #[cfg(feature = "rocm")]
2410                (Storage::Rocm(storage), Device::Rocm(rocm)) => {
2411                    let cpu_storage = storage.to_cpu_storage()?;
2412                    Storage::Rocm(rocm.storage_from_cpu_storage(&cpu_storage)?)
2413                }
2414                #[cfg(feature = "vulkan")]
2415                (Storage::Vulkan(storage), Device::Vulkan(vulkan)) => {
2416                    let cpu_storage = storage.to_cpu_storage()?;
2417                    Storage::Vulkan(vulkan.storage_from_cpu_storage(&cpu_storage)?)
2418                }
2419                (Storage::Cuda(storage), Device::Cuda(cuda)) => {
2420                    // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
2421                    // are the same.
2422                    let cpu_storage = storage.to_cpu_storage()?;
2423                    Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
2424                }
2425                (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
2426                _ => {
2427                    bail!(
2428                        "not implemented yet, self.device: {:?}, device: {:?}",
2429                        self.device(),
2430                        device
2431                    )
2432                }
2433            };
2434            let op = BackpropOp::new1(self, Op::ToDevice);
2435            let tensor_ = Tensor_ {
2436                id: TensorId::new(),
2437                storage: Arc::new(RwLock::new(storage)),
2438                layout: self.layout.clone(),
2439                op,
2440                is_variable: false,
2441                dtype: self.dtype,
2442                device: device.clone(),
2443            };
2444            Ok(Tensor(Arc::new(tensor_)))
2445        }
2446    }
2447
2448    /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
2449    /// on the left.
2450    pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
2451        let left_shape = left_shape.into();
2452        let mut dims = left_shape.into_dims();
2453        dims.extend(self.dims());
2454        self.broadcast_as(dims)
2455    }
2456
2457    /// Broadcast the input tensor to the target shape. This returns an error if the input shape is
2458    /// not compatible with the target shape.
2459    ///
2460    /// If the input shape is `i_1, i_2, ... i_k`, the target shape has to have `k` dimensions or
2461    /// more and shape `j_1, ..., j_l, t_1, t_2, ..., t_k`. The dimensions `j_1` to `j_l` can have
2462    /// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If
2463    /// `i_a` is equal to 1, any value can be used.
2464    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2465        let tensor_ = Tensor_ {
2466            id: TensorId::new(),
2467            storage: self.storage.clone(),
2468            layout: self.layout.broadcast_as(shape)?,
2469            op: BackpropOp::new1(self, Op::Broadcast),
2470            is_variable: false,
2471            dtype: self.dtype,
2472            device: self.device.clone(),
2473        };
2474        Ok(Tensor(Arc::new(tensor_)))
2475    }
2476
2477    /// An alias for broadcast_as.
2478    pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
2479        self.broadcast_as(shape)
2480    }
2481
2482    /// Casts the input tensor to the target `dtype`.
2483    ///
2484    /// ```rust
2485    /// use hanzo_ml::{Tensor, Device};
2486    /// let tensor = Tensor::new(3.14159265358979f64, &Device::Cpu)?;
2487    /// assert_eq!(tensor.to_scalar::<f64>()?, 3.14159265358979);
2488    /// let tensor = tensor.to_dtype(hanzo_ml::DType::F32)?;
2489    /// assert_eq!(tensor.to_scalar::<f32>()?, 3.1415927);
2490    /// # Ok::<(), hanzo_ml::Error>(())
2491    /// ```
2492    pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2493        if self.dtype() == dtype {
2494            Ok(self.clone())
2495        } else {
2496            let shape = self.shape();
2497            let storage = self.storage().to_dtype(self.layout(), dtype)?;
2498            let op = BackpropOp::new1(self, Op::ToDType);
2499            Ok(from_storage(storage, shape.clone(), op, false))
2500        }
2501    }
2502
2503    /// Returns a tensor that is in row major order. This is the same as the original tensor if it
2504    /// was already contiguous, otherwise a copy is triggered.
2505    pub fn contiguous(&self) -> Result<Tensor> {
2506        if self.is_contiguous() {
2507            Ok(self.clone())
2508        } else {
2509            let shape = self.shape();
2510            let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2511            self.storage()
2512                .copy_strided_src(&mut storage, 0, self.layout())?;
2513            let op = BackpropOp::new1(self, Op::Copy);
2514            Ok(from_storage(storage, shape.clone(), op, false))
2515        }
2516    }
2517
2518    /// Returns a tensor that is in row major order. This always makes a copy.
2519    pub fn force_contiguous(&self) -> Result<Tensor> {
2520        let shape = self.shape();
2521        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
2522        self.storage()
2523            .copy_strided_src(&mut storage, 0, self.layout())?;
2524        let op = BackpropOp::new1(self, Op::Copy);
2525        Ok(from_storage(storage, shape.clone(), op, false))
2526    }
2527
2528    /// Create a variable based on the values currently stored in a tensor. The storage is always
2529    /// copied.
2530    pub(crate) fn make_var(&self) -> Result<Tensor> {
2531        let shape = self.shape().clone();
2532        let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2533        self.storage()
2534            .copy_strided_src(&mut storage, 0, self.layout())?;
2535        Ok(from_storage(storage, shape, BackpropOp::none(), true))
2536    }
2537
2538    /// Reshape returns a tensor with the target shape provided that the number of elements of the
2539    /// original tensor is the same.
2540    /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
2541    /// a new storage and copies the data over, the returned tensor is always contiguous.
2542    ///
2543    /// The shape can be specified using a tuple of `usize` and at most one `()` in which case
2544    /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
2545    /// as to match the number of elements in the tensor.
2546    ///
2547    /// ```rust
2548    /// # use hanzo_ml::{Tensor, DType, Device, D};
2549    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2550    ///
2551    /// let c = a.reshape((1, 6))?;
2552    /// assert_eq!(c.shape().dims(), &[1, 6]);
2553    ///
2554    /// let c = a.reshape((3, 2))?;
2555    /// assert_eq!(c.shape().dims(), &[3, 2]);
2556    ///
2557    /// let c = a.reshape((2, (), 1))?;
2558    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
2559    ///
2560    /// # Ok::<(), hanzo_ml::Error>(())
2561    /// ```
2562    pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
2563        let shape = s.into_shape(self.elem_count())?;
2564        if shape.elem_count() != self.elem_count() {
2565            return Err(Error::ShapeMismatchBinaryOp {
2566                lhs: self.shape().clone(),
2567                rhs: shape,
2568                op: "reshape",
2569            }
2570            .bt());
2571        }
2572        let op = BackpropOp::new1(self, Op::Reshape);
2573        if self.is_contiguous() {
2574            let tensor_ = Tensor_ {
2575                id: TensorId::new(),
2576                storage: self.storage.clone(),
2577                layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
2578                op,
2579                is_variable: false,
2580                dtype: self.dtype,
2581                device: self.device.clone(),
2582            };
2583            Ok(Tensor(Arc::new(tensor_)))
2584        } else {
2585            let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
2586            self.storage()
2587                .copy_strided_src(&mut storage, 0, self.layout())?;
2588            Ok(from_storage(storage, shape, op, false))
2589        }
2590    }
2591
2592    /// Creates a new tensor with the specified dimension removed if its size was one.
2593    ///
2594    /// ```rust
2595    /// # use hanzo_ml::{Tensor, DType, Device, D};
2596    /// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
2597    ///
2598    /// let c = a.squeeze(2)?;
2599    /// assert_eq!(c.shape().dims(), &[2, 3]);
2600    ///
2601    /// let c = a.squeeze(D::Minus1)?;
2602    /// assert_eq!(c.shape().dims(), &[2, 3]);
2603    /// # Ok::<(), hanzo_ml::Error>(())
2604    /// ```
2605    pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
2606        // The PyTorch semantics are to return the same tensor if the target dimension
2607        // does not have a size of 1.
2608        let dims = self.dims();
2609        let dim = dim.to_index(self.shape(), "squeeze")?;
2610        if dims[dim] == 1 {
2611            let mut dims = dims.to_vec();
2612            let mut strides = self.stride().to_vec();
2613            dims.remove(dim);
2614            strides.remove(dim);
2615            let tensor_ = Tensor_ {
2616                id: TensorId::new(),
2617                storage: self.storage.clone(),
2618                layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2619                op: BackpropOp::new1(self, Op::Reshape),
2620                is_variable: false,
2621                dtype: self.dtype,
2622                device: self.device.clone(),
2623            };
2624            Ok(Tensor(Arc::new(tensor_)))
2625        } else {
2626            Ok(self.clone())
2627        }
2628    }
2629
2630    /// Creates a new tensor with a dimension of size one inserted at the specified position.
2631    ///
2632    /// ```rust
2633    /// # use hanzo_ml::{Tensor, DType, Device, D};
2634    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2635    ///
2636    /// let c = a.unsqueeze(0)?;
2637    /// assert_eq!(c.shape().dims(), &[1, 2, 3]);
2638    ///
2639    /// let c = a.unsqueeze(D::Minus1)?;
2640    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
2641    /// # Ok::<(), hanzo_ml::Error>(())
2642    /// ```
2643    pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
2644        let mut dims = self.dims().to_vec();
2645        let mut strides = self.stride().to_vec();
2646        let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
2647        // Cannot panic because to_index_plus_one already checks dimensions
2648        dims.insert(dim, 1);
2649        // Any stride would work here, but we pick one so as to maximize the probability to remain
2650        // C contiguous.
2651        let stride = if dim < strides.len() { strides[dim] } else { 1 };
2652        strides.insert(dim, stride);
2653        let tensor_ = Tensor_ {
2654            id: TensorId::new(),
2655            storage: self.storage.clone(),
2656            layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
2657            op: BackpropOp::new1(self, Op::Reshape),
2658            is_variable: false,
2659            dtype: self.dtype,
2660            device: self.device.clone(),
2661        };
2662        Ok(Tensor(Arc::new(tensor_)))
2663    }
2664
2665    /// Stacks two or more tensors along a particular dimension.
2666    ///
2667    /// All tensors must have the same rank, and the output has one additional rank
2668    ///
2669    /// ```rust
2670    /// # use hanzo_ml::{Tensor, DType, Device};
2671    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2672    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
2673    ///
2674    /// let c = Tensor::stack(&[&a, &b], 0)?;
2675    /// assert_eq!(c.shape().dims(), &[2, 2, 3]);
2676    ///
2677    /// let c = Tensor::stack(&[&a, &b], 2)?;
2678    /// assert_eq!(c.shape().dims(), &[2, 3, 2]);
2679    /// # Ok::<(), hanzo_ml::Error>(())
2680    /// ```
2681    pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
2682        if args.is_empty() {
2683            Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())?
2684        }
2685        let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
2686        let args = args
2687            .iter()
2688            .map(|t| t.as_ref().unsqueeze(dim))
2689            .collect::<Result<Vec<_>>>()?;
2690        Self::cat(&args, dim)
2691    }
2692
2693    /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
2694    /// input tensor values and `right` elements after.
2695    pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2696        if left == 0 && right == 0 {
2697            Ok(self.clone())
2698        } else if left == 0 {
2699            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2700            let mut dims = self.dims().to_vec();
2701            dims[dim] = right;
2702            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2703            Tensor::cat(&[self, &right], dim)
2704        } else if right == 0 {
2705            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2706            let mut dims = self.dims().to_vec();
2707            dims[dim] = left;
2708            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2709            Tensor::cat(&[&left, self], dim)
2710        } else {
2711            let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
2712            let mut dims = self.dims().to_vec();
2713            dims[dim] = left;
2714            let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2715            dims[dim] = right;
2716            let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
2717            Tensor::cat(&[&left, self, &right], dim)
2718        }
2719    }
2720
2721    /// Pad the input tensor using same values along dimension `dim`. This adds `left` elements before the
2722    /// input tensor values and `right` elements after.
2723    pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
2724        if left == 0 && right == 0 {
2725            Ok(self.clone())
2726        } else if self.elem_count() == 0 {
2727            bail!("cannot use pad_with_same on an empty tensor")
2728        } else if left == 0 {
2729            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2730            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2731            let mut v = vec![self];
2732            for _ in 0..right {
2733                v.push(&r)
2734            }
2735            Tensor::cat(&v, dim)
2736        } else if right == 0 {
2737            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2738            let l = self.narrow(dim, 0, 1)?;
2739            let mut v = vec![];
2740            for _ in 0..left {
2741                v.push(&l)
2742            }
2743            v.push(self);
2744            Tensor::cat(&v, dim)
2745        } else {
2746            let dim = dim.to_index(self.shape(), "pad_with_same")?;
2747            let l = self.narrow(dim, 0, 1)?;
2748            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
2749            let mut v = vec![];
2750            for _ in 0..left {
2751                v.push(&l)
2752            }
2753            v.push(self);
2754            for _ in 0..right {
2755                v.push(&r)
2756            }
2757            Tensor::cat(&v, dim)
2758        }
2759    }
2760
2761    /// Run the `forward` method of `m` on `self`.
2762    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
2763        m.forward(self)
2764    }
2765
2766    /// Run the `forward` method of `m` on `self`.
2767    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
2768        m.forward_t(self, train)
2769    }
2770
2771    pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
2772        self.storage.read().unwrap()
2773    }
2774
2775    pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
2776        self.storage.write().unwrap()
2777    }
2778
2779    // If we extend the visibility of this function to be usable outside of this crate, we should
2780    // make it unsafe.
2781    pub(crate) fn storage_mut_and_layout(
2782        &self,
2783    ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
2784        let storage = self.storage.write().unwrap();
2785        (storage, &self.layout)
2786    }
2787
2788    /// The storage used by this tensor, together with the layout to use to access it safely.
2789    pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
2790        let storage = self.storage.read().unwrap();
2791        (storage, &self.layout)
2792    }
2793
2794    pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
2795        let lhs: &RwLock<Storage> = self.storage.as_ref();
2796        let rhs: &RwLock<Storage> = rhs.storage.as_ref();
2797        std::ptr::eq(lhs, rhs)
2798    }
2799
2800    /// Normalize a 'relative' axis value: positive values are kept, negative
2801    /// values means counting the dimensions from the back.
2802    pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
2803        let rank = self.rank() as i64;
2804        if rank <= axis {
2805            bail!("axis {axis} is too large, tensor rank {rank}")
2806        } else if 0 <= axis {
2807            Ok(axis as usize)
2808        } else {
2809            let naxis = rank + axis;
2810            if naxis < 0 {
2811                bail!("axis {axis} is too small, tensor rank {rank}")
2812            }
2813            Ok(naxis as usize)
2814        }
2815    }
2816
2817    /// Returns a lower triangular matrix of ones of size n by n.
2818    pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2819        let t = Tensor::arange(0u32, n as u32, device)?;
2820        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2821        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2822        t1.le(&t2)?.to_dtype(dtype)
2823    }
2824
2825    /// Returns an upper triangular matrix of ones of size n by n.
2826    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2827        let t = Tensor::arange(0u32, n as u32, device)?;
2828        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2829        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2830        t1.ge(&t2)?.to_dtype(dtype)
2831    }
2832
2833    /// Returns a matrix with a diagonal of ones of size n by n.
2834    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
2835        let t = Tensor::arange(0u32, n as u32, device)?;
2836        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
2837        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
2838        t1.eq(&t2)?.to_dtype(dtype)
2839    }
2840
2841    /// Returns the cumulative sum of elements of the input tensor summed over the specified
2842    /// dimension.
2843    ///
2844    /// This operation is most efficient when dim is the last dimension of the tensor.
2845    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
2846        let dim = dim.to_index(self.shape(), "cumsum")?;
2847        let rank = self.rank();
2848        if rank == 0 {
2849            return Ok(self.clone());
2850        }
2851        let n_axis = self.dim(dim)?;
2852        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
2853        if rank == 1 {
2854            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
2855        } else {
2856            let last = rank - 1;
2857            let t = self.transpose(dim, last)?;
2858            let t = t.broadcast_matmul(&triu)?;
2859            t.transpose(dim, last)
2860        }
2861    }
2862
2863    /// Returns a copy of `self` where the values within `ranges` have been replaced with the
2864    /// content of `src`.
2865    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
2866        &self,
2867        ranges: &[D],
2868        src: &Tensor,
2869    ) -> Result<Self> {
2870        let src_dims = src.dims();
2871        let self_dims = self.dims();
2872        if self_dims.len() != src_dims.len() {
2873            bail!(
2874                "slice-assign requires input with the same rank {} <> {}",
2875                self_dims.len(),
2876                src_dims.len()
2877            )
2878        }
2879        if self_dims.len() != ranges.len() {
2880            bail!(
2881                "slice-assign requires input with the same rank as there are ranges {} <> {}",
2882                self_dims.len(),
2883                ranges.len()
2884            )
2885        }
2886        let mut src = src.clone();
2887        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
2888        for (i, range) in ranges.iter().enumerate() {
2889            let start_included = match range.start_bound() {
2890                std::ops::Bound::Unbounded => 0,
2891                std::ops::Bound::Included(v) => *v,
2892                std::ops::Bound::Excluded(v) => *v + 1,
2893            };
2894            let end_excluded = match range.end_bound() {
2895                std::ops::Bound::Unbounded => self_dims[i],
2896                std::ops::Bound::Included(v) => *v + 1,
2897                std::ops::Bound::Excluded(v) => *v,
2898            };
2899            if end_excluded <= start_included {
2900                bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
2901            }
2902            if self_dims[i] < end_excluded {
2903                bail!(
2904                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
2905                    self_dims[i]
2906                )
2907            }
2908            if end_excluded - start_included != src_dims[i] {
2909                bail!(
2910                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
2911                )
2912            }
2913            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
2914            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
2915        }
2916        mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
2917    }
2918
2919    /// Returns log(sum(exp(tensor), dim)).
2920    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
2921        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
2922        if sum_dims.is_empty() {
2923            return Ok(self.clone());
2924        }
2925        let max = sum_dims[1..]
2926            .iter()
2927            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
2928                max.max_keepdim(dim)
2929            })?;
2930        let exp = self.broadcast_sub(&max)?.exp()?;
2931        let sum = exp.sum(sum_dims.clone())?;
2932
2933        sum.log()? + max.squeeze_dims(&sum_dims)
2934    }
2935
2936    /// Pointwise pow operation.
2937    pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
2938        rhs.mul(&self.log()?)?.exp()
2939    }
2940
2941    /// Broadcasting version of `pow`.
2942    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
2943        rhs.broadcast_mul(&self.log()?)?.exp()
2944    }
2945
2946    /// Returns a new tensor with the order of elements reversed along the specified dimensions.
2947    /// This function makes a copy of the tensor’s data.
2948    ///
2949    /// ```rust
2950    /// # use hanzo_ml::{Tensor, Device};
2951    /// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
2952    /// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
2953    /// let t_flipped = t.flip(&[0])?;
2954    /// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
2955    /// # Ok::<(), hanzo_ml::Error>(())
2956    /// ```
2957    pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
2958        let mut result = self.clone();
2959        for &dim in dims.iter() {
2960            let size = result.dim(dim)?;
2961            let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
2962            let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
2963            result = result.index_select(&indices_tensor, dim)?;
2964        }
2965        Ok(result)
2966    }
2967
2968    /// Returns a view of which contains all slices of size `size` from self tensor in the dimension
2969    /// `dim` and stepped by `step`.
2970    pub fn unfold<D: Dim>(&self, dim: D, size: usize, step: usize) -> Result<Self> {
2971        // https://github.com/pytorch/pytorch/blob/75b0720a97ac5d82e8a7a1a6ae7c5f7a87d7183d/aten/src/ATen/native/TensorShape.cpp#L3785-L3804
2972        let mut sizes = self.dims().to_vec();
2973        let mut strides = self.stride().to_vec();
2974
2975        let dim = dim.to_index(self.shape(), "unfold")?;
2976
2977        let max_len = if self.dims().is_empty() {
2978            1
2979        } else {
2980            sizes[dim]
2981        };
2982        if size > max_len {
2983            bail!(
2984                "unsqueeze: maximum size for tensor at dimension {dim} is {max_len} but size is {size}"
2985            )
2986        }
2987        sizes.push(size);
2988        strides.push(if self.dims().is_empty() {
2989            1
2990        } else {
2991            strides[dim]
2992        });
2993
2994        if !self.dims().is_empty() {
2995            sizes[dim] = ((sizes[dim] as f32 - size as f32) / step as f32 + 1.) as usize;
2996            strides[dim] *= step;
2997        }
2998
2999        let tensor_ = Tensor_ {
3000            id: TensorId::new(),
3001            storage: self.storage.clone(),
3002            layout: Layout::new(sizes.into(), strides, self.layout.start_offset()),
3003            op: BackpropOp::new1(self, Op::Reshape),
3004            is_variable: false,
3005            dtype: self.dtype,
3006            device: self.device.clone(),
3007        };
3008        Ok(Tensor(Arc::new(tensor_)))
3009    }
3010}
3011
// Implements the operator trait `$op_trait` (with method `$method`) for tensors. The generated
// impls cover tensor operands, owned or borrowed, operands wrapped in a `Result` (the error is
// propagated), and a `f64` right hand side. The `f64` variants call `affine` with the values
// obtained by applying `$scale` and `$shift` to the scalar.
macro_rules! bin_trait {
    ($op_trait:ident, $method:ident, $scale:expr, $shift:expr) => {
        // tensor <op> tensor-like
        impl<B: std::borrow::Borrow<Tensor>> std::ops::$op_trait<B> for Tensor {
            type Output = Result<Tensor>;

            fn $method(self, rhs: B) -> Self::Output {
                let rhs: &Tensor = rhs.borrow();
                Tensor::$method(&self, rhs)
            }
        }

        // &tensor <op> tensor-like
        impl<B: std::borrow::Borrow<Tensor>> std::ops::$op_trait<B> for &Tensor {
            type Output = Result<Tensor>;

            fn $method(self, rhs: B) -> Self::Output {
                let rhs: &Tensor = rhs.borrow();
                Tensor::$method(self, rhs)
            }
        }

        // result <op> tensor
        impl<B: std::borrow::Borrow<Tensor>> std::ops::$op_trait<Tensor> for Result<B> {
            type Output = Result<Tensor>;

            fn $method(self, rhs: Tensor) -> Self::Output {
                let lhs = self?;
                let lhs: &Tensor = lhs.borrow();
                Tensor::$method(lhs, &rhs)
            }
        }

        // result <op> &tensor
        impl<B: std::borrow::Borrow<Tensor>> std::ops::$op_trait<&Tensor> for Result<B> {
            type Output = Result<Tensor>;

            fn $method(self, rhs: &Tensor) -> Self::Output {
                let lhs = self?;
                let lhs: &Tensor = lhs.borrow();
                Tensor::$method(lhs, rhs)
            }
        }

        // tensor <op> result
        impl<B: std::borrow::Borrow<Tensor>> std::ops::$op_trait<Result<B>> for Tensor {
            type Output = Result<Tensor>;

            fn $method(self, rhs: Result<B>) -> Self::Output {
                let rhs = rhs?;
                let rhs: &Tensor = rhs.borrow();
                Tensor::$method(&self, rhs)
            }
        }

        // &tensor <op> result
        impl<B: std::borrow::Borrow<Tensor>> std::ops::$op_trait<Result<B>> for &Tensor {
            type Output = Result<Tensor>;

            fn $method(self, rhs: Result<B>) -> Self::Output {
                let rhs = rhs?;
                let rhs: &Tensor = rhs.borrow();
                Tensor::$method(self, rhs)
            }
        }

        // tensor <op> scalar
        impl std::ops::$op_trait<f64> for Tensor {
            type Output = Result<Tensor>;

            fn $method(self, rhs: f64) -> Self::Output {
                self.affine($scale(rhs), $shift(rhs))
            }
        }

        // &tensor <op> scalar
        impl std::ops::$op_trait<f64> for &Tensor {
            type Output = Result<Tensor>;

            fn $method(self, rhs: f64) -> Self::Output {
                self.affine($scale(rhs), $shift(rhs))
            }
        }
    };
}
3079
3080bin_trait!(Add, add, |_| 1., |v| v);
3081bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
3082bin_trait!(Mul, mul, |v| v, |_| 0.);
3083bin_trait!(Div, div, |v| 1. / v, |_| 0.);
3084
3085impl std::ops::Add<Tensor> for f64 {
3086    type Output = Result<Tensor>;
3087
3088    fn add(self, rhs: Tensor) -> Self::Output {
3089        rhs + self
3090    }
3091}
3092
3093impl std::ops::Add<&Tensor> for f64 {
3094    type Output = Result<Tensor>;
3095
3096    fn add(self, rhs: &Tensor) -> Self::Output {
3097        rhs + self
3098    }
3099}
3100
3101impl std::ops::Mul<Tensor> for f64 {
3102    type Output = Result<Tensor>;
3103
3104    fn mul(self, rhs: Tensor) -> Self::Output {
3105        rhs * self
3106    }
3107}
3108
3109impl std::ops::Mul<&Tensor> for f64 {
3110    type Output = Result<Tensor>;
3111
3112    fn mul(self, rhs: &Tensor) -> Self::Output {
3113        rhs * self
3114    }
3115}
3116
3117impl std::ops::Sub<Tensor> for f64 {
3118    type Output = Result<Tensor>;
3119
3120    fn sub(self, rhs: Tensor) -> Self::Output {
3121        rhs.affine(-1., self)
3122    }
3123}
3124
3125impl std::ops::Sub<&Tensor> for f64 {
3126    type Output = Result<Tensor>;
3127
3128    fn sub(self, rhs: &Tensor) -> Self::Output {
3129        rhs.affine(-1., self)
3130    }
3131}
3132
3133impl std::ops::Div<Tensor> for f64 {
3134    type Output = Result<Tensor>;
3135
3136    #[allow(clippy::suspicious_arithmetic_impl)]
3137    fn div(self, rhs: Tensor) -> Self::Output {
3138        rhs.recip()? * self
3139    }
3140}
3141
3142impl std::ops::Div<&Tensor> for f64 {
3143    type Output = Result<Tensor>;
3144
3145    #[allow(clippy::suspicious_arithmetic_impl)]
3146    fn div(self, rhs: &Tensor) -> Self::Output {
3147        rhs.recip()? * self
3148    }
3149}
3150
3151impl<S: Into<Shape>> From<(Storage, S)> for Tensor {
3152    fn from((storage, shape): (Storage, S)) -> Self {
3153        from_storage(storage, shape, BackpropOp::none(), false)
3154    }
3155}