burn_tensor/tensor/api/
base.rs

1#![allow(clippy::single_range_in_vec_init)]
2use crate::backend::ExecutionError;
3use crate::check::unwrap_shape_reshape;
4
5pub use burn_backend::tensor::BasicOps;
6
7use alloc::vec::Vec;
8
9use alloc::format;
10use alloc::string::String;
11use alloc::vec;
12
13use burn_std::stub::RwLock;
14use core::iter::repeat;
15use core::{fmt::Debug, ops::Range};
16use serde::{Deserialize, Deserializer};
17
18use crate::{AsIndex, Slice, SliceArg, wrap_index};
19use crate::{
20    Bool, ElementConversion, Float, Int, Shape, TensorData, TensorKind, TensorMetadata,
21    backend::Backend, check,
22};
23use crate::{DType, Element};
24use crate::{IndexingUpdateOp, TensorCreationOptions};
25use crate::{cast::ToElement, check::TensorCheck};
26use serde::{Serialize, Serializer};
27
28/// A tensor with a given backend, shape and data type.
29///
30/// # Indexing
31/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
32/// or [`select`](Tensor::select) for numeric types.
33///
34/// ## Example
35///
36/// ```rust
37/// use burn_tensor::backend::Backend;
38/// use burn_tensor::Tensor;
39/// use burn_tensor::Int;
40///
41/// fn example<B: Backend>() {
42///     let device = Default::default();
43///
44///     let tensor = Tensor::<B, 2>::from_data(
45///         [
46///             [3.0, 4.9, 2.0],
47///             [2.0, 1.9, 3.0],
48///             [6.0, 1.5, 7.0],
49///             [3.0, 4.9, 9.0],
50///         ],
51///         &device,
52///     );
53///
54///     // Slice the tensor to get the second and third rows:
55///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
56///     // The resulting tensor will have dimensions [2, 3].
57///     let slice = tensor.clone().slice([1..3]);
58///     println!("{slice}");
59///
60///     // Slice the tensor to get the first two rows and the first 2 columns:
61///     // [[3.0, 4.9], [2.0, 1.9]]
62///     // The resulting tensor will have dimensions [2, 2].
63///     let slice = tensor.clone().slice([0..2, 0..2]);
64///     println!("{slice}");
65///
66///     // Index the tensor along the dimension 1 to get the elements 0 and 2:
67///     // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
68///     // The resulting tensor will have dimensions [4, 2]
69///     let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
70///     let indexed = tensor.select(1, indices);
71///     println!("{indexed}");
72/// }
73/// ```
74#[derive(new, Clone, Debug)]
75pub struct Tensor<B, const D: usize, K = Float>
76where
77    B: Backend,
78    K: TensorKind<B>,
79{
80    pub(crate) primitive: K::Primitive,
81}
82
83impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
84where
85    B: Backend,
86    K: BasicOps<B>,
87    T: Into<TensorData>,
88{
89    fn from(value: T) -> Self {
90        Tensor::from_data(value.into(), &Default::default())
91    }
92}
93
94impl<B, const D: usize, K> Tensor<B, D, K>
95where
96    B: Backend,
97    K: BasicOps<B>,
98    K::Elem: Element,
99{
100    /// Executes an operation on the tensor and modifies its value.
101    ///
102    /// # Notes
103    ///
104    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
105    /// no other reference pointing to the same tensor.
106    ///
107    /// Wrapping operations with inplace is not an optimization, it's mainly there if you
108    /// want to mutate a tensor by using owned operations. A plausible usage would be to
109    /// update the weights of a mutable model reference.
110    pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
111        let mut tensor_owned = Tensor::empty([0; D], &self.device());
112        core::mem::swap(&mut tensor_owned, self);
113
114        let mut tensor_new = func(tensor_owned);
115        core::mem::swap(&mut tensor_new, self);
116    }
117
118    /// Converts the tensor into a primitive tensor.
119    pub fn into_primitive(self) -> K::Primitive {
120        self.primitive
121    }
122
123    /// Converts from a primitive tensor into a tensor.
124    pub fn from_primitive(tensor: K::Primitive) -> Self {
125        Self::new(tensor)
126    }
127
128    /// Returns the number of dimensions of the tensor.
129    pub fn rank(&self) -> usize {
130        self.primitive.rank()
131    }
132
133    /// Returns the tensor primitive data type.
134    ///
135    /// # Note
136    /// Some element types are encoded in different primitive types depending on the backend
137    /// (e.g., bool could be encoded as `u8` or `u32`).
138    pub fn dtype(&self) -> DType {
139        self.primitive.dtype()
140    }
141
142    /// Create an empty tensor of the given shape.
143    ///
144    /// # Arguments
145    ///
146    /// - `shape`: The shape of the tensor.
147    /// - `device`: The device where the tensor will be created.
148    ///
149    /// # Example
150    /// ```rust
151    /// use burn_tensor::backend::Backend;
152    /// use burn_tensor::Tensor;
153    ///
154    /// fn example<B: Backend>() {
155    ///    let device = Default::default();
156    ///    // Create an empty tensor with dimensions [2, 3, 4].
157    ///    let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
158    /// }
159    /// ```
160    pub fn empty<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
161        let opt = options.into();
162        let shape = shape.into();
163        check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
164        Self::new(K::empty(shape, &opt.device, opt.dtype_or(K::Elem::dtype())))
165    }
166
167    /// Create a tensor of the given shape where each element is zero.
168    ///
169    /// # Example
170    ///
171    /// ```rust
172    /// use burn_tensor::backend::Backend;
173    /// use burn_tensor::{Tensor, Shape};
174    ///
175    /// fn example<B: Backend>() {
176    ///    let device = B::Device::default();
177    ///    let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
178    ///    println!("{tensor}");
179    ///    // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
180    /// }
181    /// ```
182    pub fn zeros<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
183        let opt = options.into();
184        let shape = shape.into();
185        check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
186        Self::new(K::zeros(shape, &opt.device, opt.dtype_or(K::Elem::dtype())))
187    }
188
189    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with zeros.
190    ///
191    /// # Example
192    ///
193    /// ```rust
194    /// use burn_tensor::backend::Backend;
195    /// use burn_tensor::{Tensor, Shape};
196    ///
197    /// fn example<B: Backend>() {
198    ///   let device = B::Device::default();
199    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
200    ///   let tensor = tensor.zeros_like();
201    ///   println!("{tensor}");
202    ///   // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
203    /// }
204    /// ```
205    pub fn zeros_like(&self) -> Self {
206        Self::new(K::zeros(self.shape(), &self.device(), self.dtype()))
207    }
208
209    /// Create a tensor of the given shape where each element is one.
210    ///
211    /// # Example
212    ///
213    /// ```rust
214    /// use burn_tensor::backend::Backend;
215    /// use burn_tensor::{Tensor, Shape};
216    ///
217    /// fn example<B: Backend>() {
218    ///   let device = B::Device::default();
219    ///   let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
220    ///   println!("{tensor}");
221    ///   // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
222    /// }
223    /// ```
224    pub fn ones<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
225        let opt = options.into();
226        let shape = shape.into();
227        check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
228        Self::new(K::ones(shape, &opt.device, opt.dtype_or(K::Elem::dtype())))
229    }
230
231    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with ones.
232    ///
233    /// # Example
234    ///
235    /// ```rust
236    /// use burn_tensor::backend::Backend;
237    /// use burn_tensor::{Tensor, Shape};
238    ///
239    /// fn example<B: Backend>() {
240    ///    let device = B::Device::default();
241    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
242    ///    let tensor = tensor.ones_like();
243    ///    println!("{tensor}");
244    ///    // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
245    /// }
246    /// ```
247    pub fn ones_like(&self) -> Self {
248        Self::new(K::ones(self.shape(), &self.device(), self.dtype()))
249    }
250
251    /// Create a tensor of the given shape where each element is equal to the provided value.
252    ///
253    /// # Example
254    ///
255    /// ```rust
256    /// use burn_tensor::backend::Backend;
257    /// use burn_tensor::{Tensor, Shape};
258    ///
259    /// fn example<B: Backend>() {
260    ///   let device = B::Device::default();
261    ///   let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
262    ///   println!("{tensor}");
263    ///   // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
264    /// }
265    /// ```
266    pub fn full<S: Into<Shape>, E: ElementConversion>(
267        shape: S,
268        fill_value: E,
269        options: impl Into<TensorCreationOptions<B>>,
270    ) -> Self {
271        let opt = options.into();
272        let shape = shape.into();
273        check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
274        Self::new(K::full(
275            shape,
276            fill_value,
277            &opt.device,
278            opt.dtype_or(K::Elem::dtype()),
279        ))
280    }
281
282    /// Returns a new tensor with the same shape, dtype, and device as the current tensor,
283    /// filled with the provided value.
284    ///
285    /// # Example
286    ///
287    /// ```rust
288    /// use burn_tensor::backend::Backend;
289    /// use burn_tensor::{Tensor, Shape};
290    ///
291    /// fn example<B: Backend>() {
292    ///    let device = B::Device::default();
293    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
294    ///    let tensor = tensor.full_like(5.0);
295    ///    println!("{tensor}");
296    ///    // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
297    /// }
298    /// ```
299    pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
300        Self::new(K::full(
301            self.shape(),
302            fill_value,
303            &self.device(),
304            self.dtype(),
305        ))
306    }
307
308    /// Returns the dimensions of the current tensor.
309    ///
310    /// # Example
311    /// ```rust
312    /// use burn_tensor::backend::Backend;
313    /// use burn_tensor::Tensor;
314    ///
315    /// fn example<B: Backend>() {
316    ///   let device = Default::default();
317    ///   let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
318    ///   let dims = tensor.dims(); // [2, 3, 4]
319    ///   println!("{dims:?}");
320    /// }
321    /// ```
322    pub fn dims(&self) -> [usize; D] {
323        Self::shape(self).dims()
324    }
325
326    /// Returns the shape of the current tensor.
327    ///
328    /// # Example
329    /// ```rust
330    /// use burn_tensor::backend::Backend;
331    /// use burn_tensor::Tensor;
332    ///
333    /// fn example<B: Backend>() {
334    ///    let device = Default::default();
335    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
336    ///    // Shape { dims: [2, 3, 4] }
337    ///    let shape = tensor.shape();
338    /// }
339    /// ```
340    pub fn shape(&self) -> Shape {
341        self.primitive.shape()
342    }
343
344    /// Reshape the tensor to have the given shape.
345    ///
346    /// The tensor has the same data and number of elements as the input.
347    ///
348    /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
349    /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
350    ///
351    /// A `0` in the shape instructs to keep the current dimension from the original tensor,
352    /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
353    /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
354    /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
355    /// with [1, 3, 4] dimensions to [1, 12].
356    ///
357    /// # Arguments
358    /// - `shape`: The new shape of the tensor.
359    ///
360    /// # Panics
361    /// - If the tensor contains more than one `-1` in the shape.
362    /// - If the tensor contains values that are not positive (other than -1).
363    /// - If the shape does not match the number of elements of the original shape.
364    ///
365    /// # Example
366    ///
367    /// ```rust
368    /// use burn_tensor::backend::Backend;
369    /// use burn_tensor::Tensor;
370    ///
371    /// fn example<B: Backend>() {
372    ///    let device = Default::default();
373    ///    // Create a tensor with dimensions [2, 3, 4]
374    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
375    ///    // Reshape it to [2, 12], where 12 is inferred from the number of elements.
376    ///    let reshaped = tensor.reshape([2, -1]);
377    ///    println!("{reshaped}");
378    /// }
379    /// ```
380    pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
381        // Convert reshape args to shape
382        let shape = shape.into_shape::<D2>(self.shape());
383        Tensor::new(K::reshape(self.primitive, shape))
384    }
385
386    /// Transpose the tensor.
387    ///
388    /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
389    /// applied on the last two dimensions. For example, the transpose of a tensor with shape
390    /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
391    ///
392    /// See also [`permute`](Tensor::permute).
393    ///
394    /// # Arguments
395    ///
396    /// * `tensor` - The tensor to transpose.
397    ///
398    /// # Returns
399    ///
400    /// The transposed tensor.
401    ///
402    /// # Example
403    ///
404    /// ```rust
405    /// use burn_tensor::backend::Backend;
406    /// use burn_tensor::Tensor;
407    ///
408    /// fn example<B: Backend>() {
409    ///     let device = Default::default();
410    ///     // Create a 2D tensor of shape [2, 3]
411    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
412    ///
413    ///     // Transpose the tensor:
414    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
415    ///     // The resulting tensor will have dimensions [3, 2].
416    ///     let transposed = tensor.transpose();
417    ///     println!("{transposed}");
418    /// }
419    /// ```
420    pub fn transpose(self) -> Tensor<B, D, K> {
421        Tensor::new(K::transpose(self.primitive))
422    }
423
424    /// Alias for `transpose`.
425    #[inline(always)]
426    pub fn t(self) -> Tensor<B, D, K> {
427        self.transpose()
428    }
429
430    /// Swaps two dimensions of a tensor.
431    ///
432    /// This is a no-op when `dim1 == dim2`, assuming both are within bounds.
433    ///
434    /// # Arguments
435    ///
436    /// * `tensor` - The tensor to swap the dimensions of.
437    /// * `dim1` - The first dimension to swap, supports negative indexing.
438    /// * `dim2` - The second dimension to swap, supports negative indexing.
439    ///
440    /// # Returns
441    ///
442    /// The tensor with the dimensions swapped.
443    ///
444    /// # Panics
445    ///
446    /// When dimensions are out of bounds.
447    ///
448    /// # Example
449    ///
450    /// ```rust
451    /// use burn_tensor::backend::Backend;
452    /// use burn_tensor::Tensor;
453    ///
454    /// fn example<B: Backend>() {
455    ///     let device = Default::default();
456    ///     // Create a 2D tensor of shape [2, 3]
457    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
458    ///
459    ///     // Swap the dimensions 0 and -1 (equivalent to `tensor.transpose()`):
460    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
461    ///     // The resulting tensor will have dimensions [3, 2].
462    ///     let swapped = tensor.swap_dims(0, -1);
463    ///     println!("{swapped}");
464    /// }
465    /// ```
466    pub fn swap_dims<Dim1, Dim2>(self, dim1: Dim1, dim2: Dim2) -> Tensor<B, D, K>
467    where
468        Dim1: AsIndex,
469        Dim2: AsIndex,
470    {
471        let dim1 = dim1.expect_dim_index(D);
472        let dim2 = dim2.expect_dim_index(D);
473        check!(TensorCheck::swap_dims::<D>(dim1, dim2));
474        if dim1 == dim2 {
475            self
476        } else {
477            Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
478        }
479    }
480
481    /// Permute the dimensions of the tensor.
482    ///
483    /// This is a no-op when the resolved `axes` match the current order.
484    ///
485    /// # Arguments
486    ///
487    /// * `axes` - The new order of the dimensions. The length of the axes
488    ///   must be equal to the number of dimensions of the tensor.
489    ///   The values must be unique and in the range of the number of dimensions.
490    ///   The values can be negative, in which case they are used as an offset from the end.
491    ///
492    /// # Returns
493    ///
494    /// The tensor with the dimensions permuted.
495    ///
496    /// # Example
497    ///
498    /// ```rust
499    /// use burn_tensor::backend::Backend;
500    /// use burn_tensor::Tensor;
501    ///
502    /// fn example<B: Backend>() {
503    ///     let device = Default::default();
504    ///     // Create a 2D tensor of shape [3, 2]
505    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
506    ///
507    ///     // Permute the dimensions 1 and 0:
508    ///     // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
509    ///     // The resulting tensor will have dimensions [3, 2].
510    ///     let permuted = tensor.permute([1, 0]);
511    ///     println!("{permuted}");
512    /// }
513    /// ```
514    pub fn permute<Dim>(self, axes: [Dim; D]) -> Tensor<B, D, K>
515    where
516        Dim: AsIndex,
517    {
518        let mut no_op = true;
519        let mut fixed_axes = [0; D];
520        for (i, axis) in axes.into_iter().enumerate() {
521            let dim = axis.expect_dim_index(D);
522            no_op &= dim == i;
523            fixed_axes[i] = dim;
524        }
525
526        if no_op {
527            self
528        } else {
529            check!(TensorCheck::permute(fixed_axes));
530            Tensor::new(K::permute(self.primitive, &fixed_axes))
531        }
532    }
533
534    /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
535    ///
536    /// Other dimensions of input that are not explicitly moved remain in their original order and appear
537    /// at the positions not specified in destination.
538    ///
539    /// # Arguments
540    ///
541    /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
542    ///   The values can be negative, in which case they are used as an offset from the end.
543    ///
544    /// * `dst` - Destination positions for each of the original dims. These must also be unique.
545    ///
546    /// # Panics
547    ///
548    /// - If the source and destination dimensions are not of the same length.
549    /// - If the source and destination vectors contain duplicate values.
550    /// - If the source and destination vectors contain values that are out of bounds.
551    ///
552    /// # Returns
553    ///
554    /// The tensor with the dimensions moved.
555    ///
556    /// # Example
557    ///
558    /// ```rust
559    /// use burn_tensor::backend::Backend;
560    /// use burn_tensor::Tensor;
561    ///
562    /// fn example<B: Backend>() {
563    ///     let device = Default::default();
564    ///     // Create a 3D tensor of shape [3, 2, 1]
565    ///     let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
566    ///
567    ///     // Move the dimensions 0 and 1:
568    ///     // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
569    ///     // The resulting tensor will have dimensions [2, 3, 1].
570    ///     let moved = tensor.movedim(1, 0);
571    ///     println!("{moved}");
572    /// }
573    /// ```
574    ///
575    /// # Note
576    ///
577    /// This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op
578    /// for it
579    pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
580        let source_dims = src.into_dim_vec::<D>();
581        let destination_dims = dst.into_dim_vec::<D>();
582
583        check!(TensorCheck::movedim_args_length(
584            &source_dims,
585            &destination_dims
586        ));
587
588        let mut m = [-1; D];
589        for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
590            m[d] = s as isize;
591        }
592        let mut axes: [isize; D] = [0; D];
593        let mut source_i = 0;
594        for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
595            *item = if m[dest_i] != -1 {
596                m[dest_i]
597            } else {
598                while source_dims.contains(&source_i) {
599                    source_i += 1;
600                }
601                let result = source_i as isize;
602                source_i += 1;
603                result
604            };
605        }
606
607        self.permute(axes)
608    }
609
610    /// Reverse the order of elements in the tensor along the given dimensions.
611    ///
612    /// # Arguments
613    ///
614    /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
615    ///   The values can be negative, in which case they are used as an offset from the end.
616    ///
617    /// # Returns
618    ///
619    /// The tensor with the axes flipped.
620    ///
621    /// # Example
622    ///
623    /// ```rust
624    /// use burn_tensor::backend::Backend;
625    /// use burn_tensor::Tensor;
626    ///
627    /// fn example<B: Backend>() {
628    ///     let device = Default::default();
629    ///     // Create a 2D tensor with dimensions [4, 3]
630    ///     let tensor = Tensor::<B, 2>::from_data(
631    ///         [
632    ///             [3.0, 4.9, 2.0],
633    ///             [2.0, 1.9, 3.0],
634    ///             [4.0, 5.9, 8.0],
635    ///             [1.4, 5.8, 6.0],
636    ///         ],
637    ///         &device,
638    ///     );
639    ///
640    ///     // Flip the elements in dimensions 0 and 1:
641    ///     // [[6.0, 5.8, 1.4],
642    ///     //  [8.0, 5.9, 4.0],
643    ///     //  [3.0, 1.9, 2.0],
644    ///     //  [2.0, 4.9, 3.0]]
645    ///     // The resulting tensor will have dimensions [4, 3].
646    ///     let flipped = tensor.flip([0, 1]);
647    ///     println!("{flipped}");
648    /// }
649    /// ```
650    pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
651        // Convert the axes to usize and handle negative values without using vector
652        let mut transformed_axes: [usize; N] = [0; N];
653        for (i, &x) in axes.iter().enumerate() {
654            transformed_axes[i] = if x < 0 {
655                (D as isize + x) as usize
656            } else {
657                x as usize
658            };
659        }
660
661        // Check if the axes are valid
662        check!(TensorCheck::flip(D, &transformed_axes));
663
664        Tensor::new(K::flip(self.primitive, &transformed_axes))
665    }
666
667    /// Flatten the tensor along a given range of dimensions.
668    ///
669    /// This function collapses the specified range of dimensions into a single dimension,
670    /// effectively flattening the tensor in that range.
671    ///
672    /// # Arguments
673    ///
674    /// - `start_dim`: The starting dimension of the range to be flattened,
675    ///   supports negative indexing.
676    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),
677    ///   supports negative indexing.
678    ///
679    /// # Type Parameters
680    ///
681    /// - `D2`: The resulting number of dimensions in the flattened tensor.
682    ///
683    /// # Returns
684    ///
685    /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
686    ///
687    /// # Example
688    ///
689    /// ```rust
690    ///
691    /// use burn_tensor::backend::Backend;
692    /// use burn_tensor::{Tensor, Shape};
693    ///
694    /// fn example<B: Backend>() {
695    ///     let device = Default::default();
696    ///     // Create a 3D tensor with dimensions [2, 3, 4]
697    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
698    ///
699    ///     // Flatten the tensor from dimensions 1 to 2 (inclusive).
700    ///     // The resulting tensor will have dimensions [2, 12]
701    ///     let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
702    ///     println!("{flattened}");
703    /// }
704    /// ```
705    pub fn flatten<const D2: usize>(
706        self,
707        start_dim: impl AsIndex,
708        end_dim: impl AsIndex,
709    ) -> Tensor<B, D2, K> {
710        let start_dim = start_dim.expect_dim_index(D);
711        let end_dim = end_dim.expect_dim_index(D);
712        check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));
713        let new_shape = self.shape().flatten_dims(start_dim, end_dim);
714
715        Tensor::new(K::reshape(self.primitive, new_shape))
716    }
717
718    /// Squeeze the tensor along all dimensions, removing dimensions
719    /// of size one, and effectively reducing the rank of the tensor.
720    ///
721    /// # Type Parameters
722    ///
723    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
724    ///
725    /// # Returns
726    ///
727    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
728    ///
729    /// # Example
730    ///
731    /// ```rust
732    ///
733    /// use burn_tensor::backend::Backend;
734    /// use burn_tensor::{Tensor, Shape};
735    ///
736    /// fn example<B: Backend>() {
737    ///     let device = Default::default();
738    ///     // Create a 4D tensor with dimensions [1, 3, 1, 3]
739    ///     let tensor = Tensor::<B, 4>::from_data(
740    ///         [[[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]]],
741    ///         &device,
742    ///     );
743    ///
744    ///     // Squeeze the tensor dimensions.
745    ///     // The resulting tensor will have dimensions [3, 3].
746    ///     let squeezed = tensor.squeeze::<2>();
747    ///     println!("{squeezed}");
748    /// }
749    /// ```
750    pub fn squeeze<const D2: usize>(self) -> Tensor<B, D2, K> {
751        let new_dims = self
752            .shape()
753            .dims
754            .iter()
755            .filter_map(|&dim| if dim == 1 { None } else { Some(dim) })
756            .collect::<Vec<_>>();
757        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
758
759        Tensor::new(K::reshape(self.primitive, new_dims.into()))
760    }
761
762    /// Squeeze the tensor along the given dimension, removing the specified dimension
763    /// of size one, and effectively reducing the rank of the tensor by one.
764    ///
765    /// # Arguments
766    ///
767    /// - `dim`: The dimension to be squeezed.
768    ///
769    /// # Type Parameters
770    ///
771    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
772    ///
773    /// # Panics
774    ///
775    /// If the size in the squeezed dimension is not 1.
776    ///
777    /// # Returns
778    ///
779    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
780    ///
781    /// # Example
782    ///
783    /// ```rust
784    ///
785    /// use burn_tensor::backend::Backend;
786    /// use burn_tensor::{Tensor, Shape};
787    ///
788    /// fn example<B: Backend>() {
789    ///     let device = Default::default();
790    ///     // Create a 3D tensor with dimensions [3, 1, 3]
791    ///     let tensor = Tensor::<B, 3>::from_data(
792    ///         [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
793    ///         &device,
794    ///     );
795    ///
796    ///     // Squeeze the dimension 1.
797    ///     // The resulting tensor will have dimensions [3, 3].
798    ///     let squeezed = tensor.squeeze_dim::<2>(1);
799    ///     println!("{squeezed}");
800    /// }
801    /// ```
802    pub fn squeeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
803        check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));
804
805        let current_dims = self.shape().dims;
806        let mut new_dims: [usize; D2] = [0; D2];
807
808        new_dims[..dim].copy_from_slice(&current_dims[..dim]);
809        new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);
810
811        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
812        Tensor::new(K::reshape(self.primitive, new_dims.into()))
813    }
814
815    /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
816    /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
817    /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
818    /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
819    /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
820    /// from the back.
821    ///
822    /// # Arguments
823    ///
824    /// - `dims`: The dimension(s) to be squeezed.
825    ///
826    /// # Type Parameters
827    ///
828    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
829    ///
830    /// # Returns
831    ///
832    /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
833    ///
834    /// # Example
835    ///
836    /// ```rust
837    ///
838    /// use burn_tensor::backend::Backend;
839    /// use burn_tensor::{Tensor, Shape};
840    ///
841    /// fn example<B: Backend>() {
842    ///     let device = Default::default();
843    ///     // Create a 4D tensor with dimensions [2, 1, 4, 1]
844    ///     let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
845    ///
846    ///     // Squeeze the dimensions 1 and 3.
847    ///     // The resulting tensor will have dimensions [2, 4].
848    ///     let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
849    ///     println!("{squeezed}");
850    /// }
851    /// ```
852    pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
853        let current_dims = self.shape().dims;
854        let mut dim_indices: Vec<usize>;
855
856        // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
857        if dims.is_empty() {
858            dim_indices = current_dims
859                .iter()
860                .enumerate()
861                .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
862                .collect();
863        } else {
864            // If negative dims, count from the back
865            dim_indices = dims
866                .iter()
867                .map(|&d| {
868                    if d < 0 {
869                        (current_dims.len() as isize + d) as usize
870                    } else {
871                        d as usize
872                    }
873                })
874                .collect();
875        }
876
877        // Sort indices and remove duplicates
878        dim_indices.sort_unstable();
879        dim_indices.dedup();
880
881        // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
882        check!(TensorCheck::squeeze_dims_input::<D2>(
883            &dim_indices,
884            &current_dims
885        ));
886
887        // Calculate new dimensions
888        let mut new_dims = Vec::new();
889        for (index, &dim_size) in current_dims.iter().enumerate() {
890            // Exclude the dimension if it's explicitly marked for squeezing
891            if dim_indices.contains(&index) {
892                check!(TensorCheck::squeeze::<D2>(index, &current_dims));
893                continue;
894            }
895            new_dims.push(dim_size);
896        }
897
898        // Check that after squeezing, we still respect the D2 size
899        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
900
901        Tensor::new(K::reshape(self.primitive, new_dims.into()))
902    }
903
904    /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
905    ///
906    /// # Type Parameters
907    ///
908    ///  - `D2`: The resulting number of dimensions in the unsqueezed tensor.
909    ///
910    /// # Panics
911    ///
912    /// If the output size `D2` is smaller than the current number of dimensions.
913    ///
914    /// # Returns
915    ///
916    /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
917    ///
918    /// # Example
919    ///
920    /// ```rust
921    /// use burn_tensor::backend::Backend;
922    /// use burn_tensor::{Tensor, Shape};
923    ///
924    /// fn example<B: Backend>() {
925    ///     let device = Default::default();
926    ///     // Create a 2D tensor with dimensions [3, 3]
927    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
928    ///     // Unsqueeze the tensor up to 4 dimensions.
929    ///     // The resulting tensor will have dimensions [1, 1, 3, 3].
930    ///     let unsqueezed = tensor.unsqueeze::<4>();
931    ///     println!("{unsqueezed}");
932    /// }
933    /// ```
934    pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
935        check!(TensorCheck::unsqueeze::<D, D2>());
936
937        let mut dims = [1; D2];
938        let num_ones = D2 - D;
939        let shape = self.shape();
940
941        dims[num_ones..(D + num_ones)].copy_from_slice(&shape[..D]);
942
943        let shape = Shape::new(dims);
944        self.reshape(shape)
945    }
946
947    /// Creates a new tensor with a dimension of size one inserted at the specified position.
948    ///
949    /// # Example
950    ///
951    /// ```rust
952    /// use burn_tensor::backend::Backend;
953    /// use burn_tensor::{Tensor, Shape};
954    ///
955    /// fn example<B: Backend>() {
956    ///     let device = Default::default();
957    ///     // Create a 2D tensor with dimensions [3, 3]
958    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
959    ///     // Unsqueeze the dimension 1.
960    ///     // The resulting tensor will have dimensions [3, 1, 3].
961    ///     let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
962    ///     println!("{unsqueezed}");
963    /// }
964    /// ```
965    pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
966        check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));
967
968        let mut dims = [1; D2];
969        let shape = self.shape();
970
971        dims[0..dim].copy_from_slice(&shape[0..dim]);
972
973        if dim < D {
974            dims[dim] = 1;
975            dims[(dim + 1)..].copy_from_slice(&shape[dim..]);
976        } else {
977            dims[dim] = 1;
978        }
979
980        let shape = Shape::new(dims);
981        self.reshape(shape)
982    }
983
984    /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
985    /// The indices can be negative, in which case they are counted from the last to the first dimension.
986    /// the axes can contain duplicates, in which case the number of dimensions inserted at the index
987    /// is the number of duplicates.
988    /// # Example
989    ///
990    /// ```rust
991    /// use burn_tensor::backend::Backend;
992    /// use burn_tensor::{Tensor, Shape};
993    ///
994    /// fn example<B: Backend>() {
995    ///     let device = Default::default();
996    ///     // Create a 3D tensor with dimensions [3, 4, 5]
997    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
998    ///     // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
999    ///     // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
1000    ///     let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
1001    ///     println!("{unsqueezed}");
1002    /// }
1003    /// ```
1004    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[impl AsIndex]) -> Tensor<B, D2, K> {
1005        let mut new_dims = [1; D2];
1006        let old_dims = self.shape().dims;
1007        //for checking if the dimension is in the acceptable range
1008
1009        //part 1: convert the negative indices to positive
1010        let mut neg_offset = D2;
1011        let mut dim_indices = axes
1012            .iter()
1013            .map(|d| {
1014                let d = d.as_index();
1015                // check if the dimension is in the acceptable range
1016                check!(TensorCheck::unsqueeze_dims::<{ D2 }>(d));
1017                (if d < 0 {
1018                    neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)
1019                    d + neg_offset as isize + 1
1020                } else {
1021                    d
1022                }) as usize
1023            })
1024            .collect::<Vec<usize>>();
1025
1026        //sort the indices
1027        dim_indices.sort_unstable();
1028
1029        //Now use this to copy the chunks of the dims
1030        let mut prev_idx: usize = 0;
1031        let mut current_left_b: usize = 0;
1032        let mut current_right_b: usize = 0;
1033        let mut offset: usize = 0;
1034        dim_indices.iter().for_each(|d| {
1035            //check if there is space for at least one dimension
1036            if prev_idx < *d {
1037                current_right_b = *d - offset;
1038                //copy the chunks of the dims
1039                if current_right_b < D {
1040                    new_dims[prev_idx..*d]
1041                        .copy_from_slice(&old_dims[current_left_b..current_right_b]);
1042                } else {
1043                    new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
1044                }
1045                prev_idx = *d + 1;
1046                //offset is equal to the number of extracted elements from the original shape
1047                offset += current_right_b - current_left_b;
1048                current_left_b = current_right_b;
1049            } else {
1050                //it's sorted so the only reason this would happen
1051                //is if multiple indices are the same
1052                prev_idx += 1;
1053            }
1054        });
1055        //copy over anything past the index of the last new dimension
1056        if current_left_b < D {
1057            new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
1058        }
1059
1060        //lastly, create the shape and reshape
1061        let shape = Shape::new(new_dims);
1062        self.reshape(shape)
1063    }
1064
1065    /// Roll operation along a specific dimension; wrapping around the elements.
1066    ///
1067    /// ## Parameters
1068    ///
1069    /// - `shift`: The roll extent; supports negative values and wraps around.
1070    /// - `dim`: The dimension to roll; supports negative indexing.
1071    ///
1072    /// ## Returns
1073    ///
1074    /// A new tensor with the specified dimension rolled by the given shift amount.
1075    pub fn roll_dim<Shift, Dim>(self, shift: Shift, dim: Dim) -> Self
1076    where
1077        Shift: AsIndex,
1078        Dim: AsIndex,
1079    {
1080        let dim = dim.expect_dim_index(D);
1081        let size = self.shape().dims[dim];
1082        if size == 0 {
1083            // If the dimension is empty, return the tensor as is.
1084            return self;
1085        }
1086
1087        let shift = wrap_index(shift, size);
1088        if shift == 0 {
1089            // If the shift is zero, return the tensor as is.
1090            return self;
1091        }
1092
1093        self.unchecked_roll_dim(shift, dim)
1094    }
1095
1096    /// Internal implementation of `roll_dim` that does not canonicalize dimensions or shifts.
1097    ///
1098    /// ## Parameters
1099    ///
1100    /// - `shift`: The number of positions to shift; must be (0 < shift < size).
1101    /// - `dim`: The dimension to roll; must be a valid index for the tensor's shape.
1102    ///
1103    /// ## Returns
1104    ///
1105    /// A new tensor with the specified dimension rolled by the given shift amount.
1106    #[inline(always)]
1107    fn unchecked_roll_dim(self, shift: usize, dim: usize) -> Self {
1108        #[cfg(debug_assertions)]
1109        {
1110            let size = self.shape().dims[dim];
1111            assert!(
1112                0 < shift && shift < size,
1113                "Expected: 0 < shift < size: found shift={shift}, size={size}",
1114            );
1115            assert!(
1116                dim < self.shape().num_dims(),
1117                "Expected: dim < num_dims: found dim={dim}, num_dims={size}",
1118            );
1119        }
1120
1121        Tensor::cat(
1122            vec![
1123                self.clone().slice_dim(dim, shift..),
1124                self.slice_dim(dim, ..shift),
1125            ],
1126            dim,
1127        )
1128    }
1129
1130    /// Roll operation.
1131    ///
1132    /// Note: unlike ``pytorch``, `dims` and `shifts` must have the same length.
1133    ///
1134    /// A given `dim` may be rolled multiple times, and the shifts will be applied sequentially.
1135    ///
1136    /// ## Parameters
1137    ///
1138    /// - `shifts`: A slice of shifts corresponding to each dimension;
1139    ///   supports negative values and wraps around.
1140    /// - `dims`: A slice of dimensions to roll; supports negative indexing.
1141    ///
1142    /// ## Returns
1143    ///
1144    /// A new tensor with the specified dimensions rolled by the given shifts.
1145    pub fn roll<Shift, Dim>(self, shifts: &[Shift], dims: &[Dim]) -> Self
1146    where
1147        Shift: AsIndex,
1148        Dim: AsIndex,
1149    {
1150        assert_eq!(
1151            dims.len(),
1152            shifts.len(),
1153            "Dimensions and shifts must align; found dims={dims:#?}, shifts={shifts:#?}",
1154        );
1155
1156        // This is a fair amount of complexity, which could be replaced
1157        // by a simple canonicalization of `dims` and wrapping of `shifts`.
1158        // The work is done here to ensure that any roll operation
1159        // which could be a no-op is a no-op; simplifying the accounting
1160        // needed by backend-specific implementations of the inner roll op.
1161
1162        let item_count = dims.len();
1163
1164        let shape = self.shape().dims;
1165
1166        // Accumulate the effective shifts for each dimension.
1167        let mut accumulated_shifts: Vec<isize> = vec![0; shape.len()];
1168        for i in 0..item_count {
1169            let dim = dims[i].expect_dim_index(D);
1170            accumulated_shifts[dim] += shifts[i].as_index();
1171        }
1172
1173        // Do this after we've checked the validity of `dims` and `shifts`.
1174        if self.shape().num_elements() == 0 {
1175            // If the tensor is empty, return it as is.
1176            return self;
1177        }
1178
1179        // Wrap the accumulated shifts, and filter out empty dimensions.
1180        let mut effective_dims: Vec<usize> = Vec::with_capacity(item_count);
1181        let mut effective_shifts: Vec<usize> = Vec::with_capacity(item_count);
1182        for dim in 0..shape.len() {
1183            // `wrap_index` should inline, and has a fast-exit path for zero shifts.
1184            let shift = wrap_index(accumulated_shifts[dim], shape[dim]);
1185            if shift == 0 {
1186                continue;
1187            }
1188
1189            effective_dims.push(dim);
1190            effective_shifts.push(shift);
1191        }
1192
1193        // If no shifts are needed, return the original tensor.
1194        if effective_shifts.is_empty() {
1195            return self;
1196        }
1197
1198        // At this point:
1199        // - `dims` contains the effective dimensions to roll, in index order,
1200        // - `shifts` contains the effective usize shifts for each dimension.
1201        // - Every shift is non-zero, and less than the size of the corresponding dimension.
1202        self.unchecked_roll(&effective_shifts, &effective_dims)
1203    }
1204
1205    /// `roll` internal implementation.
1206    ///
1207    /// ## Parameters
1208    ///
1209    /// - `shifts`: A slice of shifts corresponding to each dimension;
1210    ///   must be non-empty, the same length as `dims`, and all ``1..<size>``.
1211    /// - `dims`: A slice of dimensions to roll; must be non-empty;
1212    ///   the same length as `shifts`, and must not contain repeats.
1213    ///
1214    /// ## Panics
1215    ///
1216    /// Panics if the shifts and dimensions do not align, or if dimensions contain repeats.
1217    ///
1218    /// ## Returns
1219    ///
1220    /// A new tensor with the specified dimensions rolled by the given shifts.
1221    #[inline(always)]
1222    fn unchecked_roll(self, shifts: &[usize], dims: &[usize]) -> Self {
1223        #[cfg(debug_assertions)]
1224        {
1225            assert!(!shifts.is_empty());
1226            assert_eq!(
1227                shifts.len(),
1228                dims.len(),
1229                "Shifts and dimensions must align; found {} shifts and {} dims",
1230                shifts.len(),
1231                dims.len()
1232            );
1233
1234            let mut unique_dims = dims.to_vec();
1235            unique_dims.dedup();
1236
1237            assert_eq!(
1238                unique_dims.len(),
1239                dims.len(),
1240                "Dimensions must not contain repeats; found {} unique dims and {} total dims",
1241                unique_dims.len(),
1242                dims.len()
1243            )
1244        }
1245
1246        let x = self.unchecked_roll_dim(shifts[0], dims[0]);
1247
1248        if dims.len() == 1 {
1249            x
1250        } else {
1251            x.unchecked_roll(&shifts[1..], &dims[1..])
1252        }
1253    }
1254
1255    /// Returns a tensor containing the elements selected from the given slices.
1256    ///
1257    /// This method provides flexible tensor slicing with support for various range types,
1258    /// negative indices, and stepped slicing. The method accepts both single slices and
1259    /// arrays of slices, with the [`s!`] macro providing convenient syntax for complex patterns.
1260    ///
1261    /// # Arguments
1262    ///
1263    /// * `slices` - Can be:
1264    ///   - A single range for 1D slicing (e.g., `0..5`, `..`, `2..`)
1265    ///   - An array of ranges (e.g., `[0..2, 1..4]`)
1266    ///   - The [`s!`] macro output for advanced slicing with steps
1267    ///   - a `&Vec<Slice>` or `&[Slice]`
1268    ///
1269    /// # Behavior
1270    ///
1271    /// - Supports partial and full slicing in any number of dimensions
1272    /// - Handles negative indices by wrapping from the end (-1 is the last element)
1273    /// - Automatically clamps ranges that exceed tensor dimensions
1274    /// - Supports stepped slicing for selecting every nth element
1275    /// - Negative steps reverse the selection order
1276    ///
1277    /// # Panics
1278    ///
1279    /// - If the number of slices exceeds the tensor's dimensions
1280    /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1) without negative step
1281    /// - If a step is zero
1282    ///
1283    /// # Examples
1284    ///
1285    /// ```rust
1286    /// use burn_tensor::backend::Backend;
1287    /// use burn_tensor::{Tensor, Shape, s};
1288    ///
1289    /// fn example<B: Backend>() {
1290    ///     let device = B::Device::default();
1291    ///
1292    ///     // Single dimension slicing - no brackets needed!
1293    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..10, &device);
1294    ///     let slice = tensor.clone().slice(2..8);  // Simple range
1295    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![2, 3, 4, 5, 6, 7]);
1296    ///
1297    ///     // Using s! macro for single dimension with step
1298    ///     let slice = tensor.clone().slice(s![0..10;2]);  // Every 2nd element
1299    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![0, 2, 4, 6, 8]);
1300    ///
1301    ///     // Reverse a dimension with negative step
1302    ///     let slice = tensor.slice(s![..;-1]);  // Reverse entire tensor
1303    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
1304    ///
1305    ///     // Multi-dimensional slicing
1306    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
1307    ///
1308    ///     // Array syntax for simple ranges
1309    ///     let slice = tensor.clone().slice([1..3, 2..5]);
1310    ///     assert_eq!(slice.dims(), [2, 3]);
1311    ///
1312    ///     // Advanced multi-dimensional with s! macro
1313    ///     let slice = tensor.clone().slice(s![0..4;2, ..;-1]);  // Every 2nd row, reverse columns
1314    ///     assert_eq!(slice.dims(), [2, 6]);
1315    ///
1316    ///     // Complex 3D example with mixed slice types
1317    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([4, 6, 8]), &device);
1318    ///     let slice = tensor.slice(s![1..3, ..;2, -3..]);  // Rows 1-2, every 2nd col, last 3 depth
1319    ///     assert_eq!(slice.dims(), [2, 3, 3]);
1320    ///
1321    ///     // Using negative indices
1322    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
1323    ///     let slice = tensor.slice(s![-2.., ..-1]);  // Last 2 rows, all but last column
1324    ///     assert_eq!(slice.dims(), [2, 5]);
1325    /// }
1326    /// ```
1327    ///
1328    /// # See Also
1329    ///
1330    /// - [`s!`] - The recommended macro for creating complex slice specifications
1331    /// - [`slice_assign`](Self::slice_assign) - Assign values to a slice
1332    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
1333    /// - [`slice_dim`](Self::slice_dim) - Slice a single dimension
1334    ///
1335    /// [`s!`]: crate::s!
1336    pub fn slice<S>(self, slices: S) -> Self
1337    where
1338        S: SliceArg,
1339    {
1340        let shape = self.shape();
1341        let slices = slices.into_slices(&shape);
1342
1343        // Validate slices
1344        check!(TensorCheck::slice::<D>(&shape, &slices));
1345
1346        // Calculate output shape and check for empty slices
1347        let mut output_dims = shape.dims.clone();
1348        for (dim, slice) in slices.iter().enumerate() {
1349            output_dims[dim] = slice.output_size(shape.dims[dim]);
1350        }
1351
1352        // Return empty tensor if any dimension is 0 (empty slice)
1353        if output_dims.contains(&0) {
1354            return Self::empty(output_dims, &self.device());
1355        }
1356        Self::new(K::slice(self.primitive, &slices))
1357    }
1358
1359    /// Assigns values to a slice of the tensor and returns the updated tensor.
1360    ///
1361    /// This method supports advanced slicing with steps, including negative steps for reverse
1362    /// assignment. Like `slice`, it accepts both single slices and arrays, with the [`s!`] macro
1363    /// providing powerful syntax for complex patterns.
1364    ///
1365    /// # Arguments
1366    ///
1367    /// * `slices` - Slice specification (same format as `slice` method)
1368    /// * `values` - Tensor with values to assign (must match slice dimensions)
1369    ///
1370    /// # Panics
1371    ///
1372    /// - If slices exceed tensor dimensions
1373    /// - If values dimensions don't match the selected slice shape
1374    /// - If a step is zero
1375    ///
1376    /// # Examples
1377    ///
1378    /// ```rust
1379    /// use burn_tensor::backend::Backend;
1380    /// use burn_tensor::{Tensor, s};
1381    ///
1382    /// fn example<B: Backend>() {
1383    ///     let device = B::Device::default();
1384    ///
1385    ///     // Simple assignment to a sub-region
1386    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
1387    ///     let values = Tensor::<B, 2>::ones([2, 3], &device);
1388    ///     tensor = tensor.slice_assign([1..3, 2..5], values);
1389    ///     // Now tensor[1..3, 2..5] contains ones
1390    ///
1391    ///     // Single dimension assignment with step
1392    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1393    ///     let values = Tensor::<B, 1>::ones([5], &device);
1394    ///     tensor = tensor.slice_assign(s![0..10;2], values);
1395    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
1396    ///
1397    ///     // Reverse assignment with negative step
1398    ///     let mut tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1399    ///     let values = Tensor::<B, 1>::from_data([10.0, 11.0, 12.0, 13.0, 14.0], &device);
1400    ///     tensor = tensor.slice_assign(s![..;-1], values);
1401    ///     // Assigns in reverse: [14, 13, 12, 11, 10]
1402    ///
1403    ///     // Complex multi-dimensional assignment
1404    ///     let mut tensor = Tensor::<B, 3>::zeros([4, 6, 8], &device);
1405    ///     let values = Tensor::<B, 3>::ones([2, 3, 3], &device);
1406    ///     tensor = tensor.slice_assign(s![0..4;2, ..;2, -3..], values);
1407    ///     // Assigns to every 2nd row, every 2nd column, last 3 in depth
1408    ///
1409    ///     // Mixed syntax example
1410    ///     let mut tensor = Tensor::<B, 2>::zeros([8, 8], &device);
1411    ///     let pattern = Tensor::<B, 2>::ones([4, 4], &device);
1412    ///     tensor = tensor.slice_assign(s![..;2, ..;2], pattern);
1413    ///     // Creates a checkerboard pattern with ones
1414    /// }
1415    /// ```
1416    ///
1417    /// # See Also
1418    ///
1419    /// - [`s!`] - The recommended macro for creating complex slice specifications
1420    /// - [`slice`](Self::slice) - Extract a slice from a tensor
1421    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
1422    ///
1423    /// [`s!`]: crate::s!
1424    pub fn slice_assign<S>(self, slices: S, values: Self) -> Self
1425    where
1426        S: SliceArg,
1427    {
1428        let shape = self.shape();
1429        let slices = slices.into_slices(&shape);
1430
1431        // Check if any slice produces 0 elements (empty assignment).
1432        // Empty assignments are no-ops and would cause issues in backend implementations.
1433        let is_empty_assignment = slices
1434            .iter()
1435            .enumerate()
1436            .any(|(i, slice)| slice.output_size(shape.dims[i]) == 0);
1437
1438        if is_empty_assignment {
1439            return self;
1440        }
1441
1442        check!(TensorCheck::slice_assign::<D>(
1443            &shape,
1444            &values.shape(),
1445            &slices
1446        ));
1447
1448        Self::new(K::slice_assign(self.primitive, &slices, values.primitive))
1449    }
1450
1451    /// Fills a slice of the tensor with a constant value and returns the updated tensor.
1452    ///
1453    /// Like other slice methods, accepts both single slices and arrays. However, this method
1454    /// currently **does not support stepped slicing** - use [`slice_assign`](Self::slice_assign)
1455    /// with a constant tensor for stepped patterns.
1456    ///
1457    /// # Arguments
1458    ///
1459    /// * `slices` - Slice specification (same format as `slice` method, but no steps)
1460    /// * `value` - The value to fill the slice with
1461    ///
1462    /// # Panics
1463    ///
1464    /// - If slices exceed tensor dimensions
1465    /// - If any slice has a step != 1 (not yet supported)
1466    ///
1467    /// # Examples
1468    ///
1469    /// ```rust
1470    /// use burn_tensor::backend::Backend;
1471    /// use burn_tensor::{Tensor, s};
1472    ///
1473    /// fn example<B: Backend>() {
1474    ///     let device = B::Device::default();
1475    ///
1476    ///     // Simple fill for a single dimension
1477    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1478    ///     tensor = tensor.slice_fill(2..5, 1.0);
1479    ///     // Now tensor is [0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
1480    ///
1481    ///     // Multi-dimensional fill
1482    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
1483    ///     tensor = tensor.slice_fill([1..3, 2..5], -1.0);
1484    ///     // Fills the rectangle at rows 1-2, columns 2-4 with -1
1485    ///
1486    ///     // Using negative indices
1487    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1488    ///     tensor = tensor.slice_fill(-3.., 2.0);
1489    ///     // Fills the last 3 elements with 2.0
1490    ///
1491    ///     // Complex multi-dimensional example
1492    ///     let mut tensor = Tensor::<B, 3>::ones([4, 6, 8], &device);
1493    ///     tensor = tensor.slice_fill(s![1..3, .., -2..], 0.0);
1494    ///     // Sets rows 1-2, all columns, last 2 in depth to 0
1495    ///
1496    ///     // Stepped slicing is supported
1497    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1498    ///     tensor = tensor.slice_fill(s![0..10;2], 1.0);
1499    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
1500    /// }
1501    /// ```
1502    ///
1503    /// # See Also
1504    ///
1505    /// - [`s!`] - The macro for creating slice specifications with steps
1506    /// - [`slice`](Self::slice) - Extract a slice from a tensor
1507    /// - [`slice_assign`](Self::slice_assign) - Assign tensor values to a slice
1508    ///
1509    /// [`s!`]: crate::s!
1510    pub fn slice_fill<S, E: ElementConversion>(self, slices: S, value: E) -> Self
1511    where
1512        S: SliceArg,
1513    {
1514        let shape = self.shape();
1515        let slices = slices.into_slices(&shape);
1516
1517        check!(TensorCheck::slice::<D>(&shape, &slices));
1518
1519        let slice_shape = shape.slice(&slices).unwrap();
1520        let value = Tensor::<B, 1, K>::from_data_dtype(
1521            [value.elem::<K::Elem>()],
1522            &self.device(),
1523            self.dtype(),
1524        );
1525        let value = value.expand(slice_shape);
1526        self.slice_assign(&slices, value)
1527    }
1528
1529    /// Returns a new tensor with the specified dimension sliced.
1530    ///
1531    /// # Arguments
1532    ///
1533    /// * `dim`: The dimension to slice.
1534    /// * `slice`: The slice specification for the dimension. Can be a range (e.g., `2..5`),
1535    ///   slice with step (via `s!` macro, e.g., `s![0..10;2]`), or any type that implements `Into<Slice>`.
1536    ///
1537    /// # Returns
1538    ///
1539    /// A new tensor with the specified dimension sliced.
1540    ///
1541    /// # Panics
1542    ///
1543    /// If the slice is out of bounds for the specified dimension.
1544    ///
1545    /// # Examples
1546    ///
1547    /// ```rust
1548    /// # use burn_tensor::{Tensor, s};
1549    /// # use burn_tensor::backend::Backend;
1550    /// #
1551    /// # fn example<B: Backend>() {
1552    /// #     let device = B::Device::default();
1553    ///     let tensor = Tensor::<B, 3>::zeros([3, 4, 5], &device);
1554    ///
1555    ///     // Simple range slicing
1556    ///     let sliced = tensor.clone().slice_dim(1, 1..3);
1557    ///     assert_eq!(sliced.shape().dims, [3, 2, 5]);
1558    ///
1559    ///     // Slicing with step - take every 2nd element
1560    ///     let sliced = tensor.clone().slice_dim(2, s![0..5;2]);
1561    ///     assert_eq!(sliced.shape().dims, [3, 4, 3]); // Takes indices 0, 2, 4
1562    ///
1563    ///     // Reverse slicing with negative step
1564    ///     let sliced = tensor.clone().slice_dim(1, s![..;-1]);
1565    ///     assert_eq!(sliced.shape().dims, [3, 4, 5]); // Reverses dimension 1
1566    ///
1567    ///     // Select from index 2 with step 3
1568    ///     let sliced = tensor.clone().slice_dim(0, s![2..;3]);
1569    ///     assert_eq!(sliced.shape().dims, [1, 4, 5]); // Takes only index 2
1570    ///
1571    ///     // Select single index (reduces dimension to size 1)
1572    ///     let sliced = tensor.slice_dim(0, 1);
1573    ///     assert_eq!(sliced.shape().dims, [1, 4, 5]);
1574    /// # }
1575    /// ```
1576    ///
1577    /// # See Also
1578    ///
1579    /// - [`slice`](Self::slice) - Slice multiple dimensions simultaneously
1580    /// - [`s!`] - The macro for creating complex slice specifications
1581    ///
1582    /// [`s!`]: crate::s!
1583    pub fn slice_dim<S>(self, dim: usize, slice: S) -> Self
1584    where
1585        S: Into<Slice>,
1586    {
1587        check!(TensorCheck::check_dim::<D>(dim));
1588        let slice: Slice = slice.into();
1589
1590        let mut slices = vec![Slice::full(); D];
1591        slices[dim] = slice;
1592
1593        self.slice(&slices)
1594    }
1595
1596    /// Returns the device of the current tensor.
1597    pub fn device(&self) -> B::Device {
1598        K::device(&self.primitive)
1599    }
1600
1601    /// Move the tensor to the given device.
1602    pub fn to_device(self, device: &B::Device) -> Self {
1603        Self::new(K::to_device(self.primitive, device))
1604    }
1605
1606    /// Select tensor elements along the given dimension corresponding to the given indices.
1607    ///
1608    /// # Arguments
1609    ///
1610    /// * `dim` - The dimension to select from. Supports negative indexing.
1611    /// * `indices` - The indices of the elements to select.
1612    ///
1613    /// # Example
1614    ///
1615    /// ```rust
1616    /// use burn_tensor::backend::Backend;
1617    /// use burn_tensor::{Tensor, Int};
1618    ///
1619    /// fn example<B: Backend>() {
1620    ///   let device = B::Device::default();
1621    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [4.0, 5.0, 6.0]], &device);
1622    ///   let indices = Tensor::<B, 1, Int>::from_data([0], &device);
1623    ///   let tensor = tensor.select(0, indices);
1624    ///   println!("{tensor}");
1625    ///   //  [[1.0, -2.0, 3.0]]
1626    /// }
1627    /// ```
1628    pub fn select(self, dim: impl AsIndex, indices: Tensor<B, 1, Int>) -> Self {
1629        let dim = dim.expect_dim_index(D);
1630        check!(TensorCheck::select::<D>(dim));
1631        Self::new(K::select(self.primitive, dim, indices.primitive))
1632    }
1633
1634    /// Assign the selected elements along the given dimension corresponding to the given indices
1635    /// from the value tensor to the original tensor using sum reduction.
1636    ///
1637    /// # Note
1638    /// For booleans, the sum operator is logical or.
1639    ///
1640    /// # Arguments
1641    ///
1642    /// * `dim` - The dimension along which to select. Supports negative indexing.
1643    /// * `indices` - The indices to select from the tensor.
1644    /// * `values` - The values to assign to the selected indices.
1645    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1646    ///
1647    /// # Example
1648    ///
1649    /// Example using a 3D tensor:
1650    ///
1651    /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
1652    /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
1653    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
1654    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = -1 (same as dim = 2)`
1655    ///
1656    /// # Warning
1657    ///
1658    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1659    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1660    pub fn select_assign(
1661        self,
1662        dim: impl AsIndex,
1663        indices: Tensor<B, 1, Int>,
1664        values: Tensor<B, D, K>,
1665        update: IndexingUpdateOp,
1666    ) -> Self {
1667        let dim = dim.expect_dim_index(D);
1668        check!(TensorCheck::select_assign::<D>(
1669            dim,
1670            &indices.shape(),
1671            &values.shape()
1672        ));
1673
1674        Self::new(K::select_assign(
1675            self.primitive,
1676            dim,
1677            indices.primitive,
1678            values.primitive,
1679            update,
1680        ))
1681    }
1682
1683    /// Update the given tensor with the value tensor where the mask is true.
1684    ///
1685    /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
1686    /// a scalar.
1687    ///
1688    /// # Example
1689    ///
1690    /// ```rust
1691    /// use burn_tensor::backend::Backend;
1692    /// use burn_tensor::{Tensor, Shape, Bool};
1693    ///
1694    /// fn example<B: Backend>() {
1695    ///   let device = B::Device::default();
1696    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1697    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1698    ///   let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1699    ///   let tensor = tensor.mask_where(mask, value);
1700    ///   println!("{tensor}");
1701    ///   // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
1702    /// }
1703    /// ```
1704    pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
1705        Self::new(K::mask_where(
1706            self.primitive,
1707            mask.primitive,
1708            value.primitive,
1709        ))
1710    }
1711
1712    /// Update the given tensor with the value where the mask is true.
1713    ///
1714    /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
1715    /// a tensor.
1716    ///
1717    /// # Example
1718    ///
1719    /// ```rust
1720    /// use burn_tensor::backend::Backend;
1721    /// use burn_tensor::{Tensor, Shape, Bool};
1722    ///
1723    /// fn example<B: Backend>() {
1724    ///   let device = B::Device::default();
1725    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1726    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1727    ///   let tensor = tensor.mask_fill(mask, 3.0);
1728    ///   println!("{tensor}");
1729    ///   // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
1730    /// }
1731    /// ```
1732    pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
1733        Self::new(K::mask_fill(self.primitive, mask.primitive, value.elem()))
1734    }
1735
1736    /// Gather tensor elements corresponding to the given indices from the specified dim.
1737    ///
1738    /// Example using a 3D tensor:
1739    ///
1740    /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
1741    /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
1742    /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
1743    ///
1744    /// # Notes
1745    ///
1746    /// The index tensor should have the same shape as the original tensor except for the dim
1747    /// specified.
1748    ///
1749    /// # Warning
1750    /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
1751    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1752    pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
1753        check!(TensorCheck::gather::<D>(
1754            dim,
1755            &self.shape(),
1756            &indices.shape()
1757        ));
1758
1759        Self::new(K::gather(dim, self.primitive, indices.primitive))
1760    }
1761
1762    /// Assign the gathered elements corresponding to the given indices along the specified dimension
1763    /// from the value tensor to the original tensor using sum reduction.
1764    ///
1765    /// Example using a 3D tensor:
1766    ///
1767    /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
1768    /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
1769    /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
1770    ///
1771    /// # Arguments
1772    /// * `dim` - The axis along which to scatter elements.
1773    /// * `indices` - The indices of the elements to scatter.
1774    /// * `values` - The values to scatter into the tensor.
1775    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1776    ///
1777    /// # Notes
1778    ///
1779    /// The index tensor should have the same shape as the original tensor except for the specified
1780    /// dimension. The value and index tensors should have the same shape.
1781    ///
1782    /// Other references to the input tensor will not be modified by this operation.
1783    ///
1784    /// # Warning
1785    /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
1786    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1787    pub fn scatter(
1788        self,
1789        dim: usize,
1790        indices: Tensor<B, D, Int>,
1791        values: Self,
1792        update: IndexingUpdateOp,
1793    ) -> Self {
1794        check!(TensorCheck::scatter::<D>(
1795            dim,
1796            &self.shape(),
1797            &indices.shape(),
1798            &values.shape()
1799        ));
1800
1801        Self::new(K::scatter(
1802            dim,
1803            self.primitive,
1804            indices.primitive,
1805            values.primitive,
1806            update,
1807        ))
1808    }
1809
1810    /// Converts the data of the current tensor.
1811    ///
1812    /// # Note
1813    ///
1814    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1815    /// tensors at once. This may improve laziness, especially if executed on a different
1816    /// thread in native environments.
1817    pub fn into_data(self) -> TensorData {
1818        self.try_into_data().expect(
1819            "Error while reading data: use `try_into_data` instead to catch the error at runtime",
1820        )
1821    }
1822
1823    /// Converts the data of the current tensor and returns any error that might have occurred since the
1824    /// last time the device was synchronized.
1825    ///
1826    /// # Note
1827    ///
1828    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1829    /// tensors at once. This may improve laziness, especially if executed on a different
1830    /// thread in native environments.
1831    pub fn try_into_data(self) -> Result<TensorData, ExecutionError> {
1832        crate::try_read_sync(self.into_data_async()).expect(
1833            "Failed to read tensor data synchronously.
1834        This can happen on platforms that don't support blocking futures like WASM.
1835        If possible, try using into_data_async instead.",
1836        )
1837    }
1838
1839    /// Converts the data of the current tensor.
1840    ///
1841    /// # Note
1842    ///
1843    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1844    /// tensors at once. This may improve laziness, especially if executed on a different
1845    /// thread in native environments.
1846    pub fn to_data(&self) -> TensorData {
1847        self.clone().into_data()
1848    }
1849
1850    /// Returns the data of the current tensor.
1851    pub async fn into_data_async(self) -> Result<TensorData, ExecutionError> {
1852        K::into_data_async(self.primitive).await
1853    }
1854
1855    /// Returns the data of the current tensor.
1856    pub async fn to_data_async(&self) -> Result<TensorData, ExecutionError> {
1857        self.clone().into_data_async().await
1858    }
1859
1860    /// Create a tensor from the given data on the given device.
1861    pub fn from_data<T>(data: T, device: &B::Device) -> Self
1862    where
1863        T: Into<TensorData>,
1864    {
1865        let data = data.into();
1866        check!(TensorCheck::creation_ops::<D>(
1867            "From Data",
1868            data.shape.as_slice()
1869        ));
1870        Self::new(K::from_data(data, device))
1871    }
1872
1873    /// Create a tensor from the given data on the given device enforcing the given data type.
1874    pub fn from_data_dtype<T>(data: T, device: &B::Device, dtype: DType) -> Self
1875    where
1876        T: Into<TensorData>,
1877    {
1878        let data = data.into();
1879        check!(TensorCheck::creation_ops::<D>(
1880            "From Data",
1881            data.shape.as_slice()
1882        ));
1883        Self::new(K::from_data_dtype(data, device, dtype))
1884    }
1885
1886    /// Repeat the tensor along the given dimension.
1887    ///
1888    /// The output tensor has the same shape, except along the given dimension.
1889    ///
1890    /// # Arguments
1891    /// - `dim`: The dimension to repeat.
1892    /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
1893    ///
1894    /// # Returns
1895    ///
1896    /// A new tensor with the given dimension repeated `times` times.
1897    ///
1898    /// # Example
1899    ///
1900    /// ```rust
1901    /// use burn_tensor::backend::Backend;
1902    /// use burn_tensor::Tensor;
1903    ///
1904    /// fn example<B: Backend>() {
1905    ///     let device = Default::default();
1906    ///     // Create a 2D tensor with dimensions [3, 2]
1907    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1908    ///
1909    ///     // Repeat the tensor along the dimension 0 twice.
1910    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1911    ///     // The resulting tensor will have dimensions [6, 2].
1912    ///     let repeated = tensor.repeat_dim(0, 2);
1913    ///     println!("{repeated}");
1914    /// }
1915    /// ```
1916    pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
1917        if times > 0 {
1918            Self::new(K::repeat_dim(self.primitive, dim, times))
1919        } else {
1920            let shape = self.shape().repeat(dim, times).unwrap();
1921            Self::empty(shape, &self.device())
1922        }
1923    }
1924
1925    /// Repeat the tensor along the given dimensions.
1926    /// # Arguments
1927    /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
1928    ///
1929    /// # Returns
1930    ///
1931    /// A new tensor with the given dimensions repeated `times` times.
1932    ///
1933    /// # Panics
1934    ///
1935    /// If `sizes` contains more elements than the number of dimensions.
1936    ///
1937    /// # Example
1938    ///
1939    /// ```rust
1940    ///
1941    /// use burn_tensor::backend::Backend;
1942    /// use burn_tensor::Tensor;
1943    ///
1944    /// fn example<B: Backend>() {
1945    ///     let device = Default::default();
1946    ///     // Create a 2D tensor with dimensions [3, 2]
1947    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1948    ///
1949    ///     // Repeat the tensor along the dimension 0 twice and the dimension 0 once.
1950    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1951    ///     // The resulting tensor will have dimensions [6, 2].
1952    ///     let repeated = tensor.repeat(&[2, 1]);
1953    /// }
1954    /// ```
1955    pub fn repeat(self, sizes: &[usize]) -> Self {
1956        if sizes.contains(&0) {
1957            let mut shape = self.shape();
1958            for (dim, &times) in sizes.iter().enumerate() {
1959                shape = shape.repeat(dim, times).unwrap();
1960            }
1961
1962            return Self::empty(shape, &self.device());
1963        }
1964
1965        let mut tensor = self;
1966        for (dim, &times) in sizes.iter().enumerate() {
1967            if times > 1 {
1968                tensor = tensor.repeat_dim(dim, times);
1969            }
1970        }
1971        tensor
1972    }
1973
1974    /// Applies element-wise equal comparison.
1975    ///
1976    /// # Returns
1977    /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
1978    ///
1979    /// # Panics
1980    ///
1981    /// If the two tensors don't have the same shape.
1982    ///
1983    /// # Example
1984    ///
1985    /// ```rust
1986    /// use burn_tensor::backend::Backend;
1987    /// use burn_tensor::Tensor;
1988    ///
1989    /// fn example<B: Backend>() {
1990    ///     let device = Default::default();
1991    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1992    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1993    ///     // Compare the elements of the two 2D tensors with dimensions [3, 2].
1994    ///     // [[false, true], [true, true], [true, true]]
1995    ///     let equal = t1.equal(t2);
1996    ///     println!("{equal}");
1997    /// }
1998    /// ```
1999    pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
2000        check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
2001        Tensor::new(K::equal(self.primitive, other.primitive))
2002    }
2003
2004    /// Applies element-wise non-equality comparison.
2005    ///
2006    /// # Returns
2007    /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
2008    ///
2009    /// # Panics
2010    ///
2011    /// If the two tensors don't have the same shape.
2012    ///
2013    /// # Example
2014    ///
2015    /// ```rust
2016    /// use burn_tensor::backend::Backend;
2017    /// use burn_tensor::Tensor;
2018    ///
2019    /// fn example<B: Backend>() {
2020    ///     let device = Default::default();
2021    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2022    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2023    ///     // Compare the elements of the two 2D tensors for inequality.
2024    ///     // [[true, false], [false, false], [false, false]]
2025    ///     let not_equal = t1.not_equal(t2);
2026    ///     println!("{not_equal}");
2027    /// }
2028    /// ```
2029    pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
2030        check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
2031        Tensor::new(K::not_equal(self.primitive, other.primitive))
2032    }
2033
2034    /// Applies element wise equal comparison and returns a boolean tensor.
2035    ///
2036    /// # Arguments
2037    ///
2038    /// * `other` - The element to compare.
2039    ///
2040    /// # Example
2041    ///
2042    /// ```rust
2043    /// use burn_tensor::backend::Backend;
2044    /// use burn_tensor::{Tensor, Shape};
2045    ///
2046    /// fn example<B: Backend>() {
2047    ///    let device = B::Device::default();
2048    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2049    ///    let tensor = tensor.equal_elem(3.0);
2050    ///    println!("{tensor}");
2051    ///    // [[false, false, true], [false, false, false]]
2052    /// }
2053    /// ```
2054    pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2055        Tensor::new(K::equal_elem(self.primitive, other.elem()))
2056    }
2057
2058    /// Applies element wise non-equality comparison and returns a boolean tensor.
2059    ///
2060    /// # Arguments
2061    ///
2062    /// * `other` - The element to compare.
2063    ///
2064    /// # Example
2065    ///
2066    /// ```rust
2067    /// use burn_tensor::backend::Backend;
2068    /// use burn_tensor::{Tensor, Shape};
2069    ///
2070    /// fn example<B: Backend>() {
2071    ///    let device = B::Device::default();
2072    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2073    ///    let tensor = tensor.not_equal_elem(3.0);
2074    ///    println!("{tensor}");
2075    ///    // [[true, true, false], [true, true, true]]
2076    /// }
2077    /// ```
2078    pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2079        Tensor::new(K::not_equal_elem(self.primitive, other.elem()))
2080    }
2081
2082    /// Concatenates all tensors into a new one along the given dimension.
2083    ///
2084    /// # Panics
2085    ///
2086    /// - If `dim` is higher than the rank.
2087    /// - If `tensors` is an empty vector.
2088    /// - If all tensors don't have the same shape (the dimension `dim` is ignored).
2089    ///
2090    /// # Example
2091    ///
2092    /// ```rust
2093    /// use burn_tensor::backend::Backend;
2094    /// use burn_tensor::Tensor;
2095    ///
2096    /// fn example<B: Backend>() {
2097    ///     let device = Default::default();
2098    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
2099    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2100    ///
2101    ///     // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
2102    ///     // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
2103    ///     // The resulting tensor will have shape [2, 7].
2104    ///     let concat = Tensor::cat(vec![t1, t2], 1);
2105    ///     println!("{concat}");
2106    /// }
2107    /// ```
2108    pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
2109        check!(TensorCheck::cat(&tensors, dim));
2110
2111        // Filter out tensors with size 0 along the concatenation dimension.
2112        // Empty tensors don't contribute to the output and would cause issues
2113        // in backend implementations (e.g., division by zero in slice_assign).
2114        // Safety: TensorCheck::cat ensures tensors is non-empty
2115        let first_tensor = tensors.first().unwrap();
2116        let device = first_tensor.device();
2117        let mut shape = first_tensor.shape();
2118
2119        let non_empty_primitives: Vec<_> = tensors
2120            .into_iter()
2121            .filter(|t| t.shape().dims[dim] > 0)
2122            .map(|t| t.primitive)
2123            .collect();
2124
2125        // If all tensors were empty, return an empty tensor with size 0 on concat dim
2126        if non_empty_primitives.is_empty() {
2127            shape.dims[dim] = 0;
2128            return Self::empty(shape, &device);
2129        }
2130
2131        Self::new(K::cat(non_empty_primitives, dim))
2132    }
2133
2134    /// Concatenates all tensors into a new one along a new dimension.
2135    ///
2136    /// # Panics
2137    ///
2138    /// - If all tensors don't have the same shape.
2139    /// - If given dimension is not with range of 0..D2
2140    ///
2141    /// # Example
2142    ///
2143    /// ```rust
2144    /// use burn_tensor::backend::Backend;
2145    /// use burn_tensor::Tensor;
2146    ///
2147    /// fn example<B: Backend>() {
2148    ///     let device = Default::default();
2149    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2150    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2151    ///     let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2152    ///
2153    ///     // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
2154    ///     // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
2155    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
2156    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
2157    ///     // The resulting tensor will have shape [3, 2, 3].
2158    ///     let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0);
2159    ///     println!("{stacked}");
2160    /// }
2161    /// ```
2162    pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
2163        check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
2164        let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
2165        Tensor::<B, D2, K>::cat(tensors, dim)
2166    }
2167
2168    /// Iterate over slices of tensors alongside a given dimension.
2169    ///
2170    /// # Panics
2171    ///
2172    /// If given dimension is greater than or equal to tensor rank.
2173    ///
2174    /// # Returns
2175    ///
2176    /// A tensor iterator.
2177    ///
2178    /// # Example
2179    ///
2180    /// ```rust
2181    /// use burn_tensor::backend::Backend;
2182    /// use burn_tensor::Tensor;
2183    /// fn example<B: Backend>() {
2184    ///   let device = Default::default();
2185    ///   let tensor = Tensor::<B,2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2186    ///   // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
2187    ///   let iter = tensor.iter_dim(0);
2188    ///   for (i,tensor) in iter.enumerate() {
2189    ///     println!("Tensor {}: {}", i, tensor);
2190    ///     // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
2191    ///     // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
2192    ///  }
2193    /// }
2194    /// ```
2195    pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
2196        check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
2197        DimIter::new(self, dim)
2198    }
2199
2200    /// Returns a new tensor with the given dimension narrowed to the given range.
2201    ///
2202    /// # Panics
2203    ///
2204    /// - If the dimension is greater than the number of dimensions of the tensor.
2205    /// - If the given range exceeds the number of elements on the given dimension.
2206    ///
2207    /// # Returns
2208    ///
2209    /// A new tensor with the given dimension narrowed to the given range.
2210    ///
2211    /// # Example
2212    ///
2213    /// ```rust
2214    /// use burn_tensor::backend::Backend;
2215    /// use burn_tensor::Tensor;
2216    ///
2217    /// fn example<B: Backend>() {
2218    ///     let device = Default::default();
2219    ///     // Create a 2D tensor with dimensions [4, 3]
2220    ///     let tensor = Tensor::<B, 2>::from_data(
2221    ///         [
2222    ///             [3.0, 4.9, 2.0],
2223    ///             [2.0, 1.9, 3.0],
2224    ///             [6.0, 1.5, 7.0],
2225    ///             [3.0, 4.9, 9.0],
2226    ///         ],
2227    ///         &device,
2228    ///     );
2229    ///     // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
2230    ///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
2231    ///     // The resulting tensor will have dimensions [3, 3].
2232    ///     let narrowed = tensor.narrow(0, 1, 3);
2233    ///     println!("{narrowed}");
2234    /// }
2235    /// ```
2236    pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
2237        check!(TensorCheck::dim_ops::<D>("narrow", dim));
2238        check!(TensorCheck::narrow(&self, dim, start, length));
2239        let dims = self.dims();
2240
2241        let ranges: [Range<usize>; D] = dims
2242            .iter()
2243            .enumerate()
2244            .map(|(i, d)| {
2245                if i == dim {
2246                    start..(start + length)
2247                } else {
2248                    0..*d
2249                }
2250            })
2251            .collect::<Vec<_>>()
2252            .try_into()
2253            .unwrap();
2254
2255        Self::slice(self, ranges)
2256    }
2257
2258    /// Attempts to split the tensor into a specified number of chunks along a given dimension.
2259    /// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
2260    ///
2261    /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
2262    /// Otherwise all chunks will be of equal size except for the last one.
2263    ///
2264    /// # Panics
2265    ///
2266    /// If the dimension is greater than the number of dimensions of the tensor.
2267    ///
2268    /// # Returns
2269    /// A vector of tensors.
2270    ///
2271    /// # Example
2272    ///
2273    /// ```rust
2274    /// use burn_tensor::backend::Backend;
2275    /// use burn_tensor::Tensor;
2276    ///
2277    /// fn example<B: Backend>() {
2278    ///     let device = Default::default();
2279    ///     // Create a 2D tensor with dimensions [4, 3]
2280    ///     let tensor = Tensor::<B, 2>::from_data(
2281    ///         [
2282    ///             [3.0, 4.9, 2.0],
2283    ///             [2.0, 1.9, 3.0],
2284    ///             [6.0, 1.5, 7.0],
2285    ///             [3.0, 4.9, 9.0],
2286    ///         ],
2287    ///         &device,
2288    ///     );
2289    ///     // Split the tensor along the dimension 1 into 2 chunks.
2290    ///     // The first chuck will have shape [4, 2]:
2291    ///     // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
2292    ///     // The second chunk will have shape [4, 1]:
2293    ///     // [[2.0], [3.0], [7.0], [9.0]]
2294    ///     let chunks = tensor.chunk(2, 1);
2295    ///     println!("{chunks:?}");
2296    /// }
2297    /// ```
2298    pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
2299        check!(TensorCheck::dim_ops::<D>("chunk", dim));
2300        let size = self.shape().dims[dim];
2301        if size < chunks {
2302            return (0..size)
2303                .map(|i| Self::narrow(self.clone(), dim, i, 1))
2304                .collect();
2305        }
2306
2307        let mut tensors = Vec::with_capacity(chunks);
2308        let mut sum_chunk_size = 0;
2309        if size.is_multiple_of(chunks) {
2310            let chunk_size = size / chunks;
2311            for _ in 0..chunks {
2312                tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
2313                sum_chunk_size += chunk_size;
2314            }
2315        } else {
2316            let chunk_size = (size / chunks) + 1; // assumes not divisible
2317            for _ in 0..chunks - 1 {
2318                tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
2319                sum_chunk_size += chunk_size;
2320            }
2321            let remainder = size % chunk_size;
2322            tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, remainder));
2323        }
2324
2325        tensors
2326    }
2327
2328    /// Splits the tensor into chunks of a specified size along a given dimension.
2329    /// Each chunk is a view of the original tensor.
2330    ///
2331    /// If the tensor size along the given dimension is not divisible by `split_size`,
2332    /// then the last chunk will be smaller.
2333    ///
2334    /// # Panics
2335    ///
2336    /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
2337    ///
2338    /// # Returns
2339    ///
2340    /// A vector of tensors.
2341    ///
2342    /// # Example
2343    /// ```rust
2344    /// use burn_tensor::backend::Backend;
2345    /// use burn_tensor::Tensor;
2346    ///
2347    /// fn example<B: Backend>() {
2348    ///     let device = Default::default();
2349    ///     // Create a 1D tensor with 5 elements
2350    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2351    ///     // Split the tensor into chunks of size 2 along dimension 0
2352    ///     let chunks = tensor.split(2, 0);
2353    ///     // The result is a vector of tensors:
2354    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
2355    ///     println!("{:?}", chunks);
2356    /// }
2357    /// ```
2358    pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
2359        check!(TensorCheck::split::<D>(&self.shape(), split_size, dim));
2360        let size = self.shape().dims[dim];
2361        let mut tensors = Vec::new();
2362
2363        let mut start = 0;
2364        while start < size {
2365            let length = usize::min(split_size, size - start);
2366            tensors.push(Self::narrow(self.clone(), dim, start, length));
2367            start += length;
2368        }
2369
2370        tensors
2371    }
2372
2373    /// Splits the tensor into chunks with the specified sizes along a given dimension.
2374    /// Each chunk is a view of the original tensor.
2375    ///
2376    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2377    /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2378    ///
2379    /// # Panics
2380    ///
2381    /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
2382    /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`.
2383    ///
2384    /// # Returns
2385    ///
2386    /// A vector of tensors.
2387    ///
2388    /// # Example
2389    /// ```rust
2390    /// use burn_tensor::backend::Backend;
2391    /// use burn_tensor::Tensor;
2392    ///
2393    /// fn example<B: Backend>() {
2394    ///     let device = Default::default();
2395    ///     // Create a 1D tensor with 5 elements
2396    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2397    ///     // Split the tensor into chunks with sizes [2, 3] along dimension 0
2398    ///     let chunks = tensor.split_with_sizes(vec![2, 3], 0);
2399    ///     // The result is a vector of tensors:
2400    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
2401    ///     println!("{:?}", chunks);
2402    /// }
2403    /// ```
2404    pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
2405        check!(TensorCheck::split_with_sizes::<D>(
2406            &self.shape(),
2407            &split_sizes,
2408            dim
2409        ));
2410        let mut tensors = Vec::new();
2411
2412        let mut start = 0;
2413        for length in split_sizes {
2414            if length == 0 {
2415                continue;
2416            }
2417            tensors.push(Self::narrow(self.clone(), dim, start, length));
2418            start += length;
2419        }
2420
2421        tensors
2422    }
2423
2424    /// Tests if any element in the `tensor` evaluates to True.
2425    ///
2426    /// # Arguments
2427    ///
2428    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2429    ///
2430    /// # Returns
2431    ///
2432    /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
2433    /// evaluates to True, False otherwise.
2434    ///
2435    /// # Example
2436    ///
2437    /// ```rust
2438    /// use burn_tensor::backend::Backend;
2439    /// use burn_tensor::{Tensor, Bool};
2440    ///
2441    /// fn example<B: Backend>() {
2442    ///   let device = Default::default();
2443    ///   let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);
2444    ///   let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);
2445    ///
2446    ///   // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2447    ///   let any_tensor = tensor.any();
2448    ///   println!("{}", any_tensor);
2449    ///   // Tensor { data: [true], ... }
2450    ///
2451    ///   // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2452    ///   let any_tensor_two = tensor_two.any();
2453    ///   println!("{}", any_tensor_two);
2454    ///   // Tensor { data: [false], ... }
2455    /// }
2456    /// ```
2457    pub fn any(self) -> Tensor<B, 1, Bool> {
2458        Tensor::new(K::any(self.primitive))
2459    }
2460
2461    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
2462    ///
2463    /// # Arguments
2464    ///
2465    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2466    /// * `dim` - The axis along which to test.
2467    ///
2468    /// # Returns
2469    ///
2470    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2471    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
2472    /// evaluates to True, False otherwise.
2473    ///
2474    /// # Example
2475    ///
2476    /// ```rust
2477    /// use burn_tensor::backend::Backend;
2478    /// use burn_tensor::{Tensor, Bool};
2479    ///
2480    /// fn example<B: Backend>() {
2481    ///     let device = Default::default();
2482    ///     let tensor =
2483    ///         Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
2484    ///     // Check if any element in the tensor evaluates to True along the dimension 1.
2485    ///     // [[true], [true]],
2486    ///     let any_dim = tensor.clone().any_dim(1);
2487    ///     println!("{any_dim}");
2488    /// }
2489    /// ```
2490    pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2491        Tensor::new(K::any_dim(self.primitive, dim))
2492    }
2493
2494    /// Tests if all elements in the `tensor` evaluate to True.
2495    ///
2496    /// # Arguments
2497    ///
2498    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2499    ///
2500    /// # Returns
2501    ///
2502    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
2503    /// evaluate to True, False otherwise.
2504    ///
2505    /// # Example
2506    ///
2507    /// ```rust
2508    /// use burn_tensor::backend::Backend;
2509    /// use burn_tensor::{Tensor, Bool};
2510    ///
2511    /// fn example<B: Backend>() {
2512    ///     let device = Default::default();
2513    ///     let tensor =
2514    ///         Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
2515    ///     // Check if all elements in the tensor evaluate to True (which is not the case).
2516    ///     // [false]
2517    ///     let all = tensor.all();
2518    ///     println!("{all}");
2519    /// }
2520    /// ```
2521    pub fn all(self) -> Tensor<B, 1, Bool> {
2522        Tensor::new(K::all(self.primitive))
2523    }
2524
2525    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2526    ///
2527    /// # Arguments
2528    ///
2529    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2530    /// * `dim` - The axis along which to test.
2531    ///
2532    /// # Returns
2533    ///
2534    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2535    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
2536    /// evaluates to True, False otherwise.
2537    ///
2538    /// # Example
2539    ///
2540    /// ```rust
2541    /// use burn_tensor::backend::Backend;
2542    /// use burn_tensor::{Tensor, Bool};
2543    ///
2544    /// fn example<B: Backend>() {
2545    ///     let device = Default::default();
2546    ///     let tensor =
2547    ///         Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
2548    ///     // Check if all elements in the tensor evaluate to True along the dimension 1.
2549    ///     // [[true, true, false]]
2550    ///     let all_dim = tensor.clone().all_dim(0);
2551    ///     println!("{all_dim}");
2552    /// }
2553    /// ```
2554    pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2555        Tensor::new(K::all_dim(self.primitive, dim))
2556    }
2557
2558    /// Convert the tensor into a scalar.
2559    ///
2560    /// # Panics
2561    ///
2562    /// - If the tensor doesn't have one element.
2563    /// - If the backend fails to read the tensor data synchronously.
2564    ///
2565    /// # Returns
2566    ///
2567    /// The scalar value of the tensor.
2568    ///
2569    /// # Example
2570    ///
2571    /// ```rust
2572    /// use burn_tensor::backend::Backend;
2573    /// use burn_tensor::Tensor;
2574    ///
2575    /// fn example<B: Backend>() {
2576    ///     let device = Default::default();
2577    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
2578    ///     // Convert the tensor with a single element into a scalar.
2579    ///     let scalar = tensor.into_scalar();
2580    ///     println!("{scalar}");
2581    /// }
2582    /// ```
2583    pub fn into_scalar(self) -> K::Elem {
2584        crate::try_read_sync(self.into_scalar_async())
2585            .expect(
2586            "Failed to read tensor data synchronously. This can happen on platforms
2587            that don't support blocking futures like WASM. Try into_scalar_async instead.",
2588            )
2589            .expect("Error while reading data: use `try_into_scalar` instead to catch the error at runtime")
2590    }
2591
2592    /// Convert the tensor into a scalar and returns any error that might have occurred since the
2593    /// last time the device was synchronized.
2594    ///
2595    /// # Panics
2596    ///
2597    /// - If the tensor doesn't have one element.
2598    /// - If the backend fails to read the tensor data synchronously.
2599    ///
2600    /// # Returns
2601    ///
2602    /// The scalar value of the tensor.
2603    pub fn try_into_scalar(self) -> Result<K::Elem, ExecutionError> {
2604        crate::try_read_sync(self.into_scalar_async()).expect(
2605            "Failed to read tensor data synchronously. This can happen on platforms
2606            that don't support blocking futures like WASM. Try into_scalar_async instead.",
2607        )
2608    }
2609
2610    /// Convert the tensor into a scalar.
2611    ///
2612    /// # Panics
2613    ///
2614    /// If the tensor doesn't have one element.
2615    pub async fn into_scalar_async(self) -> Result<K::Elem, ExecutionError> {
2616        check!(TensorCheck::into_scalar::<D>(&self.shape()));
2617
2618        Ok(self.into_data_async().await?.iter().next().unwrap())
2619    }
2620
2621    /// Broadcast the tensor to the given shape.
2622    ///
2623    /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size
2624    /// (which can be inferred with `-1`).
2625    ///
2626    /// # Arguments
2627    ///
2628    /// * `shape` - The shape to broadcast the tensor to.
2629    ///   Can contain -1 for dimensions that should be inferred.
2630    ///   The number of elements in the shape must be greater or equal as
2631    ///   the number of dimensions of the tensor.
2632    ///
2633    /// # Panics
2634    ///
2635    /// If the tensor cannot be broadcasted to the given shape.
2636    ///
2637    /// # Returns
2638    ///
2639    /// A new tensor with the given shape.
2640    ///
2641    /// # Example
2642    ///
2643    /// ```rust
2644    /// use burn_tensor::backend::Backend;
2645    /// use burn_tensor::Tensor;
2646    ///
2647    /// fn example<B: Backend>() {
2648    ///     let device = Default::default();
2649    ///     // Create a 2D tensor with dimensions [3, 1]
2650    ///     let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
2651    ///     // Expand the tensor to a new shape [3, 4]
2652    ///     // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
2653    ///     let expanded = tensor.expand([3, 4]);
2654    ///     println!("{}", expanded);
2655    /// }
2656    /// ```
2657    pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
2658        let shape = shape.into_shape(&self.shape());
2659        check!(TensorCheck::expand::<D, D2>(
2660            "expand",
2661            &self.shape(),
2662            &shape,
2663        ));
2664
2665        Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
2666    }
2667
2668    /// Unfold windows along a dimension.
2669    ///
2670    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
2671    /// where windows are advanced by `step` at each index.
2672    ///
2673    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
2674    ///
2675    /// The new view will have the unfolded dimension replaced by two dimensions;
2676    /// one in the position of the original dimension, with size equal to the number of windows,
2677    /// and one appended to the right-most position, with size equal to `size`.
2678    ///
2679    /// # Warning
2680    ///
2681    /// For the `ndarray` and `candle` backends; this is not a view but a copy
2682    /// with duplicated data.
2683    ///
2684    /// # Arguments
2685    ///
2686    /// * `dim` - the dimension to unfold.
2687    /// * `size` - the size of each unfolded window.
2688    /// * `step` - the step between each window.
2689    ///
2690    /// # Returns
2691    ///
2692    /// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
2693    pub fn unfold<const D2: usize, I: AsIndex>(
2694        self,
2695        dim: I,
2696        size: usize,
2697        step: usize,
2698    ) -> Tensor<B, D2, K> {
2699        let dim = dim.expect_dim_index(D);
2700        check!(TensorCheck::unfold::<D, D2>(
2701            "unfold",
2702            &self.shape(),
2703            dim,
2704            size,
2705            step,
2706        ));
2707        Tensor::<B, D2, K>::new(K::unfold(self.primitive, dim, size, step))
2708    }
2709}
2710
2711/// Iterator given by (Tensor::iter_dim).
2712pub struct DimIter<B, const D: usize, K>
2713where
2714    B: Backend,
2715    K: BasicOps<B>,
2716{
2717    start: usize,
2718    end: usize,
2719    dim: usize,
2720    ranges: [Range<usize>; D],
2721    tensor: Tensor<B, D, K>,
2722}
2723
2724impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
2725    type Item = Tensor<B, D, K>;
2726
2727    fn next(&mut self) -> Option<Self::Item> {
2728        if self.start >= self.end {
2729            return None;
2730        }
2731
2732        let mut ranges = self.ranges.clone();
2733        ranges[self.dim] = self.start..(self.start + 1);
2734
2735        let slice = self.tensor.clone().slice(ranges);
2736        self.start += 1;
2737
2738        Some(slice)
2739    }
2740}
2741
2742impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
2743    fn next_back(&mut self) -> Option<Self::Item> {
2744        if self.start >= self.end {
2745            return None;
2746        }
2747
2748        let mut ranges = self.ranges.clone();
2749        ranges[self.dim] = (self.end - 1)..self.end;
2750
2751        let slice = self.tensor.clone().slice(ranges);
2752        self.end = self.end.saturating_sub(1);
2753
2754        Some(slice)
2755    }
2756}
2757
2758impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
2759    fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
2760        let dims = tensor.dims();
2761        let ranges = dims
2762            .iter()
2763            .map(|&dim| 0..dim)
2764            .collect::<Vec<Range<usize>>>();
2765        let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
2766        Self {
2767            end: dims[dim],
2768            ranges,
2769            start: 0,
2770            dim,
2771            tensor,
2772        }
2773    }
2774}
2775
impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    <K as BasicOps<B>>::Elem: Debug,
{
    /// Appends a newline followed by `indent` spaces to `acc`.
    #[inline]
    fn push_newline_indent(acc: &mut String, indent: usize) {
        acc.push('\n');
        for _ in 0..indent {
            acc.push(' ');
        }
    }
    /// Appends the elements at indices `range.0..range.1` of the innermost dimension `depth`
    /// to `acc`, separated by `", "`.
    ///
    /// Each element is read on its own: the tensor is sliced down to the single element at
    /// `multi_index` and its data is read synchronously. If that read is unavailable or fails,
    /// the placeholder `<Tensor data not available>` is written instead.
    ///
    /// The separator is emitted before every element whose index is non-zero, so a range that
    /// does not start at 0 (the trailing half of a summarized row) begins with `", "`.
    ///
    /// `precision`, when set, applies only to `Float` tensors; `Bool` tensors are printed via
    /// `to_bool`, and everything else uses the element's `Debug` formatting.
    fn fmt_inner_tensor(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        range: (usize, usize),
        precision: Option<usize>,
    ) {
        let (start, end) = range;
        for i in start..end {
            if i > 0 {
                acc.push_str(", ");
            }
            multi_index[depth] = i;
            // Single-element slice at the current multi-index (closure `i` is the axis, not
            // the element index from the enclosing loop).
            let range: [Range<usize>; D] =
                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);

            let data = burn_std::reader::try_read_sync(self.clone().slice(range).into_data_async());

            if let Some(Ok(data)) = data {
                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
                match (precision, K::name()) {
                    (Some(p), "Float") => acc.push_str(&format!("{elem:.p$}")),
                    (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
                    _ => acc.push_str(&format!("{elem:?}")),
                }
            } else {
                acc.push_str("<Tensor data not available>");
            }
        }
    }

    /// Appends the sub-tensors at indices `range.0..range.1` of the non-innermost dimension
    /// `depth` to `acc`, each wrapped in `[`/`]` and formatted by `display_recursive`.
    ///
    /// Consecutive sub-tensors are separated by `,` plus a newline indented to `depth + 1`;
    /// unlike `fmt_inner_tensor`, no separator precedes the first sub-tensor of the range.
    fn fmt_outer_tensor(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        print_options: &PrintOptions,
        summarize: bool,
        range: (usize, usize),
    ) {
        let (start, end) = range;
        for i in start..end {
            if i > start {
                acc.push(',');
                Self::push_newline_indent(acc, depth + 1);
            }
            acc.push('[');
            multi_index[depth] = i;
            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
            acc.push(']');
        }
    }

    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
    ///
    /// This function is designed to work with tensors of any dimensionality.
    /// It traverses the tensor dimensions recursively, converting the elements
    /// to strings and appending them to the accumulator string with the
    /// appropriate formatting.
    ///
    /// When `summarize` is true, any dimension longer than `2 * print_options.edge_items`
    /// is abbreviated to its first and last `edge_items` entries with `...` in between.
    ///
    /// # Arguments
    ///
    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
    /// * `depth` - The current depth of the tensor dimensions being processed.
    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
    /// * `print_options` - Formatting options (edge items and float precision).
    /// * `summarize` - Whether long dimensions should be abbreviated.
    fn display_recursive(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        print_options: &PrintOptions,
        summarize: bool,
    ) {
        let edge_items = print_options.edge_items;

        // The outermost call wraps the whole output in one pair of brackets.
        if depth == 0 {
            acc.push('[');
        }

        // NOTE(review): assumes the tensor has at least one dimension; `len() - 1` would
        // underflow for a rank-0 tensor.
        if depth == self.dims().len() - 1 {
            // if we are at the innermost dimension, just push its elements into the accumulator
            if summarize && self.dims()[depth] > 2 * edge_items {
                // print the starting `edge_items` elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (0, edge_items),
                    print_options.precision,
                );
                acc.push_str(", ...");
                // print the last `edge_items` elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (self.dims()[depth] - edge_items, self.dims()[depth]),
                    print_options.precision,
                );
            } else {
                // print all the elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (0, self.dims()[depth]),
                    print_options.precision,
                );
            }
        } else {
            // otherwise, iterate through the current dimension and recursively display the inner tensors
            if summarize && self.dims()[depth] > 2 * edge_items {
                // Leading `edge_items` sub-tensors, then a `...` line, then the trailing ones.
                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (0, edge_items),
                );

                acc.push(',');
                Self::push_newline_indent(acc, depth + 1);
                acc.push_str("...");
                Self::push_newline_indent(acc, depth + 1);

                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (self.dims()[depth] - edge_items, self.dims()[depth]),
                );
            } else {
                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (0, self.dims()[depth]),
                );
            }
        }

        if depth == 0 {
            acc.push(']');
        }
    }
}
2941
/// Options controlling how tensors are pretty-printed by their `Display` implementation.
///
/// Install new options globally with [`set_print_options`].
#[derive(Clone, Debug)]
pub struct PrintOptions {
    /// Tensors with more elements than this are summarized: each dimension longer than
    /// `2 * edge_items` shows only its first and last `edge_items` entries, with `...` between.
    pub threshold: usize,

