burn_tensor/tensor/api/base.rs

#![allow(clippy::single_range_in_vec_init)]
use crate::backend::ExecutionError;

pub use burn_backend::tensor::BasicOps;

use alloc::vec::Vec;

use alloc::format;
use alloc::string::String;
use alloc::vec;

use burn_std::stub::RwLock;
use core::iter::repeat;
use core::{fmt::Debug, ops::Range};
use serde::{Deserialize, Deserializer};

use crate::IndexingUpdateOp;
use crate::{AsIndex, Slice, SliceArg, wrap_index};
use crate::{
    Bool, ElementConversion, Float, Int, Shape, TensorData, TensorKind, TensorMetadata,
    backend::Backend, check,
};
use crate::{DType, Element};
use crate::{cast::ToElement, check::TensorCheck};
use serde::{Serialize, Serializer};

/// A tensor with a given backend, shape and data type.
///
/// # Indexing
/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
/// or [`select`](Tensor::select) for numeric types.
///
/// ## Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
/// use burn_tensor::Int;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///
///     let tensor = Tensor::<B, 2>::from_data(
///         [
///             [3.0, 4.9, 2.0],
///             [2.0, 1.9, 3.0],
///             [6.0, 1.5, 7.0],
///             [3.0, 4.9, 9.0],
///         ],
///         &device,
///     );
///
///     // Slice the tensor to get the second and third rows:
///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
///     // The resulting tensor will have dimensions [2, 3].
///     let slice = tensor.clone().slice([1..3]);
///     println!("{slice}");
///
///     // Slice the tensor to get the first two rows and the first 2 columns:
///     // [[3.0, 4.9], [2.0, 1.9]]
///     // The resulting tensor will have dimensions [2, 2].
///     let slice = tensor.clone().slice([0..2, 0..2]);
///     println!("{slice}");
///
///     // Index the tensor along the dimension 1 to get the elements 0 and 2:
///     // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
///     // The resulting tensor will have dimensions [4, 2].
///     let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
///     let indexed = tensor.select(1, indices);
///     println!("{indexed}");
/// }
/// ```
#[derive(new, Clone, Debug)]
pub struct Tensor<B, const D: usize, K = Float>
where
    B: Backend,
    K: TensorKind<B>,
{
    pub(crate) primitive: K::Primitive,
}

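/// Creates a tensor from any value convertible into [`TensorData`] (e.g., nested arrays),
/// allocated on the backend's default device. A minimal sketch (assuming a backend `B`
/// whose default device is available):
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     // Nested arrays implement `Into<TensorData>`, so they convert directly.
///     let tensor: Tensor<B, 2> = [[1.0, 2.0], [3.0, 4.0]].into();
///     println!("{tensor}");
/// }
/// ```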
impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    T: Into<TensorData>,
{
    fn from(value: T) -> Self {
        Tensor::from_data(value.into(), &Default::default())
    }
}

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    K::Elem: Element,
{
    /// Executes an operation on the tensor and modifies its value.
    ///
    /// # Notes
    ///
    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
    /// no other reference pointing to the same tensor.
    ///
    /// Wrapping operations with `inplace` is not an optimization; it mainly exists so you
    /// can mutate a tensor using owned operations. A plausible usage would be to
    /// update the weights of a mutable model reference.
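    ///
    /// # Example
    ///
    /// A minimal sketch of mutating a tensor through a mutable reference (assuming a
    /// backend `B`):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let mut tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     // Apply an owned operation without giving up the binding.
    ///     tensor.inplace(|t| t * 2.0);
    ///     println!("{tensor}"); // [2.0, 4.0, 6.0]
    /// }
    /// ```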
    pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
        let mut tensor_owned = Tensor::empty([0; D], &self.device());
        core::mem::swap(&mut tensor_owned, self);

        let mut tensor_new = func(tensor_owned);
        core::mem::swap(&mut tensor_new, self);
    }

    /// Converts the tensor into a primitive tensor.
    pub fn into_primitive(self) -> K::Primitive {
        self.primitive
    }

    /// Converts from a primitive tensor into a tensor.
    pub fn from_primitive(tensor: K::Primitive) -> Self {
        Self::new(tensor)
    }

    /// Returns the number of dimensions of the tensor.
    pub fn rank(&self) -> usize {
        self.primitive.rank()
    }

    /// Returns the tensor primitive data type.
    ///
    /// # Note
    /// Some element types are encoded in different primitive types depending on the backend
    /// (e.g., bool could be encoded as `u8` or `u32`).
    pub fn dtype(&self) -> DType {
        self.primitive.dtype()
    }

    /// Create an empty tensor of the given shape.
    ///
    /// # Arguments
    ///
    /// - `shape`: The shape of the tensor.
    /// - `device`: The device where the tensor will be created.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    // Create an empty tensor with dimensions [2, 3, 4].
    ///    let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
    /// }
    /// ```
    pub fn empty<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
        Self::new(K::empty(shape, device, K::Elem::dtype()))
    }

    /// Create a tensor of the given shape where each element is zero.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
    ///    println!("{tensor}");
    ///    // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
    /// }
    /// ```
    pub fn zeros<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
        Self::new(K::zeros(shape, device, K::Elem::dtype()))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with zeros.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.zeros_like();
    ///   println!("{tensor}");
    ///   // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
    /// }
    /// ```
    pub fn zeros_like(&self) -> Self {
        Self::new(K::zeros(self.shape(), &self.device(), self.dtype()))
    }

    /// Create a tensor of the given shape where each element is one.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
    ///   println!("{tensor}");
    ///   // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    /// }
    /// ```
    pub fn ones<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
        Self::new(K::ones(shape, device, K::Elem::dtype()))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with ones.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.ones_like();
    ///    println!("{tensor}");
    ///    // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    /// }
    /// ```
    pub fn ones_like(&self) -> Self {
        Self::new(K::ones(self.shape(), &self.device(), self.dtype()))
    }

    /// Create a tensor of the given shape where each element is equal to the provided value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
    ///   println!("{tensor}");
    ///   // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
    /// }
    /// ```
    pub fn full<S: Into<Shape>, E: ElementConversion>(
        shape: S,
        fill_value: E,
        device: &B::Device,
    ) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
        Self::new(K::full(shape, fill_value, device, K::Elem::dtype()))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor,
    /// filled with the provided value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.full_like(5.0);
    ///    println!("{tensor}");
    ///    // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
    /// }
    /// ```
    pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
        Self::new(K::full(
            self.shape(),
            fill_value,
            &self.device(),
            self.dtype(),
        ))
    }

    /// Returns the dimensions of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///   let device = Default::default();
    ///   let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///   let dims = tensor.dims(); // [2, 3, 4]
    ///   println!("{dims:?}");
    /// }
    /// ```
    pub fn dims(&self) -> [usize; D] {
        Self::shape(self).dims()
    }

    /// Returns the shape of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Shape { dims: [2, 3, 4] }
    ///    let shape = tensor.shape();
    /// }
    /// ```
    pub fn shape(&self) -> Shape {
        self.primitive.shape()
    }

    /// Reshape the tensor to have the given shape.
    ///
    /// The tensor has the same data and number of elements as the input.
    ///
    /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
    /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
    ///
    /// A `0` in the shape instructs to keep the current dimension from the original tensor,
    /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
    /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
    /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
    /// with [1, 3, 4] dimensions to [1, 12].
    ///
    /// # Arguments
    /// - `shape`: The new shape of the tensor.
    ///
    /// # Panics
    /// - If the shape contains more than one `-1`.
    /// - If the shape contains values that are not positive (other than `-1`).
    /// - If the shape does not match the number of elements of the original shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    // Create a tensor with dimensions [2, 3, 4]
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Reshape it to [2, 12], where 12 is inferred from the number of elements.
    ///    let reshaped = tensor.reshape([2, -1]);
    ///    println!("{reshaped}");
    /// }
    /// ```
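    ///
    /// A minimal sketch of the `0` placeholder (assuming the same setup), which keeps the
    /// corresponding dimension from the original tensor:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Keep dimension 0 (size 2) and infer the rest: the result has shape [2, 12].
    ///    let reshaped = tensor.reshape([0, -1]);
    ///    println!("{reshaped}");
    /// }
    /// ```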
    pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
        // Convert reshape args to shape
        let shape = shape.into_shape(&self);
        Tensor::new(K::reshape(self.primitive, shape))
    }

    /// Transpose the tensor.
    ///
    /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
    /// applied on the last two dimensions. For example, the transpose of a tensor with shape
    /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
    ///
    /// See also [`permute`](Tensor::permute).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to transpose.
    ///
    /// # Returns
    ///
    /// The transposed tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Transpose the tensor:
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let transposed = tensor.transpose();
    ///     println!("{transposed}");
    /// }
    /// ```
    pub fn transpose(self) -> Tensor<B, D, K> {
        Tensor::new(K::transpose(self.primitive))
    }

    /// Alias for `transpose`.
    #[inline(always)]
    pub fn t(self) -> Tensor<B, D, K> {
        self.transpose()
    }

    /// Swaps two dimensions of a tensor.
    ///
    /// This is a no-op when `dim1 == dim2`, assuming both are within bounds.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap, supports negative indexing.
    /// * `dim2` - The second dimension to swap, supports negative indexing.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    ///
    /// # Panics
    ///
    /// When dimensions are out of bounds.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Swap the dimensions 0 and -1 (equivalent to `tensor.transpose()`):
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let swapped = tensor.swap_dims(0, -1);
    ///     println!("{swapped}");
    /// }
    /// ```
    pub fn swap_dims<Dim1, Dim2>(self, dim1: Dim1, dim2: Dim2) -> Tensor<B, D, K>
    where
        Dim1: AsIndex,
        Dim2: AsIndex,
    {
        let dim1 = dim1.expect_dim_index(D);
        let dim2 = dim2.expect_dim_index(D);
        check!(TensorCheck::swap_dims::<D>(dim1, dim2));
        if dim1 == dim2 {
            self
        } else {
            Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
        }
    }

    /// Permute the dimensions of the tensor.
    ///
    /// This is a no-op when the resolved `axes` match the current order.
    ///
    /// # Arguments
    ///
    /// * `axes` - The new order of the dimensions. The length of the axes
    ///   must be equal to the number of dimensions of the tensor.
    ///   The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
    ///
    ///     // Permute the dimensions 1 and 0:
    ///     // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
    ///     // The resulting tensor will have dimensions [2, 3].
    ///     let permuted = tensor.permute([1, 0]);
    ///     println!("{permuted}");
    /// }
    /// ```
    pub fn permute<Dim>(self, axes: [Dim; D]) -> Tensor<B, D, K>
    where
        Dim: AsIndex,
    {
        let mut no_op = true;
        let mut fixed_axes = [0; D];
        for (i, axis) in axes.into_iter().enumerate() {
            let dim = axis.expect_dim_index(D);
            no_op &= dim == i;
            fixed_axes[i] = dim;
        }

        if no_op {
            self
        } else {
            check!(TensorCheck::permute(fixed_axes));
            Tensor::new(K::permute(self.primitive, &fixed_axes))
        }
    }

    /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
    ///
    /// Other dimensions of input that are not explicitly moved remain in their original order and appear
    /// at the positions not specified in destination.
    ///
    /// # Arguments
    ///
    /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// * `dst` - Destination positions for each of the original dims. These must also be unique.
    ///
    /// # Panics
    ///
    /// - If the source and destination dimensions are not of the same length.
    /// - If the source and destination vectors contain duplicate values.
    /// - If the source and destination vectors contain values that are out of bounds.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions moved.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor of shape [3, 2, 1]
    ///     let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
    ///
    ///     // Move dimension 1 to position 0:
    ///     // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
    ///     // The resulting tensor will have dimensions [2, 3, 1].
    ///     let moved = tensor.movedim(1, 0);
    ///     println!("{moved}");
    /// }
    /// ```
    ///
    /// # Note
    ///
    /// This is syntactic sugar for `permute`. It is used widely enough that we provide a
    /// dedicated method for it.
    pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
        let source_dims = src.into_dim_vec::<D>();
        let destination_dims = dst.into_dim_vec::<D>();

        check!(TensorCheck::movedim_args_length(
            &source_dims,
            &destination_dims
        ));

        let mut m = [-1; D];
        for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
            m[d] = s as isize;
        }
        let mut axes: [isize; D] = [0; D];
        let mut source_i = 0;
        for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
            *item = if m[dest_i] != -1 {
                m[dest_i]
            } else {
                while source_dims.contains(&source_i) {
                    source_i += 1;
                }
                let result = source_i as isize;
                source_i += 1;
                result
            };
        }

        self.permute(axes)
    }

    /// Reverse the order of elements in the tensor along the given dimensions.
    ///
    /// # Arguments
    ///
    /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the axes flipped.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [4.0, 5.9, 8.0],
    ///             [1.4, 5.8, 6.0],
    ///         ],
    ///         &device,
    ///     );
    ///
    ///     // Flip the elements in dimensions 0 and 1:
    ///     // [[6.0, 5.8, 1.4],
    ///     //  [8.0, 5.9, 4.0],
    ///     //  [3.0, 1.9, 2.0],
    ///     //  [2.0, 4.9, 3.0]]
    ///     // The resulting tensor will have dimensions [4, 3].
    ///     let flipped = tensor.flip([0, 1]);
    ///     println!("{flipped}");
    /// }
    /// ```
    pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
        // Convert the axes to usize, handling negative values without allocating a vector.
        let mut transformed_axes: [usize; N] = [0; N];
        for (i, &x) in axes.iter().enumerate() {
            transformed_axes[i] = if x < 0 {
                (D as isize + x) as usize
            } else {
                x as usize
            };
        }

        // Check if the axes are valid
        check!(TensorCheck::flip(D, &transformed_axes));

        Tensor::new(K::flip(self.primitive, &transformed_axes))
    }

    /// Flatten the tensor along a given range of dimensions.
    ///
    /// This function collapses the specified range of dimensions into a single dimension,
    /// effectively flattening the tensor in that range.
    ///
    /// # Arguments
    ///
    /// - `start_dim`: The starting dimension of the range to be flattened,
    ///   supports negative indexing.
    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),
    ///   supports negative indexing.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the flattened tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [2, 3, 4]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
    ///
    ///     // Flatten the tensor from dimensions 1 to 2 (inclusive).
    ///     // The resulting tensor will have dimensions [2, 12]
    ///     let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
    ///     println!("{flattened}");
    /// }
    /// ```
    pub fn flatten<const D2: usize>(
        self,
        start_dim: impl AsIndex,
        end_dim: impl AsIndex,
    ) -> Tensor<B, D2, K> {
        let start_dim = start_dim.expect_dim_index(D);
        let end_dim = end_dim.expect_dim_index(D);
        check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));
        let new_shape = self.shape().flatten_dims(start_dim, end_dim);

        Tensor::new(K::reshape(self.primitive, new_shape))
    }

    /// Squeeze the tensor along all dimensions, removing dimensions
    /// of size one, and effectively reducing the rank of the tensor.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 4D tensor with dimensions [1, 3, 1, 3]
    ///     let tensor = Tensor::<B, 4>::from_data(
    ///         [[[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]]],
    ///         &device,
    ///     );
    ///
    ///     // Squeeze the tensor dimensions.
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let squeezed = tensor.squeeze::<2>();
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze<const D2: usize>(self) -> Tensor<B, D2, K> {
        let new_dims = self
            .shape()
            .dims
            .iter()
            .filter_map(|&dim| if dim == 1 { None } else { Some(dim) })
            .collect::<Vec<_>>();
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Squeeze the tensor along the given dimension, removing the specified dimension
    /// of size one, and effectively reducing the rank of the tensor by one.
    ///
    /// # Arguments
    ///
    /// - `dim`: The dimension to be squeezed.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Panics
    ///
    /// If the size in the squeezed dimension is not 1.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 1, 3]
    ///     let tensor = Tensor::<B, 3>::from_data(
    ///         [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
    ///         &device,
    ///     );
    ///
    ///     // Squeeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let squeezed = tensor.squeeze_dim::<2>(1);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));

        let current_dims = self.shape().dims;
        let mut new_dims: [usize; D2] = [0; D2];

        new_dims[..dim].copy_from_slice(&current_dims[..dim]);
        new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);

        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
    /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
    /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
    /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
    /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
    /// from the back.
    ///
    /// # Arguments
    ///
    /// - `dims`: The dimension(s) to be squeezed.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 4D tensor with dimensions [2, 1, 4, 1]
    ///     let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
    ///
    ///     // Squeeze the dimensions 1 and 3.
    ///     // The resulting tensor will have dimensions [2, 4].
    ///     let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
        let current_dims = self.shape().dims;
        let mut dim_indices: Vec<usize>;

        // If `dims` is empty, collect the indices of all dimensions of size 1.
        if dims.is_empty() {
            dim_indices = current_dims
                .iter()
                .enumerate()
                .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
                .collect();
        } else {
            // If negative dims, count from the back
            dim_indices = dims
                .iter()
                .map(|&d| {
                    if d < 0 {
                        (current_dims.len() as isize + d) as usize
                    } else {
                        d as usize
                    }
                })
                .collect();
        }

        // Sort indices and remove duplicates
        dim_indices.sort_unstable();
        dim_indices.dedup();

        // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
        check!(TensorCheck::squeeze_dims_input::<D2>(
            &dim_indices,
            &current_dims
        ));

        // Calculate new dimensions
        let mut new_dims = Vec::new();
        for (index, &dim_size) in current_dims.iter().enumerate() {
            // Exclude the dimension if it's explicitly marked for squeezing
            if dim_indices.contains(&index) {
                check!(TensorCheck::squeeze::<D2>(index, &current_dims));
                continue;
            }
            new_dims.push(dim_size);
        }

        // Check that after squeezing, we still respect the D2 size
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the unsqueezed tensor.
    ///
    /// # Panics
    ///
    /// If the output size `D2` is smaller than the current number of dimensions.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the tensor up to 4 dimensions.
    ///     // The resulting tensor will have dimensions [1, 1, 3, 3].
    ///     let unsqueezed = tensor.unsqueeze::<4>();
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze::<D, D2>());

        let mut dims = [1; D2];
        let num_ones = D2 - D;
        let shape = self.shape();

        dims[num_ones..(D + num_ones)].copy_from_slice(&shape[..D]);

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with a dimension of size one inserted at the specified position.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 1, 3].
    ///     let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));

        let mut dims = [1; D2];
        let shape = self.shape();

        dims[0..dim].copy_from_slice(&shape[0..dim]);

        if dim < D {
            dims[dim] = 1;
            dims[(dim + 1)..].copy_from_slice(&shape[dim..]);
        } else {
            dims[dim] = 1;
        }

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
    /// The indices can be negative, in which case they are counted from the last to the first dimension.
    /// The axes can contain duplicates, in which case as many dimensions are inserted at that
    /// index as there are duplicates.
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 4, 5]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
    ///     // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
    ///     // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
    ///     let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K> {
        let mut new_dims = [1; D2];
        let old_dims = self.shape().dims;
        // Part 1: convert the negative indices to positive.
        let mut neg_offset = D2;
        let mut dim_indices = axes
            .iter()
            .map(|d| {
                // Check that the dimension is in the acceptable range.
                check!(TensorCheck::unsqueeze_dims::<{ D2 }>(*d));
                (if *d < 0 {
                    neg_offset -= 1; // Handle multiple negative indices (decrease dim value in reverse).
                    d + neg_offset as isize + 1
                } else {
                    *d
                }) as usize
            })
            .collect::<Vec<usize>>();

        // Sort the indices.
        dim_indices.sort_unstable();

        // Now use this to copy the chunks of the dims.
        let mut prev_idx: usize = 0;
        let mut current_left_b: usize = 0;
        let mut current_right_b: usize = 0;
        let mut offset: usize = 0;
        dim_indices.iter().for_each(|d| {
            // Check if there is space for at least one dimension.
            if prev_idx < *d {
                current_right_b = *d - offset;
                // Copy the chunks of the dims.
                if current_right_b < D {
                    new_dims[prev_idx..*d]
                        .copy_from_slice(&old_dims[current_left_b..current_right_b]);
                } else {
                    new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
                }
                prev_idx = *d + 1;
                // The offset equals the number of elements extracted from the original shape.
                offset += current_right_b - current_left_b;
                current_left_b = current_right_b;
            } else {
                // The indices are sorted, so the only reason this can happen
                // is if multiple indices are the same.
                prev_idx += 1;
            }
        });
        // Copy over anything past the index of the last new dimension.
        if current_left_b < D {
            new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
        }

        // Lastly, create the shape and reshape.
        let shape = Shape::new(new_dims);
        self.reshape(shape)
    }

    /// Roll operation along a specific dimension; wrapping around the elements.
    ///
    /// ## Parameters
    ///
    /// - `shift`: The roll extent; supports negative values and wraps around.
    /// - `dim`: The dimension to roll; supports negative indexing.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimension rolled by the given shift amount.
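    ///
    /// ## Example
    ///
    /// A minimal sketch (assuming a backend `B`). For a positive, in-range shift the result
    /// is equivalent to concatenating `tensor[shift..]` with `tensor[..shift]` along `dim`:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0], &device);
    ///     // Roll dimension 0 by one position, wrapping around.
    ///     let rolled = tensor.roll_dim(1, 0);
    ///     println!("{rolled}"); // [1.0, 2.0, 3.0, 0.0]
    /// }
    /// ```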
    pub fn roll_dim<Shift, Dim>(self, shift: Shift, dim: Dim) -> Self
    where
        Shift: AsIndex,
        Dim: AsIndex,
    {
        let dim = dim.expect_dim_index(D);
        let size = self.shape().dims[dim];
        if size == 0 {
            // If the dimension is empty, return the tensor as is.
            return self;
        }

        let shift = wrap_index(shift, size);
        if shift == 0 {
            // If the shift is zero, return the tensor as is.
            return self;
        }

        self.unchecked_roll_dim(shift, dim)
    }

    /// Internal implementation of `roll_dim` that does not canonicalize dimensions or shifts.
    ///
    /// ## Parameters
    ///
    /// - `shift`: The number of positions to shift; must satisfy `0 < shift < size`.
    /// - `dim`: The dimension to roll; must be a valid index for the tensor's shape.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimension rolled by the given shift amount.
    #[inline(always)]
    fn unchecked_roll_dim(self, shift: usize, dim: usize) -> Self {
        #[cfg(debug_assertions)]
        {
            let num_dims = self.shape().num_dims();
            assert!(
                dim < num_dims,
                "Expected: dim < num_dims: found dim={dim}, num_dims={num_dims}",
            );
            let size = self.shape().dims[dim];
            assert!(
                0 < shift && shift < size,
                "Expected: 0 < shift < size: found shift={shift}, size={size}",
            );
        }

        Tensor::cat(
            vec![
                self.clone().slice_dim(dim, shift..),
                self.slice_dim(dim, ..shift),
            ],
            dim,
        )
    }

    /// Roll operation.
    ///
    /// Note: unlike PyTorch's `roll`, `dims` and `shifts` must have the same length.
    ///
    /// A given `dim` may be rolled multiple times, and the shifts will be applied sequentially.
    ///
    /// ## Parameters
    ///
    /// - `shifts`: A slice of shifts corresponding to each dimension;
    ///   supports negative values and wraps around.
    /// - `dims`: A slice of dimensions to roll; supports negative indexing.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimensions rolled by the given shifts.
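    ///
    /// ## Example
    ///
    /// A minimal sketch (assuming a backend `B`); each listed dimension is rolled by its
    /// corresponding shift, wrapping around:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);
    ///     // Roll dimension 0 by 1 and dimension 1 by -1 (both wrap around).
    ///     let rolled = tensor.roll(&[1, -1], &[0, 1]);
    ///     println!("{rolled}");
    /// }
    /// ```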
    pub fn roll<Shift, Dim>(self, shifts: &[Shift], dims: &[Dim]) -> Self
    where
        Shift: AsIndex,
        Dim: AsIndex,
    {
        assert_eq!(
            dims.len(),
            shifts.len(),
            "Dimensions and shifts must align; found dims={dims:#?}, shifts={shifts:#?}",
        );

        // This is a fair amount of complexity, which could be replaced
        // by a simple canonicalization of `dims` and wrapping of `shifts`.
        // The work is done here to ensure that any roll operation
        // which could be a no-op is a no-op; simplifying the accounting
        // needed by backend-specific implementations of the inner roll op.

        let item_count = dims.len();

        let shape = self.shape().dims;

        // Accumulate the effective shifts for each dimension.
        let mut accumulated_shifts: Vec<isize> = vec![0; shape.len()];
        for i in 0..item_count {
            let dim = dims[i].expect_dim_index(D);
            accumulated_shifts[dim] += shifts[i].index();
        }

        // Do this after we've checked the validity of `dims` and `shifts`.
        if self.shape().num_elements() == 0 {
            // If the tensor is empty, return it as is.
            return self;
        }

        // Wrap the accumulated shifts, and filter out empty dimensions.
        let mut effective_dims: Vec<usize> = Vec::with_capacity(item_count);
        let mut effective_shifts: Vec<usize> = Vec::with_capacity(item_count);
        for dim in 0..shape.len() {
            // `wrap_index` should inline, and has a fast-exit path for zero shifts.
            let shift = wrap_index(accumulated_shifts[dim], shape[dim]);
            if shift == 0 {
                continue;
            }

            effective_dims.push(dim);
            effective_shifts.push(shift);
        }

        // If no shifts are needed, return the original tensor.
        if effective_shifts.is_empty() {
            return self;
        }

        // At this point:
        // - `effective_dims` contains the effective dimensions to roll, in index order,
        // - `effective_shifts` contains the effective usize shifts for each dimension.
        // - Every shift is non-zero, and less than the size of the corresponding dimension.
        self.unchecked_roll(&effective_shifts, &effective_dims)
    }

    /// `roll` internal implementation.
    ///
    /// ## Parameters
    ///
    /// - `shifts`: A slice of shifts corresponding to each dimension;
    ///   must be non-empty, the same length as `dims`, with each shift in `1..size`
    ///   for the corresponding dimension.
    /// - `dims`: A slice of dimensions to roll; must be non-empty,
    ///   the same length as `shifts`, and must not contain repeats.
    ///
    /// ## Panics
    ///
    /// Panics if the shifts and dimensions do not align, or if dimensions contain repeats.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimensions rolled by the given shifts.
    #[inline(always)]
    fn unchecked_roll(self, shifts: &[usize], dims: &[usize]) -> Self {
        #[cfg(debug_assertions)]
        {
            assert!(!shifts.is_empty());
            assert_eq!(
                shifts.len(),
                dims.len(),
                "Shifts and dimensions must align; found {} shifts and {} dims",
                shifts.len(),
                dims.len()
            );

            let mut unique_dims = dims.to_vec();
            // Sort first: `dedup` only removes consecutive duplicates.
            unique_dims.sort_unstable();
            unique_dims.dedup();

            assert_eq!(
                unique_dims.len(),
                dims.len(),
                "Dimensions must not contain repeats; found {} unique dims and {} total dims",
                unique_dims.len(),
                dims.len()
            )
        }

        let x = self.unchecked_roll_dim(shifts[0], dims[0]);

        if dims.len() == 1 {
            x
        } else {
            x.unchecked_roll(&shifts[1..], &dims[1..])
        }
    }

    /// Returns a tensor containing the elements selected from the given slices.
    ///
    /// This method provides flexible tensor slicing with support for various range types,
    /// negative indices, and stepped slicing. The method accepts both single slices and
    /// arrays of slices, with the [`s!`] macro providing convenient syntax for complex patterns.
    ///
    /// # Arguments
    ///
    /// * `slices` - Can be:
    ///   - A single range for 1D slicing (e.g., `0..5`, `..`, `2..`)
    ///   - An array of ranges (e.g., `[0..2, 1..4]`)
    ///   - The [`s!`] macro output for advanced slicing with steps
    ///   - A `&Vec<Slice>` or `&[Slice]`
    ///
    /// # Behavior
    ///
    /// - Supports partial and full slicing in any number of dimensions
    /// - Handles negative indices by wrapping from the end (-1 is the last element)
    /// - Automatically clamps ranges that exceed tensor dimensions
    /// - Supports stepped slicing for selecting every nth element
    /// - Negative steps reverse the selection order
    ///
    /// # Panics
    ///
    /// - If the number of slices exceeds the tensor's dimensions
    /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1) without negative step
    /// - If a step is zero
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // Single dimension slicing - no brackets needed!
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..10, &device);
    ///     let slice = tensor.clone().slice(2..8);  // Simple range
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![2, 3, 4, 5, 6, 7]);
    ///
    ///     // Using s! macro for single dimension with step
    ///     let slice = tensor.clone().slice(s![0..10;2]);  // Every 2nd element
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![0, 2, 4, 6, 8]);
    ///
    ///     // Reverse a dimension with negative step
    ///     let slice = tensor.slice(s![..;-1]);  // Reverse entire tensor
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
    ///
    ///     // Multi-dimensional slicing
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
    ///
    ///     // Array syntax for simple ranges
    ///     let slice = tensor.clone().slice([1..3, 2..5]);
    ///     assert_eq!(slice.dims(), [2, 3]);
    ///
    ///     // Advanced multi-dimensional with s! macro
    ///     let slice = tensor.clone().slice(s![0..4;2, ..;-1]);  // Every 2nd row, reverse columns
    ///     assert_eq!(slice.dims(), [2, 6]);
    ///
    ///     // Complex 3D example with mixed slice types
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([4, 6, 8]), &device);
    ///     let slice = tensor.slice(s![1..3, ..;2, -3..]);  // Rows 1-2, every 2nd col, last 3 depth
    ///     assert_eq!(slice.dims(), [2, 3, 3]);
    ///
    ///     // Using negative indices
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
    ///     let slice = tensor.slice(s![-2.., ..-1]);  // Last 2 rows, all but last column
    ///     assert_eq!(slice.dims(), [2, 5]);
    /// }
    /// ```
    ///
    /// # See Also
    ///
    /// - [`s!`] - The recommended macro for creating complex slice specifications
    /// - [`slice_assign`](Self::slice_assign) - Assign values to a slice
    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
    /// - [`slice_dim`](Self::slice_dim) - Slice a single dimension
    ///
    /// [`s!`]: crate::s!
    pub fn slice<S>(self, slices: S) -> Self
    where
        S: SliceArg,
    {
        let shape = self.shape();
        let slices = slices.into_slices(&shape);

        // Validate slices
        check!(TensorCheck::slice::<D>(&shape, &slices));

        // Calculate output shape and check for empty slices
        let mut output_dims = shape.dims.clone();
        for (dim, slice) in slices.iter().enumerate() {
            output_dims[dim] = slice.output_size(shape.dims[dim]);
        }

        // Return empty tensor if any dimension is 0 (empty slice)
        if output_dims.contains(&0) {
            return Self::empty(output_dims, &self.device());
        }
        Self::new(K::slice(self.primitive, &slices))
    }

    /// Assigns values to a slice of the tensor and returns the updated tensor.
    ///
    /// This method supports advanced slicing with steps, including negative steps for reverse
    /// assignment. Like `slice`, it accepts both single slices and arrays, with the [`s!`] macro
    /// providing powerful syntax for complex patterns.
    ///
    /// # Arguments
    ///
    /// * `slices` - Slice specification (same format as `slice` method)
    /// * `values` - Tensor with values to assign (must match slice dimensions)
    ///
    /// # Panics
    ///
    /// - If slices exceed tensor dimensions
    /// - If values dimensions don't match the selected slice shape
    /// - If a step is zero
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // Simple assignment to a sub-region
    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
    ///     let values = Tensor::<B, 2>::ones([2, 3], &device);
    ///     tensor = tensor.slice_assign([1..3, 2..5], values);
    ///     // Now tensor[1..3, 2..5] contains ones
    ///
    ///     // Single dimension assignment with step
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     let values = Tensor::<B, 1>::ones([5], &device);
    ///     tensor = tensor.slice_assign(s![0..10;2], values);
    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
    ///
    ///     // Reverse assignment with negative step
    ///     let mut tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
    ///     let values = Tensor::<B, 1>::from_data([10.0, 11.0, 12.0, 13.0, 14.0], &device);
    ///     tensor = tensor.slice_assign(s![..;-1], values);
    ///     // Assigns in reverse: [14, 13, 12, 11, 10]
    ///
    ///     // Complex multi-dimensional assignment
    ///     let mut tensor = Tensor::<B, 3>::zeros([4, 6, 8], &device);
    ///     let values = Tensor::<B, 3>::ones([2, 3, 3], &device);
    ///     tensor = tensor.slice_assign(s![0..4;2, ..;2, -3..], values);
    ///     // Assigns to every 2nd row, every 2nd column, last 3 in depth
    ///
    ///     // Mixed syntax example
    ///     let mut tensor = Tensor::<B, 2>::zeros([8, 8], &device);
    ///     let pattern = Tensor::<B, 2>::ones([4, 4], &device);
    ///     tensor = tensor.slice_assign(s![..;2, ..;2], pattern);
    ///     // Sets the intersections of every 2nd row and every 2nd column to one (a grid pattern)
    /// }
    /// ```
    ///
    /// # See Also
    ///
    /// - [`s!`] - The recommended macro for creating complex slice specifications
    /// - [`slice`](Self::slice) - Extract a slice from a tensor
    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
    ///
    /// [`s!`]: crate::s!
    pub fn slice_assign<S>(self, slices: S, values: Self) -> Self
    where
        S: SliceArg,
    {
        let shape = self.shape();
        let slices = slices.into_slices(&shape);

        // Check if any slice produces 0 elements (empty assignment).
        // Empty assignments are no-ops and would cause issues in backend implementations.
        let is_empty_assignment = slices
            .iter()
            .enumerate()
            .any(|(i, slice)| slice.output_size(shape.dims[i]) == 0);

        if is_empty_assignment {
            return self;
        }

        check!(TensorCheck::slice_assign::<D>(
            &shape,
            &values.shape(),
            &slices
        ));

        Self::new(K::slice_assign(self.primitive, &slices, values.primitive))
    }

    /// Fills a slice of the tensor with a constant value and returns the updated tensor.
    ///
    /// Like other slice methods, this accepts both single slices and arrays, including stepped
    /// slicing via the [`s!`] macro. Internally, the value is expanded to the slice shape and
    /// assigned with [`slice_assign`](Self::slice_assign).
    ///
    /// # Arguments
    ///
    /// * `slices` - Slice specification (same format as `slice` method)
    /// * `value` - The value to fill the slice with
    ///
    /// # Panics
    ///
    /// - If slices exceed tensor dimensions
    /// - If a step is zero
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // Simple fill for a single dimension
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(2..5, 1.0);
    ///     // Now tensor is [0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
    ///
    ///     // Multi-dimensional fill
    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
    ///     tensor = tensor.slice_fill([1..3, 2..5], -1.0);
    ///     // Fills the rectangle at rows 1-2, columns 2-4 with -1
    ///
    ///     // Using negative indices
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(-3.., 2.0);
    ///     // Fills the last 3 elements with 2.0
    ///
    ///     // Complex multi-dimensional example
    ///     let mut tensor = Tensor::<B, 3>::ones([4, 6, 8], &device);
    ///     tensor = tensor.slice_fill(s![1..3, .., -2..], 0.0);
    ///     // Sets rows 1-2, all columns, last 2 in depth to 0
    ///
    ///     // Stepped slicing is supported
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(s![0..10;2], 1.0);
    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
    /// }
    /// ```
    ///
    /// # See Also
    ///
    /// - [`s!`] - The macro for creating slice specifications with steps
    /// - [`slice`](Self::slice) - Extract a slice from a tensor
    /// - [`slice_assign`](Self::slice_assign) - Assign tensor values to a slice
    ///
    /// [`s!`]: crate::s!
    pub fn slice_fill<S, E: ElementConversion>(self, slices: S, value: E) -> Self
    where
        S: SliceArg,
    {
        let shape = self.shape();
        let slices = slices.into_slices(&shape);

        check!(TensorCheck::slice::<D>(&shape, &slices));

        let slice_shape = shape.slice(&slices).unwrap();
        let value = Tensor::<B, 1, K>::from_data_dtype(
            [value.elem::<K::Elem>()],
            &self.device(),
            self.dtype(),
        );
        let value = value.expand(slice_shape);
        self.slice_assign(&slices, value)
    }

    /// Returns a new tensor with the specified dimension sliced.
    ///
    /// # Arguments
    ///
    /// * `dim`: The dimension to slice.
    /// * `slice`: The slice specification for the dimension. Can be a range (e.g., `2..5`),
    ///   slice with step (via `s!` macro, e.g., `s![0..10;2]`), or any type that implements `Into<Slice>`.
    ///
    /// # Returns
    ///
    /// A new tensor with the specified dimension sliced.
    ///
    /// # Panics
    ///
    /// If the slice is out of bounds for the specified dimension.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use burn_tensor::{Tensor, s};
    /// # use burn_tensor::backend::Backend;
    /// #
    /// # fn example<B: Backend>() {
    /// #     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 3>::zeros([3, 4, 5], &device);
    ///
    ///     // Simple range slicing
    ///     let sliced = tensor.clone().slice_dim(1, 1..3);
    ///     assert_eq!(sliced.shape().dims, [3, 2, 5]);
    ///
    ///     // Slicing with step - take every 2nd element
    ///     let sliced = tensor.clone().slice_dim(2, s![0..5;2]);
    ///     assert_eq!(sliced.shape().dims, [3, 4, 3]); // Takes indices 0, 2, 4
    ///
    ///     // Reverse slicing with negative step
    ///     let sliced = tensor.clone().slice_dim(1, s![..;-1]);
    ///     assert_eq!(sliced.shape().dims, [3, 4, 5]); // Reverses dimension 1
    ///
    ///     // Select from index 2 with step 3
    ///     let sliced = tensor.clone().slice_dim(0, s![2..;3]);
    ///     assert_eq!(sliced.shape().dims, [1, 4, 5]); // Takes only index 2
1559    ///
1560    ///     // Select single index (reduces dimension to size 1)
1561    ///     let sliced = tensor.slice_dim(0, 1);
1562    ///     assert_eq!(sliced.shape().dims, [1, 4, 5]);
1563    /// # }
1564    /// ```
1565    ///
1566    /// # See Also
1567    ///
1568    /// - [`slice`](Self::slice) - Slice multiple dimensions simultaneously
1569    /// - [`s!`] - The macro for creating complex slice specifications
1570    ///
1571    /// [`s!`]: crate::s!
1572    pub fn slice_dim<S>(self, dim: usize, slice: S) -> Self
1573    where
1574        S: Into<Slice>,
1575    {
1576        check!(TensorCheck::check_dim::<D>(dim));
1577        let slice: Slice = slice.into();
1578
1579        let mut slices = vec![Slice::full(); D];
1580        slices[dim] = slice;
1581
1582        self.slice(&slices)
1583    }
1584
1585    /// Returns the device of the current tensor.
1586    pub fn device(&self) -> B::Device {
1587        K::device(&self.primitive)
1588    }
1589
1590    /// Move the tensor to the given device.
1591    pub fn to_device(self, device: &B::Device) -> Self {
1592        Self::new(K::to_device(self.primitive, device))
1593    }
1594
1595    /// Select tensor elements along the given dimension corresponding to the given indices.
1596    ///
1597    /// # Arguments
1598    ///
1599    /// * `dim` - The dimension to select from. Supports negative indexing.
1600    /// * `indices` - The indices of the elements to select.
1601    ///
1602    /// # Example
1603    ///
1604    /// ```rust
1605    /// use burn_tensor::backend::Backend;
1606    /// use burn_tensor::{Tensor, Int};
1607    ///
1608    /// fn example<B: Backend>() {
1609    ///   let device = B::Device::default();
1610    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [4.0, 5.0, 6.0]], &device);
1611    ///   let indices = Tensor::<B, 1, Int>::from_data([0], &device);
1612    ///   let tensor = tensor.select(0, indices);
1613    ///   println!("{tensor}");
1614    ///   //  [[1.0, -2.0, 3.0]]
1615    /// }
1616    /// ```
1617    pub fn select(self, dim: impl AsIndex, indices: Tensor<B, 1, Int>) -> Self {
1618        let dim = dim.expect_dim_index(D);
1619        check!(TensorCheck::select::<D>(dim));
1620        Self::new(K::select(self.primitive, dim, indices.primitive))
1621    }
1622
1623    /// Update the elements of the original tensor, selected along the given dimension at the given
1624    /// indices, with the values from the value tensor using the provided update operation.
1625    ///
1626    /// # Note
1627    /// For booleans, the add update is a logical OR.
1628    ///
1629    /// # Arguments
1630    ///
1631    /// * `dim` - The dimension along which to select. Supports negative indexing.
1632    /// * `indices` - The indices to select from the tensor.
1633    /// * `values` - The values to assign to the selected indices.
1634    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1635    ///
1636    /// # Example
1637    ///
1638    /// Example using a 3D tensor with the add update operation:
1639    ///
1640    /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
1641    /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
1642    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
1643    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = -1 (same as dim = 2)`
1644    ///
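    /// A minimal runnable sketch of the `dim = 0` case; it assumes an `Add` variant on
    /// [`IndexingUpdateOp`](crate::IndexingUpdateOp), matching the "(e.g., add)" note below.
    ///
    /// ```rust,ignore
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{IndexingUpdateOp, Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     let indices = Tensor::<B, 1, Int>::from_data([0, 0], &device);
    ///     let values = Tensor::<B, 2>::from_data([[10.0, 10.0], [20.0, 20.0]], &device);
    ///     // Both rows of `values` are accumulated into row 0:
    ///     // [[31.0, 32.0], [3.0, 4.0]]
    ///     let tensor = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);
    ///     println!("{tensor}");
    /// }
    /// ```
    ///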
1645    /// # Warning
1646    ///
1647    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1648    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1649    pub fn select_assign(
1650        self,
1651        dim: impl AsIndex,
1652        indices: Tensor<B, 1, Int>,
1653        values: Tensor<B, D, K>,
1654        update: IndexingUpdateOp,
1655    ) -> Self {
1656        let dim = dim.expect_dim_index(D);
1657        check!(TensorCheck::select_assign::<D>(
1658            dim,
1659            &indices.shape(),
1660            &values.shape()
1661        ));
1662
1663        Self::new(K::select_assign(
1664            self.primitive,
1665            dim,
1666            indices.primitive,
1667            values.primitive,
1668            update,
1669        ))
1670    }
1671
1672    /// Update the given tensor with the value tensor where the mask is true.
1673    ///
1674    /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
1675    /// a scalar.
1676    ///
1677    /// # Example
1678    ///
1679    /// ```rust
1680    /// use burn_tensor::backend::Backend;
1681    /// use burn_tensor::{Tensor, Shape, Bool};
1682    ///
1683    /// fn example<B: Backend>() {
1684    ///   let device = B::Device::default();
1685    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1686    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1687    ///   let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1688    ///   let tensor = tensor.mask_where(mask, value);
1689    ///   println!("{tensor}");
1690    ///   // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
1691    /// }
1692    /// ```
1693    pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
1694        Self::new(K::mask_where(
1695            self.primitive,
1696            mask.primitive,
1697            value.primitive,
1698        ))
1699    }
1700
1701    /// Update the given tensor with the value where the mask is true.
1702    ///
1703    /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
1704    /// a tensor.
1705    ///
1706    /// # Example
1707    ///
1708    /// ```rust
1709    /// use burn_tensor::backend::Backend;
1710    /// use burn_tensor::{Tensor, Shape, Bool};
1711    ///
1712    /// fn example<B: Backend>() {
1713    ///   let device = B::Device::default();
1714    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1715    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1716    ///   let tensor = tensor.mask_fill(mask, 3.0);
1717    ///   println!("{tensor}");
1718    ///   // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
1719    /// }
1720    /// ```
1721    pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
1722        Self::new(K::mask_fill(self.primitive, mask.primitive, value.elem()))
1723    }
1724
1725    /// Gather tensor elements corresponding to the given indices from the specified dim.
1726    ///
1727    /// Example using a 3D tensor:
1728    ///
1729    /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
1730    /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
1731    /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
1732    ///
1733    /// # Notes
1734    ///
1735    /// The index tensor should have the same shape as the original tensor except for the dim
1736    /// specified.
1737    ///
1738    /// # Warning
1739    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1740    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
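    ///
    /// # Example
    ///
    /// A small example following the `dim = 0` formula above:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     let indices = Tensor::<B, 2, Int>::from_data([[0, 1], [1, 0]], &device);
    ///     // output[i, j] = input[indices[i, j], j]
    ///     // [[1.0, 4.0], [3.0, 2.0]]
    ///     let gathered = tensor.gather(0, indices);
    ///     println!("{gathered}");
    /// }
    /// ```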
1741    pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
1742        check!(TensorCheck::gather::<D>(
1743            dim,
1744            &self.shape(),
1745            &indices.shape()
1746        ));
1747
1748        Self::new(K::gather(dim, self.primitive, indices.primitive))
1749    }
1750
1751    /// Assign the gathered elements corresponding to the given indices along the specified dimension
1752    /// from the value tensor to the original tensor using the provided update operation.
1753    ///
1754    /// Example using a 3D tensor with the add update operation:
1755    ///
1756    /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
1757    /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
1758    /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
1759    ///
1760    /// # Arguments
1761    /// * `dim` - The axis along which to scatter elements.
1762    /// * `indices` - The indices of the elements to scatter.
1763    /// * `values` - The values to scatter into the tensor.
1764    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1765    ///
1766    /// # Notes
1767    ///
1768    /// The index tensor should have the same shape as the original tensor except for the specified
1769    /// dimension. The value and index tensors should have the same shape.
1770    ///
1771    /// Other references to the input tensor will not be modified by this operation.
1772    ///
1773    /// # Warning
1774    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1775    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
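    ///
    /// # Example
    ///
    /// A minimal sketch of the `dim = 0` case; as with [`select_assign`](Self::select_assign),
    /// it assumes an `Add` variant on [`IndexingUpdateOp`](crate::IndexingUpdateOp).
    ///
    /// ```rust,ignore
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{IndexingUpdateOp, Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::zeros([2, 2], &device);
    ///     let indices = Tensor::<B, 2, Int>::from_data([[0, 1], [1, 0]], &device);
    ///     let values = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // input[indices[i, j], j] += values[i, j]
    ///     // [[1.0, 4.0], [3.0, 2.0]]
    ///     let tensor = tensor.scatter(0, indices, values, IndexingUpdateOp::Add);
    ///     println!("{tensor}");
    /// }
    /// ```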
1776    pub fn scatter(
1777        self,
1778        dim: usize,
1779        indices: Tensor<B, D, Int>,
1780        values: Self,
1781        update: IndexingUpdateOp,
1782    ) -> Self {
1783        check!(TensorCheck::scatter::<D>(
1784            dim,
1785            &self.shape(),
1786            &indices.shape(),
1787            &values.shape()
1788        ));
1789
1790        Self::new(K::scatter(
1791            dim,
1792            self.primitive,
1793            indices.primitive,
1794            values.primitive,
1795            update,
1796        ))
1797    }
1798
1799    /// Converts the data of the current tensor.
1800    ///
1801    /// # Note
1802    ///
1803    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1804    /// tensors at once. This may improve laziness, especially if executed on a different
1805    /// thread in native environments.
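    ///
    /// # Example
    ///
    /// Reading a tensor back on the host (the `shape` field of [TensorData](crate::TensorData)
    /// holds the dimensions, as used elsewhere in this file):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // Read the tensor back on the host as raw data.
    ///     let data = tensor.into_data();
    ///     assert_eq!(data.shape, vec![2, 2]);
    /// }
    /// ```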
1806    pub fn into_data(self) -> TensorData {
1807        self.try_into_data().expect(
1808            "Error while reading data: use `try_into_data` instead to catch the error at runtime",
1809        )
1810    }
1811
1812    /// Converts the data of the current tensor and returns any error that might have occurred since the
1813    /// last time the device was synchronized.
1814    ///
1815    /// # Note
1816    ///
1817    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1818    /// tensors at once. This may improve laziness, especially if executed on a different
1819    /// thread in native environments.
1820    pub fn try_into_data(self) -> Result<TensorData, ExecutionError> {
1821        crate::try_read_sync(self.into_data_async()).expect(
1822            "Failed to read tensor data synchronously.
1823        This can happen on platforms that don't support blocking futures like WASM.
1824        If possible, try using into_data_async instead.",
1825        )
1826    }
1827
1828    /// Converts the data of the current tensor.
1829    ///
1830    /// # Note
1831    ///
1832    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1833    /// tensors at once. This may improve laziness, especially if executed on a different
1834    /// thread in native environments.
1835    pub fn to_data(&self) -> TensorData {
1836        self.clone().into_data()
1837    }
1838
1839    /// Returns the data of the current tensor.
1840    pub async fn into_data_async(self) -> Result<TensorData, ExecutionError> {
1841        K::into_data_async(self.primitive).await
1842    }
1843
1844    /// Returns the data of the current tensor.
1845    pub async fn to_data_async(&self) -> Result<TensorData, ExecutionError> {
1846        self.clone().into_data_async().await
1847    }
1848
1849    /// Create a tensor from the given data on the given device.
1850    pub fn from_data<T>(data: T, device: &B::Device) -> Self
1851    where
1852        T: Into<TensorData>,
1853    {
1854        let data = data.into();
1855        check!(TensorCheck::creation_ops::<D>(
1856            "From Data",
1857            data.shape.as_slice()
1858        ));
1859        Self::new(K::from_data(data, device))
1860    }
1861
1862    /// Create a tensor from the given data on the given device enforcing the given data type.
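    ///
    /// # Example
    ///
    /// A short sketch, assuming the target data type is a variant of [DType](crate::DType):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{DType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Store the values as 32-bit floats, regardless of the backend default.
    ///     let tensor = Tensor::<B, 2>::from_data_dtype([[1.0, 2.0], [3.0, 4.0]], &device, DType::F32);
    ///     assert_eq!(tensor.dtype(), DType::F32);
    /// }
    /// ```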
1863    pub fn from_data_dtype<T>(data: T, device: &B::Device, dtype: DType) -> Self
1864    where
1865        T: Into<TensorData>,
1866    {
1867        let data = data.into();
1868        check!(TensorCheck::creation_ops::<D>(
1869            "From Data",
1870            data.shape.as_slice()
1871        ));
1872        Self::new(K::from_data_dtype(data, device, dtype))
1873    }
1874
1875    /// Repeat the tensor along the given dimension.
1876    ///
1877    /// The output tensor has the same shape, except along the given dimension.
1878    ///
1879    /// # Arguments
1880    /// - `dim`: The dimension to repeat.
1881    /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
1882    ///
1883    /// # Returns
1884    ///
1885    /// A new tensor with the given dimension repeated `times` times.
1886    ///
1887    /// # Example
1888    ///
1889    /// ```rust
1890    /// use burn_tensor::backend::Backend;
1891    /// use burn_tensor::Tensor;
1892    ///
1893    /// fn example<B: Backend>() {
1894    ///     let device = Default::default();
1895    ///     // Create a 2D tensor with dimensions [3, 2]
1896    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1897    ///
1898    ///     // Repeat the tensor along the dimension 0 twice.
1899    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1900    ///     // The resulting tensor will have dimensions [6, 2].
1901    ///     let repeated = tensor.repeat_dim(0, 2);
1902    ///     println!("{repeated}");
1903    /// }
1904    /// ```
1905    pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
1906        if times > 0 {
1907            Self::new(K::repeat_dim(self.primitive, dim, times))
1908        } else {
1909            let shape = self.shape().repeat(dim, times).unwrap();
1910            Self::empty(shape, &self.device())
1911        }
1912    }
1913
1914    /// Repeat the tensor along the given dimensions.
1915    /// # Arguments
1916    /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
1917    ///
1918    /// # Returns
1919    ///
1920    /// A new tensor with each dimension repeated according to the corresponding entry in `sizes`.
1921    ///
1922    /// # Panics
1923    ///
1924    /// If `sizes` contains more elements than the number of dimensions.
1925    ///
1926    /// # Example
1927    ///
1928    /// ```rust
1929    ///
1930    /// use burn_tensor::backend::Backend;
1931    /// use burn_tensor::Tensor;
1932    ///
1933    /// fn example<B: Backend>() {
1934    ///     let device = Default::default();
1935    ///     // Create a 2D tensor with dimensions [3, 2]
1936    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1937    ///
1938    ///     // Repeat the tensor along the dimension 0 twice and the dimension 1 once.
1939    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1940    ///     // The resulting tensor will have dimensions [6, 2].
1941    ///     let repeated = tensor.repeat(&[2, 1]);
1942    /// }
1943    /// ```
1944    pub fn repeat(self, sizes: &[usize]) -> Self {
1945        if sizes.contains(&0) {
1946            let mut shape = self.shape();
1947            for (dim, &times) in sizes.iter().enumerate() {
1948                shape = shape.repeat(dim, times).unwrap();
1949            }
1950
1951            return Self::empty(shape, &self.device());
1952        }
1953
1954        let mut tensor = self;
1955        for (dim, &times) in sizes.iter().enumerate() {
1956            if times > 1 {
1957                tensor = tensor.repeat_dim(dim, times);
1958            }
1959        }
1960        tensor
1961    }
1962
1963    /// Applies element-wise equal comparison.
1964    ///
1965    /// # Returns
1966    /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
1967    ///
1968    /// # Panics
1969    ///
1970    /// If the two tensors don't have the same shape.
1971    ///
1972    /// # Example
1973    ///
1974    /// ```rust
1975    /// use burn_tensor::backend::Backend;
1976    /// use burn_tensor::Tensor;
1977    ///
1978    /// fn example<B: Backend>() {
1979    ///     let device = Default::default();
1980    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1981    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1982    ///     // Compare the elements of the two 2D tensors with dimensions [3, 2].
1983    ///     // [[false, true], [true, true], [true, true]]
1984    ///     let equal = t1.equal(t2);
1985    ///     println!("{equal}");
1986    /// }
1987    /// ```
1988    pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
1989        check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
1990        Tensor::new(K::equal(self.primitive, other.primitive))
1991    }
1992
1993    /// Applies element-wise non-equality comparison.
1994    ///
1995    /// # Returns
1996    /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
1997    ///
1998    /// # Panics
1999    ///
2000    /// If the two tensors don't have the same shape.
2001    ///
2002    /// # Example
2003    ///
2004    /// ```rust
2005    /// use burn_tensor::backend::Backend;
2006    /// use burn_tensor::Tensor;
2007    ///
2008    /// fn example<B: Backend>() {
2009    ///     let device = Default::default();
2010    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2011    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2012    ///     // Compare the elements of the two 2D tensors for inequality.
2013    ///     // [[true, false], [false, false], [false, false]]
2014    ///     let not_equal = t1.not_equal(t2);
2015    ///     println!("{not_equal}");
2016    /// }
2017    /// ```
2018    pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
2019        check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
2020        Tensor::new(K::not_equal(self.primitive, other.primitive))
2021    }
2022
2023    /// Applies element wise equal comparison and returns a boolean tensor.
2024    ///
2025    /// # Arguments
2026    ///
2027    /// * `other` - The element to compare.
2028    ///
2029    /// # Example
2030    ///
2031    /// ```rust
2032    /// use burn_tensor::backend::Backend;
2033    /// use burn_tensor::{Tensor, Shape};
2034    ///
2035    /// fn example<B: Backend>() {
2036    ///    let device = B::Device::default();
2037    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2038    ///    let tensor = tensor.equal_elem(3.0);
2039    ///    println!("{tensor}");
2040    ///    // [[false, false, true], [false, false, false]]
2041    /// }
2042    /// ```
2043    pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2044        Tensor::new(K::equal_elem(self.primitive, other.elem()))
2045    }
2046
2047    /// Applies element wise non-equality comparison and returns a boolean tensor.
2048    ///
2049    /// # Arguments
2050    ///
2051    /// * `other` - The element to compare.
2052    ///
2053    /// # Example
2054    ///
2055    /// ```rust
2056    /// use burn_tensor::backend::Backend;
2057    /// use burn_tensor::{Tensor, Shape};
2058    ///
2059    /// fn example<B: Backend>() {
2060    ///    let device = B::Device::default();
2061    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2062    ///    let tensor = tensor.not_equal_elem(3.0);
2063    ///    println!("{tensor}");
2064    ///    // [[true, true, false], [true, true, true]]
2065    /// }
2066    /// ```
2067    pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2068        Tensor::new(K::not_equal_elem(self.primitive, other.elem()))
2069    }
2070
2071    /// Concatenates all tensors into a new one along the given dimension.
2072    ///
2073    /// # Panics
2074    ///
2075    /// - If `dim` is higher than the rank.
2076    /// - If `tensors` is an empty vector.
2077    /// - If the tensors don't all have the same shape (ignoring the dimension `dim`).
2078    ///
2079    /// # Example
2080    ///
2081    /// ```rust
2082    /// use burn_tensor::backend::Backend;
2083    /// use burn_tensor::Tensor;
2084    ///
2085    /// fn example<B: Backend>() {
2086    ///     let device = Default::default();
2087    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
2088    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2089    ///
2090    ///     // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
2091    ///     // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
2092    ///     // The resulting tensor will have shape [2, 7].
2093    ///     let concat = Tensor::cat(vec![t1, t2], 1);
2094    ///     println!("{concat}");
2095    /// }
2096    /// ```
2097    pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
2098        check!(TensorCheck::cat(&tensors, dim));
2099
2100        // Filter out tensors with size 0 along the concatenation dimension.
2101        // Empty tensors don't contribute to the output and would cause issues
2102        // in backend implementations (e.g., division by zero in slice_assign).
2103        // Safety: TensorCheck::cat ensures tensors is non-empty
2104        let first_tensor = tensors.first().unwrap();
2105        let device = first_tensor.device();
2106        let mut shape = first_tensor.shape();
2107
2108        let non_empty_primitives: Vec<_> = tensors
2109            .into_iter()
2110            .filter(|t| t.shape().dims[dim] > 0)
2111            .map(|t| t.primitive)
2112            .collect();
2113
2114        // If all tensors were empty, return an empty tensor with size 0 on concat dim
2115        if non_empty_primitives.is_empty() {
2116            shape.dims[dim] = 0;
2117            return Self::empty(shape, &device);
2118        }
2119
2120        Self::new(K::cat(non_empty_primitives, dim))
2121    }
2122
2123    /// Concatenates all tensors into a new one along a new dimension.
2124    ///
2125    /// # Panics
2126    ///
2127    /// - If the tensors don't all have the same shape.
2128    /// - If the given dimension is not within the range `0..D2`.
2129    ///
2130    /// # Example
2131    ///
2132    /// ```rust
2133    /// use burn_tensor::backend::Backend;
2134    /// use burn_tensor::Tensor;
2135    ///
2136    /// fn example<B: Backend>() {
2137    ///     let device = Default::default();
2138    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2139    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2140    ///     let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2141    ///
2142    ///     // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
2143    ///     // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
2144    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
2145    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
2146    ///     // The resulting tensor will have shape [3, 2, 3].
2147    ///     let stacked = Tensor::stack::<3>(vec![t1, t2, t3], 0);
2148    ///     println!("{stacked}");
2149    /// }
2150    /// ```
2151    pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
2152        check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
2153        let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
2154        Tensor::<B, D2, K>::cat(tensors, dim)
2155    }
2156
2157    /// Iterate over slices of the tensor along a given dimension.
2158    ///
2159    /// # Panics
2160    ///
2161    /// If given dimension is greater than or equal to tensor rank.
2162    ///
2163    /// # Returns
2164    ///
2165    /// A tensor iterator.
2166    ///
2167    /// # Example
2168    ///
2169    /// ```rust
2170    /// use burn_tensor::backend::Backend;
2171    /// use burn_tensor::Tensor;
2172    /// fn example<B: Backend>() {
2173    ///   let device = Default::default();
2174    ///   let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2175    ///   // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
2176    ///   let iter = tensor.iter_dim(0);
2177    ///   for (i,tensor) in iter.enumerate() {
2178    ///     println!("Tensor {}: {}", i, tensor);
2179    ///     // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
2180    ///     // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
2181    ///   }
2182    /// }
2183    /// ```
2184    pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
2185        check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
2186        DimIter::new(self, dim)
2187    }
2188
2189    /// Returns a new tensor with the given dimension narrowed to the given range.
2190    ///
2191    /// # Panics
2192    ///
2193    /// - If the dimension is greater than the number of dimensions of the tensor.
2194    /// - If the given range exceeds the number of elements on the given dimension.
2195    ///
2196    /// # Returns
2197    ///
2198    /// A new tensor with the given dimension narrowed to the given range.
2199    ///
2200    /// # Example
2201    ///
2202    /// ```rust
2203    /// use burn_tensor::backend::Backend;
2204    /// use burn_tensor::Tensor;
2205    ///
2206    /// fn example<B: Backend>() {
2207    ///     let device = Default::default();
2208    ///     // Create a 2D tensor with dimensions [4, 3]
2209    ///     let tensor = Tensor::<B, 2>::from_data(
2210    ///         [
2211    ///             [3.0, 4.9, 2.0],
2212    ///             [2.0, 1.9, 3.0],
2213    ///             [6.0, 1.5, 7.0],
2214    ///             [3.0, 4.9, 9.0],
2215    ///         ],
2216    ///         &device,
2217    ///     );
2218    ///     // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
2219    ///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
2220    ///     // The resulting tensor will have dimensions [3, 3].
2221    ///     let narrowed = tensor.narrow(0, 1, 3);
2222    ///     println!("{narrowed}");
2223    /// }
2224    /// ```
2225    pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
2226        check!(TensorCheck::dim_ops::<D>("narrow", dim));
2227        check!(TensorCheck::narrow(&self, dim, start, length));
2228        let dims = self.dims();
2229
2230        let ranges: [Range<usize>; D] = dims
2231            .iter()
2232            .enumerate()
2233            .map(|(i, d)| {
2234                if i == dim {
2235                    start..(start + length)
2236                } else {
2237                    0..*d
2238                }
2239            })
2240            .collect::<Vec<_>>()
2241            .try_into()
2242            .unwrap();
2243
2244        Self::slice(self, ranges)
2245    }
2246
2247    /// Attempts to split the tensor into a specified number of chunks along a given dimension.
2248    /// May return fewer chunks than requested if the tensor size along the given dimension is smaller than, or not divisible by, the number of chunks.
2249    ///
2250    /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
2251    /// Otherwise all chunks will be of equal size except for the last one.
2252    ///
2253    /// # Panics
2254    ///
2255    /// If the dimension is greater than the number of dimensions of the tensor.
2256    ///
2257    /// # Returns
2258    /// A vector of tensors.
2259    ///
2260    /// # Example
2261    ///
2262    /// ```rust
2263    /// use burn_tensor::backend::Backend;
2264    /// use burn_tensor::Tensor;
2265    ///
2266    /// fn example<B: Backend>() {
2267    ///     let device = Default::default();
2268    ///     // Create a 2D tensor with dimensions [4, 3]
2269    ///     let tensor = Tensor::<B, 2>::from_data(
2270    ///         [
2271    ///             [3.0, 4.9, 2.0],
2272    ///             [2.0, 1.9, 3.0],
2273    ///             [6.0, 1.5, 7.0],
2274    ///             [3.0, 4.9, 9.0],
2275    ///         ],
2276    ///         &device,
2277    ///     );
2278    ///     // Split the tensor along the dimension 1 into 2 chunks.
2279    ///     // The first chunk will have shape [4, 2]:
2280    ///     // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
2281    ///     // The second chunk will have shape [4, 1]:
2282    ///     // [[2.0], [3.0], [7.0], [9.0]]
2283    ///     let chunks = tensor.chunk(2, 1);
2284    ///     println!("{chunks:?}");
2285    /// }
2286    /// ```
2287    pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
2288        check!(TensorCheck::dim_ops::<D>("chunk", dim));
2289        let size = self.shape().dims[dim];
2290        if size < chunks {
2291            return (0..size)
2292                .map(|i| Self::narrow(self.clone(), dim, i, 1))
2293                .collect();
2294        }
2295
2296        let mut tensors = Vec::with_capacity(chunks);
2297        let mut sum_chunk_size = 0;
2298        if size.is_multiple_of(chunks) {
2299            let chunk_size = size / chunks;
2300            for _ in 0..chunks {
2301                tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
2302                sum_chunk_size += chunk_size;
2303            }
2304        } else {
2305            let chunk_size = (size / chunks) + 1; // not divisible: round the chunk size up
2306            while sum_chunk_size + chunk_size < size {
2307                tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
2308                sum_chunk_size += chunk_size;
2309            }
2310            // The last chunk holds whatever remains; stopping the loop early avoids stepping
2311            // past the end of the dimension when `chunks - 1` full chunks would exceed `size`.
2312            tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, size - sum_chunk_size));
2313        }
2313
2314        tensors
2315    }
2316
2317    /// Splits the tensor into chunks of a specified size along a given dimension.
2318    /// Each chunk is a view of the original tensor.
2319    ///
2320    /// If the tensor size along the given dimension is not divisible by `split_size`,
2321    /// then the last chunk will be smaller.
2322    ///
2323    /// # Panics
2324    ///
2325    /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
2326    ///
2327    /// # Returns
2328    ///
2329    /// A vector of tensors.
2330    ///
2331    /// # Example
2332    /// ```rust
2333    /// use burn_tensor::backend::Backend;
2334    /// use burn_tensor::Tensor;
2335    ///
2336    /// fn example<B: Backend>() {
2337    ///     let device = Default::default();
2338    ///     // Create a 1D tensor with 5 elements
2339    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2340    ///     // Split the tensor into chunks of size 2 along dimension 0
2341    ///     let chunks = tensor.split(2, 0);
2342    ///     // The result is a vector of tensors:
2343    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
2344    ///     println!("{:?}", chunks);
2345    /// }
2346    /// ```
2347    pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
2348        check!(TensorCheck::split::<D>(&self.shape(), split_size, dim));
2349        let size = self.shape().dims[dim];
2350        let mut tensors = Vec::new();
2351
2352        let mut start = 0;
2353        while start < size {
2354            let length = usize::min(split_size, size - start);
2355            tensors.push(Self::narrow(self.clone(), dim, start, length));
2356            start += length;
2357        }
2358
2359        tensors
2360    }
2361
2362    /// Splits the tensor into chunks with the specified sizes along a given dimension.
2363    /// Each chunk is a view of the original tensor.
2364    ///
2365    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2366    /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2367    ///
2368    /// # Panics
2369    ///
2370    /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
2371    /// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
2372    ///
2373    /// # Returns
2374    ///
2375    /// A vector of tensors.
2376    ///
2377    /// # Example
2378    /// ```rust
2379    /// use burn_tensor::backend::Backend;
2380    /// use burn_tensor::Tensor;
2381    ///
2382    /// fn example<B: Backend>() {
2383    ///     let device = Default::default();
2384    ///     // Create a 1D tensor with 5 elements
2385    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2386    ///     // Split the tensor into chunks with sizes [2, 3] along dimension 0
2387    ///     let chunks = tensor.split_with_sizes(vec![2, 3], 0);
2388    ///     // The result is a vector of tensors:
2389    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
2390    ///     println!("{:?}", chunks);
2391    /// }
2392    /// ```
2393    pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
2394        check!(TensorCheck::split_with_sizes::<D>(
2395            &self.shape(),
2396            &split_sizes,
2397            dim
2398        ));
2399        let mut tensors = Vec::new();
2400
2401        let mut start = 0;
2402        for length in split_sizes {
2403            if length == 0 {
2404                continue;
2405            }
2406            tensors.push(Self::narrow(self.clone(), dim, start, length));
2407            start += length;
2408        }
2409
2410        tensors
2411    }
2412
2413    /// Tests if any element in the `tensor` evaluates to True.
2414    ///
2415    /// # Arguments
2416    ///
2417    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2418    ///
2419    /// # Returns
2420    ///
2421    /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
2422    /// evaluates to True, False otherwise.
2423    ///
2424    /// # Example
2425    ///
2426    /// ```rust
2427    /// use burn_tensor::backend::Backend;
2428    /// use burn_tensor::{Tensor, Bool};
2429    ///
2430    /// fn example<B: Backend>() {
2431    ///   let device = Default::default();
2432    ///   let tensor = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
2433    ///   let tensor_two = Tensor::<B, 2, Bool>::from_data([[false, false, false], [false, false, false]], &device);
2434    ///
2435    ///   // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2436    ///   let any_tensor = tensor.any();
2437    ///   println!("{}", any_tensor);
2438    ///   // Tensor { data: [true], ... }
2439    ///
2440    ///   // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2441    ///   let any_tensor_two = tensor_two.any();
2442    ///   println!("{}", any_tensor_two);
2443    ///   // Tensor { data: [false], ... }
2444    /// }
2445    /// ```
2446    pub fn any(self) -> Tensor<B, 1, Bool> {
2447        Tensor::new(K::any(self.primitive))
2448    }
2449
2450    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
2451    ///
2452    /// # Arguments
2453    ///
2454    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2455    /// * `dim` - The axis along which to test.
2456    ///
2457    /// # Returns
2458    ///
2459    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2460    /// where the size is 1. The element in the `dim` axis is True if any element along this dim in the input
2461    /// evaluates to True, False otherwise.
2462    ///
2463    /// # Example
2464    ///
2465    /// ```rust
2466    /// use burn_tensor::backend::Backend;
2467    /// use burn_tensor::{Tensor, Bool};
2468    ///
2469    /// fn example<B: Backend>() {
2470    ///     let device = Default::default();
2471    ///     let tensor =
2472    ///         Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
2473    ///     // Check if any element in the tensor evaluates to True along the dimension 1.
2474    ///     // [[true], [true]],
2475    ///     let any_dim = tensor.clone().any_dim(1);
2476    ///     println!("{any_dim}");
2477    /// }
2478    /// ```
2479    pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2480        Tensor::new(K::any_dim(self.primitive, dim))
2481    }
2482
2483    /// Tests if all elements in the `tensor` evaluate to True.
2484    ///
2485    /// # Arguments
2486    ///
2487    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2488    ///
2489    /// # Returns
2490    ///
2491    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
2492    /// evaluate to True, False otherwise.
2493    ///
2494    /// # Example
2495    ///
2496    /// ```rust
2497    /// use burn_tensor::backend::Backend;
2498    /// use burn_tensor::{Tensor, Bool};
2499    ///
2500    /// fn example<B: Backend>() {
2501    ///     let device = Default::default();
2502    ///     let tensor =
2503    ///         Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
2504    ///     // Check if all elements in the tensor evaluate to True (which is not the case).
2505    ///     // [false]
2506    ///     let all = tensor.all();
2507    ///     println!("{all}");
2508    /// }
2509    /// ```
2510    pub fn all(self) -> Tensor<B, 1, Bool> {
2511        Tensor::new(K::all(self.primitive))
2512    }
2513
2514    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2515    ///
2516    /// # Arguments
2517    ///
2518    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2519    /// * `dim` - The axis along which to test.
2520    ///
2521    /// # Returns
2522    ///
2523    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2524    /// where the size is 1. The element in the `dim` axis is True if all elements along this dim in the input
2525    /// evaluate to True, False otherwise.
2526    ///
2527    /// # Example
2528    ///
2529    /// ```rust
2530    /// use burn_tensor::backend::Backend;
2531    /// use burn_tensor::{Tensor, Bool};
2532    ///
2533    /// fn example<B: Backend>() {
2534    ///     let device = Default::default();
2535    ///     let tensor =
2536    ///         Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
2537    ///     // Check if all elements in the tensor evaluate to True along the dimension 1.
2538    ///     // [[true, true, false]]
2539    ///     let all_dim = tensor.clone().all_dim(0);
2540    ///     println!("{all_dim}");
2541    /// }
2542    /// ```
2543    pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2544        Tensor::new(K::all_dim(self.primitive, dim))
2545    }
2546
2547    /// Convert the tensor into a scalar.
2548    ///
2549    /// # Panics
2550    ///
2551    /// - If the tensor doesn't have one element.
2552    /// - If the backend fails to read the tensor data synchronously.
2553    ///
2554    /// # Returns
2555    ///
2556    /// The scalar value of the tensor.
2557    ///
2558    /// # Example
2559    ///
2560    /// ```rust
2561    /// use burn_tensor::backend::Backend;
2562    /// use burn_tensor::Tensor;
2563    ///
2564    /// fn example<B: Backend>() {
2565    ///     let device = Default::default();
2566    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
2567    ///     // Convert the tensor with a single element into a scalar.
2568    ///     let scalar = tensor.into_scalar();
2569    ///     println!("{scalar}");
2570    /// }
2571    /// ```
2572    pub fn into_scalar(self) -> K::Elem {
2573        crate::try_read_sync(self.into_scalar_async())
2574            .expect(
2575            "Failed to read tensor data synchronously. This can happen on platforms
2576            that don't support blocking futures like WASM. Try into_scalar_async instead.",
2577            )
2578            .expect("Error while reading data: use `try_into_scalar` instead to catch the error at runtime")
2579    }
2580
2581    /// Converts the tensor into a scalar, returning any error that might have occurred since the
2582    /// last time the device was synchronized.
2583    ///
2584    /// # Panics
2585    ///
2586    /// - If the tensor doesn't have one element.
2587    /// - If the backend fails to read the tensor data synchronously.
2588    ///
2589    /// # Returns
2590    ///
2591    /// The scalar value of the tensor, or an error if the data could not be read.
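    ///
    /// # Example
    ///
    /// Handling the failure case explicitly instead of panicking:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([42.0], &device);
    ///     match tensor.try_into_scalar() {
    ///         Ok(value) => println!("{value:?}"),
    ///         Err(err) => println!("failed to read the tensor: {err:?}"),
    ///     }
    /// }
    /// ```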
2592    pub fn try_into_scalar(self) -> Result<K::Elem, ExecutionError> {
2593        crate::try_read_sync(self.into_scalar_async()).expect(
2594            "Failed to read tensor data synchronously. This can happen on platforms
2595            that don't support blocking futures like WASM. Try into_scalar_async instead.",
2596        )
2597    }
2598
2599    /// Convert the tensor into a scalar.
2600    ///
2601    /// # Panics
2602    ///
2603    /// If the tensor doesn't have one element.
2604    pub async fn into_scalar_async(self) -> Result<K::Elem, ExecutionError> {
2605        check!(TensorCheck::into_scalar::<D>(&self.shape()));
2606
2607        Ok(self.into_data_async().await?.iter().next().unwrap())
2608    }
2609
2610    /// Broadcast the tensor to the given shape.
2611    ///
2612    /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size
2613    /// (which can be inferred with `-1`).
2614    ///
2615    /// # Arguments
2616    ///
2617    /// * `shape` - The shape to broadcast the tensor to.
2618    ///   Can contain -1 for dimensions that should be inferred.
2619    ///   The number of elements in the shape must be greater than or equal to
2620    ///   the number of dimensions of the tensor.
2621    ///
2622    /// # Panics
2623    ///
2624    /// If the tensor cannot be broadcast to the given shape.
2625    ///
2626    /// # Returns
2627    ///
2628    /// A new tensor with the given shape.
2629    ///
2630    /// # Example
2631    ///
2632    /// ```rust
2633    /// use burn_tensor::backend::Backend;
2634    /// use burn_tensor::Tensor;
2635    ///
2636    /// fn example<B: Backend>() {
2637    ///     let device = Default::default();
2638    ///     // Create a 2D tensor with dimensions [3, 1]
2639    ///     let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
2640    ///     // Expand the tensor to a new shape [3, 4]
2641    ///     // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
2642    ///     let expanded = tensor.expand([3, 4]);
2643    ///     println!("{}", expanded);
2644    /// }
2645    /// ```
2646    pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
2647        let shape = shape.into_shape(&self.shape());
2648        check!(TensorCheck::expand::<D, D2>(
2649            "expand",
2650            &self.shape(),
2651            &shape,
2652        ));
2653
2654        Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
2655    }
2656
2657    /// Unfold windows along a dimension.
2658    ///
2659    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`,
2660    /// where consecutive windows are advanced by `step` elements.
2661    ///
2662    /// The number of windows is `(shape[dim] - size) / step + 1` when `shape[dim] >= size`, and `0` otherwise.
2663    ///
2664    /// The new view will have the unfolded dimension replaced by two dimensions;
2665    /// one in the position of the original dimension, with size equal to the number of windows,
2666    /// and one appended to the right-most position, with size equal to `size`.
2667    ///
2668    /// # Warning
2669    ///
2670    /// For the `ndarray` and `candle` backends; this is not a view but a copy
2671    /// with duplicated data.
2672    ///
2673    /// # Arguments
2674    ///
2675    /// * `dim` - the dimension to unfold.
2676    /// * `size` - the size of each unfolded window.
2677    /// * `step` - the step between each window.
2678    ///
2679    /// # Returns
2680    ///
2681    /// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
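    ///
    /// # Example
    ///
    /// Unfolding a 1-D tensor into overlapping windows; the output rank is `D + 1`:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
    ///     // Windows of size 3, advanced by 1: [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
    ///     let windows: Tensor<B, 2> = tensor.unfold(0, 3, 1);
    ///     assert_eq!(windows.shape().dims, [3, 3]);
    /// }
    /// ```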
2682    pub fn unfold<const D2: usize, I: AsIndex>(
2683        self,
2684        dim: I,
2685        size: usize,
2686        step: usize,
2687    ) -> Tensor<B, D2, K> {
2688        let dim = dim.expect_dim_index(D);
2689        check!(TensorCheck::unfold::<D, D2>(
2690            "unfold",
2691            &self.shape(),
2692            dim,
2693            size,
2694            step,
2695        ));
2696        Tensor::<B, D2, K>::new(K::unfold(self.primitive, dim, size, step))
2697    }
2698}
2699
2700/// Iterator returned by [iter_dim](Tensor::iter_dim).
2701pub struct DimIter<B, const D: usize, K>
2702where
2703    B: Backend,
2704    K: BasicOps<B>,
2705{
2706    start: usize,
2707    end: usize,
2708    dim: usize,
2709    ranges: [Range<usize>; D],
2710    tensor: Tensor<B, D, K>,
2711}
2712
2713impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
2714    type Item = Tensor<B, D, K>;
2715
2716    fn next(&mut self) -> Option<Self::Item> {
2717        if self.start >= self.end {
2718            return None;
2719        }
2720
2721        let mut ranges = self.ranges.clone();
2722        ranges[self.dim] = self.start..(self.start + 1);
2723
2724        let slice = self.tensor.clone().slice(ranges);
2725        self.start += 1;
2726
2727        Some(slice)
2728    }
2729}
2730
2731impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
2732    fn next_back(&mut self) -> Option<Self::Item> {
2733        if self.start >= self.end {
2734            return None;
2735        }
2736
2737        let mut ranges = self.ranges.clone();
2738        ranges[self.dim] = (self.end - 1)..self.end;
2739
2740        let slice = self.tensor.clone().slice(ranges);
2741        self.end = self.end.saturating_sub(1);
2742
2743        Some(slice)
2744    }
2745}
2746
2747impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
2748    fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
2749        let dims = tensor.dims();
2750        let ranges = dims
2751            .iter()
2752            .map(|&dim| 0..dim)
2753            .collect::<Vec<Range<usize>>>();
2754        let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
2755        Self {
2756            end: dims[dim],
2757            ranges,
2758            start: 0,
2759            dim,
2760            tensor,
2761        }
2762    }
2763}
2764
2765impl<B, const D: usize, K> Tensor<B, D, K>
2766where
2767    B: Backend,
2768    K: BasicOps<B>,
2769    <K as BasicOps<B>>::Elem: Debug,
2770{
2771    #[inline]
2772    fn push_newline_indent(acc: &mut String, indent: usize) {
2773        acc.push('\n');
2774        for _ in 0..indent {
2775            acc.push(' ');
2776        }
2777    }
2778    fn fmt_inner_tensor(
2779        &self,
2780        acc: &mut String,
2781        depth: usize,
2782        multi_index: &mut [usize],
2783        range: (usize, usize),
2784        precision: Option<usize>,
2785    ) {
2786        let (start, end) = range;
2787        for i in start..end {
2788            if i > 0 {
2789                acc.push_str(", ");
2790            }
2791            multi_index[depth] = i;
2792            let range: [Range<usize>; D] =
2793                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
2794
2795            let data = burn_std::reader::try_read_sync(self.clone().slice(range).into_data_async());
2796
2797            if let Some(Ok(data)) = data {
2798                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
2799                match (precision, K::name()) {
2800                    (Some(p), "Float") => acc.push_str(&format!("{elem:.p$}")),
2801                    (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
2802                    _ => acc.push_str(&format!("{elem:?}")),
2803                }
2804            } else {
2805                acc.push_str("<Tensor data not available>");
2806            }
2807        }
2808    }
2809
2810    fn fmt_outer_tensor(
2811        &self,
2812        acc: &mut String,
2813        depth: usize,
2814        multi_index: &mut [usize],
2815        print_options: &PrintOptions,
2816        summarize: bool,
2817        range: (usize, usize),
2818    ) {
2819        let (start, end) = range;
2820        for i in start..end {
2821            if i > start {
2822                acc.push(',');
2823                Self::push_newline_indent(acc, depth + 1);
2824            }
2825            acc.push('[');
2826            multi_index[depth] = i;
2827            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
2828            acc.push(']');
2829        }
2830    }
2831
2832    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
2833    ///
2834    /// This function is designed to work with tensors of any dimensionality.
2835    /// It traverses the tensor dimensions recursively, converting the elements
2836    /// to strings and appending them to the accumulator string with the
2837    /// appropriate formatting.
2838    ///
2839    /// # Arguments
2840    ///
2841    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
2842    /// * `depth` - The current depth of the tensor dimensions being processed.
2843    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
2844    fn display_recursive(
2845        &self,
2846        acc: &mut String,
2847        depth: usize,
2848        multi_index: &mut [usize],
2849        print_options: &PrintOptions,
2850        summarize: bool,
2851    ) {
2852        let edge_items = print_options.edge_items;
2853
2854        if depth == 0 {
2855            acc.push('[');
2856        }
2857
2858        if depth == self.dims().len() - 1 {
2859            // if we are at the innermost dimension, just push its elements into the accumulator
2860            if summarize && self.dims()[depth] > 2 * edge_items {
2861                // print the starting `edge_items` elements
2862                self.fmt_inner_tensor(
2863                    acc,
2864                    depth,
2865                    multi_index,
2866                    (0, edge_items),
2867                    print_options.precision,
2868                );
2869                acc.push_str(", ...");
2870                // print the last `edge_items` elements
2871                self.fmt_inner_tensor(
2872                    acc,
2873                    depth,
2874                    multi_index,
2875                    (self.dims()[depth] - edge_items, self.dims()[depth]),
2876                    print_options.precision,
2877                );
2878            } else {
2879                // print all the elements
2880                self.fmt_inner_tensor(
2881                    acc,
2882                    depth,
2883                    multi_index,
2884                    (0, self.dims()[depth]),
2885                    print_options.precision,
2886                );
2887            }
2888        } else {
2889            // otherwise, iterate through the current dimension and recursively display the inner tensors
2890            if summarize && self.dims()[depth] > 2 * edge_items {
2891                self.fmt_outer_tensor(
2892                    acc,
2893                    depth,
2894                    multi_index,
2895                    print_options,
2896                    summarize,
2897                    (0, edge_items),
2898                );
2899
2900                acc.push(',');
2901                Self::push_newline_indent(acc, depth + 1);
2902                acc.push_str("...");
2903                Self::push_newline_indent(acc, depth + 1);
2904
2905                self.fmt_outer_tensor(
2906                    acc,
2907                    depth,
2908                    multi_index,
2909                    print_options,
2910                    summarize,
2911                    (self.dims()[depth] - edge_items, self.dims()[depth]),
2912                );
2913            } else {
2914                self.fmt_outer_tensor(
2915                    acc,
2916                    depth,
2917                    multi_index,
2918                    print_options,
2919                    summarize,
2920                    (0, self.dims()[depth]),
2921                );
2922            }
2923        }
2924
2925        if depth == 0 {
2926            acc.push(']');
2927        }
2928    }
2929}
2930
/// Options for [`Tensor`] pretty printing.
#[derive(Clone, Debug)]
pub struct PrintOptions {
    /// Maximum number of elements before the tensor is summarized.
    pub threshold: usize,

    /// Number of elements to display at the start and end of each summarized dimension.
    pub edge_items: usize,

    /// Precision for floating point numbers.
    pub precision: Option<usize>,
}
2943
2944static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
2945
2946impl PrintOptions {
    /// Returns print options with default values, usable in `const` contexts.
2948    pub const fn const_default() -> Self {
2949        Self {
2950            threshold: 1000,
2951            edge_items: 3,
2952            precision: None,
2953        }
2954    }
2955}
2956
2957impl Default for PrintOptions {
2958    fn default() -> Self {
2959        Self::const_default()
2960    }
2961}
2962
/// Sets the global print options used when displaying tensors.
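///
/// A minimal usage sketch (the concrete values and the crate-root import path
/// are illustrative assumptions):
///
/// ```rust,ignore
/// use burn_tensor::{PrintOptions, set_print_options};
///
/// let options = PrintOptions {
///     threshold: 100,     // summarize tensors with more than 100 elements
///     edge_items: 2,      // show 2 elements at each edge when summarizing
///     precision: Some(4), // print floats with 4 decimal places
/// };
/// set_print_options(options);
/// ```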
2964pub fn set_print_options(options: PrintOptions) {
2965    let mut print_opts = PRINT_OPTS.write().unwrap();
2966    *print_opts = options;
2967}
2968
2969/// Pretty print tensors
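///
/// A formatting sketch: `Display` honors the global [`PrintOptions`], and the
/// precision can be overridden per call site through the format spec:
///
/// ```rust,ignore
/// println!("{tensor}");    // uses the globally configured print options
/// println!("{tensor:.3}"); // overrides the precision to 3 decimal places
/// ```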
2970impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
2971where
2972    B: Backend,
2973    B::IntElem: core::fmt::Display,
2974    K: BasicOps<B>,
2975    <K as BasicOps<B>>::Elem: Debug,
2976{
2977    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2978        writeln!(f, "Tensor {{")?;
2979
        {
            // Clone the options so the read lock is not held for the whole function
            let mut po = PRINT_OPTS.read().unwrap().clone();

            // Override the precision if it is set on the formatter,
            // e.g. when the tensor is printed with the `{:.3}` syntax
            if let Some(precision) = f.precision() {
                po.precision = Some(precision);
            }
2989
2990            let mut acc = String::new();
2991            let mut multi_index = vec![0; D];
2992            let summarize = self.shape().num_elements() > po.threshold;
2993
2994            self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);
2995
2996            writeln!(f, "  data:")?;
2997            write!(f, "{acc}")?;
2998            writeln!(f, ",")?;
2999        }
3000
3001        writeln!(f, "  shape:  {:?},", self.dims())?;
3002        writeln!(f, "  device:  {:?},", self.device())?;
3003        writeln!(f, "  backend:  {:?},", B::name(&self.device()))?;
3004        writeln!(f, "  kind:  {:?},", K::name())?;
3005
3006        let dtype = self.primitive.dtype();
3007
3008        writeln!(f, "  dtype:  {:?},", dtype.name())?;
3009        write!(f, "}}")
3010    }
3011}
3012
3013/// Trait used for movedim arguments
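///
/// A usage sketch for `tensor.movedim(src, dst)` (negative indices count from
/// the last dimension; the calls below are illustrative):
///
/// ```rust,ignore
/// let t = tensor.clone().movedim(0, 2);           // single usize
/// let t = tensor.clone().movedim(-1, 0);          // single i32
/// let t = tensor.movedim(vec![0, 1], vec![2, 0]); // Vec<usize> or Vec<i32>
/// ```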
3014pub trait MovedimArgs {
3015    /// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function
3016    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
3017}
3018
3019impl MovedimArgs for Vec<i32> {
3020    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3021        let set = self
3022            .iter()
3023            .map(|&dim| {
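                // Negative dims wrap around: -1 refers to the last dimension.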
3024                if dim < 0 {
3025                    (D as i32 + dim) as usize
3026                } else {
3027                    dim as usize
3028                }
3029            })
3030            .collect::<Vec<usize>>();
3031        check!(TensorCheck::movedim_args_vec::<D>(&set));
3032
3033        set
3034    }
3035}
3036
3037impl MovedimArgs for Vec<usize> {
3038    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3039        check!(TensorCheck::movedim_args_vec::<D>(&self));
3040        self
3041    }
3042}
3043
3044impl MovedimArgs for usize {
3045    #[allow(clippy::vec_init_then_push)]
3046    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3047        check!(TensorCheck::movedim_args_usize::<D>(self));
3048
3049        let mut set = Vec::with_capacity(1);
3050        set.push(self);
3051
3052        set
3053    }
3054}
3055
3056impl MovedimArgs for i32 {
3057    #[allow(clippy::vec_init_then_push)]
3058    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3059        check!(TensorCheck::movedim_args_i32::<D>(self));
3060
3061        let dim = if self < 0 {
3062            (D as i32 + self) as usize
3063        } else {
3064            self as usize
3065        };
3066
3067        let mut set = Vec::with_capacity(1);
3068        set.push(dim);
3069
3070        set
3071    }
3072}
3073
3074/// Trait used for reshape arguments.
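///
/// A reshape sketch showing the accepted argument forms; in the integer-array
/// forms, `0` keeps the corresponding dimension and `-1` infers a single
/// dimension from the remaining number of elements (the calls below are
/// illustrative):
///
/// ```rust,ignore
/// // `tensor` has shape [2, 3, 4], i.e. 24 elements
/// let a = tensor.clone().reshape([6, 4]);         // explicit sizes
/// let b = tensor.clone().reshape([0, -1]);        // keep 2, infer 12
/// let c = tensor.reshape(Shape::from([4, 3, 2])); // an existing Shape
/// ```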
3075pub trait ReshapeArgs<const D2: usize> {
3076    /// Converts to a shape.
3077    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3078        self,
3079        tensor: &Tensor<B, D, K>,
3080    ) -> Shape;
3081}
3082
3083impl<const D2: usize> ReshapeArgs<D2> for Shape {
3084    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3085        self,
3086        tensor: &Tensor<B, D, K>,
3087    ) -> Shape {
3088        check!(TensorCheck::reshape_args_usize::<D, D2>(
3089            &tensor.shape(),
3090            &self
3091        ));
3092
3093        self
3094    }
3095}
3096impl<const D2: usize> ReshapeArgs<D2> for [usize; D2] {
3097    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3098        self,
3099        tensor: &Tensor<B, D, K>,
3100    ) -> Shape {
3101        let shape = Shape::from(self);
3102
3103        check!(TensorCheck::reshape_args_usize::<D, D2>(
3104            &tensor.shape(),
3105            &shape
3106        ));
3107
3108        shape
3109    }
3110}
3111
3112impl<const D2: usize> ReshapeArgs<D2> for [i64; D2] {
3113    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3114        self,
3115        tensor: &Tensor<B, D, K>,
3116    ) -> Shape {
3117        // Validate the reshape arguments
3118        check!(TensorCheck::reshape_args_i64(&self));
3119
3120        // Temporary shape
3121        let mut new_shape: [i64; D2] = [1; D2];
3122
        // Replace any 0 entries with the corresponding dimension of the
        // current shape (0 means "keep this dimension's size").
        for (i, &s) in self.iter().enumerate() {
            if s != 0 {
                new_shape[i] = s;
            } else {
                new_shape[i] = tensor.dims()[i] as i64;
            }
        }
3132
3133        // Find the index of the inferred dimension (-1)
3134        let infer_index = new_shape.iter().position(|x| x == &-1);
3135
        // Handle the case where one dimension is inferred (via -1)
        if let Some(index) = infer_index {
            // The inferred size is the total number of elements divided by
            // the product of all the known dimensions.
            let mut product = 1;
            for (i, &s) in new_shape.iter().enumerate() {
                if i != index {
                    product *= s;
                }
            }
            let product_current = tensor.shape().num_elements() as i64;

            // The reshape is only valid if the known dimensions divide the
            // total number of elements evenly.
            if product_current % product != 0 {
                panic!(
                    "Cannot reshape tensor of shape {:?} to shape {:?}",
                    tensor.shape(),
                    self
                );
            }

            new_shape[index] = product_current / product;
        };
3158
3159        // Convert each element to usize
3160        let new_shape: [usize; D2] = new_shape.map(|x| x as usize);
3161
3162        Shape::from(new_shape)
3163    }
3164}
3165
3166impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
3167    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3168        self,
3169        tensor: &Tensor<B, D, K>,
3170    ) -> Shape {
3171        // Convert i32 array to i64 array and use existing implementation
3172        let i64_array: [i64; D2] = self.map(|x| x as i64);
3173        ReshapeArgs::into_shape(i64_array, tensor)
3174    }
3175}
3176
3177/// Trait used for broadcast arguments.
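///
/// A broadcasting sketch (assuming `expand` is the consumer of this trait;
/// `-1` keeps the size of the matching dimension):
///
/// ```rust,ignore
/// // `tensor` has shape [1, 3]
/// let a = tensor.clone().expand([4, 3]);  // broadcast dim 0 -> [4, 3]
/// let b = tensor.expand([2, -1, 3]);      // add a leading dim -> [2, 1, 3]
/// ```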
3178pub trait BroadcastArgs<const D1: usize, const D2: usize> {
3179    /// Converts to a shape.
3180    fn into_shape(self, shape: &Shape) -> Shape;
3181}
3182
3183impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
3184    fn into_shape(self, _shape: &Shape) -> Shape {
3185        self
3186    }
3187}
3188
3189impl<const D1: usize, const D2: usize, E: AsIndex> BroadcastArgs<D1, D2> for [E; D2] {
3190    // Passing -1 as the size for a dimension means not changing the size of that dimension.
3191    fn into_shape(self, shape: &Shape) -> Shape {
        if self.len() < shape.num_dims() {
            panic!("Broadcast arguments must have at least as many dimensions as the tensor");
        }
3195
3196        // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
3197        let new_shape: Vec<_> = self
3198            .iter()
3199            .rev()
3200            .map(|x| {
3201                let primitive = x.index();
3202                if primitive < -1 || primitive == 0 {
3203                    panic!("Broadcast arguments must be positive or -1");
3204                }
3205                primitive
3206            })
3207            .zip(shape.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
3208            .map(|(x, &y)| if x == -1 { y } else { x as usize })
3209            .collect::<Vec<_>>()
3210            .into_iter()
3211            .rev()
3212            .collect();
3213
3214        if new_shape.contains(&0) {
3215            panic!("Cannot substitute -1 for a non-existing dimension");
3216        }
3217
3218        let new_shape: [usize; D2] = new_shape.try_into().unwrap();
3219
3220        Shape::from(new_shape)
3221    }
3222}
3223
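/// Tensors serialize through their [`TensorData`], so any serde format works;
/// deserialization places the tensor on the backend's default device. A
/// round-trip sketch (assuming the `serde_json` crate; error handling elided):
///
/// ```rust,ignore
/// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
/// let json = serde_json::to_string(&tensor).unwrap();
/// let restored: Tensor<B, 2> = serde_json::from_str(&json).unwrap();
/// ```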
3224impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
3225where
3226    B: Backend,
3227    K: BasicOps<B>,
3228    K::Elem: Debug + Copy + Serialize,
3229{
3230    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
3231        let data = self.to_data();
3232        data.serialize(serializer)
3233    }
3234}
3235
3236impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
3237where
3238    B: Backend,
3239    K: BasicOps<B>,
3240    K::Elem: Debug + Copy + Deserialize<'de>,
3241{
3242    fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
3243        let tensor = Tensor::from_data(
3244            TensorData::deserialize(deserializer)?,
3245            &<B::Device as Default>::default(),
3246        );
3247        Ok(tensor)
3248    }
3249}
3250
3251#[cfg(test)]
3252mod tests {
3253    use crate::{Shape, s};
3254
3255    #[test]
3256    fn slice_range_single_dim_leading() {
3257        let shape = Shape::new([8, 4]);
3258
3259        // Half-open range
3260        let slices = shape.clone().into_slices([0..5]);
3261        assert_eq!(slices[0].to_range(8), 0..5);
3262        let slices = shape.clone().into_slices([-3..-1]);
3263        assert_eq!(slices[0].to_range(8), 5..7);
3264
3265        // Inclusive range
3266        let slices = shape.clone().into_slices([0..=4]);
3267        assert_eq!(slices[0].to_range(8), 0..5);
3268        let slices = shape.clone().into_slices([-2..=-1]);
3269        assert_eq!(slices[0].to_range(8), 6..8);
3270
3271        // Unbounded start
3272        let slices = shape.clone().into_slices([..3]);
3273        assert_eq!(slices[0].to_range(8), 0..3);
3274        let slices = shape.clone().into_slices([..-5]);
3275        assert_eq!(slices[0].to_range(8), 0..3);
3276
3277        // Unbounded end
3278        let slices = shape.clone().into_slices([5..]);
3279        assert_eq!(slices[0].to_range(8), 5..8);
3280        let slices = shape.clone().into_slices([-3..]);
3281        assert_eq!(slices[0].to_range(8), 5..8);
3282
3283        // Full range
3284        let slices = shape.into_slices([..]);
3285        assert_eq!(slices[0].to_range(8), 0..8);
3286    }
3287
3288    #[test]
3289    fn test_negative_slice_indices() {
3290        use crate::Slice;
3291
3292        // Test negative indices conversion
3293        let slice: Slice = (-3..-1).into();
3294        assert_eq!(slice.start, -3);
3295        assert_eq!(slice.end, Some(-1));
3296
3297        // Test to_range conversion with size 8
3298        let range = slice.to_range(8);
3299        assert_eq!(range, 5..7);
3300
3301        // Test with shape slice
3302        let shape = Shape::new([8, 4]);
3303        let result = shape.clone().into_slices([-3..-1]);
3304        assert_eq!(result[0].to_range(8), 5..7);
3305
3306        // Test more negative index cases
3307        let slice2: Slice = (-5..).into();
3308        assert_eq!(slice2.to_range(10), 5..10);
3309
3310        let slice3: Slice = (..-2).into();
3311        assert_eq!(slice3.to_range(10), 0..8);
3312
3313        // Test with s! macro - single dimension returns Slice directly
3314        let slice4 = s![-3..-1];
3315        assert_eq!(slice4.start, -3);
3316        assert_eq!(slice4.end, Some(-1));
3317    }
3318
3319    #[test]
3320    fn slice_range_multi_dim() {
3321        let shape = Shape::new([8, 4]);
3322
3323        // Multiple ways to provide ranges
3324        let slices = shape.clone().into_slices([0..5, 0..4]);
3325        assert_eq!(slices[0].to_range(8), 0..5);
3326        assert_eq!(slices[1].to_range(4), 0..4);
3327
3328        let slices = shape.clone().into_slices([0.., 0..]);
3329        assert_eq!(slices[0].to_range(8), 0..8);
3330        assert_eq!(slices[1].to_range(4), 0..4);
3331
3332        let slices = shape.clone().into_slices([0..=7, 0..=3]);
3333        assert_eq!(slices[0].to_range(8), 0..8);
3334        assert_eq!(slices[1].to_range(4), 0..4);
3335
3336        let slices = shape.clone().into_slices([0..5, 0..3]);
3337        assert_eq!(slices[0].to_range(8), 0..5);
3338        assert_eq!(slices[1].to_range(4), 0..3);
3339
3340        let slices = shape.into_slices([0.., 0..]);
3341        assert_eq!(slices[0].to_range(8), 0..8);
3342        assert_eq!(slices[1].to_range(4), 0..4);
3343    }
3344
3345    #[test]
3346    fn slice_range_multi_dim_index() {
3347        let shape = Shape::new([8, 4]);
3348
        // Single integer indices should also convert to the correct range
3350        let slices = shape.clone().into_slices([0, 2]);
3351        assert_eq!(slices[0].to_range(8), 0..1);
3352        assert_eq!(slices[1].to_range(4), 2..3);
3353
3354        let slices = shape.into_slices([-1, -1]);
3355        assert_eq!(slices[0].to_range(8), 7..8);
3356        assert_eq!(slices[1].to_range(4), 3..4);
3357    }
3358
3359    #[test]
3360    fn slice_range_multi_dim_heterogeneous() {
        // The slice macro `s![]` can mix different range types
3362        let shape = Shape::new([8, 4, 2]);
3363        let slice = s![0..5, .., -1];
3364        let slices = shape.into_slices(slice);
3365        assert_eq!(slices[0].to_range(8), 0..5);
3366        assert_eq!(slices[1].to_range(4), 0..4);
3367        assert_eq!(slices[2].to_range(2), 1..2);
3368
3369        let shape = Shape::new([8, 4, 2, 3]);
3370        let slice = s![..=4, 0..=3, .., -2..];
3371        let slices = shape.into_slices(slice);
3372        assert_eq!(slices[0].to_range(8), 0..5);
3373        assert_eq!(slices[1].to_range(4), 0..4);
3374        assert_eq!(slices[2].to_range(2), 0..2);
3375        assert_eq!(slices[3].to_range(3), 1..3);
3376
3377        let shape = Shape::new([3, 4]);
3378        let slice = s![1..-1, ..];
3379        let slices = shape.into_slices(slice);
3380        assert_eq!(slices[0].to_range(3), 1..2);
3381        assert_eq!(slices[1].to_range(4), 0..4);
3382    }
3383}