// burn_tensor/tensor/api/base.rs

1#![allow(clippy::single_range_in_vec_init)]
2use crate::backend::ExecutionError;
3use crate::check::unwrap_shape_reshape;
4
5use burn_backend::Scalar;
6pub use burn_backend::tensor::BasicOps;
7
8use alloc::vec::Vec;
9
10use alloc::format;
11use alloc::string::String;
12use alloc::vec;
13
14use burn_std::{SliceOps, stub::RwLock};
15use core::iter::repeat;
16use core::{fmt::Debug, ops::Range};
17use serde::{Deserialize, Deserializer};
18
19use crate::{AsIndex, Slice, SliceArg, wrap_index};
20use crate::{
21    Bool, ElementConversion, Float, Int, Shape, TensorData, TensorKind, TensorMetadata,
22    backend::Backend, check,
23};
24use crate::{DType, Element};
25use crate::{IndexingUpdateOp, TensorCreationOptions};
26use crate::{cast::ToElement, check::TensorCheck};
27use serde::{Serialize, Serializer};
28
29/// A tensor with a given backend, shape and data type.
30///
31/// # Indexing
32/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
33/// or [`select`](Tensor::select) for numeric types.
34///
35/// ## Example
36///
37/// ```rust
38/// use burn_tensor::backend::Backend;
39/// use burn_tensor::Tensor;
40/// use burn_tensor::Int;
41///
42/// fn example<B: Backend>() {
43///     let device = Default::default();
44///
45///     let tensor = Tensor::<B, 2>::from_data(
46///         [
47///             [3.0, 4.9, 2.0],
48///             [2.0, 1.9, 3.0],
49///             [6.0, 1.5, 7.0],
50///             [3.0, 4.9, 9.0],
51///         ],
52///         &device,
53///     );
54///
55///     // Slice the tensor to get the second and third rows:
56///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
57///     // The resulting tensor will have dimensions [2, 3].
58///     let slice = tensor.clone().slice([1..3]);
59///     println!("{slice}");
60///
61///     // Slice the tensor to get the first two rows and the first 2 columns:
62///     // [[3.0, 4.9], [2.0, 1.9]]
63///     // The resulting tensor will have dimensions [2, 2].
64///     let slice = tensor.clone().slice([0..2, 0..2]);
65///     println!("{slice}");
66///
67///     // Index the tensor along the dimension 1 to get the elements 0 and 2:
68///     // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
69///     // The resulting tensor will have dimensions [4, 2]
70///     let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
71///     let indexed = tensor.select(1, indices);
72///     println!("{indexed}");
73/// }
74/// ```
#[derive(new, Clone, Debug)]
pub struct Tensor<B, const D: usize, K = Float>
where
    B: Backend,
    K: TensorKind<B>,
{
    /// The backend-specific primitive for this tensor kind; all tensor
    /// operations delegate to it (shape, dtype, and the `K` ops).
    pub(crate) primitive: K::Primitive,
}
83
84impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
85where
86    B: Backend,
87    K: BasicOps<B>,
88    T: Into<TensorData>,
89{
90    fn from(value: T) -> Self {
91        Tensor::from_data(value.into(), &Default::default())
92    }
93}
94
95impl<B, const D: usize, K> Tensor<B, D, K>
96where
97    B: Backend,
98    K: BasicOps<B>,
99    K::Elem: Element,
100{
101    /// Executes an operation on the tensor and modifies its value.
102    ///
103    /// # Notes
104    ///
105    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
106    /// no other reference pointing to the same tensor.
107    ///
108    /// Wrapping operations with inplace is not an optimization, it's mainly there if you
109    /// want to mutate a tensor by using owned operations. A plausible usage would be to
110    /// update the weights of a mutable model reference.
111    pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
112        let mut tensor_owned = Tensor::empty([0; D], &self.device());
113        core::mem::swap(&mut tensor_owned, self);
114
115        let mut tensor_new = func(tensor_owned);
116        core::mem::swap(&mut tensor_new, self);
117    }
118
119    /// Converts the tensor into a primitive tensor.
120    pub fn into_primitive(self) -> K::Primitive {
121        self.primitive
122    }
123
124    /// Converts from a primitive tensor into a tensor.
125    pub fn from_primitive(tensor: K::Primitive) -> Self {
126        Self::new(tensor)
127    }
128
    /// Returns the number of dimensions of the tensor.
    ///
    /// Queries the runtime rank of the underlying primitive, which is
    /// expected to match the compile-time rank `D`.
    pub fn rank(&self) -> usize {
        self.primitive.rank()
    }
133
    /// Returns the tensor primitive data type.
    ///
    /// # Note
    /// Some element types are encoded in different primitive types depending on the backend
    /// (e.g., bool could be encoded as `u8` or `u32`), so this reflects the backend's
    /// storage type rather than the kind's logical element type.
    pub fn dtype(&self) -> DType {
        self.primitive.dtype()
    }
142
    /// Create an empty tensor of the given shape.
    ///
    /// # Arguments
    ///
    /// - `shape`: The shape of the tensor.
    /// - `options`: Anything convertible into [`TensorCreationOptions`] (e.g. a device
    ///   reference); determines the target device and element data type.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    // Create an empty tensor with dimensions [2, 3, 4].
    ///    let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
    /// }
    /// ```
    pub fn empty<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
        let opt = options.into();
        let shape = shape.into();
        // Resolve the element dtype for this tensor kind from the creation options.
        let dtype = opt.resolve_dtype::<K>();
        check!(TensorCheck::creation_ops::<D>("Empty", &shape));
        Self::new(K::empty(shape, &opt.device, dtype))
    }
168
169    /// Create a tensor of the given shape where each element is zero.
170    ///
171    /// # Example
172    ///
173    /// ```rust
174    /// use burn_tensor::backend::Backend;
175    /// use burn_tensor::{Tensor, Shape};
176    ///
177    /// fn example<B: Backend>() {
178    ///    let device = B::Device::default();
179    ///    let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
180    ///    println!("{tensor}");
181    ///    // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
182    /// }
183    /// ```
184    pub fn zeros<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
185        let opt = options.into();
186        let shape = shape.into();
187        let dtype = opt.resolve_dtype::<K>();
188        check!(TensorCheck::creation_ops::<D>("Zeros", &shape));
189        Self::new(K::zeros(shape, &opt.device, dtype))
190    }
191
192    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with zeros.
193    ///
194    /// # Example
195    ///
196    /// ```rust
197    /// use burn_tensor::backend::Backend;
198    /// use burn_tensor::{Tensor, Shape};
199    ///
200    /// fn example<B: Backend>() {
201    ///   let device = B::Device::default();
202    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
203    ///   let tensor = tensor.zeros_like();
204    ///   println!("{tensor}");
205    ///   // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
206    /// }
207    /// ```
208    pub fn zeros_like(&self) -> Self {
209        Self::new(K::zeros(self.shape(), &self.device(), self.dtype()))
210    }
211
212    /// Create a tensor of the given shape where each element is one.
213    ///
214    /// # Example
215    ///
216    /// ```rust
217    /// use burn_tensor::backend::Backend;
218    /// use burn_tensor::{Tensor, Shape};
219    ///
220    /// fn example<B: Backend>() {
221    ///   let device = B::Device::default();
222    ///   let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
223    ///   println!("{tensor}");
224    ///   // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
225    /// }
226    /// ```
227    pub fn ones<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
228        let opt = options.into();
229        let shape = shape.into();
230        let dtype = opt.resolve_dtype::<K>();
231        check!(TensorCheck::creation_ops::<D>("Ones", &shape));
232        Self::new(K::ones(shape, &opt.device, dtype))
233    }
234
235    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with ones.
236    ///
237    /// # Example
238    ///
239    /// ```rust
240    /// use burn_tensor::backend::Backend;
241    /// use burn_tensor::{Tensor, Shape};
242    ///
243    /// fn example<B: Backend>() {
244    ///    let device = B::Device::default();
245    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
246    ///    let tensor = tensor.ones_like();
247    ///    println!("{tensor}");
248    ///    // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
249    /// }
250    /// ```
251    pub fn ones_like(&self) -> Self {
252        Self::new(K::ones(self.shape(), &self.device(), self.dtype()))
253    }
254
255    /// Create a tensor of the given shape where each element is equal to the provided value.
256    ///
257    /// # Example
258    ///
259    /// ```rust
260    /// use burn_tensor::backend::Backend;
261    /// use burn_tensor::{Tensor, Shape};
262    ///
263    /// fn example<B: Backend>() {
264    ///   let device = B::Device::default();
265    ///   let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
266    ///   println!("{tensor}");
267    ///   // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
268    /// }
269    /// ```
270    pub fn full<S: Into<Shape>, E: ElementConversion>(
271        shape: S,
272        fill_value: E,
273        options: impl Into<TensorCreationOptions<B>>,
274    ) -> Self {
275        let opt = options.into();
276        let shape = shape.into();
277        let dtype = opt.resolve_dtype::<K>();
278        check!(TensorCheck::creation_ops::<D>("Full", &shape));
279        Self::new(K::full(
280            shape,
281            Scalar::new(fill_value, &dtype),
282            &opt.device,
283            dtype,
284        ))
285    }
286
287    /// Returns a new tensor with the same shape, dtype, and device as the current tensor,
288    /// filled with the provided value.
289    ///
290    /// # Example
291    ///
292    /// ```rust
293    /// use burn_tensor::backend::Backend;
294    /// use burn_tensor::{Tensor, Shape};
295    ///
296    /// fn example<B: Backend>() {
297    ///    let device = B::Device::default();
298    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
299    ///    let tensor = tensor.full_like(5.0);
300    ///    println!("{tensor}");
301    ///    // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
302    /// }
303    /// ```
304    pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
305        let dtype = self.dtype();
306        Self::new(K::full(
307            self.shape(),
308            Scalar::new(fill_value, &dtype),
309            &self.device(),
310            dtype,
311        ))
312    }
313
314    /// Returns the dimensions of the current tensor.
315    ///
316    /// # Example
317    /// ```rust
318    /// use burn_tensor::backend::Backend;
319    /// use burn_tensor::Tensor;
320    ///
321    /// fn example<B: Backend>() {
322    ///   let device = Default::default();
323    ///   let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
324    ///   let dims = tensor.dims(); // [2, 3, 4]
325    ///   println!("{dims:?}");
326    /// }
327    /// ```
328    pub fn dims(&self) -> [usize; D] {
329        Self::shape(self).dims()
330    }
331
    /// Returns the shape of the current tensor.
    ///
    /// See [`dims`](Tensor::dims) for the dimensions as a fixed-size array.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Shape { dims: [2, 3, 4] }
    ///    let shape = tensor.shape();
    /// }
    /// ```
    pub fn shape(&self) -> Shape {
        self.primitive.shape()
    }
349
350    /// Reshape the tensor to have the given shape.
351    ///
352    /// The tensor has the same data and number of elements as the input.
353    ///
354    /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
355    /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
356    ///
357    /// A `0` in the shape instructs to keep the current dimension from the original tensor,
358    /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
359    /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
360    /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
361    /// with [1, 3, 4] dimensions to [1, 12].
362    ///
363    /// # Arguments
364    /// - `shape`: The new shape of the tensor.
365    ///
366    /// # Panics
367    /// - If the tensor contains more than one `-1` in the shape.
368    /// - If the tensor contains values that are not positive (other than -1).
369    /// - If the shape does not match the number of elements of the original shape.
370    ///
371    /// # Example
372    ///
373    /// ```rust
374    /// use burn_tensor::backend::Backend;
375    /// use burn_tensor::Tensor;
376    ///
377    /// fn example<B: Backend>() {
378    ///    let device = Default::default();
379    ///    // Create a tensor with dimensions [2, 3, 4]
380    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
381    ///    // Reshape it to [2, 12], where 12 is inferred from the number of elements.
382    ///    let reshaped = tensor.reshape([2, -1]);
383    ///    println!("{reshaped}");
384    /// }
385    /// ```
386    pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
387        // Convert reshape args to shape
388        let shape = shape.into_shape::<D2>(self.shape());
389        Tensor::new(K::reshape(self.primitive, shape))
390    }
391
392    /// Transpose the tensor.
393    ///
394    /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
395    /// applied on the last two dimensions. For example, the transpose of a tensor with shape
396    /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
397    ///
398    /// See also [`permute`](Tensor::permute).
399    ///
400    /// # Arguments
401    ///
402    /// * `tensor` - The tensor to transpose.
403    ///
404    /// # Returns
405    ///
406    /// The transposed tensor.
407    ///
408    /// # Example
409    ///
410    /// ```rust
411    /// use burn_tensor::backend::Backend;
412    /// use burn_tensor::Tensor;
413    ///
414    /// fn example<B: Backend>() {
415    ///     let device = Default::default();
416    ///     // Create a 2D tensor of shape [2, 3]
417    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
418    ///
419    ///     // Transpose the tensor:
420    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
421    ///     // The resulting tensor will have dimensions [3, 2].
422    ///     let transposed = tensor.transpose();
423    ///     println!("{transposed}");
424    /// }
425    /// ```
426    pub fn transpose(self) -> Tensor<B, D, K> {
427        Tensor::new(K::transpose(self.primitive))
428    }
429
    /// Alias for [`transpose`](Tensor::transpose).
    #[inline(always)]
    pub fn t(self) -> Tensor<B, D, K> {
        self.transpose()
    }