    /// Number of leading and trailing entries shown per dimension when summarizing.
    pub edge_items: usize,

    /// Number of decimal places used for float tensors; `None` uses the element's `Debug`
    /// formatting. A precision given in the format string (e.g. `{:.2}`) takes precedence.
    pub precision: Option<usize>,
}
2954
/// Global print options: read (and cloned) by the tensor `Display` implementation and
/// replaced wholesale by `set_print_options`. Initialized with `PrintOptions::const_default()`.
static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
2956
2957impl PrintOptions {
2958    /// Print options with default values
2959    pub const fn const_default() -> Self {
2960        Self {
2961            threshold: 1000,
2962            edge_items: 3,
2963            precision: None,
2964        }
2965    }
2966}
2967
2968impl Default for PrintOptions {
2969    fn default() -> Self {
2970        Self::const_default()
2971    }
2972}
2973
2974/// Set print options
2975pub fn set_print_options(options: PrintOptions) {
2976    let mut print_opts = PRINT_OPTS.write().unwrap();
2977    *print_opts = options;
2978}
2979
2980/// Pretty print tensors
2981impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
2982where
2983    B: Backend,
2984    B::IntElem: core::fmt::Display,
2985    K: BasicOps<B>,
2986    <K as BasicOps<B>>::Elem: Debug,
2987{
2988    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2989        writeln!(f, "Tensor {{")?;
2990
2991        {
2992            // Do not lock the mutex for the whole function
2993            let mut po = { PRINT_OPTS.read().unwrap().clone() };
2994
2995            // Override the precision if it is set from the formatter
2996            // This will be possible when the tensor is printed using the `{:.*}` syntax
2997            if let Some(precision) = f.precision() {
2998                po.precision = Some(precision);
2999            }
3000
3001            let mut acc = String::new();
3002            let mut multi_index = vec![0; D];
3003            let summarize = self.shape().num_elements() > po.threshold;
3004
3005            self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);
3006
3007            writeln!(f, "  data:")?;
3008            write!(f, "{acc}")?;
3009            writeln!(f, ",")?;
3010        }
3011
3012        writeln!(f, "  shape:  {:?},", self.dims())?;
3013        writeln!(f, "  device:  {:?},", self.device())?;
3014        writeln!(f, "  backend:  {:?},", B::name(&self.device()))?;
3015        writeln!(f, "  kind:  {:?},", K::name())?;
3016
3017        let dtype = self.primitive.dtype();
3018
3019        writeln!(f, "  dtype:  {:?},", dtype.name())?;
3020        write!(f, "}}")
3021    }
3022}
3023
/// Argument types accepted by `tensor.movedim()`: a single dimension or a list of dimensions.
pub trait MovedimArgs {
    /// Validates the arguments for a rank-`D` tensor and returns them as non-negative
    /// dimension indices.
    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}
3029
3030impl MovedimArgs for Vec<i32> {
3031    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3032        let set = self
3033            .iter()
3034            .map(|&dim| {
3035                if dim < 0 {
3036                    (D as i32 + dim) as usize
3037                } else {
3038                    dim as usize
3039                }
3040            })
3041            .collect::<Vec<usize>>();
3042        check!(TensorCheck::movedim_args_vec::<D>(&set));
3043
3044        set
3045    }
3046}
3047
3048impl MovedimArgs for Vec<usize> {
3049    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3050        check!(TensorCheck::movedim_args_vec::<D>(&self));
3051        self
3052    }
3053}
3054
3055impl MovedimArgs for usize {
3056    #[allow(clippy::vec_init_then_push)]
3057    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3058        check!(TensorCheck::movedim_args_usize::<D>(self));
3059
3060        let mut set = Vec::with_capacity(1);
3061        set.push(self);
3062
3063        set
3064    }
3065}
3066
3067impl MovedimArgs for i32 {
3068    #[allow(clippy::vec_init_then_push)]
3069    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3070        check!(TensorCheck::movedim_args_i32::<D>(self));
3071
3072        let dim = if self < 0 {
3073            (D as i32 + self) as usize
3074        } else {
3075            self as usize
3076        };
3077
3078        let mut set = Vec::with_capacity(1);
3079        set.push(dim);
3080
3081        set
3082    }
3083}
3084
3085/// Trait used for reshape arguments.
3086pub trait ReshapeArgs<const D2: usize>: Debug {
3087    /// Converts to a shape.
3088    fn into_shape<const D: usize>(self, source: Shape) -> Shape;
3089}
3090
3091impl<const D2: usize, I: AsIndex> ReshapeArgs<D2> for [I; D2] {
3092    fn into_shape<const D: usize>(self, source: Shape) -> Shape {
3093        unwrap_shape_reshape(source.reshape(self))
3094    }
3095}
3096
3097impl<const D2: usize> ReshapeArgs<D2> for Shape {
3098    fn into_shape<const D: usize>(self, source: Shape) -> Shape {
3099        unwrap_shape_reshape(source.reshape(self))
3100    }
3101}
3102
3103/// Trait used for broadcast arguments.
3104pub trait BroadcastArgs<const D1: usize, const D2: usize> {
3105    /// Converts to a shape.
3106    fn into_shape(self, shape: &Shape) -> Shape;
3107}
3108
3109impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
3110    fn into_shape(self, _shape: &Shape) -> Shape {
3111        self
3112    }
3113}
3114
3115impl<const D1: usize, const D2: usize, E: AsIndex> BroadcastArgs<D1, D2> for [E; D2] {
3116    // Passing -1 as the size for a dimension means not changing the size of that dimension.
3117    fn into_shape(self, shape: &Shape) -> Shape {
3118        if self.len() < shape.num_dims() {
3119            panic!("Broadcast arguments must be greater than the number of dimensions");
3120        }
3121
3122        // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
3123        let new_shape: Vec<_> = self
3124            .iter()
3125            .rev()
3126            .map(|x| {
3127                let primitive = x.as_index();
3128                if primitive < -1 || primitive == 0 {
3129                    panic!("Broadcast arguments must be positive or -1");
3130                }
3131                primitive
3132            })
3133            .zip(shape.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
3134            .map(|(x, &y)| if x == -1 { y } else { x as usize })
3135            .collect::<Vec<_>>()
3136            .into_iter()
3137            .rev()
3138            .collect();
3139
3140        if new_shape.contains(&0) {
3141            panic!("Cannot substitute -1 for a non-existing dimension");
3142        }
3143
3144        let new_shape: [usize; D2] = new_shape.try_into().unwrap();
3145
3146        Shape::from(new_shape)
3147    }
3148}
3149
3150impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
3151where
3152    B: Backend,
3153    K: BasicOps<B>,
3154    K::Elem: Debug + Copy + Serialize,
3155{
3156    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
3157        let data = self.to_data();
3158        data.serialize(serializer)
3159    }
3160}
3161
3162impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
3163where
3164    B: Backend,
3165    K: BasicOps<B>,
3166    K::Elem: Debug + Copy + Deserialize<'de>,
3167{
3168    fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
3169        let tensor = Tensor::from_data(
3170            TensorData::deserialize(deserializer)?,
3171            &<B::Device as Default>::default(),
3172        );
3173        Ok(tensor)
3174    }
3175}
3176
#[cfg(test)]
mod tests {
    use crate::{Shape, s};