435
436    /// Swaps two dimensions of a tensor.
437    ///
438    /// This is a no-op when `dim1 == dim2`, assuming both are within bounds.
439    ///
440    /// # Arguments
441    ///
442    /// * `tensor` - The tensor to swap the dimensions of.
443    /// * `dim1` - The first dimension to swap, supports negative indexing.
444    /// * `dim2` - The second dimension to swap, supports negative indexing.
445    ///
446    /// # Returns
447    ///
448    /// The tensor with the dimensions swapped.
449    ///
450    /// # Panics
451    ///
452    /// When dimensions are out of bounds.
453    ///
454    /// # Example
455    ///
456    /// ```rust
457    /// use burn_tensor::backend::Backend;
458    /// use burn_tensor::Tensor;
459    ///
460    /// fn example<B: Backend>() {
461    ///     let device = Default::default();
462    ///     // Create a 2D tensor of shape [2, 3]
463    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
464    ///
465    ///     // Swap the dimensions 0 and -1 (equivalent to `tensor.transpose()`):
466    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
467    ///     // The resulting tensor will have dimensions [3, 2].
468    ///     let swapped = tensor.swap_dims(0, -1);
469    ///     println!("{swapped}");
470    /// }
471    /// ```
472    pub fn swap_dims<Dim1, Dim2>(self, dim1: Dim1, dim2: Dim2) -> Tensor<B, D, K>
473    where
474        Dim1: AsIndex,
475        Dim2: AsIndex,
476    {
477        let dim1 = dim1.expect_dim_index(D);
478        let dim2 = dim2.expect_dim_index(D);
479        check!(TensorCheck::swap_dims::<D>(dim1, dim2));
480        if dim1 == dim2 {
481            self
482        } else {
483            Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
484        }
485    }
486
487    /// Permute the dimensions of the tensor.
488    ///
489    /// This is a no-op when the resolved `axes` match the current order.
490    ///
491    /// # Arguments
492    ///
493    /// * `axes` - The new order of the dimensions. The length of the axes
494    ///   must be equal to the number of dimensions of the tensor.
495    ///   The values must be unique and in the range of the number of dimensions.
496    ///   The values can be negative, in which case they are used as an offset from the end.
497    ///
498    /// # Returns
499    ///
500    /// The tensor with the dimensions permuted.
501    ///
502    /// # Example
503    ///
504    /// ```rust
505    /// use burn_tensor::backend::Backend;
506    /// use burn_tensor::Tensor;
507    ///
508    /// fn example<B: Backend>() {
509    ///     let device = Default::default();
510    ///     // Create a 2D tensor of shape [3, 2]
511    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
512    ///
513    ///     // Permute the dimensions 1 and 0:
514    ///     // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
515    ///     // The resulting tensor will have dimensions [3, 2].
516    ///     let permuted = tensor.permute([1, 0]);
517    ///     println!("{permuted}");
518    /// }
519    /// ```
520    pub fn permute<Dim>(self, axes: [Dim; D]) -> Tensor<B, D, K>
521    where
522        Dim: AsIndex,
523    {
524        let mut no_op = true;
525        let mut fixed_axes = [0; D];
526        for (i, axis) in axes.into_iter().enumerate() {
527            let dim = axis.expect_dim_index(D);
528            no_op &= dim == i;
529            fixed_axes[i] = dim;
530        }
531
532        if no_op {
533            self
534        } else {
535            check!(TensorCheck::permute(fixed_axes));
536            Tensor::new(K::permute(self.primitive, &fixed_axes))
537        }
538    }
539
    /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
    ///
    /// Other dimensions of input that are not explicitly moved remain in their original order and appear
    /// at the positions not specified in destination.
    ///
    /// # Arguments
    ///
    /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// * `dst` - Destination positions for each of the original dims. These must also be unique.
    ///
    /// # Panics
    ///
    /// - If the source and destination dimensions are not of the same length.
    /// - If the source and destination vectors contain duplicate values.
    /// - If the source and destination vectors contain values that are out of bounds.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions moved.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor of shape [3, 2, 1]
    ///     let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
    ///
    ///     // Move the dimensions 0 and 1:
    ///     // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
    ///     // The resulting tensor will have dimensions [2, 3, 1].
    ///     let moved = tensor.movedim(1, 0);
    ///     println!("{moved}");
    /// }
    /// ```
    ///
    /// # Note
    ///
    /// This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op
    /// for it
    pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
        let source_dims = src.into_dim_vec::<D>();
        let destination_dims = dst.into_dim_vec::<D>();

        check!(TensorCheck::movedim_args_length(
            &source_dims,
            &destination_dims
        ));

        // `m[d] = s` records that destination position `d` is taken by source dim `s`;
        // -1 marks destination positions that are still free.
        let mut m = [-1; D];
        for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
            m[d] = s as isize;
        }
        // Build the full permutation: explicitly moved dims go to their destinations,
        // and the remaining dims fill the free slots in their original relative order.
        let mut axes: [isize; D] = [0; D];
        let mut source_i = 0;
        for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
            *item = if m[dest_i] != -1 {
                m[dest_i]
            } else {
                // Skip dims that were explicitly moved; take the next untouched one.
                while source_dims.contains(&source_i) {
                    source_i += 1;
                }
                let result = source_i as isize;
                source_i += 1;
                result
            };
        }

        self.permute(axes)
    }
615
616    /// Reverse the order of elements in the tensor along the given dimensions.
617    ///
618    /// # Arguments
619    ///
620    /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
621    ///   The values can be negative, in which case they are used as an offset from the end.
622    ///
623    /// # Returns
624    ///
625    /// The tensor with the axes flipped.
626    ///
627    /// # Example
628    ///
629    /// ```rust
630    /// use burn_tensor::backend::Backend;
631    /// use burn_tensor::Tensor;
632    ///
633    /// fn example<B: Backend>() {
634    ///     let device = Default::default();
635    ///     // Create a 2D tensor with dimensions [4, 3]
636    ///     let tensor = Tensor::<B, 2>::from_data(
637    ///         [
638    ///             [3.0, 4.9, 2.0],
639    ///             [2.0, 1.9, 3.0],
640    ///             [4.0, 5.9, 8.0],
641    ///             [1.4, 5.8, 6.0],
642    ///         ],
643    ///         &device,
644    ///     );
645    ///
646    ///     // Flip the elements in dimensions 0 and 1:
647    ///     // [[6.0, 5.8, 1.4],
648    ///     //  [8.0, 5.9, 4.0],
649    ///     //  [3.0, 1.9, 2.0],
650    ///     //  [2.0, 4.9, 3.0]]
651    ///     // The resulting tensor will have dimensions [4, 3].
652    ///     let flipped = tensor.flip([0, 1]);
653    ///     println!("{flipped}");
654    /// }
655    /// ```
656    pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
657        // Convert the axes to usize and handle negative values without using vector
658        let mut transformed_axes: [usize; N] = [0; N];
659        for (i, &x) in axes.iter().enumerate() {
660            transformed_axes[i] = if x < 0 {
661                (D as isize + x) as usize
662            } else {
663                x as usize
664            };
665        }
666
667        // Check if the axes are valid
668        check!(TensorCheck::flip(D, &transformed_axes));
669
670        Tensor::new(K::flip(self.primitive, &transformed_axes))
671    }
672
673    /// Flatten the tensor along a given range of dimensions.
674    ///
675    /// This function collapses the specified range of dimensions into a single dimension,
676    /// effectively flattening the tensor in that range.
677    ///
678    /// # Arguments
679    ///
680    /// - `start_dim`: The starting dimension of the range to be flattened,
681    ///   supports negative indexing.
682    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),
683    ///   supports negative indexing.
684    ///
685    /// # Type Parameters
686    ///
687    /// - `D2`: The resulting number of dimensions in the flattened tensor.
688    ///
689    /// # Returns
690    ///
691    /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
692    ///
693    /// # Example
694    ///
695    /// ```rust
696    ///
697    /// use burn_tensor::backend::Backend;
698    /// use burn_tensor::{Tensor, Shape};
699    ///
700    /// fn example<B: Backend>() {
701    ///     let device = Default::default();
702    ///     // Create a 3D tensor with dimensions [2, 3, 4]
703    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
704    ///
705    ///     // Flatten the tensor from dimensions 1 to 2 (inclusive).
706    ///     // The resulting tensor will have dimensions [2, 12]
707    ///     let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
708    ///     println!("{flattened}");
709    /// }
710    /// ```
711    pub fn flatten<const D2: usize>(
712        self,
713        start_dim: impl AsIndex,
714        end_dim: impl AsIndex,
715    ) -> Tensor<B, D2, K> {
716        let start_dim = start_dim.expect_dim_index(D);
717        let end_dim = end_dim.expect_dim_index(D);
718        check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));
719        let new_shape = self.shape().flatten_dims(start_dim, end_dim);
720
721        Tensor::new(K::reshape(self.primitive, new_shape))
722    }
723
724    /// Squeeze the tensor along all dimensions, removing dimensions
725    /// of size one, and effectively reducing the rank of the tensor.
726    ///
727    /// # Type Parameters
728    ///
729    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
730    ///
731    /// # Returns
732    ///
733    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
734    ///
735    /// # Example
736    ///
737    /// ```rust
738    ///
739    /// use burn_tensor::backend::Backend;
740    /// use burn_tensor::{Tensor, Shape};
741    ///
742    /// fn example<B: Backend>() {
743    ///     let device = Default::default();
744    ///     // Create a 4D tensor with dimensions [1, 3, 1, 3]
745    ///     let tensor = Tensor::<B, 4>::from_data(
746    ///         [[[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]]],
747    ///         &device,
748    ///     );
749    ///
750    ///     // Squeeze the tensor dimensions.
751    ///     // The resulting tensor will have dimensions [3, 3].
752    ///     let squeezed = tensor.squeeze::<2>();
753    ///     println!("{squeezed}");
754    /// }
755    /// ```
756    pub fn squeeze<const D2: usize>(self) -> Tensor<B, D2, K> {
757        let new_dims = self
758            .shape()
759            .iter()
760            .filter_map(|&dim| if dim == 1 { None } else { Some(dim) })
761            .collect::<Vec<_>>();
762        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
763
764        Tensor::new(K::reshape(self.primitive, new_dims.into()))
765    }
766
767    /// Squeeze the tensor along the given dimension, removing the specified dimension
768    /// of size one, and effectively reducing the rank of the tensor by one.
769    ///
770    /// # Arguments
771    ///
772    /// - `dim`: The dimension to be squeezed.
773    ///
774    /// # Type Parameters
775    ///
776    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
777    ///
778    /// # Panics
779    ///
780    /// If the size in the squeezed dimension is not 1.
781    ///
782    /// # Returns
783    ///
784    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
785    ///
786    /// # Example
787    ///
788    /// ```rust
789    ///
790    /// use burn_tensor::backend::Backend;
791    /// use burn_tensor::{Tensor, Shape};
792    ///
793    /// fn example<B: Backend>() {
794    ///     let device = Default::default();
795    ///     // Create a 3D tensor with dimensions [3, 1, 3]
796    ///     let tensor = Tensor::<B, 3>::from_data(
797    ///         [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
798    ///         &device,
799    ///     );
800    ///
801    ///     // Squeeze the dimension 1.
802    ///     // The resulting tensor will have dimensions [3, 3].
803    ///     let squeezed = tensor.squeeze_dim::<2>(1);
804    ///     println!("{squeezed}");
805    /// }
806    /// ```
807    pub fn squeeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
808        check!(TensorCheck::squeeze::<D2>(dim, &self.shape()));
809
810        let current_dims = self.shape();
811        let mut new_dims: [usize; D2] = [0; D2];
812
813        new_dims[..dim].copy_from_slice(&current_dims[..dim]);
814        new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);
815
816        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
817        Tensor::new(K::reshape(self.primitive, new_dims.into()))
818    }
819
    /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
    /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
    /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
    /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
    /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
    /// from the back.
    ///
    /// # Arguments
    ///
    /// - `dims`: The dimension(s) to be squeezed.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 4D tensor with dimensions [2, 1, 4, 1]
    ///     let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
    ///
    ///     // Squeeze the dimensions 1 and 3.
    ///     // The resulting tensor will have dimensions [2, 4].
    ///     let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
        let current_dims = self.shape();
        let mut dim_indices: Vec<usize>;

        // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
        if dims.is_empty() {
            dim_indices = current_dims
                .iter()
                .enumerate()
                .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
                .collect();
        } else {
            // If negative dims, count from the back
            dim_indices = dims
                .iter()
                .map(|&d| {
                    if d < 0 {
                        (current_dims.len() as isize + d) as usize
                    } else {
                        d as usize
                    }
                })
                .collect();
        }

        // Sort indices and remove duplicates
        dim_indices.sort_unstable();
        dim_indices.dedup();

        // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
        check!(TensorCheck::squeeze_dims_input::<D2>(
            &dim_indices,
            &current_dims
        ));

        // Calculate new dimensions
        let mut new_dims = Vec::new();
        for (index, &dim_size) in current_dims.iter().enumerate() {
            // Exclude the dimension if it's explicitly marked for squeezing
            // (the per-index check panics if that dimension's size is not 1).
            if dim_indices.contains(&index) {
                check!(TensorCheck::squeeze::<D2>(index, &current_dims));
                continue;
            }
            new_dims.push(dim_size);
        }

        // Check that after squeezing, we still respect the D2 size
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }
908
909    /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
910    ///
911    /// # Type Parameters
912    ///
913    ///  - `D2`: The resulting number of dimensions in the unsqueezed tensor.
914    ///
915    /// # Panics
916    ///
917    /// If the output size `D2` is smaller than the current number of dimensions.
918    ///
919    /// # Returns
920    ///
921    /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
922    ///
923    /// # Example
924    ///
925    /// ```rust
926    /// use burn_tensor::backend::Backend;
927    /// use burn_tensor::{Tensor, Shape};
928    ///
929    /// fn example<B: Backend>() {
930    ///     let device = Default::default();
931    ///     // Create a 2D tensor with dimensions [3, 3]
932    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
933    ///     // Unsqueeze the tensor up to 4 dimensions.
934    ///     // The resulting tensor will have dimensions [1, 1, 3, 3].
935    ///     let unsqueezed = tensor.unsqueeze::<4>();
936    ///     println!("{unsqueezed}");
937    /// }
938    /// ```
939    pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
940        check!(TensorCheck::unsqueeze::<D, D2>());
941
942        let mut dims = [1; D2];
943        let num_ones = D2 - D;
944        let shape = self.shape();
945
946        dims[num_ones..(D + num_ones)].copy_from_slice(&shape[..D]);
947
948        let shape = Shape::new(dims);
949        self.reshape(shape)
950    }
951
952    /// Creates a new tensor with a dimension of size one inserted at the specified position.
953    ///
954    /// # Example
955    ///
956    /// ```rust
957    /// use burn_tensor::backend::Backend;
958    /// use burn_tensor::{Tensor, Shape};
959    ///
960    /// fn example<B: Backend>() {
961    ///     let device = Default::default();
962    ///     // Create a 2D tensor with dimensions [3, 3]
963    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
964    ///     // Unsqueeze the dimension 1.
965    ///     // The resulting tensor will have dimensions [3, 1, 3].
966    ///     let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
967    ///     println!("{unsqueezed}");
968    /// }
969    /// ```
970    pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
971        check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));
972
973        let mut dims = [1; D2];
974        let shape = self.shape();
975
976        dims[0..dim].copy_from_slice(&shape[0..dim]);
977
978        if dim < D {
979            dims[dim] = 1;
980            dims[(dim + 1)..].copy_from_slice(&shape[dim..]);
981        } else {
982            dims[dim] = 1;
983        }
984
985        let shape = Shape::new(dims);
986        self.reshape(shape)
987    }
988
    /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
    /// The indices can be negative, in which case they are counted from the last to the first dimension.
    /// the axes can contain duplicates, in which case the number of dimensions inserted at the index
    /// is the number of duplicates.
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 4, 5]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
    ///     // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
    ///     // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
    ///     let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[impl AsIndex]) -> Tensor<B, D2, K> {
        let mut new_dims = [1; D2];
        let old_dims = self.shape();

        // Step 1: convert negative indices to positive ones, relative to the
        // output rank `D2`. `neg_offset` decreases for each negative index
        // encountered so that repeated negative axes (e.g. `[-1, -1]`) map to
        // distinct trailing slots rather than colliding on the same index.
        let mut neg_offset = D2;
        let mut dim_indices = axes
            .iter()
            .map(|d| {
                let d = d.as_index();
                // Reject axes outside the acceptable range for rank `D2`.
                check!(TensorCheck::unsqueeze_dims::<{ D2 }>(d));
                (if d < 0 {
                    neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)
                    d + neg_offset as isize + 1
                } else {
                    d
                }) as usize
            })
            .collect::<Vec<usize>>();

        // Step 2: sort the resolved indices ascending.
        dim_indices.sort_unstable();

        // Per the documented semantics, duplicate axes mean "insert N dims at that index".
        // After sorting, N insertions at position `i` logically occupy positions
        // `i, i+1, ..., i+N-1` in the output, so bump each duplicate to the next slot.
        // Example: sorted `[0, 0, 3]` becomes `[0, 1, 3]`, matching the intent of
        // "two 1s starting at index 0, plus one 1 at index 3".
        for i in 1..dim_indices.len() {
            if dim_indices[i] <= dim_indices[i - 1] {
                dim_indices[i] = dim_indices[i - 1] + 1;
            }
        }

        // Re-validate after normalization: bumping duplicates forward can push the
        // last index past `D2 - 1` (e.g. `[2, 2]` targeting rank 3 normalizes to
        // `[2, 3]`). The per-axis check above only runs on pre-normalization values,
        // so we re-check here to surface a clear `TensorCheck` error instead of
        // letting the copy loop panic on an out-of-bounds `old_dims` read.
        for &dim_index in &dim_indices {
            check!(TensorCheck::unsqueeze_dims::<{ D2 }>(dim_index as isize));
        }

        // Step 3: fill `new_dims`. Walk the output positions in order; when the
        // current position is one of the insertion indices, leave the default 1
        // and advance the insertion cursor. Otherwise copy the next size from
        // `old_dims`.
        let mut dim_indices_curr_idx = 0;
        let mut old_dims_curr_idx = 0;
        for new_dims_curr_idx in 0..D2 {
            // Once all insertion indices are consumed, the remaining output
            // positions are exactly the remaining old dims — bulk-copy them.
            if dim_indices_curr_idx == dim_indices.len() {
                new_dims[new_dims_curr_idx..].copy_from_slice(&old_dims[old_dims_curr_idx..]);
                break;
            }

            if new_dims_curr_idx == dim_indices[dim_indices_curr_idx] {
                dim_indices_curr_idx += 1;
            } else {
                new_dims[new_dims_curr_idx] = old_dims[old_dims_curr_idx];
                old_dims_curr_idx += 1;
            }
        }

        // Finally, build the shape and reshape.
        let shape = Shape::new(new_dims);
        self.reshape(shape)
    }