    /// Every supported single-dimension range form, with positive and negative bounds.
    #[test]
    fn slice_range_single_dim_leading() {
        let shape = Shape::new([8, 4]);

        // Half-open range
        let slices = shape.clone().into_slices([0..5]);
        assert_eq!(slices[0].to_range(8), 0..5);
        let slices = shape.clone().into_slices([-3..-1]);
        assert_eq!(slices[0].to_range(8), 5..7);

        // Inclusive range
        let slices = shape.clone().into_slices([0..=4]);
        assert_eq!(slices[0].to_range(8), 0..5);
        let slices = shape.clone().into_slices([-2..=-1]);
        assert_eq!(slices[0].to_range(8), 6..8);

        // Unbounded start
        let slices = shape.clone().into_slices([..3]);
        assert_eq!(slices[0].to_range(8), 0..3);
        let slices = shape.clone().into_slices([..-5]);
        assert_eq!(slices[0].to_range(8), 0..3);

        // Unbounded end
        let slices = shape.clone().into_slices([5..]);
        assert_eq!(slices[0].to_range(8), 5..8);
        let slices = shape.clone().into_slices([-3..]);
        assert_eq!(slices[0].to_range(8), 5..8);

        // Full range
        let slices = shape.into_slices([..]);
        assert_eq!(slices[0].to_range(8), 0..8);
    }