1080
1081    /// Roll operation along a specific dimension; wrapping around the elements.
1082    ///
1083    /// ## Parameters
1084    ///
1085    /// - `shift`: The roll extent; supports negative values and wraps around.
1086    /// - `dim`: The dimension to roll; supports negative indexing.
1087    ///
1088    /// ## Returns
1089    ///
1090    /// A new tensor with the specified dimension rolled by the given shift amount.
1091    pub fn roll_dim<Shift, Dim>(self, shift: Shift, dim: Dim) -> Self
1092    where
1093        Shift: AsIndex,
1094        Dim: AsIndex,
1095    {
1096        let dim = dim.expect_dim_index(D);
1097        let size = self.shape()[dim];
1098        if size == 0 {
1099            // If the dimension is empty, return the tensor as is.
1100            return self;
1101        }
1102
1103        let shift = wrap_index(shift, size);
1104        if shift == 0 {
1105            // If the shift is zero, return the tensor as is.
1106            return self;
1107        }
1108
1109        self.unchecked_roll_dim(shift, dim)
1110    }
1111
1112    /// Internal implementation of `roll_dim` that does not canonicalize dimensions or shifts.
1113    ///
1114    /// ## Parameters
1115    ///
1116    /// - `shift`: The number of positions to shift; must be (0 < shift < size).
1117    /// - `dim`: The dimension to roll; must be a valid index for the tensor's shape.
1118    ///
1119    /// ## Returns
1120    ///
1121    /// A new tensor with the specified dimension rolled by the given shift amount.
1122    #[inline(always)]
1123    fn unchecked_roll_dim(self, shift: usize, dim: usize) -> Self {
1124        #[cfg(debug_assertions)]
1125        {
1126            let size = self.shape()[dim];
1127            assert!(
1128                0 < shift && shift < size,
1129                "Expected: 0 < shift < size: found shift={shift}, size={size}",
1130            );
1131            assert!(
1132                dim < self.shape().num_dims(),
1133                "Expected: dim < num_dims: found dim={dim}, num_dims={size}",
1134            );
1135        }
1136
1137        Tensor::cat(
1138            vec![
1139                self.clone().slice_dim(dim, shift..),
1140                self.slice_dim(dim, ..shift),
1141            ],
1142            dim,
1143        )
1144    }
1145
1146    /// Roll operation.
1147    ///
1148    /// Note: unlike ``pytorch``, `dims` and `shifts` must have the same length.
1149    ///
1150    /// A given `dim` may be rolled multiple times, and the shifts will be applied sequentially.
1151    ///
1152    /// ## Parameters
1153    ///
1154    /// - `shifts`: A slice of shifts corresponding to each dimension;
1155    ///   supports negative values and wraps around.
1156    /// - `dims`: A slice of dimensions to roll; supports negative indexing.
1157    ///
1158    /// ## Returns
1159    ///
1160    /// A new tensor with the specified dimensions rolled by the given shifts.
1161    pub fn roll<Shift, Dim>(self, shifts: &[Shift], dims: &[Dim]) -> Self
1162    where
1163        Shift: AsIndex,
1164        Dim: AsIndex,
1165    {
1166        assert_eq!(
1167            dims.len(),
1168            shifts.len(),
1169            "Dimensions and shifts must align; found dims={dims:#?}, shifts={shifts:#?}",
1170        );
1171
1172        // This is a fair amount of complexity, which could be replaced
1173        // by a simple canonicalization of `dims` and wrapping of `shifts`.
1174        // The work is done here to ensure that any roll operation
1175        // which could be a no-op is a no-op; simplifying the accounting
1176        // needed by backend-specific implementations of the inner roll op.
1177
1178        let item_count = dims.len();
1179
1180        let shape = self.shape();
1181
1182        // Accumulate the effective shifts for each dimension.
1183        let mut accumulated_shifts: Vec<isize> = vec![0; shape.len()];
1184        for i in 0..item_count {
1185            let dim = dims[i].expect_dim_index(D);
1186            accumulated_shifts[dim] += shifts[i].as_index();
1187        }
1188
1189        // Do this after we've checked the validity of `dims` and `shifts`.
1190        if self.shape().num_elements() == 0 {
1191            // If the tensor is empty, return it as is.
1192            return self;
1193        }
1194
1195        // Wrap the accumulated shifts, and filter out empty dimensions.
1196        let mut effective_dims: Vec<usize> = Vec::with_capacity(item_count);
1197        let mut effective_shifts: Vec<usize> = Vec::with_capacity(item_count);
1198        for dim in 0..shape.len() {
1199            // `wrap_index` should inline, and has a fast-exit path for zero shifts.
1200            let shift = wrap_index(accumulated_shifts[dim], shape[dim]);
1201            if shift == 0 {
1202                continue;
1203            }
1204
1205            effective_dims.push(dim);
1206            effective_shifts.push(shift);
1207        }
1208
1209        // If no shifts are needed, return the original tensor.
1210        if effective_shifts.is_empty() {
1211            return self;
1212        }
1213
1214        // At this point:
1215        // - `dims` contains the effective dimensions to roll, in index order,
1216        // - `shifts` contains the effective usize shifts for each dimension.
1217        // - Every shift is non-zero, and less than the size of the corresponding dimension.
1218        self.unchecked_roll(&effective_shifts, &effective_dims)
1219    }
1220
1221    /// `roll` internal implementation.
1222    ///
1223    /// ## Parameters
1224    ///
1225    /// - `shifts`: A slice of shifts corresponding to each dimension;
1226    ///   must be non-empty, the same length as `dims`, and all ``1..<size>``.
1227    /// - `dims`: A slice of dimensions to roll; must be non-empty;
1228    ///   the same length as `shifts`, and must not contain repeats.
1229    ///
1230    /// ## Panics
1231    ///
1232    /// Panics if the shifts and dimensions do not align, or if dimensions contain repeats.
1233    ///
1234    /// ## Returns
1235    ///
1236    /// A new tensor with the specified dimensions rolled by the given shifts.
1237    #[inline(always)]
1238    fn unchecked_roll(self, shifts: &[usize], dims: &[usize]) -> Self {
1239        #[cfg(debug_assertions)]
1240        {
1241            assert!(!shifts.is_empty());
1242            assert_eq!(
1243                shifts.len(),
1244                dims.len(),
1245                "Shifts and dimensions must align; found {} shifts and {} dims",
1246                shifts.len(),
1247                dims.len()
1248            );
1249
1250            let mut unique_dims = dims.to_vec();
1251            unique_dims.dedup();
1252
1253            assert_eq!(
1254                unique_dims.len(),
1255                dims.len(),
1256                "Dimensions must not contain repeats; found {} unique dims and {} total dims",
1257                unique_dims.len(),
1258                dims.len()
1259            )
1260        }
1261
1262        let x = self.unchecked_roll_dim(shifts[0], dims[0]);
1263
1264        if dims.len() == 1 {
1265            x
1266        } else {
1267            x.unchecked_roll(&shifts[1..], &dims[1..])
1268        }
1269    }
1270
1271    /// Returns a tensor containing the elements selected from the given slices.
1272    ///
1273    /// This method provides flexible tensor slicing with support for various range types,
1274    /// negative indices, and stepped slicing. The method accepts both single slices and
1275    /// arrays of slices, with the [`s!`] macro providing convenient syntax for complex patterns.
1276    ///
1277    /// # Arguments
1278    ///
1279    /// * `slices` - Can be:
1280    ///   - A single range for 1D slicing (e.g., `0..5`, `..`, `2..`)
1281    ///   - An array of ranges (e.g., `[0..2, 1..4]`)
1282    ///   - The [`s!`] macro output for advanced slicing with steps
1283    ///   - a `&Vec<Slice>` or `&[Slice]`
1284    ///
1285    /// # Behavior
1286    ///
1287    /// - Supports partial and full slicing in any number of dimensions
1288    /// - Handles negative indices by wrapping from the end (-1 is the last element)
1289    /// - Automatically clamps ranges that exceed tensor dimensions
1290    /// - Supports stepped slicing for selecting every nth element
1291    /// - Negative steps reverse the selection order
1292    ///
1293    /// # Panics
1294    ///
1295    /// - If the number of slices exceeds the tensor's dimensions
1296    /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1) without negative step
1297    /// - If a step is zero
1298    ///
1299    /// # Examples
1300    ///
1301    /// ```rust
1302    /// use burn_tensor::backend::Backend;
1303    /// use burn_tensor::{Tensor, Shape, s};
1304    ///
1305    /// fn example<B: Backend>() {
1306    ///     let device = B::Device::default();
1307    ///
1308    ///     // Single dimension slicing - no brackets needed!
1309    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..10, &device);
1310    ///     let slice = tensor.clone().slice(2..8);  // Simple range
1311    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![2, 3, 4, 5, 6, 7]);
1312    ///
1313    ///     // Using s! macro for single dimension with step
1314    ///     let slice = tensor.clone().slice(s![0..10;2]);  // Every 2nd element
1315    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![0, 2, 4, 6, 8]);
1316    ///
1317    ///     // Reverse a dimension with negative step
1318    ///     let slice = tensor.slice(s![..;-1]);  // Reverse entire tensor
1319    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
1320    ///
1321    ///     // Multi-dimensional slicing
1322    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
1323    ///
1324    ///     // Array syntax for simple ranges
1325    ///     let slice = tensor.clone().slice([1..3, 2..5]);
1326    ///     assert_eq!(slice.dims(), [2, 3]);
1327    ///
1328    ///     // Advanced multi-dimensional with s! macro
1329    ///     let slice = tensor.clone().slice(s![0..4;2, ..;-1]);  // Every 2nd row, reverse columns
1330    ///     assert_eq!(slice.dims(), [2, 6]);
1331    ///
1332    ///     // Complex 3D example with mixed slice types
1333    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([4, 6, 8]), &device);
1334    ///     let slice = tensor.slice(s![1..3, ..;2, -3..]);  // Rows 1-2, every 2nd col, last 3 depth
1335    ///     assert_eq!(slice.dims(), [2, 3, 3]);
1336    ///
1337    ///     // Using negative indices
1338    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
1339    ///     let slice = tensor.slice(s![-2.., ..-1]);  // Last 2 rows, all but last column
1340    ///     assert_eq!(slice.dims(), [2, 5]);
1341    /// }
1342    /// ```
1343    ///
1344    /// # See Also
1345    ///
1346    /// - [`s!`] - The recommended macro for creating complex slice specifications
1347    /// - [`slice_assign`](Self::slice_assign) - Assign values to a slice
1348    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
1349    /// - [`slice_dim`](Self::slice_dim) - Slice a single dimension
1350    ///
1351    /// [`s!`]: crate::s!
1352    pub fn slice<S>(self, slices: S) -> Self
1353    where
1354        S: SliceArg,
1355    {
1356        let shape = self.shape();
1357        let slices = slices.into_slices(&shape);
1358
1359        // Validate slices
1360        check!(TensorCheck::slice::<D>(&shape, &slices));
1361
1362        // Calculate output shape and check for empty slices
1363        let mut output_dims = shape.clone();
1364        for (dim, slice) in slices.iter().enumerate() {
1365            output_dims[dim] = slice.output_size(shape[dim]);
1366        }
1367
1368        // Return empty tensor if any dimension is 0 (empty slice)
1369        if output_dims.contains(&0) {
1370            return Self::empty(output_dims, &self.device());
1371        }
1372        Self::new(K::slice(self.primitive, &slices))
1373    }
1374
1375    /// Assigns values to a slice of the tensor and returns the updated tensor.
1376    ///
1377    /// This method supports advanced slicing with steps, including negative steps for reverse
1378    /// assignment. Like `slice`, it accepts both single slices and arrays, with the [`s!`] macro
1379    /// providing powerful syntax for complex patterns.
1380    ///
1381    /// # Arguments
1382    ///
1383    /// * `slices` - Slice specification (same format as `slice` method)
1384    /// * `values` - Tensor with values to assign (must match slice dimensions)
1385    ///
1386    /// # Panics
1387    ///
1388    /// - If slices exceed tensor dimensions
1389    /// - If values dimensions don't match the selected slice shape
1390    /// - If a step is zero
1391    ///
1392    /// # Examples
1393    ///
1394    /// ```rust
1395    /// use burn_tensor::backend::Backend;
1396    /// use burn_tensor::{Tensor, s};
1397    ///
1398    /// fn example<B: Backend>() {
1399    ///     let device = B::Device::default();
1400    ///
1401    ///     // Simple assignment to a sub-region
1402    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
1403    ///     let values = Tensor::<B, 2>::ones([2, 3], &device);
1404    ///     tensor = tensor.slice_assign([1..3, 2..5], values);
1405    ///     // Now tensor[1..3, 2..5] contains ones
1406    ///
1407    ///     // Single dimension assignment with step
1408    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1409    ///     let values = Tensor::<B, 1>::ones([5], &device);
1410    ///     tensor = tensor.slice_assign(s![0..10;2], values);
1411    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
1412    ///
1413    ///     // Reverse assignment with negative step
1414    ///     let mut tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1415    ///     let values = Tensor::<B, 1>::from_data([10.0, 11.0, 12.0, 13.0, 14.0], &device);
1416    ///     tensor = tensor.slice_assign(s![..;-1], values);
1417    ///     // Assigns in reverse: [14, 13, 12, 11, 10]
1418    ///
1419    ///     // Complex multi-dimensional assignment
1420    ///     let mut tensor = Tensor::<B, 3>::zeros([4, 6, 8], &device);
1421    ///     let values = Tensor::<B, 3>::ones([2, 3, 3], &device);
1422    ///     tensor = tensor.slice_assign(s![0..4;2, ..;2, -3..], values);
1423    ///     // Assigns to every 2nd row, every 2nd column, last 3 in depth
1424    ///
1425    ///     // Mixed syntax example
1426    ///     let mut tensor = Tensor::<B, 2>::zeros([8, 8], &device);
1427    ///     let pattern = Tensor::<B, 2>::ones([4, 4], &device);
1428    ///     tensor = tensor.slice_assign(s![..;2, ..;2], pattern);
1429    ///     // Creates a checkerboard pattern with ones
1430    /// }
1431    /// ```
1432    ///
1433    /// # See Also
1434    ///
1435    /// - [`s!`] - The recommended macro for creating complex slice specifications
1436    /// - [`slice`](Self::slice) - Extract a slice from a tensor
1437    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
1438    ///
1439    /// [`s!`]: crate::s!
1440    pub fn slice_assign<S>(self, slices: S, values: Self) -> Self
1441    where
1442        S: SliceArg,
1443    {
1444        let shape = self.shape();
1445        let slices = slices.into_slices(&shape);
1446
1447        // Check if any slice produces 0 elements (empty assignment).
1448        // Empty assignments are no-ops and would cause issues in backend implementations.
1449        let is_empty_assignment = slices
1450            .iter()
1451            .enumerate()
1452            .any(|(i, slice)| slice.output_size(shape[i]) == 0);
1453
1454        if is_empty_assignment {
1455            return self;
1456        }
1457
1458        check!(TensorCheck::slice_assign::<D>(
1459            &shape,
1460            &values.shape(),
1461            &slices
1462        ));
1463
1464        Self::new(K::slice_assign(self.primitive, &slices, values.primitive))
1465    }
1466
    /// Fills a slice of the tensor with a constant value and returns the updated tensor.
    ///
    /// Like other slice methods, accepts both single slices and arrays, including
    /// stepped slicing via the [`s!`] macro. The fill is implemented by expanding
    /// the value to the slice shape and delegating to
    /// [`slice_assign`](Self::slice_assign).
    ///
    /// # Arguments
    ///
    /// * `slices` - Slice specification (same format as the `slice` method)
    /// * `value` - The value to fill the slice with
    ///
    /// # Panics
    ///
    /// - If slices exceed tensor dimensions
    /// - If a step is zero
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // Simple fill for a single dimension
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(2..5, 1.0);
    ///     // Now tensor is [0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
    ///
    ///     // Multi-dimensional fill
    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
    ///     tensor = tensor.slice_fill([1..3, 2..5], -1.0);
    ///     // Fills the rectangle at rows 1-2, columns 2-4 with -1
    ///
    ///     // Using negative indices
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(-3.., 2.0);
    ///     // Fills the last 3 elements with 2.0
    ///
    ///     // Complex multi-dimensional example
    ///     let mut tensor = Tensor::<B, 3>::ones([4, 6, 8], &device);
    ///     tensor = tensor.slice_fill(s![1..3, .., -2..], 0.0);
    ///     // Sets rows 1-2, all columns, last 2 in depth to 0
    ///
    ///     // Stepped slicing is supported
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(s![0..10;2], 1.0);
    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
    /// }
    /// ```
    ///
    /// # See Also
    ///
    /// - [`s!`] - The macro for creating slice specifications with steps
    /// - [`slice`](Self::slice) - Extract a slice from a tensor
    /// - [`slice_assign`](Self::slice_assign) - Assign tensor values to a slice
    ///
    /// [`s!`]: crate::s!
    pub fn slice_fill<S, E: ElementConversion>(self, slices: S, value: E) -> Self
    where
        S: SliceArg,
    {
        let shape = self.shape();
        let slices = slices.into_slices(&shape);

        check!(TensorCheck::slice::<D>(&shape, &slices));

        // Build a scalar tensor on the same device/dtype, broadcast it to the
        // slice's shape, and delegate the write to `slice_assign`.
        let slice_shape = shape.slice(&slices).unwrap();
        let value =
            Tensor::<B, 1, K>::from_data([value.elem::<K::Elem>()], (&self.device(), self.dtype()));
        let value = value.expand(slice_shape);
        self.slice_assign(&slices, value)
    }
1541
1542    /// Returns a new tensor with the specified dimension sliced.
1543    ///
1544    /// # Arguments
1545    ///
1546    /// * `dim`: The dimension to slice.
1547    /// * `slice`: The slice specification for the dimension. Can be a range (e.g., `2..5`),
1548    ///   slice with step (via `s!` macro, e.g., `s![0..10;2]`), or any type that implements `Into<Slice>`.
1549    ///
1550    /// # Returns
1551    ///
1552    /// A new tensor with the specified dimension sliced.
1553    ///
1554    /// # Panics
1555    ///
1556    /// If the slice is out of bounds for the specified dimension.
1557    ///
1558    /// # Examples
1559    ///
1560    /// ```rust
1561    /// # use burn_tensor::{Tensor, s};
1562    /// # use burn_tensor::backend::Backend;
1563    /// #
1564    /// # fn example<B: Backend>() {
1565    /// #     let device = B::Device::default();
1566    ///     let tensor = Tensor::<B, 3>::zeros([3, 4, 5], &device);
1567    ///
1568    ///     // Simple range slicing
1569    ///     let sliced = tensor.clone().slice_dim(1, 1..3);
1570    ///     assert_eq!(sliced.shape().as_slice(), [3, 2, 5]);
1571    ///
1572    ///     // Slicing with step - take every 2nd element
1573    ///     let sliced = tensor.clone().slice_dim(2, s![0..5;2]);
1574    ///     assert_eq!(sliced.shape().as_slice(), [3, 4, 3]); // Takes indices 0, 2, 4
1575    ///
1576    ///     // Reverse slicing with negative step
1577    ///     let sliced = tensor.clone().slice_dim(1, s![..;-1]);
1578    ///     assert_eq!(sliced.shape().as_slice(), [3, 4, 5]); // Reverses dimension 1
1579    ///
1580    ///     // Select from index 2 with step 3
1581    ///     let sliced = tensor.clone().slice_dim(0, s![2..;3]);
1582    ///     assert_eq!(sliced.shape().as_slice(), [1, 4, 5]); // Takes only index 2
1583    ///
1584    ///     // Select single index (reduces dimension to size 1)
1585    ///     let sliced = tensor.slice_dim(0, 1);
1586    ///     assert_eq!(sliced.shape().as_slice(), [1, 4, 5]);
1587    /// # }
1588    /// ```
1589    ///
1590    /// # See Also
1591    ///
1592    /// - [`slice`](Self::slice) - Slice multiple dimensions simultaneously
1593    /// - [`s!`] - The macro for creating complex slice specifications
1594    ///
1595    /// [`s!`]: crate::s!
1596    pub fn slice_dim<S>(self, dim: usize, slice: S) -> Self
1597    where
1598        S: Into<Slice>,
1599    {
1600        check!(TensorCheck::check_dim::<D>(dim));
1601        let slice: Slice = slice.into();
1602
1603        let mut slices = vec![Slice::full(); D];
1604        slices[dim] = slice;
1605
1606        self.slice(&slices)
1607    }
1608
    /// Returns the device of the current tensor.
    pub fn device(&self) -> B::Device {
        // Delegates to the tensor-kind implementation for this primitive.
        K::device(&self.primitive)
    }
1613
    /// Move the tensor to the given device.
    ///
    /// Consumes `self` and returns a tensor placed on `device`; depending on
    /// the backend this may involve a data transfer.
    pub fn to_device(self, device: &B::Device) -> Self {
        Self::new(K::to_device(self.primitive, device))
    }
1618
    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to select from. Supports negative indexing.
    /// * `indices` - The indices of the elements to select.
    ///
    /// # Panics
    ///
    /// If `dim` is out of bounds for the tensor's rank.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Int};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///   let indices = Tensor::<B, 1, Int>::from_data([0], &device);
    ///   let tensor = tensor.select(0, indices);
    ///   println!("{tensor}");
    ///   //  [[1.0, -2.0, 3.0]]
    /// }
    /// ```
    pub fn select(self, dim: impl AsIndex, indices: Tensor<B, 1, Int>) -> Self {
        // Canonicalize a possibly-negative dimension index, then validate it.
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::select::<D>(dim));
        Self::new(K::select(self.primitive, dim, indices.primitive))
    }