    /// Negative indices are stored as-is in `Slice` and resolved by `to_range`.
    #[test]
    fn test_negative_slice_indices() {
        use crate::Slice;

        // Test negative indices conversion
        let slice: Slice = (-3..-1).into();
        assert_eq!(slice.start, -3);
        assert_eq!(slice.end, Some(-1));

        // Test to_range conversion with size 8
        let range = slice.to_range(8);
        assert_eq!(range, 5..7);

        // Test with shape slice
        let shape = Shape::new([8, 4]);
        let result = shape.clone().into_slices([-3..-1]);
        assert_eq!(result[0].to_range(8), 5..7);

        // Test more negative index cases
        let slice2: Slice = (-5..).into();
        assert_eq!(slice2.to_range(10), 5..10);

        let slice3: Slice = (..-2).into();
        assert_eq!(slice3.to_range(10), 0..8);

        // Test with s! macro - single dimension returns Slice directly
        let slice4 = s![-3..-1];
        assert_eq!(slice4.start, -3);
        assert_eq!(slice4.end, Some(-1));
    }

    /// Homogeneous range forms applied to several dimensions at once.
    #[test]
    fn slice_range_multi_dim() {
        let shape = Shape::new([8, 4]);

        // Multiple ways to provide ranges
        let slices = shape.clone().into_slices([0..5, 0..4]);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..4);