1646
1647    /// Assign the selected elements along the given dimension corresponding to the given indices
1648    /// from the value tensor to the original tensor using sum reduction.
1649    ///
1650    /// # Note
1651    /// For booleans, the sum operator is logical or.
1652    ///
1653    /// # Arguments
1654    ///
1655    /// * `dim` - The dimension along which to select. Supports negative indexing.
1656    /// * `indices` - The indices to select from the tensor.
1657    /// * `values` - The values to assign to the selected indices.
1658    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1659    ///
1660    /// # Example
1661    ///
1662    /// Example using a 3D tensor:
1663    ///
1664    /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
1665    /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
1666    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
1667    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = -1 (same as dim = 2)`
1668    ///
1669    /// # Warning
1670    ///
1671    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1672    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1673    pub fn select_assign(
1674        self,
1675        dim: impl AsIndex,
1676        indices: Tensor<B, 1, Int>,
1677        values: Tensor<B, D, K>,
1678        update: IndexingUpdateOp,
1679    ) -> Self {
1680        let dim = dim.expect_dim_index(D);
1681        check!(TensorCheck::select_assign::<D>(
1682            dim,
1683            &indices.shape(),
1684            &values.shape()
1685        ));
1686
1687        Self::new(K::select_assign(
1688            self.primitive,
1689            dim,
1690            indices.primitive,
1691            values.primitive,
1692            update,
1693        ))
1694    }
1695
1696    /// Update the given tensor with the value tensor where the mask is true.
1697    ///
1698    /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
1699    /// a scalar.
1700    ///
1701    /// # Example
1702    ///
1703    /// ```rust
1704    /// use burn_tensor::backend::Backend;
1705    /// use burn_tensor::{Tensor, Shape, Bool};
1706    ///
1707    /// fn example<B: Backend>() {
1708    ///   let device = B::Device::default();
1709    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1710    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1711    ///   let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1712    ///   let tensor = tensor.mask_where(mask, value);
1713    ///   println!("{tensor}");
1714    ///   // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
1715    /// }
1716    /// ```
1717    pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
1718        Self::new(K::mask_where(
1719            self.primitive,
1720            mask.primitive,
1721            value.primitive,
1722        ))
1723    }
1724
1725    /// Update the given tensor with the value where the mask is true.
1726    ///
1727    /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
1728    /// a tensor.
1729    ///
1730    /// # Example
1731    ///
1732    /// ```rust
1733    /// use burn_tensor::backend::Backend;
1734    /// use burn_tensor::{Tensor, Shape, Bool};
1735    ///
1736    /// fn example<B: Backend>() {
1737    ///   let device = B::Device::default();
1738    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1739    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1740    ///   let tensor = tensor.mask_fill(mask, 3.0);
1741    ///   println!("{tensor}");
1742    ///   // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
1743    /// }
1744    /// ```
1745    pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
1746        let value = Scalar::new(value, &self.dtype());
1747        Self::new(K::mask_fill(self.primitive, mask.primitive, value))
1748    }
1749
1750    /// Gather tensor elements corresponding to the given indices from the specified dim.
1751    ///
1752    /// Example using a 3D tensor:
1753    ///
1754    /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
1755    /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
1756    /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
1757    ///
1758    /// # Notes
1759    ///
1760    /// The index tensor should have the same shape as the original tensor except for the dim
1761    /// specified.
1762    ///
1763    /// # Warning
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1765    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1766    pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
1767        check!(TensorCheck::gather::<D>(
1768            dim,
1769            &self.shape(),
1770            &indices.shape()
1771        ));
1772
1773        Self::new(K::gather(dim, self.primitive, indices.primitive))
1774    }
1775
1776    /// Assign the gathered elements corresponding to the given indices along the specified dimension
1777    /// from the value tensor to the original tensor using sum reduction.
1778    ///
1779    /// Example using a 3D tensor:
1780    ///
1781    /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
1782    /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
1783    /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
1784    ///
1785    /// # Arguments
1786    /// * `dim` - The axis along which to scatter elements.
1787    /// * `indices` - The indices of the elements to scatter.
1788    /// * `values` - The values to scatter into the tensor.
1789    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1790    ///
1791    /// # Notes
1792    ///
1793    /// The index tensor should have the same shape as the original tensor except for the specified
1794    /// dimension. The value and index tensors should have the same shape.
1795    ///
1796    /// Other references to the input tensor will not be modified by this operation.
1797    ///
1798    /// # Warning
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1800    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1801    ///
1802    /// # Panics
1803    /// If the `update` is not `IndexingUpdateOp::Add`. Other operations are currently not implemented.
1804    pub fn scatter(
1805        self,
1806        dim: usize,
1807        indices: Tensor<B, D, Int>,
1808        values: Self,
1809        update: IndexingUpdateOp,
1810    ) -> Self {
1811        check!(TensorCheck::scatter::<D>(
1812            dim,
1813            &self.shape(),
1814            &indices.shape(),
1815            &values.shape()
1816        ));
1817
1818        Self::new(K::scatter(
1819            dim,
1820            self.primitive,
1821            indices.primitive,
1822            values.primitive,
1823            update,
1824        ))
1825    }
1826
1827    /// Multi-dimensional scatter: update the tensor at locations given by `indices` using the specified `update` operation.
1828    ///
1829    /// The size of `indices`'s last axis (call it `K`) indexes the leading `K` dims of `self`;
1830    /// the batch shape `indices.shape[0..M-1]` is preserved. `values` has shape
1831    /// `indices.shape[0..M-1] ++ self.shape[K..D]`. Constraints: `K <= D` and `M >= 1`.
1832    ///
1833    /// # Arguments
1834    /// * `indices` - The indices of the elements to scatter.
1835    /// * `values` - The values to scatter into the tensor.
1836    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1837    ///
1838    /// # Note
1839    ///
1840    /// When `indices` contains duplicate entries, behavior varies by operation:
1841    /// - For `Add`, accumulation is supported, though results may be non-deterministic on GPU
1842    ///   backends.
1843    /// - For other operations (`Assign`, `Mul`, `Min`, `Max`), duplicate indices result in
1844    ///   undefined behavior for both the forward result and the backward gradients.
1845    ///
1846    /// For deterministic results and correct gradient calculation across all operations,
1847    /// `indices` should contain unique entries.
1848    ///
1849    /// # Warning
1850    ///
1851    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1852    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1853    pub fn scatter_nd<const M: usize, const DV: usize>(
1854        self,
1855        indices: Tensor<B, M, Int>,
1856        values: Tensor<B, DV, K>,
1857        update: IndexingUpdateOp,
1858    ) -> Self {
1859        check!(TensorCheck::scatter_nd::<D, M, DV>(
1860            &self.shape(),
1861            &indices.shape(),
1862            &values.shape()
1863        ));
1864        Self::new(K::scatter_nd(
1865            self.primitive,
1866            indices.primitive,
1867            values.primitive,
1868            update,
1869        ))
1870    }
1871
1872    /// Multi-dimensional gather: collect slices from `self` at multi-index locations
1873    /// specified by `indices`.
1874    ///
1875    /// The size of `indices`'s last axis (call it `K`) indexes the leading `K` dims of `self`;
1876    /// the batch shape `indices.shape[0..M-1]` is preserved. The output has shape
1877    /// `indices.shape[0..M-1] ++ self.shape[K..D]`. Constraints: `K <= D` and `M >= 1`.
1878    ///
1879    /// # Warning
1880    ///
1881    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1882    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1883    pub fn gather_nd<const M: usize, const DV: usize>(
1884        self,
1885        indices: Tensor<B, M, Int>,
1886    ) -> Tensor<B, DV, K> {
1887        check!(TensorCheck::gather_nd::<D, M, DV>(&indices.shape()));
1888        Tensor::new(K::gather_nd(self.primitive, indices.primitive))
1889    }
1890
1891    /// Converts the data of the current tensor.
1892    ///
1893    /// # Note
1894    ///
1895    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1896    /// tensors at once. This may improve laziness, especially if executed on a different
1897    /// thread in native environments.
1898    pub fn into_data(self) -> TensorData {
1899        self.try_into_data().expect(
1900            "Error while reading data: use `try_into_data` instead to catch the error at runtime",
1901        )
1902    }
1903
1904    /// Converts the data of the current tensor and returns any error that might have occurred since the
1905    /// last time the device was synchronized.
1906    ///
1907    /// # Note
1908    ///
1909    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1910    /// tensors at once. This may improve laziness, especially if executed on a different
1911    /// thread in native environments.
1912    pub fn try_into_data(self) -> Result<TensorData, ExecutionError> {
1913        crate::try_read_sync(self.into_data_async()).expect(
1914            "Failed to read tensor data synchronously.
1915        This can happen on platforms that don't support blocking futures like WASM.
1916        If possible, try using into_data_async instead.",
1917        )
1918    }
1919
1920    /// Converts the data of the current tensor.
1921    ///
1922    /// # Note
1923    ///
1924    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1925    /// tensors at once. This may improve laziness, especially if executed on a different
1926    /// thread in native environments.
1927    pub fn to_data(&self) -> TensorData {
1928        self.clone().into_data()
1929    }
1930
1931    /// Returns the data of the current tensor.
1932    pub async fn into_data_async(self) -> Result<TensorData, ExecutionError> {
1933        K::into_data_async(self.primitive).await
1934    }
1935
1936    /// Returns the data of the current tensor.
1937    pub async fn to_data_async(&self) -> Result<TensorData, ExecutionError> {
1938        self.clone().into_data_async().await
1939    }
1940
1941    /// Create a tensor from the given data on the given device.
1942    pub fn from_data<T>(data: T, options: impl Into<TensorCreationOptions<B>>) -> Self
1943    where
1944        T: Into<TensorData>,
1945    {
1946        let data = data.into();
1947        check!(TensorCheck::creation_ops::<D>(
1948            "From Data",
1949            data.shape.as_slice()
1950        ));
1951
1952        // Use the given dtype when provided, otherwise default device dtype
1953        let opt = options.into();
1954        let dtype = opt.resolve_dtype::<K>();
1955        Self::new(K::from_data(data, &opt.device, dtype))
1956    }
1957
1958    /// Repeat the tensor along the given dimension.
1959    ///
1960    /// The output tensor has the same shape, except along the given dimension.
1961    ///
1962    /// # Arguments
1963    /// - `dim`: The dimension to repeat.
1964    /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
1965    ///
1966    /// # Returns
1967    ///
1968    /// A new tensor with the given dimension repeated `times` times.
1969    ///
1970    /// # Example
1971    ///
1972    /// ```rust
1973    /// use burn_tensor::backend::Backend;
1974    /// use burn_tensor::Tensor;
1975    ///
1976    /// fn example<B: Backend>() {
1977    ///     let device = Default::default();
1978    ///     // Create a 2D tensor with dimensions [3, 2]
1979    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1980    ///
1981    ///     // Repeat the tensor along the dimension 0 twice.
1982    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1983    ///     // The resulting tensor will have dimensions [6, 2].
1984    ///     let repeated = tensor.repeat_dim(0, 2);
1985    ///     println!("{repeated}");
1986    /// }
1987    /// ```
1988    pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
1989        if times > 0 {
1990            Self::new(K::repeat_dim(self.primitive, dim, times))
1991        } else {
1992            let shape = self.shape().repeat(dim, times).unwrap();
1993            Self::empty(shape, &self.device())
1994        }
1995    }
1996
1997    /// Repeat the tensor along the given dimensions.
1998    /// # Arguments
1999    /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
2000    ///
2001    /// # Returns
2002    ///
2003    /// A new tensor with the given dimensions repeated `times` times.
2004    ///
2005    /// # Panics
2006    ///
2007    /// If `sizes` contains more elements than the number of dimensions.
2008    ///
2009    /// # Example
2010    ///
2011    /// ```rust
2012    ///
2013    /// use burn_tensor::backend::Backend;
2014    /// use burn_tensor::Tensor;
2015    ///
2016    /// fn example<B: Backend>() {
2017    ///     let device = Default::default();
2018    ///     // Create a 2D tensor with dimensions [3, 2]
2019    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2020    ///
    ///     // Repeat the tensor along the dimension 0 twice and the dimension 1 once.
2022    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
2023    ///     // The resulting tensor will have dimensions [6, 2].
2024    ///     let repeated = tensor.repeat(&[2, 1]);
2025    /// }
2026    /// ```
2027    pub fn repeat(self, sizes: &[usize]) -> Self {
2028        if sizes.contains(&0) {
2029            let mut shape = self.shape();
2030            for (dim, &times) in sizes.iter().enumerate() {
2031                shape = shape.repeat(dim, times).unwrap();
2032            }
2033
2034            return Self::empty(shape, &self.device());
2035        }
2036
2037        let mut tensor = self;
2038        for (dim, &times) in sizes.iter().enumerate() {
2039            if times > 1 {
2040                tensor = tensor.repeat_dim(dim, times);
2041            }
2042        }
2043        tensor
2044    }
2045
2046    /// Applies element-wise equal comparison.
2047    ///
2048    /// # Returns
2049    /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
2050    ///
2051    /// # Panics
2052    ///
2053    /// If the two tensors don't have the same shape.
2054    ///
2055    /// # Example
2056    ///
2057    /// ```rust
2058    /// use burn_tensor::backend::Backend;
2059    /// use burn_tensor::Tensor;
2060    ///
2061    /// fn example<B: Backend>() {
2062    ///     let device = Default::default();
2063    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2064    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2065    ///     // Compare the elements of the two 2D tensors with dimensions [3, 2].
2066    ///     // [[false, true], [true, true], [true, true]]
2067    ///     let equal = t1.equal(t2);
2068    ///     println!("{equal}");
2069    /// }
2070    /// ```
2071    pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
2072        check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
2073        Tensor::new(K::equal(self.primitive, other.primitive))
2074    }
2075
2076    /// Applies element-wise non-equality comparison.
2077    ///
2078    /// # Returns
2079    /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
2080    ///
2081    /// # Panics
2082    ///
2083    /// If the two tensors don't have the same shape.
2084    ///
2085    /// # Example
2086    ///
2087    /// ```rust
2088    /// use burn_tensor::backend::Backend;
2089    /// use burn_tensor::Tensor;
2090    ///
2091    /// fn example<B: Backend>() {
2092    ///     let device = Default::default();
2093    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2094    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2095    ///     // Compare the elements of the two 2D tensors for inequality.
2096    ///     // [[true, false], [false, false], [false, false]]
2097    ///     let not_equal = t1.not_equal(t2);
2098    ///     println!("{not_equal}");
2099    /// }
2100    /// ```
2101    pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
2102        check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
2103        Tensor::new(K::not_equal(self.primitive, other.primitive))
2104    }
2105
2106    /// Applies element wise equal comparison and returns a boolean tensor.
2107    ///
2108    /// # Arguments
2109    ///
2110    /// * `other` - The element to compare.
2111    ///
2112    /// # Example
2113    ///
2114    /// ```rust
2115    /// use burn_tensor::backend::Backend;
2116    /// use burn_tensor::{Tensor, Shape};
2117    ///
2118    /// fn example<B: Backend>() {
2119    ///    let device = B::Device::default();
2120    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2121    ///    let tensor = tensor.equal_elem(3.0);
2122    ///    println!("{tensor}");
2123    ///    // [[false, false, true], [false, false, false]]
2124    /// }
2125    /// ```
2126    pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2127        let other = Scalar::new(other, &self.dtype());
2128        Tensor::new(K::equal_elem(self.primitive, other))
2129    }
2130
2131    /// Applies element wise non-equality comparison and returns a boolean tensor.
2132    ///
2133    /// # Arguments
2134    ///
2135    /// * `other` - The element to compare.
2136    ///
2137    /// # Example
2138    ///
2139    /// ```rust
2140    /// use burn_tensor::backend::Backend;
2141    /// use burn_tensor::{Tensor, Shape};
2142    ///
2143    /// fn example<B: Backend>() {
2144    ///    let device = B::Device::default();
2145    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2146    ///    let tensor = tensor.not_equal_elem(3.0);
2147    ///    println!("{tensor}");
2148    ///    // [[true, true, false], [true, true, true]]
2149    /// }
2150    /// ```
2151    pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2152        let other = Scalar::new(other, &self.dtype());
2153        Tensor::new(K::not_equal_elem(self.primitive, other))
2154    }
2155
2156    /// Concatenates all tensors into a new one along the given dimension.
2157    ///
2158    /// # Panics
2159    ///
2160    /// - If `dim` is higher than the rank.
2161    /// - If `tensors` is an empty vector.
2162    /// - If all tensors don't have the same shape (the dimension `dim` is ignored).
2163    ///
2164    /// # Example
2165    ///
2166    /// ```rust
2167    /// use burn_tensor::backend::Backend;
2168    /// use burn_tensor::Tensor;
2169    ///
2170    /// fn example<B: Backend>() {
2171    ///     let device = Default::default();
2172    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
2173    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2174    ///
2175    ///     // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
2176    ///     // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
2177    ///     // The resulting tensor will have shape [2, 7].
2178    ///     let concat = Tensor::cat(vec![t1, t2], 1);
2179    ///     println!("{concat}");
2180    /// }
2181    /// ```
2182    pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
2183        check!(TensorCheck::cat(&tensors, dim));
2184
2185        // Filter out tensors with size 0 along the concatenation dimension.
2186        // Empty tensors don't contribute to the output and would cause issues
2187        // in backend implementations (e.g., division by zero in slice_assign).
2188        // Safety: TensorCheck::cat ensures tensors is non-empty
2189        let first_tensor = tensors.first().unwrap();
2190        let device = first_tensor.device();
2191        let mut shape = first_tensor.shape();
2192
2193        let non_empty_primitives: Vec<_> = tensors
2194            .into_iter()
2195            .filter(|t| t.shape()[dim] > 0)
2196            .map(|t| t.primitive)
2197            .collect();
2198
2199        // If all tensors were empty, return an empty tensor with size 0 on concat dim
2200        if non_empty_primitives.is_empty() {
2201            shape[dim] = 0;
2202            return Self::empty(shape, &device);
2203        }
2204
2205        Self::new(K::cat(non_empty_primitives, dim))
2206    }
2207
2208    /// Concatenates all tensors into a new one along a new dimension.
2209    ///
2210    /// # Panics
2211    ///
2212    /// - If all tensors don't have the same shape.
    /// - If the given dimension is not within the range 0..D2.