        let slices = shape.clone().into_slices([0.., 0..]);
        assert_eq!(slices[0].to_range(8), 0..8);
        assert_eq!(slices[1].to_range(4), 0..4);

        let slices = shape.clone().into_slices([0..=7, 0..=3]);
        assert_eq!(slices[0].to_range(8), 0..8);
        assert_eq!(slices[1].to_range(4), 0..4);

        let slices = shape.clone().into_slices([0..5, 0..3]);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..3);

        // Unbounded start on every dimension
        let slices = shape.into_slices([..5, ..3]);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..3);
    }

    /// Single integer indices select a size-1 range, counting from the end when negative.
    #[test]
    fn slice_range_multi_dim_index() {
        let shape = Shape::new([8, 4]);

        // Indices (single integer) should also convert to correct range
        let slices = shape.clone().into_slices([0, 2]);
        assert_eq!(slices[0].to_range(8), 0..1);
        assert_eq!(slices[1].to_range(4), 2..3);

        let slices = shape.into_slices([-1, -1]);
        assert_eq!(slices[0].to_range(8), 7..8);
        assert_eq!(slices[1].to_range(4), 3..4);
    }

    /// Mixed range kinds per dimension via the `s![]` macro.
    #[test]
    fn slice_range_multi_dim_heterogeneous() {
        // Slice macro `s![]` can be used to provide different range types
        let shape = Shape::new([8, 4, 2]);
        let slice = s![0..5, .., -1];
        let slices = shape.into_slices(slice);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..4);
        assert_eq!(slices[2].to_range(2), 1..2);

        let shape = Shape::new([8, 4, 2, 3]);
        let slice = s![..=4, 0..=3, .., -2..];
        let slices = shape.into_slices(slice);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..4);
        assert_eq!(slices[2].to_range(2), 0..2);
        assert_eq!(slices[3].to_range(3), 1..3);

        let shape = Shape::new([3, 4]);
        let slice = s![1..-1, ..];
        let slices = shape.into_slices(slice);
        assert_eq!(slices[0].to_range(3), 1..2);
        assert_eq!(slices[1].to_range(4), 0..4);
    }
}