2214    ///
2215    /// # Example
2216    ///
2217    /// ```rust
2218    /// use burn_tensor::backend::Backend;
2219    /// use burn_tensor::Tensor;
2220    ///
2221    /// fn example<B: Backend>() {
2222    ///     let device = Default::default();
2223    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2224    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2225    ///     let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2226    ///
2227    ///     // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
2228    ///     // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
2229    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
2230    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
2231    ///     // The resulting tensor will have shape [3, 2, 3].
2232    ///     let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0);
2233    ///     println!("{stacked}");
2234    /// }
2235    /// ```
2236    pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
2237        check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
2238        let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
2239        Tensor::<B, D2, K>::cat(tensors, dim)
2240    }
2241
2242    /// Iterate over slices of tensors alongside a given dimension.
2243    ///
2244    /// # Panics
2245    ///
2246    /// If given dimension is greater than or equal to tensor rank.
2247    ///
2248    /// # Returns
2249    ///
2250    /// A tensor iterator.
2251    ///
2252    /// # Example
2253    ///
2254    /// ```rust
2255    /// use burn_tensor::backend::Backend;
2256    /// use burn_tensor::Tensor;
2257    /// fn example<B: Backend>() {
2258    ///   let device = Default::default();
2259    ///   let tensor = Tensor::<B,2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2260    ///   // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
2261    ///   let iter = tensor.iter_dim(0);
2262    ///   for (i,tensor) in iter.enumerate() {
2263    ///     println!("Tensor {}: {}", i, tensor);
2264    ///     // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
2265    ///     // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
2266    ///  }
2267    /// }
2268    /// ```
2269    pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
2270        check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
2271        DimIter::new(self, dim)
2272    }
2273
2274    /// Returns a new tensor with the given dimension narrowed to the given range.
2275    ///
2276    /// # Panics
2277    ///
2278    /// - If the dimension is greater than the number of dimensions of the tensor.
2279    /// - If the given range exceeds the number of elements on the given dimension.
2280    ///
2281    /// # Returns
2282    ///
2283    /// A new tensor with the given dimension narrowed to the given range.
2284    ///
2285    /// # Example
2286    ///
2287    /// ```rust
2288    /// use burn_tensor::backend::Backend;
2289    /// use burn_tensor::Tensor;
2290    ///
2291    /// fn example<B: Backend>() {
2292    ///     let device = Default::default();
2293    ///     // Create a 2D tensor with dimensions [4, 3]
2294    ///     let tensor = Tensor::<B, 2>::from_data(
2295    ///         [
2296    ///             [3.0, 4.9, 2.0],
2297    ///             [2.0, 1.9, 3.0],
2298    ///             [6.0, 1.5, 7.0],
2299    ///             [3.0, 4.9, 9.0],
2300    ///         ],
2301    ///         &device,
2302    ///     );
2303    ///     // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
2304    ///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
2305    ///     // The resulting tensor will have dimensions [3, 3].
2306    ///     let narrowed = tensor.narrow(0, 1, 3);
2307    ///     println!("{narrowed}");
2308    /// }
2309    /// ```
2310    pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
2311        check!(TensorCheck::dim_ops::<D>("narrow", dim));
2312        check!(TensorCheck::narrow(&self, dim, start, length));
2313        let dims = self.dims();
2314
2315        let ranges: [Range<usize>; D] = dims
2316            .iter()
2317            .enumerate()
2318            .map(|(i, d)| {
2319                if i == dim {
2320                    start..(start + length)
2321                } else {
2322                    0..*d
2323                }
2324            })
2325            .collect::<Vec<_>>()
2326            .try_into()
2327            .unwrap();
2328
2329        Self::slice(self, ranges)
2330    }
2331
2332    /// Attempts to split the tensor into a specified number of chunks along a given dimension.
2333    /// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
2334    ///
2335    /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
2336    /// Otherwise all chunks will be of equal size except for the last one.
2337    ///
2338    /// # Panics
2339    ///
2340    /// If the dimension is greater than the number of dimensions of the tensor.
2341    ///
2342    /// # Returns
2343    /// A vector of tensors.
2344    ///
2345    /// # Example
2346    ///
2347    /// ```rust
2348    /// use burn_tensor::backend::Backend;
2349    /// use burn_tensor::Tensor;
2350    ///
2351    /// fn example<B: Backend>() {
2352    ///     let device = Default::default();
2353    ///     // Create a 2D tensor with dimensions [4, 3]
2354    ///     let tensor = Tensor::<B, 2>::from_data(
2355    ///         [
2356    ///             [3.0, 4.9, 2.0],
2357    ///             [2.0, 1.9, 3.0],
2358    ///             [6.0, 1.5, 7.0],
2359    ///             [3.0, 4.9, 9.0],
2360    ///         ],
2361    ///         &device,
2362    ///     );
2363    ///     // Split the tensor along the dimension 1 into 2 chunks.
2364    ///     // The first chuck will have shape [4, 2]:
2365    ///     // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
2366    ///     // The second chunk will have shape [4, 1]:
2367    ///     // [[2.0], [3.0], [7.0], [9.0]]
2368    ///     let chunks = tensor.chunk(2, 1);
2369    ///     println!("{chunks:?}");
2370    /// }
2371    /// ```
2372    pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
2373        check!(TensorCheck::dim_ops::<D>("chunk", dim));
2374        let size = self.shape()[dim];
2375        if size < chunks {
2376            return (0..size)
2377                .map(|i| Self::narrow(self.clone(), dim, i, 1))
2378                .collect();
2379        }
2380
2381        let mut tensors = Vec::with_capacity(chunks);
2382        let mut sum_chunk_size = 0;
2383        if size.is_multiple_of(chunks) {
2384            let chunk_size = size / chunks;
2385            for _ in 0..chunks {
2386                tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
2387                sum_chunk_size += chunk_size;
2388            }
2389        } else {
2390            let chunk_size = (size / chunks) + 1; // assumes not divisible
2391            for _ in 0..chunks - 1 {
2392                tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
2393                sum_chunk_size += chunk_size;
2394            }
2395            let remainder = size % chunk_size;
2396            tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, remainder));
2397        }
2398
2399        tensors
2400    }
2401
2402    /// Splits the tensor into chunks of a specified size along a given dimension.
2403    /// Each chunk is a view of the original tensor.
2404    ///
2405    /// If the tensor size along the given dimension is not divisible by `split_size`,
2406    /// then the last chunk will be smaller.
2407    ///
2408    /// # Panics
2409    ///
2410    /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
2411    ///
2412    /// # Returns
2413    ///
2414    /// A vector of tensors.
2415    ///
2416    /// # Example
2417    /// ```rust
2418    /// use burn_tensor::backend::Backend;
2419    /// use burn_tensor::Tensor;
2420    ///
2421    /// fn example<B: Backend>() {
2422    ///     let device = Default::default();
2423    ///     // Create a 1D tensor with 5 elements
2424    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2425    ///     // Split the tensor into chunks of size 2 along dimension 0
2426    ///     let chunks = tensor.split(2, 0);
2427    ///     // The result is a vector of tensors:
2428    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
2429    ///     println!("{:?}", chunks);
2430    /// }
2431    /// ```
2432    pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
2433        check!(TensorCheck::split::<D>(&self.shape(), split_size, dim));
2434        let size = self.shape()[dim];
2435        let mut tensors = Vec::new();
2436
2437        let mut start = 0;
2438        while start < size {
2439            let length = usize::min(split_size, size - start);
2440            tensors.push(Self::narrow(self.clone(), dim, start, length));
2441            start += length;
2442        }
2443
2444        tensors
2445    }
2446
2447    /// Splits the tensor into chunks with the specified sizes along a given dimension.
2448    /// Each chunk is a view of the original tensor.
2449    ///
2450    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2451    /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2452    ///
2453    /// # Panics
2454    ///
2455    /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
2456    /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`.
2457    ///
2458    /// # Returns
2459    ///
2460    /// A vector of tensors.
2461    ///
2462    /// # Example
2463    /// ```rust
2464    /// use burn_tensor::backend::Backend;
2465    /// use burn_tensor::Tensor;
2466    ///
2467    /// fn example<B: Backend>() {
2468    ///     let device = Default::default();
2469    ///     // Create a 1D tensor with 5 elements
2470    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2471    ///     // Split the tensor into chunks with sizes [2, 3] along dimension 0
2472    ///     let chunks = tensor.split_with_sizes(vec![2, 3], 0);
2473    ///     // The result is a vector of tensors:
2474    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
2475    ///     println!("{:?}", chunks);
2476    /// }
2477    /// ```
2478    pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
2479        check!(TensorCheck::split_with_sizes::<D>(
2480            &self.shape(),
2481            &split_sizes,
2482            dim
2483        ));
2484        let mut tensors = Vec::new();
2485
2486        let mut start = 0;
2487        for length in split_sizes {
2488            if length == 0 {
2489                continue;
2490            }
2491            tensors.push(Self::narrow(self.clone(), dim, start, length));
2492            start += length;
2493        }
2494
2495        tensors
2496    }
2497
2498    /// Tests if any element in the `tensor` evaluates to True.
2499    ///
2500    /// # Arguments
2501    ///
2502    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2503    ///
2504    /// # Returns
2505    ///
2506    /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
2507    /// evaluates to True, False otherwise.
2508    ///
2509    /// # Example
2510    ///
2511    /// ```rust
2512    /// use burn_tensor::backend::Backend;
2513    /// use burn_tensor::{Tensor, Bool};
2514    ///
2515    /// fn example<B: Backend>() {
2516    ///   let device = Default::default();
2517    ///   let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);
2518    ///   let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);
2519    ///
2520    ///   // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2521    ///   let any_tensor = tensor.any();
2522    ///   println!("{}", any_tensor);
2523    ///   // Tensor { data: [true], ... }
2524    ///
2525    ///   // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2526    ///   let any_tensor_two = tensor_two.any();
2527    ///   println!("{}", any_tensor_two);
2528    ///   // Tensor { data: [false], ... }
2529    /// }
2530    /// ```
2531    pub fn any(self) -> Tensor<B, 1, Bool> {
2532        Tensor::new(K::any(self.primitive))
2533    }
2534
2535    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
2536    ///
2537    /// # Arguments
2538    ///
2539    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2540    /// * `dim` - The axis along which to test.
2541    ///
2542    /// # Returns
2543    ///
2544    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2545    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
2546    /// evaluates to True, False otherwise.
2547    ///
2548    /// # Example
2549    ///
2550    /// ```rust
2551    /// use burn_tensor::backend::Backend;
2552    /// use burn_tensor::{Tensor, Bool};
2553    ///
2554    /// fn example<B: Backend>() {
2555    ///     let device = Default::default();
2556    ///     let tensor =
2557    ///         Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
2558    ///     // Check if any element in the tensor evaluates to True along the dimension 1.
2559    ///     // [[true], [true]],
2560    ///     let any_dim = tensor.clone().any_dim(1);
2561    ///     println!("{any_dim}");
2562    /// }
2563    /// ```
2564    pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2565        Tensor::new(K::any_dim(self.primitive, dim))
2566    }
2567
2568    /// Tests if all elements in the `tensor` evaluate to True.
2569    ///
2570    /// # Arguments
2571    ///
2572    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2573    ///
2574    /// # Returns
2575    ///
2576    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
2577    /// evaluate to True, False otherwise.
2578    ///
2579    /// # Example
2580    ///
2581    /// ```rust
2582    /// use burn_tensor::backend::Backend;
2583    /// use burn_tensor::{Tensor, Bool};
2584    ///
2585    /// fn example<B: Backend>() {
2586    ///     let device = Default::default();
2587    ///     let tensor =
2588    ///         Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
2589    ///     // Check if all elements in the tensor evaluate to True (which is not the case).
2590    ///     // [false]
2591    ///     let all = tensor.all();
2592    ///     println!("{all}");
2593    /// }
2594    /// ```
2595    pub fn all(self) -> Tensor<B, 1, Bool> {
2596        Tensor::new(K::all(self.primitive))
2597    }
2598
2599    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2600    ///
2601    /// # Arguments
2602    ///
2603    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2604    /// * `dim` - The axis along which to test.
2605    ///
2606    /// # Returns
2607    ///
2608    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2609    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
2610    /// evaluates to True, False otherwise.
2611    ///
2612    /// # Example
2613    ///
2614    /// ```rust
2615    /// use burn_tensor::backend::Backend;
2616    /// use burn_tensor::{Tensor, Bool};
2617    ///
2618    /// fn example<B: Backend>() {
2619    ///     let device = Default::default();
2620    ///     let tensor =
2621    ///         Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
2622    ///     // Check if all elements in the tensor evaluate to True along the dimension 1.
2623    ///     // [[true, true, false]]
2624    ///     let all_dim = tensor.clone().all_dim(0);
2625    ///     println!("{all_dim}");
2626    /// }
2627    /// ```
2628    pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2629        Tensor::new(K::all_dim(self.primitive, dim))
2630    }
2631
2632    /// Convert the tensor into a scalar.
2633    ///
2634    /// # Panics
2635    ///
2636    /// - If the tensor doesn't have one element.
2637    /// - If the backend fails to read the tensor data synchronously.
2638    ///
2639    /// # Returns
2640    ///
2641    /// The scalar value of the tensor.
2642    ///
2643    /// # Example
2644    ///
2645    /// ```rust
2646    /// use burn_tensor::backend::Backend;
2647    /// use burn_tensor::Tensor;
2648    ///
2649    /// fn example<B: Backend>() {
2650    ///     let device = Default::default();
2651    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
2652    ///     // Convert the tensor with a single element into a scalar.
2653    ///     let scalar = tensor.into_scalar();
2654    ///     println!("{scalar}");
2655    /// }
2656    /// ```
2657    pub fn into_scalar(self) -> K::Elem {
2658        crate::try_read_sync(self.into_scalar_async())
2659            .expect(
2660            "Failed to read tensor data synchronously. This can happen on platforms
2661            that don't support blocking futures like WASM. Try into_scalar_async instead.",
2662            )
2663            .expect("Error while reading data: use `try_into_scalar` instead to catch the error at runtime")
2664    }
2665
2666    /// Convert the tensor into a scalar and returns any error that might have occurred since the
2667    /// last time the device was synchronized.
2668    ///
2669    /// # Panics
2670    ///
2671    /// - If the tensor doesn't have one element.
2672    /// - If the backend fails to read the tensor data synchronously.
2673    ///
2674    /// # Returns
2675    ///
2676    /// The scalar value of the tensor.
2677    pub fn try_into_scalar(self) -> Result<K::Elem, ExecutionError> {
2678        crate::try_read_sync(self.into_scalar_async()).expect(
2679            "Failed to read tensor data synchronously. This can happen on platforms
2680            that don't support blocking futures like WASM. Try into_scalar_async instead.",
2681        )
2682    }
2683
2684    /// Convert the tensor into a scalar.
2685    ///
2686    /// # Panics
2687    ///
2688    /// If the tensor doesn't have one element.
2689    pub async fn into_scalar_async(self) -> Result<K::Elem, ExecutionError> {
2690        check!(TensorCheck::into_scalar::<D>(&self.shape()));
2691
2692        Ok(self.into_data_async().await?.iter().next().unwrap())
2693    }
2694
2695    /// Broadcast the tensor to the given shape.
2696    ///
2697    /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size
2698    /// (which can be inferred with `-1`).
2699    ///
2700    /// # Arguments
2701    ///
2702    /// * `shape` - The shape to broadcast the tensor to.
2703    ///   Can contain -1 for dimensions that should be inferred.
2704    ///   The number of elements in the shape must be greater or equal as
2705    ///   the number of dimensions of the tensor.
2706    ///
2707    /// # Panics
2708    ///
2709    /// If the tensor cannot be broadcasted to the given shape.
2710    ///
2711    /// # Returns
2712    ///
2713    /// A new tensor with the given shape.
2714    ///
2715    /// # Example
2716    ///
2717    /// ```rust
2718    /// use burn_tensor::backend::Backend;
2719    /// use burn_tensor::Tensor;
2720    ///
2721    /// fn example<B: Backend>() {
2722    ///     let device = Default::default();
2723    ///     // Create a 2D tensor with dimensions [3, 1]
2724    ///     let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
2725    ///     // Expand the tensor to a new shape [3, 4]
2726    ///     // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
2727    ///     let expanded = tensor.expand([3, 4]);
2728    ///     println!("{}", expanded);
2729    /// }
2730    /// ```
2731    pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
2732        let shape = shape.into_shape(&self.shape());
2733        check!(TensorCheck::expand::<D, D2>(
2734            "expand",
2735            &self.shape(),
2736            &shape,
2737        ));
2738
2739        Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
2740    }
2741
2742    /// Unfold windows along a dimension.
2743    ///
2744    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
2745    /// where windows are advanced by `step` at each index.
2746    ///
2747    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
2748    ///
2749    /// The new view will have the unfolded dimension replaced by two dimensions;
2750    /// one in the position of the original dimension, with size equal to the number of windows,
2751    /// and one appended to the right-most position, with size equal to `size`.
2752    ///
2753    /// # Warning
2754    ///
2755    /// For the `ndarray` backend; this is not a view but a copy
2756    /// with duplicated data.
2757    ///
2758    /// # Arguments
2759    ///
2760    /// * `dim` - the dimension to unfold.
2761    /// * `size` - the size of each unfolded window.
2762    /// * `step` - the step between each window.
2763    ///
2764    /// # Returns
2765    ///
2766    /// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
2767    pub fn unfold<const D2: usize, I: AsIndex>(
2768        self,
2769        dim: I,
2770        size: usize,
2771        step: usize,
2772    ) -> Tensor<B, D2, K> {
2773        let dim = dim.expect_dim_index(D);
2774        check!(TensorCheck::unfold::<D, D2>(
2775            "unfold",
2776            &self.shape(),
2777            dim,
2778            size,
2779            step,
2780        ));
2781        Tensor::<B, D2, K>::new(K::unfold(self.primitive, dim, size, step))
2782    }
2783}
2784
/// Iterator given by (Tensor::iter_dim).
///
/// Yields one rank-preserving slice of size 1 along `dim` per step, supporting
/// iteration from both ends (see the `Iterator`/`DoubleEndedIterator` impls).
pub struct DimIter<B, const D: usize, K>
where
    B: Backend,
    K: BasicOps<B>,
{
    // Front cursor: index of the next item produced by `next`.
    start: usize,
    // Back cursor (exclusive): `next_back` produces index `end - 1`.
    end: usize,
    // The dimension being iterated over.
    dim: usize,
    // Full-range template; `ranges[dim]` is overwritten per step before slicing.
    ranges: [Range<usize>; D],
    // The tensor being iterated; cloned for each yielded slice.
    tensor: Tensor<B, D, K>,
}
2797
2798impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
2799    type Item = Tensor<B, D, K>;
2800
2801    fn next(&mut self) -> Option<Self::Item> {
2802        if self.start >= self.end {
2803            return None;
2804        }
2805
2806        let mut ranges = self.ranges.clone();
2807        ranges[self.dim] = self.start..(self.start + 1);
2808
2809        let slice = self.tensor.clone().slice(ranges);
2810        self.start += 1;
2811
2812        Some(slice)
2813    }
2814}
2815
2816impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
2817    fn next_back(&mut self) -> Option<Self::Item> {
2818        if self.start >= self.end {
2819            return None;
2820        }
2821
2822        let mut ranges = self.ranges.clone();
2823        ranges[self.dim] = (self.end - 1)..self.end;
2824
2825        let slice = self.tensor.clone().slice(ranges);
2826        self.end = self.end.saturating_sub(1);
2827
2828        Some(slice)
2829    }
2830}
2831
2832impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
2833    fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
2834        let dims = tensor.dims();
2835        let ranges = dims
2836            .iter()
2837            .map(|&dim| 0..dim)
2838            .collect::<Vec<Range<usize>>>();
2839        let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
2840        Self {
2841            end: dims[dim],
2842            ranges,
2843            start: 0,
2844            dim,
2845            tensor,
2846        }
2847    }
2848}
2849
2850impl<B, const D: usize, K> Tensor<B, D, K>
2851where
2852    B: Backend,
2853    K: BasicOps<B>,
2854    <K as BasicOps<B>>::Elem: Debug,
2855{
2856    #[inline]
2857    fn push_newline_indent(acc: &mut String, indent: usize) {
2858        acc.push('\n');
2859        for _ in 0..indent {
2860            acc.push(' ');
2861        }
2862    }
2863    fn fmt_inner_tensor(
2864        &self,
2865        acc: &mut String,
2866        depth: usize,
2867        multi_index: &mut [usize],
2868        range: (usize, usize),
2869        precision: Option<usize>,
2870    ) {
2871        let (start, end) = range;
2872        for i in start..end {
2873            if i > 0 {
2874                acc.push_str(", ");
2875            }
2876            multi_index[depth] = i;
2877            let range: [Range<usize>; D] =
2878                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
2879
2880            let data = burn_std::reader::try_read_sync(self.clone().slice(range).into_data_async());
2881
2882            if let Some(Ok(data)) = data {
2883                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
2884                match (precision, K::name()) {
2885                    (Some(p), "Float") => acc.push_str(&format!("{elem:.p$}")),
2886                    (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
2887                    _ => acc.push_str(&format!("{elem:?}")),
2888                }
2889            } else {
2890                acc.push_str("<Tensor data not available>");
2891            }
2892        }
2893    }
2894
2895    fn fmt_outer_tensor(
2896        &self,
2897        acc: &mut String,
2898        depth: usize,
2899        multi_index: &mut [usize],
2900        print_options: &PrintOptions,
2901        summarize: bool,
2902        range: (usize, usize),
2903    ) {
2904        let (start, end) = range;
2905        for i in start..end {
2906            if i > start {
2907                acc.push(',');
2908                Self::push_newline_indent(acc, depth + 1);
2909            }
2910            acc.push('[');
2911            multi_index[depth] = i;
2912            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
2913            acc.push(']');
2914        }
2915    }
2916
2917    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
2918    ///
2919    /// This function is designed to work with tensors of any dimensionality.
2920    /// It traverses the tensor dimensions recursively, converting the elements
2921    /// to strings and appending them to the accumulator string with the
2922    /// appropriate formatting.
2923    ///
2924    /// # Arguments
2925    ///
2926    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
2927    /// * `depth` - The current depth of the tensor dimensions being processed.
2928    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
2929    fn display_recursive(
2930        &self,
2931        acc: &mut String,
2932        depth: usize,
2933        multi_index: &mut [usize],
2934        print_options: &PrintOptions,
2935        summarize: bool,
2936    ) {
2937        let edge_items = print_options.edge_items;
2938
2939        if depth == 0 {
2940            acc.push('[');
2941        }
2942
2943        if depth == self.dims().len() - 1 {
2944            // if we are at the innermost dimension, just push its elements into the accumulator
2945            if summarize && self.dims()[depth] > 2 * edge_items {
2946                // print the starting `edge_items` elements
2947                self.fmt_inner_tensor(
2948                    acc,
2949                    depth,
2950                    multi_index,
2951                    (0, edge_items),
2952                    print_options.precision,
2953                );
2954                acc.push_str(", ...");
2955                // print the last `edge_items` elements
2956                self.fmt_inner_tensor(
2957                    acc,
2958                    depth,
2959                    multi_index,
2960                    (self.dims()[depth] - edge_items, self.dims()[depth]),
2961                    print_options.precision,
2962                );
2963            } else {
2964                // print all the elements
2965                self.fmt_inner_tensor(
2966                    acc,
2967                    depth,
2968                    multi_index,
2969                    (0, self.dims()[depth]),
2970                    print_options.precision,
2971                );
2972            }
2973        } else {
2974            // otherwise, iterate through the current dimension and recursively display the inner tensors
2975            if summarize && self.dims()[depth] > 2 * edge_items {
2976                self.fmt_outer_tensor(
2977                    acc,
2978                    depth,
2979                    multi_index,
2980                    print_options,
2981                    summarize,
2982                    (0, edge_items),
2983                );
2984
2985                acc.push(',');
2986                Self::push_newline_indent(acc, depth + 1);
2987                acc.push_str("...");
2988                Self::push_newline_indent(acc, depth + 1);
2989
2990                self.fmt_outer_tensor(
2991                    acc,
2992                    depth,
2993                    multi_index,
2994                    print_options,
2995                    summarize,
2996                    (self.dims()[depth] - edge_items, self.dims()[depth]),
2997                );
2998            } else {
2999                self.fmt_outer_tensor(
3000                    acc,
3001                    depth,
3002                    multi_index,
3003                    print_options,
3004                    summarize,
3005                    (0, self.dims()[depth]),
3006                );
3007            }
3008        }
3009
3010        if depth == 0 {
3011            acc.push(']');
3012        }
3013    }
3014}
3015
#[derive(Clone, Debug)]
/// Options for Tensor pretty printing
pub struct PrintOptions {
    /// Number of elements above which the printout is summarized with ellipses
    pub threshold: usize,

    /// Number of starting elements and ending elements to display per summarized dimension
    pub edge_items: usize,

    /// Precision (number of fractional digits) for floating point numbers; `None` uses the default formatting
    pub precision: Option<usize>,
}
3028
// Process-wide print options shared by all tensor `Display` implementations;
// updated via `set_print_options`.
static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
3030
3031impl PrintOptions {
3032    /// Print options with default values
3033    pub const fn const_default() -> Self {
3034        Self {
3035            threshold: 1000,
3036            edge_items: 3,
3037            precision: None,
3038        }
3039    }
3040}
3041
3042impl Default for PrintOptions {
3043    fn default() -> Self {
3044        Self::const_default()
3045    }
3046}
3047
3048/// Set print options
3049pub fn set_print_options(options: PrintOptions) {
3050    let mut print_opts = PRINT_OPTS.write().unwrap();
3051    *print_opts = options;
3052}
3053
/// Pretty print tensors
impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
where
    B: Backend,
    B::IntElem: core::fmt::Display,
    K: BasicOps<B>,
    <K as BasicOps<B>>::Elem: Debug,
{
    // Formats the tensor as a multi-line report: the (possibly summarized)
    // data, followed by shape, device, backend, kind and dtype metadata.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(f, "Tensor {{")?;

        {
            // Clone the global options so the read lock is released right away
            // instead of being held while the data is formatted.
            let mut po = { PRINT_OPTS.read().unwrap().clone() };

            // A precision supplied through the formatter (e.g. `{:.3}`)
            // overrides the globally configured one.
            if let Some(precision) = f.precision() {
                po.precision = Some(precision);
            }

            let mut acc = String::new();
            let mut multi_index = vec![0; D];
            // Summarize (elide middle elements) only for large tensors.
            let summarize = self.shape().num_elements() > po.threshold;

            self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);

            writeln!(f, "  data:")?;
            write!(f, "{acc}")?;
            writeln!(f, ",")?;
        }

        writeln!(f, "  shape:  {:?},", self.dims())?;
        writeln!(f, "  device:  {:?},", self.device())?;
        writeln!(f, "  backend:  {:?},", B::name(&self.device()))?;
        writeln!(f, "  kind:  {:?},", K::name())?;

        let dtype = self.primitive.dtype();

        writeln!(f, "  dtype:  {:?},", dtype.name())?;
        write!(f, "}}")
    }
}
3097
/// Trait used for movedim arguments
///
/// Implemented for `usize`/`i32` (a single dimension) and `Vec<usize>`/`Vec<i32>`
/// (a set of dimensions); negative `i32` values count back from dimension `D`.
pub trait MovedimArgs {
    /// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function
    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}
3103
3104impl MovedimArgs for Vec<i32> {
3105    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3106        let set = self
3107            .iter()
3108            .map(|&dim| {
3109                if dim < 0 {
3110                    (D as i32 + dim) as usize
3111                } else {
3112                    dim as usize
3113                }
3114            })
3115            .collect::<Vec<usize>>();
3116        check!(TensorCheck::movedim_args_vec::<D>(&set));
3117
3118        set
3119    }
3120}
3121
impl MovedimArgs for Vec<usize> {
    // Indices are already non-negative; only validate them against rank `D`.
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        check!(TensorCheck::movedim_args_vec::<D>(&self));
        self
    }
}
3128
3129impl MovedimArgs for usize {
3130    #[allow(clippy::vec_init_then_push)]
3131    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3132        check!(TensorCheck::movedim_args_usize::<D>(self));
3133
3134        let mut set = Vec::with_capacity(1);
3135        set.push(self);
3136
3137        set
3138    }
3139}
3140
3141impl MovedimArgs for i32 {
3142    #[allow(clippy::vec_init_then_push)]
3143    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3144        check!(TensorCheck::movedim_args_i32::<D>(self));
3145
3146        let dim = if self < 0 {
3147            (D as i32 + self) as usize
3148        } else {
3149            self as usize
3150        };
3151
3152        let mut set = Vec::with_capacity(1);
3153        set.push(dim);
3154
3155        set
3156    }
3157}
3158
/// Trait used for reshape arguments.
pub trait ReshapeArgs<const D2: usize>: Debug {
    /// Converts to a shape.
    ///
    /// `source` is the shape of the tensor being reshaped; implementations
    /// delegate to `Shape::reshape` on it to produce the target shape.
    fn into_shape<const D: usize>(self, source: Shape) -> Shape;
}
3164
3165impl<const D2: usize, I: AsIndex> ReshapeArgs<D2> for [I; D2] {
3166    fn into_shape<const D: usize>(self, source: Shape) -> Shape {
3167        unwrap_shape_reshape(source.reshape(self))
3168    }
3169}
3170
3171impl<const D2: usize> ReshapeArgs<D2> for Shape {
3172    fn into_shape<const D: usize>(self, source: Shape) -> Shape {
3173        unwrap_shape_reshape(source.reshape(self))
3174    }
3175}
3176
/// Trait used for broadcast arguments.
pub trait BroadcastArgs<const D1: usize, const D2: usize> {
    /// Converts to a shape.
    ///
    /// `shape` is the current shape of the tensor; array-style implementations
    /// use it to resolve `-1` entries (keep the existing size).
    fn into_shape(self, shape: &Shape) -> Shape;
}
3182
impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
    // An explicit `Shape` is used verbatim as the broadcast target; the
    // tensor's current shape is not consulted (validation happens later in
    // the caller, e.g. `Tensor::expand`).
    fn into_shape(self, _shape: &Shape) -> Shape {
        self
    }
}
3188
3189impl<const D1: usize, const D2: usize, E: AsIndex> BroadcastArgs<D1, D2> for [E; D2] {
3190    // Passing -1 as the size for a dimension means not changing the size of that dimension.
3191    fn into_shape(self, shape: &Shape) -> Shape {
3192        if self.len() < shape.num_dims() {
3193            panic!("Broadcast arguments must be greater than the number of dimensions");
3194        }
3195
3196        // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
3197        let new_shape: Vec<_> = self
3198            .iter()
3199            .rev()
3200            .map(|x| {
3201                let primitive = x.as_index();
3202                if primitive < -1 || primitive == 0 {
3203                    panic!("Broadcast arguments must be positive or -1");
3204                }
3205                primitive
3206            })
3207            .zip(shape.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
3208            .map(|(x, &y)| if x == -1 { y } else { x as usize })
3209            .collect::<Vec<_>>()
3210            .into_iter()
3211            .rev()
3212            .collect();
3213
3214        if new_shape.contains(&0) {
3215            panic!("Cannot substitute -1 for a non-existing dimension");
3216        }
3217
3218        let new_shape: [usize; D2] = new_shape.try_into().unwrap();
3219
3220        Shape::from(new_shape)
3221    }
3222}
3223
3224impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
3225where
3226    B: Backend,
3227    K: BasicOps<B>,
3228    K::Elem: Debug + Copy + Serialize,
3229{
3230    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
3231        let data = self.to_data();
3232        data.serialize(serializer)
3233    }
3234}
3235
3236impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
3237where
3238    B: Backend,
3239    K: BasicOps<B>,
3240    K::Elem: Debug + Copy + Deserialize<'de>,
3241{
3242    fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
3243        let tensor = Tensor::from_data(
3244            TensorData::deserialize(deserializer)?,
3245            &<B::Device as Default>::default(),
3246        );
3247        Ok(tensor)
3248    }
3249}
3250
3251#[cfg(test)]
3252mod tests {
3253    use burn_std::SliceOps;
3254
3255    use crate::{Shape, s};
3256
    // Verifies `Shape::into_slices` resolves every supported range form on the
    // leading dimension: half-open, inclusive, unbounded start/end, full, and
    // negative (from-the-end) indices.
    #[test]
    fn slice_range_single_dim_leading() {
        let shape = Shape::new([8, 4]);

        // Half-open range
        let slices = shape.clone().into_slices([0..5]);
        assert_eq!(slices[0].to_range(8), 0..5);
        let slices = shape.clone().into_slices([-3..-1]);
        assert_eq!(slices[0].to_range(8), 5..7);

        // Inclusive range
        let slices = shape.clone().into_slices([0..=4]);
        assert_eq!(slices[0].to_range(8), 0..5);
        let slices = shape.clone().into_slices([-2..=-1]);
        assert_eq!(slices[0].to_range(8), 6..8);

        // Unbounded start
        let slices = shape.clone().into_slices([..3]);
        assert_eq!(slices[0].to_range(8), 0..3);
        let slices = shape.clone().into_slices([..-5]);
        assert_eq!(slices[0].to_range(8), 0..3);

        // Unbounded end
        let slices = shape.clone().into_slices([5..]);
        assert_eq!(slices[0].to_range(8), 5..8);
        let slices = shape.clone().into_slices([-3..]);
        assert_eq!(slices[0].to_range(8), 5..8);

        // Full range
        let slices = shape.into_slices([..]);
        assert_eq!(slices[0].to_range(8), 0..8);
    }
3289
    // Verifies negative indices through the `Slice` type itself: conversion
    // from std ranges, resolution via `to_range`, the `Shape` path, and the
    // `s!` macro.
    #[test]
    fn test_negative_slice_indices() {
        use crate::Slice;

        // Test negative indices conversion
        let slice: Slice = (-3..-1).into();
        assert_eq!(slice.start, -3);
        assert_eq!(slice.end, Some(-1));

        // Test to_range conversion with size 8
        let range = slice.to_range(8);
        assert_eq!(range, 5..7);

        // Test with shape slice
        let shape = Shape::new([8, 4]);
        let result = shape.clone().into_slices([-3..-1]);
        assert_eq!(result[0].to_range(8), 5..7);

        // Test more negative index cases
        let slice2: Slice = (-5..).into();
        assert_eq!(slice2.to_range(10), 5..10);

        let slice3: Slice = (..-2).into();
        assert_eq!(slice3.to_range(10), 0..8);

        // Test with s! macro - single dimension returns Slice directly
        let slice4 = s![-3..-1];
        assert_eq!(slice4.start, -3);
        assert_eq!(slice4.end, Some(-1));
    }
3320
3321    #[test]
3322    fn slice_range_multi_dim() {
3323        let shape = Shape::new([8, 4]);
3324
3325        // Multiple ways to provide ranges
3326        let slices = shape.clone().into_slices([0..5, 0..4]);
3327        assert_eq!(slices[0].to_range(8), 0..5);
3328        assert_eq!(slices[1].to_range(4), 0..4);
3329
3330        let slices = shape.clone().into_slices([0.., 0..]);
3331        assert_eq!(slices[0].to_range(8), 0..8);
3332        assert_eq!(slices[1].to_range(4), 0..4);
3333
3334        let slices = shape.clone().into_slices([0..=7, 0..=3]);
3335        assert_eq!(slices[0].to_range(8), 0..8);
3336        assert_eq!(slices[1].to_range(4), 0..4);
3337
3338        let slices = shape.clone().into_slices([0..5, 0..3]);
3339        assert_eq!(slices[0].to_range(8), 0..5);
3340        assert_eq!(slices[1].to_range(4), 0..3);
3341
3342        let slices = shape.into_slices([0.., 0..]);
3343        assert_eq!(slices[0].to_range(8), 0..8);
3344        assert_eq!(slices[1].to_range(4), 0..4);
3345    }
3346
3347    #[test]
3348    fn slice_range_multi_dim_index() {
3349        let shape = Shape::new([8, 4]);
3350
3351        // Indices (single integer) should also convert to correct range
3352        let slices = shape.clone().into_slices([0, 2]);
3353        assert_eq!(slices[0].to_range(8), 0..1);
3354        assert_eq!(slices[1].to_range(4), 2..3);
3355
3356        let slices = shape.into_slices([-1, -1]);
3357        assert_eq!(slices[0].to_range(8), 7..8);
3358        assert_eq!(slices[1].to_range(4), 3..4);
3359    }
3360
3361    #[test]
3362    fn slice_range_multi_dim_heterogeneous() {
3363        // Slice macro `s![]` can be used to provide different range types
3364        let shape = Shape::new([8, 4, 2]);
3365        let slice = s![0..5, .., -1];
3366        let slices = shape.into_slices(slice);
3367        assert_eq!(slices[0].to_range(8), 0..5);
3368        assert_eq!(slices[1].to_range(4), 0..4);
3369        assert_eq!(slices[2].to_range(2), 1..2);
3370
3371        let shape = Shape::new([8, 4, 2, 3]);
3372        let slice = s![..=4, 0..=3, .., -2..];
3373        let slices = shape.into_slices(slice);
3374        assert_eq!(slices[0].to_range(8), 0..5);
3375        assert_eq!(slices[1].to_range(4), 0..4);
3376        assert_eq!(slices[2].to_range(2), 0..2);
3377        assert_eq!(slices[3].to_range(3), 1..3);
3378
3379        let shape = Shape::new([3, 4]);
3380        let slice = s![1..-1, ..];
3381        let slices = shape.into_slices(slice);
3382        assert_eq!(slices[0].to_range(3), 1..2);
3383        assert_eq!(slices[1].to_range(4), 0..4);
3384    }
3385}