burn_tensor/tensor/api/base.rs

#![allow(clippy::single_range_in_vec_init)]

use alloc::vec::Vec;

use alloc::format;
use alloc::string::String;
use alloc::vec;

use burn_common::stub::RwLock;
use core::future::Future;
use core::iter::repeat;
use core::{fmt::Debug, ops::Range};
use serde::{Deserialize, Deserializer};

use serde::{Serialize, Serializer};

use super::{Slice, SliceArg, TensorMetadata, Transaction};
use crate::indexing::{AsIndex, canonicalize_dim, wrap_index};
use crate::{
    Bool, ElementConversion, Float, Int, Shape, TensorData, TensorKind, backend::Backend, check,
    ops::Device,
};
use crate::{DType, Element, TensorPrimitive};
use crate::{cast::ToElement, check::TensorCheck};

/// A tensor with a given backend, shape and data type.
///
/// # Indexing
/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
/// or [`select`](Tensor::select) for numeric types.
///
/// ## Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
/// use burn_tensor::Int;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///
///     let tensor = Tensor::<B, 2>::from_data(
///         [
///             [3.0, 4.9, 2.0],
///             [2.0, 1.9, 3.0],
///             [6.0, 1.5, 7.0],
///             [3.0, 4.9, 9.0],
///         ],
///         &device,
///     );
///
///     // Slice the tensor to get the second and third rows:
///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
///     // The resulting tensor will have dimensions [2, 3].
///     let slice = tensor.clone().slice([1..3]);
///     println!("{slice}");
///
///     // Slice the tensor to get the first two rows and the first 2 columns:
///     // [[3.0, 4.9], [2.0, 1.9]]
///     // The resulting tensor will have dimensions [2, 2].
///     let slice = tensor.clone().slice([0..2, 0..2]);
///     println!("{slice}");
///
///     // Index the tensor along the dimension 1 to get the elements 0 and 2:
///     // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
///     // The resulting tensor will have dimensions [4, 2]
///     let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
///     let indexed = tensor.select(1, indices);
///     println!("{indexed}");
/// }
/// ```
#[derive(new, Clone, Debug)]
pub struct Tensor<B, const D: usize, K = Float>
where
    B: Backend,
    K: TensorKind<B>,
{
    pub(crate) primitive: K::Primitive,
}

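/// A tensor can be created from anything convertible into [`TensorData`], such as nested
/// arrays, allocating on the default device.
///
/// A minimal sketch (assuming the backend's default device is acceptable):
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     // Equivalent to `Tensor::<B, 2>::from_data(..., &Default::default())`.
///     let tensor: Tensor<B, 2> = [[1.0, 2.0], [3.0, 4.0]].into();
///     println!("{tensor}");
/// }
/// ```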
impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    T: Into<TensorData>,
{
    fn from(value: T) -> Self {
        Tensor::from_data(value.into(), &Default::default())
    }
}

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    K::Elem: Element,
{
    /// Executes an operation on the tensor and modifies its value.
    ///
    /// # Notes
    ///
    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
    /// no other reference pointing to the same tensor.
    ///
    /// Wrapping operations with `inplace` is not an optimization; it is mainly useful when you
    /// want to mutate a tensor using owned operations. A plausible usage would be to
    /// update the weights of a mutable model reference.
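    ///
    /// # Example
    ///
    /// A minimal sketch of mutating a tensor in place through an owned operation:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let mut tensor = Tensor::<B, 2>::ones([2, 3], &device);
    ///     // Double every element without manually reassigning the binding.
    ///     tensor.inplace(|t| t.mul_scalar(2.0));
    ///     println!("{tensor}");
    /// }
    /// ```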
    pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
        let mut tensor_owned = Tensor::empty([0; D], &self.device());
        core::mem::swap(&mut tensor_owned, self);

        let mut tensor_new = func(tensor_owned);
        core::mem::swap(&mut tensor_new, self);
    }

    /// Converts the tensor into a primitive tensor.
    pub fn into_primitive(self) -> K::Primitive {
        self.primitive
    }

    /// Converts from a primitive tensor into a tensor.
    pub fn from_primitive(tensor: K::Primitive) -> Self {
        Self::new(tensor)
    }

    /// Returns the number of dimensions of the tensor.
    pub fn rank(&self) -> usize {
        self.primitive.rank()
    }

    /// Returns the tensor primitive data type.
    ///
    /// # Note
    /// Some element types are encoded in different primitive types depending on the backend
    /// (e.g., bool could be encoded as `u8` or `u32`).
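    ///
    /// # Example
    ///
    /// A minimal sketch; the exact dtype depends on the backend's default float type:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::ones([2, 3], &device);
    ///     // Commonly `F32` on float backends.
    ///     println!("{:?}", tensor.dtype());
    /// }
    /// ```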
    pub fn dtype(&self) -> DType {
        self.primitive.dtype()
    }

    /// Create an empty tensor of the given shape.
    ///
    /// # Arguments
    ///
    /// - `shape`: The shape of the tensor.
    /// - `device`: The device where the tensor will be created.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    // Create an empty tensor with dimensions [2, 3, 4].
    ///    let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
    /// }
    /// ```
    pub fn empty<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
        Self::new(K::empty(shape, device, K::Elem::dtype()))
    }

    /// Create a tensor of the given shape where each element is equal to the provided value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
    ///   println!("{tensor}");
    ///   // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
    /// }
    /// ```
    pub fn full<S: Into<Shape>, E: ElementConversion>(
        shape: S,
        fill_value: E,
        device: &B::Device,
    ) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
        Self::new(K::full(shape, fill_value, device, K::Elem::dtype()))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor,
    /// filled with the provided value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.full_like(5.0);
    ///    println!("{tensor}");
    ///    // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
    /// }
    /// ```
    pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
        Self::new(K::full(
            self.shape(),
            fill_value,
            &self.device(),
            self.dtype(),
        ))
    }

    /// Returns the dimensions of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///   let device = Default::default();
    ///   let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///   let dims = tensor.dims(); // [2, 3, 4]
    ///   println!("{dims:?}");
    /// }
    /// ```
    pub fn dims(&self) -> [usize; D] {
        Self::shape(self).dims()
    }

    /// Returns the shape of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Shape { dims: [2, 3, 4] }
    ///    let shape = tensor.shape();
    /// }
    /// ```
    pub fn shape(&self) -> Shape {
        self.primitive.shape()
    }

    /// Reshape the tensor to have the given shape.
    ///
    /// The tensor has the same data and number of elements as the input.
    ///
    /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
    /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
    ///
    /// A `0` in the shape instructs to keep the current dimension from the original tensor,
    /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
    /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
    /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
    /// with [1, 3, 4] dimensions to [1, 12].
    ///
    /// # Arguments
    /// - `shape`: The new shape of the tensor.
    ///
    /// # Panics
    /// - If the shape contains more than one `-1`.
    /// - If the shape contains negative values other than `-1`.
    /// - If the new shape does not match the number of elements of the original shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    // Create a tensor with dimensions [2, 3, 4]
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Reshape it to [2, 12], where 12 is inferred from the number of elements.
    ///    let reshaped = tensor.reshape([2, -1]);
    ///    println!("{reshaped}");
    /// }
    /// ```
    pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
        // Convert reshape args to shape
        let shape = shape.into_shape(&self);
        Tensor::new(K::reshape(self.primitive, shape))
    }

    /// Transpose the tensor.
    ///
    /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
    /// applied on the last two dimensions. For example, the transpose of a tensor with shape
    /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
    ///
    /// See also [`permute`](Tensor::permute).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to transpose.
    ///
    /// # Returns
    ///
    /// The transposed tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Transpose the tensor:
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let transposed = tensor.transpose();
    ///     println!("{transposed}");
    /// }
    /// ```
    pub fn transpose(self) -> Tensor<B, D, K> {
        Tensor::new(K::transpose(self.primitive))
    }

    /// Alias for `transpose`.
    #[inline(always)]
    pub fn t(self) -> Tensor<B, D, K> {
        self.transpose()
    }

    /// Swaps two dimensions of a tensor.
    ///
    /// This is a no-op when `dim1 == dim2`, assuming both are within bounds.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap, supports negative indexing.
    /// * `dim2` - The second dimension to swap, supports negative indexing.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    ///
    /// # Panics
    ///
    /// When dimensions are out of bounds.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Swap the dimensions 0 and -1 (equivalent to `tensor.transpose()`):
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let swapped = tensor.swap_dims(0, -1);
    ///     println!("{swapped}");
    /// }
    /// ```
    pub fn swap_dims<Dim1, Dim2>(self, dim1: Dim1, dim2: Dim2) -> Tensor<B, D, K>
    where
        Dim1: AsIndex,
        Dim2: AsIndex,
    {
        let dim1 = canonicalize_dim(dim1, D, false);
        let dim2 = canonicalize_dim(dim2, D, false);
        check!(TensorCheck::swap_dims::<D>(dim1, dim2));
        if dim1 == dim2 {
            self
        } else {
            Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
        }
    }

    /// Permute the dimensions of the tensor.
    ///
    /// This is a no-op when the resolved `axes` match the current order.
    ///
    /// # Arguments
    ///
    /// * `axes` - The new order of the dimensions. The length of the axes
    ///   must be equal to the number of dimensions of the tensor.
    ///   The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
    ///
    ///     // Permute the dimensions 1 and 0:
    ///     // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
    ///     // The resulting tensor will have dimensions [2, 3].
    ///     let permuted = tensor.permute([1, 0]);
    ///     println!("{permuted}");
    /// }
    /// ```
    pub fn permute<Dim>(self, axes: [Dim; D]) -> Tensor<B, D, K>
    where
        Dim: AsIndex,
    {
        let mut no_op = true;
        let mut fixed_axes = [0; D];
        for (i, axis) in axes.into_iter().enumerate() {
            let dim = canonicalize_dim(axis, D, false);
            no_op &= dim == i;
            fixed_axes[i] = dim;
        }

        if no_op {
            self
        } else {
            check!(TensorCheck::permute(fixed_axes));
            Tensor::new(K::permute(self.primitive, &fixed_axes))
        }
    }

    /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
    ///
    /// Other dimensions of input that are not explicitly moved remain in their original order and appear
    /// at the positions not specified in destination.
    ///
    /// # Arguments
    ///
    /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// * `dst` - Destination positions for each of the original dims. These must also be unique.
    ///
    /// # Panics
    ///
    /// - If the source and destination dimensions are not of the same length.
    /// - If the source and destination vectors contain duplicate values.
    /// - If the source and destination vectors contain values that are out of bounds.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions moved.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor of shape [3, 2, 1]
    ///     let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
    ///
    ///     // Move the dimensions 0 and 1:
    ///     // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
    ///     // The resulting tensor will have dimensions [2, 3, 1].
    ///     let moved = tensor.movedim(1, 0);
    ///     println!("{moved}");
    /// }
    /// ```
    ///
    /// # Note
    ///
    /// This is syntactic sugar for `permute`. It is used widely enough that we define a
    /// separate op for it.
    pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
        let source_dims = src.into_dim_vec::<D>();
        let destination_dims = dst.into_dim_vec::<D>();

        check!(TensorCheck::movedim_args_length(
            &source_dims,
            &destination_dims
        ));

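        // Build the permutation for `permute`: destination positions take their paired
        // source dims (recorded in `m`), and all remaining positions are filled, in order,
        // with the dims that were not explicitly moved.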
        let mut m = [-1; D];
        for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
            m[d] = s as isize;
        }
        let mut axes: [isize; D] = [0; D];
        let mut source_i = 0;
        for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
            *item = if m[dest_i] != -1 {
                m[dest_i]
            } else {
                while source_dims.contains(&source_i) {
                    source_i += 1;
                }
                let result = source_i as isize;
                source_i += 1;
                result
            };
        }

        self.permute(axes)
    }

    /// Reverse the order of elements in the tensor along the given dimensions.
    ///
    /// # Arguments
    ///
    /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the axes flipped.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [4.0, 5.9, 8.0],
    ///             [1.4, 5.8, 6.0],
    ///         ],
    ///         &device,
    ///     );
    ///
    ///     // Flip the elements in dimensions 0 and 1:
    ///     // [[6.0, 5.8, 1.4],
    ///     //  [8.0, 5.9, 4.0],
    ///     //  [3.0, 1.9, 2.0],
    ///     //  [2.0, 4.9, 3.0]]
    ///     // The resulting tensor will have dimensions [4, 3].
    ///     let flipped = tensor.flip([0, 1]);
    ///     println!("{flipped}");
    /// }
    /// ```
    pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
        // Convert the axes to usize and handle negative values without using vector
        let mut transformed_axes: [usize; N] = [0; N];
        for (i, &x) in axes.iter().enumerate() {
            transformed_axes[i] = if x < 0 {
                (D as isize + x) as usize
            } else {
                x as usize
            };
        }

        // Check if the axes are valid
        check!(TensorCheck::flip(D, &transformed_axes));

        Tensor::new(K::flip(self.primitive, &transformed_axes))
    }

    /// Flatten the tensor along a given range of dimensions.
    ///
    /// This function collapses the specified range of dimensions into a single dimension,
    /// effectively flattening the tensor in that range.
    ///
    /// # Arguments
    ///
    /// - `start_dim`: The starting dimension of the range to be flattened,
    ///   supports negative indexing.
    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),
    ///   supports negative indexing.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the flattened tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [2, 3, 4]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
    ///
    ///     // Flatten the tensor from dimensions 1 to 2 (inclusive).
    ///     // The resulting tensor will have dimensions [2, 12]
    ///     let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
    ///     println!("{flattened}");
    /// }
    /// ```
    pub fn flatten<const D2: usize>(
        self,
        start_dim: impl AsIndex,
        end_dim: impl AsIndex,
    ) -> Tensor<B, D2, K> {
        let start_dim = canonicalize_dim(start_dim, D, false);
        let end_dim = canonicalize_dim(end_dim, D, false);
        check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));

        let current_dims = self.shape().dims;
        let mut new_dims: [usize; D2] = [0; D2];
        let mut flatten_dims = 1;

        for i in current_dims[start_dim..=end_dim].iter() {
            flatten_dims *= i;
        }

        new_dims[..start_dim].copy_from_slice(&current_dims[..start_dim]);
        new_dims[start_dim] = flatten_dims;
        new_dims[start_dim + 1..].copy_from_slice(&current_dims[end_dim + 1..]);

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Squeeze the tensor along all dimensions, removing dimensions
    /// of size one, and effectively reducing the rank of the tensor.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 4D tensor with dimensions [1, 3, 1, 3]
    ///     let tensor = Tensor::<B, 4>::from_data(
    ///         [[[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]]],
    ///         &device,
    ///     );
    ///
    ///     // Squeeze the tensor dimensions.
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let squeezed = tensor.squeeze::<2>();
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze<const D2: usize>(self) -> Tensor<B, D2, K> {
        let new_dims = self
            .shape()
            .dims
            .iter()
            .filter_map(|&dim| if dim == 1 { None } else { Some(dim) })
            .collect::<Vec<_>>();
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Squeeze the tensor along the given dimension, removing the specified dimension
    /// of size one, and effectively reducing the rank of the tensor by one.
    ///
    /// # Arguments
    ///
    /// - `dim`: The dimension to be squeezed.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Panics
    ///
    /// If the size in the squeezed dimension is not 1.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 1, 3]
    ///     let tensor = Tensor::<B, 3>::from_data(
    ///         [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
    ///         &device,
    ///     );
    ///
    ///     // Squeeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let squeezed = tensor.squeeze_dim::<2>(1);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));

        let current_dims = self.shape().dims;
        let mut new_dims: [usize; D2] = [0; D2];

        new_dims[..dim].copy_from_slice(&current_dims[..dim]);
        new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);

        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
    /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
    /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
    /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
    /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
    /// from the back.
    ///
    /// # Arguments
    ///
    /// - `dims`: The dimension(s) to be squeezed.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 4D tensor with dimensions [2, 1, 4, 1]
    ///     let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
    ///
    ///     // Squeeze the dimensions 1 and 3.
    ///     // The resulting tensor will have dimensions [2, 4].
    ///     let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
        let current_dims = self.shape().dims;
        let mut dim_indices: Vec<usize>;

        // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
        if dims.is_empty() {
            dim_indices = current_dims
                .iter()
                .enumerate()
                .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
                .collect();
        } else {
            // If negative dims, count from the back
            dim_indices = dims
                .iter()
                .map(|&d| {
                    if d < 0 {
                        (current_dims.len() as isize + d) as usize
                    } else {
                        d as usize
                    }
                })
                .collect();
        }

        // Sort indices and remove duplicates
        dim_indices.sort_unstable();
        dim_indices.dedup();

        // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
        check!(TensorCheck::squeeze_dims_input::<D2>(
            &dim_indices,
            &current_dims
        ));

        // Calculate new dimensions
        let mut new_dims = Vec::new();
        for (index, &dim_size) in current_dims.iter().enumerate() {
            // Exclude the dimension if it's explicitly marked for squeezing
            if dim_indices.contains(&index) {
                check!(TensorCheck::squeeze::<D2>(index, &current_dims));
                continue;
            }
            new_dims.push(dim_size);
        }

        // Check that after squeezing, we still respect the D2 size
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the unsqueezed tensor.
    ///
    /// # Panics
    ///
    /// If the output size `D2` is smaller than the current number of dimensions.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the tensor up to 4 dimensions.
    ///     // The resulting tensor will have dimensions [1, 1, 3, 3].
    ///     let unsqueezed = tensor.unsqueeze::<4>();
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze::<D, D2>());

        let mut dims = [1; D2];
        let num_ones = D2 - D;
        let shape = self.shape();

        dims[num_ones..(D + num_ones)].copy_from_slice(&shape[..D]);

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with a dimension of size one inserted at the specified position.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 1, 3].
    ///     let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));

        let mut dims = [1; D2];
        let shape = self.shape();

        dims[0..dim].copy_from_slice(&shape[0..dim]);

        if dim < D {
            dims[dim] = 1;
            dims[(dim + 1)..].copy_from_slice(&shape[dim..]);
        } else {
            dims[dim] = 1;
        }

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with dimensions of size one inserted at the specified indices.
    /// The indices can be negative, in which case they are counted from the last to the first dimension.
    /// The axes can contain duplicates, in which case the number of dimensions inserted at that index
    /// equals the number of duplicates.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 4, 5]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
    ///     // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
    ///     // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
    ///     let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K> {
        let mut new_dims = [1; D2];
        let old_dims = self.shape().dims;

        // Part 1: convert the negative indices to positive.
        let mut neg_offset = D2;
        let mut dim_indices = axes
            .iter()
            .map(|d| {
                // Check if the dimension is in the acceptable range.
                check!(TensorCheck::unsqueeze_dims::<{ D2 }>(*d));
                (if *d < 0 {
                    neg_offset -= 1; // Handle multiple negative indices (decrease dim value in reverse).
                    d + neg_offset as isize + 1
                } else {
                    *d
                }) as usize
            })
            .collect::<Vec<usize>>();

        // Sort the indices.
        dim_indices.sort_unstable();

        // Now use this to copy the chunks of the dims.
        let mut prev_idx: usize = 0;
        let mut current_left_b: usize = 0;
        let mut current_right_b: usize = 0;
        let mut offset: usize = 0;
        dim_indices.iter().for_each(|d| {
            // Check if there is space for at least one dimension.
            if prev_idx < *d {
                current_right_b = *d - offset;
                // Copy the chunks of the dims.
                if current_right_b < D {
                    new_dims[prev_idx..*d]
                        .copy_from_slice(&old_dims[current_left_b..current_right_b]);
                } else {
                    new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
                }
                prev_idx = *d + 1;
                // The offset is equal to the number of extracted elements from the original shape.
                offset += current_right_b - current_left_b;
                current_left_b = current_right_b;
            } else {
                // It's sorted, so the only reason this would happen
                // is if multiple indices are the same.
                prev_idx += 1;
            }
        });
        // Copy over anything past the index of the last new dimension.
        if current_left_b < D {
            new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
        }

        // Lastly, create the shape and reshape.
        let shape = Shape::new(new_dims);
        self.reshape(shape)
    }

    /// Roll operation along a specific dimension, wrapping around the elements.
    ///
    /// ## Parameters
    ///
    /// - `shift`: The roll extent; supports negative values and wraps around.
    /// - `dim`: The dimension to roll; supports negative indexing.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimension rolled by the given shift amount.
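    ///
    /// ## Example
    ///
    /// A minimal sketch of rolling a single dimension (the shift wraps around):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1, Int>::arange(0..5, &device);
    ///     // Roll the elements along dimension 0 by two positions.
    ///     let rolled = tensor.roll_dim(2, 0);
    ///     println!("{rolled}");
    /// }
    /// ```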
    pub fn roll_dim<Shift, Dim>(self, shift: Shift, dim: Dim) -> Self
    where
        Shift: AsIndex,
        Dim: AsIndex,
    {
        let dim = canonicalize_dim(dim, D, false);
        let size = self.shape().dims[dim];
        if size == 0 {
            // If the dimension is empty, return the tensor as is.
            return self;
        }

        let shift = wrap_index(shift, size);
        if shift == 0 {
            // If the shift is zero, return the tensor as is.
            return self;
        }

        self.unchecked_roll_dim(shift, dim)
    }

    /// Internal implementation of `roll_dim` that does not canonicalize dimensions or shifts.
    ///
    /// ## Parameters
    ///
    /// - `shift`: The number of positions to shift; must satisfy `0 < shift < size`.
    /// - `dim`: The dimension to roll; must be a valid index for the tensor's shape.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimension rolled by the given shift amount.
    #[inline(always)]
    fn unchecked_roll_dim(self, shift: usize, dim: usize) -> Self {
        #[cfg(debug_assertions)]
        {
            let size = self.shape().dims[dim];
            let num_dims = self.shape().num_dims();
            assert!(
                0 < shift && shift < size,
                "Expected: 0 < shift < size: found shift={shift}, size={size}",
            );
            assert!(
                dim < num_dims,
                "Expected: dim < num_dims: found dim={dim}, num_dims={num_dims}",
            );
        }

        Tensor::cat(
            vec![
                self.clone().slice_dim(dim, shift..),
                self.slice_dim(dim, ..shift),
            ],
            dim,
        )
    }

    /// Roll operation.
    ///
    /// Note: unlike PyTorch, `dims` and `shifts` must have the same length.
    ///
    /// A given `dim` may be rolled multiple times, and the shifts will be applied sequentially.
    ///
    /// ## Parameters
    ///
    /// - `shifts`: A slice of shifts corresponding to each dimension;
    ///   supports negative values and wraps around.
    /// - `dims`: A slice of dimensions to roll; supports negative indexing.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimensions rolled by the given shifts.
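    ///
    /// ## Example
    ///
    /// A minimal sketch of rolling two dimensions at once:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///     // Roll dimension 0 by 1 and dimension 1 by -1; both shifts wrap around.
    ///     let rolled = tensor.roll(&[1, -1], &[0, 1]);
    ///     println!("{rolled}");
    /// }
    /// ```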
    pub fn roll<Shift, Dim>(self, shifts: &[Shift], dims: &[Dim]) -> Self
    where
        Shift: AsIndex,
        Dim: AsIndex,
    {
        assert_eq!(
            dims.len(),
            shifts.len(),
            "Dimensions and shifts must align; found dims={dims:#?}, shifts={shifts:#?}",
        );

        // This is a fair amount of complexity, which could be replaced
        // by a simple canonicalization of `dims` and wrapping of `shifts`.
        // The work is done here to ensure that any roll operation
        // which could be a no-op is a no-op; simplifying the accounting
        // needed by backend-specific implementations of the inner roll op.

        let item_count = dims.len();

        let shape = self.shape().dims;

        // Accumulate the effective shifts for each dimension.
        let mut accumulated_shifts: Vec<isize> = vec![0; shape.len()];
        for i in 0..item_count {
            let dim = canonicalize_dim(dims[i], D, false);
            accumulated_shifts[dim] += shifts[i].index();
        }

        // Do this after we've checked the validity of `dims` and `shifts`.
        if self.shape().num_elements() == 0 {
            // If the tensor is empty, return it as is.
            return self;
        }

        // Wrap the accumulated shifts, and filter out empty dimensions.
        let mut effective_dims: Vec<usize> = Vec::with_capacity(item_count);
        let mut effective_shifts: Vec<usize> = Vec::with_capacity(item_count);
        for dim in 0..shape.len() {
            // `wrap_index` should inline, and has a fast-exit path for zero shifts.
            let shift = wrap_index(accumulated_shifts[dim], shape[dim]);
            if shift == 0 {
                continue;
            }

            effective_dims.push(dim);
            effective_shifts.push(shift);
        }

        // If no shifts are needed, return the original tensor.
        if effective_shifts.is_empty() {
            return self;
        }

        // At this point:
        // - `dims` contains the effective dimensions to roll, in index order,
        // - `shifts` contains the effective usize shifts for each dimension.
        // - Every shift is non-zero, and less than the size of the corresponding dimension.
        self.unchecked_roll(&effective_shifts, &effective_dims)
    }
    /// `roll` internal implementation.
    ///
    /// ## Parameters
    ///
    /// - `shifts`: A slice of shifts corresponding to each dimension;
    ///   must be non-empty, the same length as `dims`, and each in the range `1..size`
    ///   for its dimension.
    /// - `dims`: A slice of dimensions to roll; must be non-empty,
    ///   the same length as `shifts`, and must not contain repeats.
    ///
    /// ## Panics
    ///
    /// Panics if the shifts and dimensions do not align, or if dimensions contain repeats.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimensions rolled by the given shifts.
    #[inline(always)]
    fn unchecked_roll(self, shifts: &[usize], dims: &[usize]) -> Self {
        #[cfg(debug_assertions)]
        {
            assert!(!shifts.is_empty());
            assert_eq!(
                shifts.len(),
                dims.len(),
                "Shifts and dimensions must align; found {} shifts and {} dims",
                shifts.len(),
                dims.len()
            );

            let mut unique_dims = dims.to_vec();
            // `dedup` only removes consecutive duplicates, so sort first.
            unique_dims.sort_unstable();
            unique_dims.dedup();

            assert_eq!(
                unique_dims.len(),
                dims.len(),
                "Dimensions must not contain repeats; found {} unique dims and {} total dims",
                unique_dims.len(),
                dims.len()
            )
        }

        let x = self.unchecked_roll_dim(shifts[0], dims[0]);

        if dims.len() == 1 {
            x
        } else {
            x.unchecked_roll(&shifts[1..], &dims[1..])
        }
    }

    /// Returns a tensor containing the elements selected from the given slices.
    ///
    /// This method provides flexible tensor slicing with support for various range types,
    /// negative indices, and stepped slicing. The method accepts both single slices and
    /// arrays of slices, with the [`s!`] macro providing convenient syntax for complex patterns.
    ///
    /// # Arguments
    ///
    /// * `slices` - Can be:
    ///   - A single range for 1D slicing (e.g., `0..5`, `..`, `2..`)
    ///   - An array of ranges (e.g., `[0..2, 1..4]`)
    ///   - The [`s!`] macro output for advanced slicing with steps
    ///
    /// # Behavior
    ///
    /// - Supports partial and full slicing in any number of dimensions
    /// - Handles negative indices by wrapping from the end (-1 is the last element)
    /// - Automatically clamps ranges that exceed tensor dimensions
    /// - Supports stepped slicing for selecting every nth element
    /// - Negative steps reverse the selection order
    ///
    /// # Panics
    ///
    /// - If the number of slices exceeds the tensor's dimensions
    /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1) without negative step
    /// - If a step is zero
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // Single dimension slicing - no brackets needed!
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..10, &device);
    ///     let slice = tensor.clone().slice(2..8);  // Simple range
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![2, 3, 4, 5, 6, 7]);
    ///
    ///     // Using s! macro for single dimension with step
    ///     let slice = tensor.clone().slice(s![0..10;2]);  // Every 2nd element
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![0, 2, 4, 6, 8]);
    ///
    ///     // Reverse a dimension with negative step
    ///     let slice = tensor.slice(s![..;-1]);  // Reverse entire tensor
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
    ///
    ///     // Multi-dimensional slicing
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
    ///
    ///     // Array syntax for simple ranges
    ///     let slice = tensor.clone().slice([1..3, 2..5]);
    ///     assert_eq!(slice.dims(), [2, 3]);
    ///
    ///     // Advanced multi-dimensional with s! macro
    ///     let slice = tensor.clone().slice(s![0..4;2, ..;-1]);  // Every 2nd row, reverse columns
    ///     assert_eq!(slice.dims(), [2, 6]);
    ///
    ///     // Complex 3D example with mixed slice types
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([4, 6, 8]), &device);
    ///     let slice = tensor.slice(s![1..3, ..;2, -3..]);  // Rows 1-2, every 2nd col, last 3 depth
    ///     assert_eq!(slice.dims(), [2, 3, 3]);
    ///
    ///     // Using negative indices
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
    ///     let slice = tensor.slice(s![-2.., ..-1]);  // Last 2 rows, all but last column
    ///     assert_eq!(slice.dims(), [2, 5]);
    /// }
    /// ```
    ///
    /// # See Also
    ///
    /// - [`s!`] - The recommended macro for creating complex slice specifications
    /// - [`slice_assign`](Self::slice_assign) - Assign values to a slice
    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
    /// - [`slice_dim`](Self::slice_dim) - Slice a single dimension
    ///
    /// [`s!`]: crate::s!
    pub fn slice<const D2: usize, S>(self, slices: S) -> Self
    where
        S: SliceArg<D2>,
    {
        let shape = self.shape();
        let slices = slices.into_slices(shape.clone());

        // Validate slices
        check!(TensorCheck::slice::<D, D2>(&shape, &slices));

        // Use the slice method that supports steps
        Self::new(K::slice(self.primitive, &slices))
    }

    /// Assigns values to a slice of the tensor and returns the updated tensor.
    ///
    /// This method supports advanced slicing with steps, including negative steps for reverse
    /// assignment. Like `slice`, it accepts both single slices and arrays, with the [`s!`] macro
    /// providing powerful syntax for complex patterns.
    ///
    /// # Arguments
    ///
    /// * `slices` - Slice specification (same format as `slice` method)
    /// * `values` - Tensor with values to assign (must match slice dimensions)
    ///
    /// # Panics
    ///
    /// - If slices exceed tensor dimensions
    /// - If values dimensions don't match the selected slice shape
    /// - If a step is zero
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // Simple assignment to a sub-region
    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
    ///     let values = Tensor::<B, 2>::ones([2, 3], &device);
    ///     tensor = tensor.slice_assign([1..3, 2..5], values);
    ///     // Now tensor[1..3, 2..5] contains ones
    ///
    ///     // Single dimension assignment with step
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     let values = Tensor::<B, 1>::ones([5], &device);
    ///     tensor = tensor.slice_assign(s![0..10;2], values);
    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
    ///
    ///     // Reverse assignment with negative step
    ///     let mut tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
    ///     let values = Tensor::<B, 1>::from_data([10.0, 11.0, 12.0, 13.0, 14.0], &device);
    ///     tensor = tensor.slice_assign(s![..;-1], values);
    ///     // Assigns in reverse: [14, 13, 12, 11, 10]
    ///
    ///     // Complex multi-dimensional assignment
    ///     let mut tensor = Tensor::<B, 3>::zeros([4, 6, 8], &device);
    ///     let values = Tensor::<B, 3>::ones([2, 3, 3], &device);
    ///     tensor = tensor.slice_assign(s![0..4;2, ..;2, -3..], values);
    ///     // Assigns to every 2nd row, every 2nd column, last 3 in depth
    ///
    ///     // Mixed syntax example
    ///     let mut tensor = Tensor::<B, 2>::zeros([8, 8], &device);
    ///     let pattern = Tensor::<B, 2>::ones([4, 4], &device);
    ///     tensor = tensor.slice_assign(s![..;2, ..;2], pattern);
    ///     // Creates a checkerboard pattern with ones
    /// }
    /// ```
    ///
    /// # See Also
    ///
    /// - [`s!`] - The recommended macro for creating complex slice specifications
    /// - [`slice`](Self::slice) - Extract a slice from a tensor
    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
    ///
    /// [`s!`]: crate::s!
    pub fn slice_assign<const D2: usize, S>(self, slices: S, values: Self) -> Self
    where
        S: SliceArg<D2>,
    {
        let shape = self.shape();
        let slices = slices.into_slices(shape.clone());

        check!(TensorCheck::slice_assign::<D, D2>(
            &shape,
            &values.shape(),
            &slices
        ));

        Self::new(K::slice_assign(self.primitive, &slices, values.primitive))
    }

    /// Fills a slice of the tensor with a constant value and returns the updated tensor.
    ///
    /// Like other slice methods, this accepts both single slices and arrays, and supports
    /// stepped slicing via the [`s!`] macro, including negative steps.
    ///
    /// # Arguments
    ///
    /// * `slices` - Slice specification (same format as the `slice` method)
    /// * `value` - The value to fill the slice with
    ///
    /// # Panics
    ///
    /// - If slices exceed tensor dimensions
    /// - If a step is zero
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // Simple fill for a single dimension
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(2..5, 1.0);
    ///     // Now tensor is [0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
    ///
    ///     // Multi-dimensional fill
    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
    ///     tensor = tensor.slice_fill([1..3, 2..5], -1.0);
    ///     // Fills the rectangle at rows 1-2, columns 2-4 with -1
    ///
    ///     // Using negative indices
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(-3.., 2.0);
    ///     // Fills the last 3 elements with 2.0
    ///
    ///     // Complex multi-dimensional example
    ///     let mut tensor = Tensor::<B, 3>::ones([4, 6, 8], &device);
    ///     tensor = tensor.slice_fill(s![1..3, .., -2..], 0.0);
    ///     // Sets rows 1-2, all columns, last 2 in depth to 0
    ///
    ///     // Stepped slicing is supported
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(s![0..10;2], 1.0);
    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
    /// }
    /// ```
    ///
    /// # See Also
    ///
    /// - [`s!`] - The macro for creating slice specifications with steps
    /// - [`slice`](Self::slice) - Extract a slice from a tensor
    /// - [`slice_assign`](Self::slice_assign) - Assign tensor values to a slice
    ///
    /// [`s!`]: crate::s!
    pub fn slice_fill<const D2: usize, S, E: ElementConversion>(self, slices: S, value: E) -> Self
    where
        S: SliceArg<D2>,
    {
        let shape = self.shape();
        let slices = slices.into_slices(shape.clone());

        check!(TensorCheck::slice::<D, D2>(&shape, &slices));

        Self::new(K::slice_fill(self.primitive, &slices, value.elem()))
    }

1418    /// Returns a new tensor with the specified dimension sliced.
1419    ///
1420    /// # Arguments
1421    ///
1422    /// * `dim`: The dimension to slice.
1423    /// * `slice`: The slice specification for the dimension. Can be a range (e.g., `2..5`),
1424    ///   slice with step (via `s!` macro, e.g., `s![0..10;2]`), or any type that implements `Into<Slice>`.
1425    ///
1426    /// # Returns
1427    ///
1428    /// A new tensor with the specified dimension sliced.
1429    ///
1430    /// # Panics
1431    ///
1432    /// If the slice is out of bounds for the specified dimension.
1433    ///
1434    /// # Examples
1435    ///
1436    /// ```rust
1437    /// # use burn_tensor::{Tensor, s};
1438    /// # use burn_tensor::backend::Backend;
1439    /// #
1440    /// # fn example<B: Backend>() {
1441    /// #     let device = B::Device::default();
1442    ///     let tensor = Tensor::<B, 3>::zeros([3, 4, 5], &device);
1443    ///
1444    ///     // Simple range slicing
1445    ///     let sliced = tensor.clone().slice_dim(1, 1..3);
1446    ///     assert_eq!(sliced.shape().dims, [3, 2, 5]);
1447    ///
1448    ///     // Slicing with step - take every 2nd element
1449    ///     let sliced = tensor.clone().slice_dim(2, s![0..5;2]);
1450    ///     assert_eq!(sliced.shape().dims, [3, 4, 3]); // Takes indices 0, 2, 4
1451    ///
1452    ///     // Reverse slicing with negative step
1453    ///     let sliced = tensor.clone().slice_dim(1, s![..;-1]);
1454    ///     assert_eq!(sliced.shape().dims, [3, 4, 5]); // Reverses dimension 1
1455    ///
1456    ///     // Select from index 2 with step 3
1457    ///     let sliced = tensor.clone().slice_dim(0, s![2..;3]);
1458    ///     assert_eq!(sliced.shape().dims, [1, 4, 5]); // Takes only index 2
1459    ///
1460    ///     // Select single index (reduces dimension to size 1)
1461    ///     let sliced = tensor.slice_dim(0, 1);
1462    ///     assert_eq!(sliced.shape().dims, [1, 4, 5]);
1463    /// # }
1464    /// ```
1465    ///
1466    /// # See Also
1467    ///
1468    /// - [`slice`](Self::slice) - Slice multiple dimensions simultaneously
1469    /// - [`s!`] - The macro for creating complex slice specifications
1470    ///
1471    /// [`s!`]: crate::s!
1472    pub fn slice_dim<S>(self, dim: usize, slice: S) -> Self
1473    where
1474        S: Into<Slice>,
1475    {
1476        check!(TensorCheck::check_dim::<D>(dim));
1477        let slice: Slice = slice.into();
1478
1479        Self::new(K::slice_dim(self.primitive, dim, &slice))
1480    }
1481
1482    /// Returns the device of the current tensor.
1483    pub fn device(&self) -> B::Device {
1484        K::device(&self.primitive)
1485    }
1486
1487    /// Move the tensor to the given device.
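    ///
    /// # Example
    ///
    /// A minimal sketch of moving a tensor between devices; the target device is
    /// backend-specific, so it is taken as a parameter here:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>(target: &B::Device) {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     // Move the tensor; subsequent operations execute on `target`.
    ///     let tensor = tensor.to_device(target);
    ///     println!("{:?}", tensor.device());
    /// }
    /// ```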
1488    pub fn to_device(self, device: &B::Device) -> Self {
1489        Self::new(K::to_device(self.primitive, device))
1490    }
1491
1492    /// Select tensor elements along the given dimension corresponding to the given indices.
1493    ///
1494    /// # Arguments
1495    ///
1496    /// * `dim` - The dimension to select from. Supports negative indexing.
1497    /// * `indices` - The indices of the elements to select.
1498    ///
1499    /// # Example
1500    ///
1501    /// ```rust
1502    /// use burn_tensor::backend::Backend;
1503    /// use burn_tensor::{Tensor, Int};
1504    ///
1505    /// fn example<B: Backend>() {
1506    ///   let device = B::Device::default();
1507    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [4.0, 5.0, 6.0]], &device);
1508    ///   let indices = Tensor::<B, 1, Int>::from_data([0], &device);
1509    ///   let tensor = tensor.select(0, indices);
1510    ///   println!("{tensor}");
1511    ///   //  [[1.0, -2.0, 3.0]]
1512    /// }
1513    /// ```
1514    pub fn select(self, dim: impl AsIndex, indices: Tensor<B, 1, Int>) -> Self {
1515        let dim = canonicalize_dim(dim, D, false);
1516        check!(TensorCheck::select::<D>(dim));
1517        Self::new(K::select(self.primitive, dim, indices))
1518    }
1519
1520    /// Assign the selected elements along the given dimension corresponding to the given indices
1521    /// from the value tensor to the original tensor using sum reduction.
1522    ///
1523    /// # Note
1524    /// For booleans, the sum operator is logical or.
1525    ///
1526    /// # Arguments
1527    ///
1528    /// * `dim` - The dimension along which to select. Supports negative indexing.
1529    /// * `indices` - The indices to select from the tensor.
1530    /// * `values` - The values to assign to the selected indices.
1531    ///
1532    /// # Example
1533    ///
1534    /// Example using a 3D tensor:
1535    ///
1536    /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
1537    /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
1538    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
1539    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = -1 (same as dim = 2)`
1540    ///
1541    /// # Warning
1542    ///
1543    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1544    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
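    ///
    /// # Example
    ///
    /// A minimal sketch of the sum reduction described above:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Int};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     let indices = Tensor::<B, 1, Int>::from_data([1], &device);
    ///     let values = Tensor::<B, 2>::from_data([[10.0, 10.0]], &device);
    ///     // Adds `values` to row 1: [[1.0, 2.0], [13.0, 14.0]]
    ///     let tensor = tensor.select_assign(0, indices, values);
    ///     println!("{tensor}");
    /// }
    /// ```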
1545    pub fn select_assign(
1546        self,
1547        dim: impl AsIndex,
1548        indices: Tensor<B, 1, Int>,
1549        values: Tensor<B, D, K>,
1550    ) -> Self {
1551        let dim = canonicalize_dim(dim, D, false);
1552        check!(TensorCheck::select_assign::<D>(
1553            dim,
1554            &indices.shape(),
1555            &values.shape()
1556        ));
1557
1558        Self::new(K::select_assign(
1559            self.primitive,
1560            dim,
1561            indices,
1562            values.primitive,
1563        ))
1564    }
1565
1566    /// Converts the data of the current tensor.
1567    ///
1568    /// # Note
1569    ///
1570    /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
1571    /// tensors at once. This may improve laziness, especially if executed on a different
1572    /// thread in native environments.
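    ///
    /// # Example
    ///
    /// A minimal sketch of reading a tensor back on the host:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // Consumes the tensor and returns its data as `TensorData`.
    ///     let data = tensor.into_data();
    ///     println!("{data:?}");
    /// }
    /// ```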
1573    pub fn into_data(self) -> TensorData {
1574        crate::try_read_sync(self.into_data_async()).expect(
1575            "Failed to read tensor data synchronously.
1576        This can happen on platforms that don't support blocking futures like WASM.
1577        If possible, try using into_data_async instead.",
1578        )
1579    }
1580
1581    /// Converts the data of the current tensor.
1582    ///
1583    /// # Note
1584    ///
1585    /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
1586    /// tensors at once. This may improve laziness, especially if executed on a different
1587    /// thread in native environments.
1588    pub fn to_data(&self) -> TensorData {
1589        self.clone().into_data()
1590    }
1591
1592    /// Returns the data of the current tensor.
1593    pub async fn into_data_async(self) -> TensorData {
1594        K::into_data_async(self.primitive).await
1595    }
1596
1597    /// Returns the data of the current tensor.
1598    pub async fn to_data_async(&self) -> TensorData {
1599        self.clone().into_data_async().await
1600    }
1601
1602    /// Create a tensor from the given data on the given device.
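    ///
    /// # Example
    ///
    /// A minimal sketch; nested arrays are converted through `Into<TensorData>`:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // A 2D tensor with shape [2, 3] created from a nested array.
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///     println!("{tensor}");
    /// }
    /// ```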
1603    pub fn from_data<T>(data: T, device: &B::Device) -> Self
1604    where
1605        T: Into<TensorData>,
1606    {
1607        let data = data.into();
1608        check!(TensorCheck::creation_ops::<D>(
1609            "From Data",
1610            data.shape.as_slice()
1611        ));
1612        Self::new(K::from_data(data, device))
1613    }
1614
1615    /// Create a tensor from the given data on the given device enforcing the given data type.
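    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the backend supports `f32` storage:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{DType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Store the values as `f32`, regardless of the backend's default float type.
    ///     let tensor = Tensor::<B, 1>::from_data_dtype([1.0, 2.0, 3.0], &device, DType::F32);
    ///     println!("{tensor}");
    /// }
    /// ```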
1616    pub fn from_data_dtype<T>(data: T, device: &B::Device, dtype: DType) -> Self
1617    where
1618        T: Into<TensorData>,
1619    {
1620        let data = data.into();
1621        check!(TensorCheck::creation_ops::<D>(
1622            "From Data",
1623            data.shape.as_slice()
1624        ));
1625        Self::new(K::from_data_dtype(data, device, dtype))
1626    }
1627
1628    /// Repeat the tensor along the given dimension.
1629    ///
1630    /// The output tensor has the same shape, except along the given dimension.
1631    ///
1632    /// # Arguments
1633    /// - `dim`: The dimension to repeat.
1634    /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
1635    ///
1636    /// # Returns
1637    ///
1638    /// A new tensor with the given dimension repeated `times` times.
1639    ///
1640    /// # Example
1641    ///
1642    /// ```rust
1643    /// use burn_tensor::backend::Backend;
1644    /// use burn_tensor::Tensor;
1645    ///
1646    /// fn example<B: Backend>() {
1647    ///     let device = Default::default();
1648    ///     // Create a 2D tensor with dimensions [3, 2]
1649    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1650    ///
1651    ///     // Repeat the tensor along the dimension 0 twice.
1652    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1653    ///     // The resulting tensor will have dimensions [6, 2].
1654    ///     let repeated = tensor.repeat_dim(0, 2);
1655    ///     println!("{repeated}");
1656    /// }
1657    /// ```
1658    pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
1659        Self::new(K::repeat_dim(self.primitive, dim, times))
1660    }
1661
1662    /// Repeat the tensor along the given dimensions.
1663    /// # Arguments
1664    /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
1665    ///
1666    /// # Returns
1667    ///
    /// A new tensor where each dimension is repeated according to the corresponding entry in `sizes`.
1669    ///
1670    /// # Panics
1671    ///
1672    /// If `sizes` contains more elements than the number of dimensions.
1673    ///
1674    /// # Example
1675    ///
1676    /// ```rust
1677    ///
1678    /// use burn_tensor::backend::Backend;
1679    /// use burn_tensor::Tensor;
1680    ///
1681    /// fn example<B: Backend>() {
1682    ///     let device = Default::default();
1683    ///     // Create a 2D tensor with dimensions [3, 2]
1684    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1685    ///
    ///     // Repeat the tensor along dimension 0 twice and dimension 1 once.
1687    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1688    ///     // The resulting tensor will have dimensions [6, 2].
1689    ///     let repeated = tensor.repeat(&[2, 1]);
1690    /// }
1691    /// ```
1692    pub fn repeat(self, sizes: &[usize]) -> Self {
1693        let mut tensor = self;
1694        for (dim, &times) in sizes.iter().enumerate() {
1695            if times > 1 {
1696                tensor = tensor.repeat_dim(dim, times);
1697            }
1698        }
1699        tensor
1700    }
1701
1702    /// Applies element-wise equal comparison.
1703    ///
1704    /// # Returns
1705    /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
1706    ///
1707    /// # Panics
1708    ///
1709    /// If the two tensors don't have the same shape.
1710    ///
1711    /// # Example
1712    ///
1713    /// ```rust
1714    /// use burn_tensor::backend::Backend;
1715    /// use burn_tensor::Tensor;
1716    ///
1717    /// fn example<B: Backend>() {
1718    ///     let device = Default::default();
1719    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1720    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1721    ///     // Compare the elements of the two 2D tensors with dimensions [3, 2].
1722    ///     // [[false, true], [true, true], [true, true]]
1723    ///     let equal = t1.equal(t2);
1724    ///     println!("{equal}");
1725    /// }
1726    /// ```
1727    pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
1728        check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
1729        Tensor::new(K::equal(self.primitive, other.primitive))
1730    }
1731
1732    /// Applies element-wise non-equality comparison.
1733    ///
1734    /// # Returns
1735    /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
1736    ///
1737    /// # Panics
1738    ///
1739    /// If the two tensors don't have the same shape.
1740    ///
1741    /// # Example
1742    ///
1743    /// ```rust
1744    /// use burn_tensor::backend::Backend;
1745    /// use burn_tensor::Tensor;
1746    ///
1747    /// fn example<B: Backend>() {
1748    ///     let device = Default::default();
1749    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1750    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1751    ///     // Compare the elements of the two 2D tensors for inequality.
1752    ///     // [[true, false], [false, false], [false, false]]
1753    ///     let not_equal = t1.not_equal(t2);
1754    ///     println!("{not_equal}");
1755    /// }
1756    /// ```
1757    pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
1758        check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
1759        Tensor::new(K::not_equal(self.primitive, other.primitive))
1760    }
1761
1762    /// Concatenates all tensors into a new one along the given dimension.
1763    ///
1764    /// # Panics
1765    ///
    /// - If `dim` is greater than or equal to the rank.
1767    /// - If `tensors` is an empty vector.
1768    /// - If all tensors don't have the same shape (the dimension `dim` is ignored).
1769    ///
1770    /// # Example
1771    ///
1772    /// ```rust
1773    /// use burn_tensor::backend::Backend;
1774    /// use burn_tensor::Tensor;
1775    ///
1776    /// fn example<B: Backend>() {
1777    ///     let device = Default::default();
1778    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
1779    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1780    ///
1781    ///     // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
1782    ///     // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
1783    ///     // The resulting tensor will have shape [2, 7].
1784    ///     let concat = Tensor::cat(vec![t1, t2], 1);
1785    ///     println!("{concat}");
1786    /// }
1787    /// ```
1788    pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
1789        check!(TensorCheck::cat(&tensors, dim));
1790
1791        Self::new(K::cat(
1792            tensors.into_iter().map(|vector| vector.primitive).collect(),
1793            dim,
1794        ))
1795    }
1796
1797    /// Concatenates all tensors into a new one along a new dimension.
1798    ///
1799    /// # Panics
1800    ///
1801    /// - If all tensors don't have the same shape.
    /// - If the given dimension is not within the range `0..D2`.
1803    ///
1804    /// # Example
1805    ///
1806    /// ```rust
1807    /// use burn_tensor::backend::Backend;
1808    /// use burn_tensor::Tensor;
1809    ///
1810    /// fn example<B: Backend>() {
1811    ///     let device = Default::default();
1812    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
1813    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1814    ///     let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1815    ///
1816    ///     // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
1817    ///     // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
1818    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
1819    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
1820    ///     // The resulting tensor will have shape [3, 2, 3].
    ///     let stacked = Tensor::stack::<3>(vec![t1, t2, t3], 0);
1822    ///     println!("{stacked}");
1823    /// }
1824    /// ```
1825    pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
1826        check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
1827        let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
1828        Tensor::<B, D2, K>::cat(tensors, dim)
1829    }
1830
    /// Iterate over slices of tensors along a given dimension.
1832    ///
1833    /// # Panics
1834    ///
    /// If the given dimension is greater than or equal to the tensor rank.
1836    ///
1837    /// # Returns
1838    ///
1839    /// A tensor iterator.
1840    ///
1841    /// # Example
1842    ///
1843    /// ```rust
1844    /// use burn_tensor::backend::Backend;
1845    /// use burn_tensor::Tensor;
1846    /// fn example<B: Backend>() {
1847    ///   let device = Default::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
    ///   // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
    ///   let iter = tensor.iter_dim(0);
    ///   for (i, tensor) in iter.enumerate() {
    ///     println!("Tensor {}: {}", i, tensor);
    ///     // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
    ///     // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
    ///   }
1856    /// }
1857    /// ```
1858    pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
1859        check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
1860        DimIter::new(self, dim)
1861    }
1862
1863    /// Returns a new tensor with the given dimension narrowed to the given range.
1864    ///
1865    /// # Panics
1866    ///
1867    /// - If the dimension is greater than the number of dimensions of the tensor.
1868    /// - If the given range exceeds the number of elements on the given dimension.
1869    ///
1870    /// # Returns
1871    ///
1872    /// A new tensor with the given dimension narrowed to the given range.
1873    ///
1874    /// # Example
1875    ///
1876    /// ```rust
1877    /// use burn_tensor::backend::Backend;
1878    /// use burn_tensor::Tensor;
1879    ///
1880    /// fn example<B: Backend>() {
1881    ///     let device = Default::default();
1882    ///     // Create a 2D tensor with dimensions [4, 3]
1883    ///     let tensor = Tensor::<B, 2>::from_data(
1884    ///         [
1885    ///             [3.0, 4.9, 2.0],
1886    ///             [2.0, 1.9, 3.0],
1887    ///             [6.0, 1.5, 7.0],
1888    ///             [3.0, 4.9, 9.0],
1889    ///         ],
1890    ///         &device,
1891    ///     );
1892    ///     // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
1893    ///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
1894    ///     // The resulting tensor will have dimensions [3, 3].
1895    ///     let narrowed = tensor.narrow(0, 1, 3);
1896    ///     println!("{narrowed}");
1897    /// }
1898    /// ```
1899    pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
1900        check!(TensorCheck::dim_ops::<D>("narrow", dim));
1901        check!(TensorCheck::narrow(&self, dim, start, length));
1902        let dims = self.dims();
1903
1904        let ranges: [Range<usize>; D] = dims
1905            .iter()
1906            .enumerate()
1907            .map(|(i, d)| {
1908                if i == dim {
1909                    start..(start + length)
1910                } else {
1911                    0..*d
1912                }
1913            })
1914            .collect::<Vec<_>>()
1915            .try_into()
1916            .unwrap();
1917
1918        Self::slice(self, ranges)
1919    }
1920
1921    /// Attempts to split the tensor into a specified number of chunks along a given dimension.
    /// May return fewer chunks than requested if the tensor cannot be split into that many
    /// non-empty chunks (e.g., when the size along `dim` is smaller than `chunks`).
    ///
    /// When the size along the given dimension is evenly divisible by the number of chunks,
    /// the chunks will be of equal size. Otherwise all chunks will be of equal size except for the last one.
1926    ///
1927    /// # Panics
1928    ///
1929    /// If the dimension is greater than the number of dimensions of the tensor.
1930    ///
1931    /// # Returns
1932    /// A vector of tensors.
1933    ///
1934    /// # Example
1935    ///
1936    /// ```rust
1937    /// use burn_tensor::backend::Backend;
1938    /// use burn_tensor::Tensor;
1939    ///
1940    /// fn example<B: Backend>() {
1941    ///     let device = Default::default();
1942    ///     // Create a 2D tensor with dimensions [4, 3]
1943    ///     let tensor = Tensor::<B, 2>::from_data(
1944    ///         [
1945    ///             [3.0, 4.9, 2.0],
1946    ///             [2.0, 1.9, 3.0],
1947    ///             [6.0, 1.5, 7.0],
1948    ///             [3.0, 4.9, 9.0],
1949    ///         ],
1950    ///         &device,
1951    ///     );
1952    ///     // Split the tensor along the dimension 1 into 2 chunks.
    ///     // The first chunk will have shape [4, 2]:
1954    ///     // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
1955    ///     // The second chunk will have shape [4, 1]:
1956    ///     // [[2.0], [3.0], [7.0], [9.0]]
1957    ///     let chunks = tensor.chunk(2, 1);
1958    ///     println!("{chunks:?}");
1959    /// }
1960    /// ```
1961    pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
1962        check!(TensorCheck::dim_ops::<D>("chunk", dim));
1963        let size = self.shape().dims[dim];
1964        if size < chunks {
1965            return (0..size)
1966                .map(|i| Self::narrow(self.clone(), dim, i, 1))
1967                .collect();
1968        }
1969
1970        let mut tensors = Vec::with_capacity(chunks);
1971        let mut sum_chunk_size = 0;
1972        if size.is_multiple_of(chunks) {
1973            let chunk_size = size / chunks;
1974            for _ in 0..chunks {
1975                tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
1976                sum_chunk_size += chunk_size;
1977            }
1978        } else {
1979            let chunk_size = (size / chunks) + 1; // assumes not divisible
1980            for _ in 0..chunks - 1 {
1981                tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
1982                sum_chunk_size += chunk_size;
1983            }
            let remainder = size % chunk_size;
            // Guard against a zero-length trailing chunk, which can occur when
            // `size` is a multiple of `chunk_size` (e.g. size = 9, chunks = 4
            // gives chunk_size = 3 and no remainder).
            if remainder > 0 {
                tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, remainder));
            }
1986        }
1987
1988        tensors
1989    }
1990
1991    /// Splits the tensor into chunks of a specified size along a given dimension.
1992    /// Each chunk is a view of the original tensor.
1993    ///
1994    /// If the tensor size along the given dimension is not divisible by `split_size`,
1995    /// then the last chunk will be smaller.
1996    ///
1997    /// # Panics
1998    ///
1999    /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
2000    ///
2001    /// # Returns
2002    ///
2003    /// A vector of tensors.
2004    ///
2005    /// # Example
2006    /// ```rust
2007    /// use burn_tensor::backend::Backend;
2008    /// use burn_tensor::Tensor;
2009    ///
2010    /// fn example<B: Backend>() {
2011    ///     let device = Default::default();
2012    ///     // Create a 1D tensor with 5 elements
2013    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2014    ///     // Split the tensor into chunks of size 2 along dimension 0
2015    ///     let chunks = tensor.split(2, 0);
2016    ///     // The result is a vector of tensors:
2017    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
2018    ///     println!("{:?}", chunks);
2019    /// }
2020    /// ```
2021    pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
2022        check!(TensorCheck::split::<D>(&self.shape(), split_size, dim));
2023        let size = self.shape().dims[dim];
2024        let mut tensors = Vec::new();
2025
2026        let mut start = 0;
2027        while start < size {
2028            let length = usize::min(split_size, size - start);
2029            tensors.push(Self::narrow(self.clone(), dim, start, length));
2030            start += length;
2031        }
2032
2033        tensors
2034    }
2035
2036    /// Splits the tensor into chunks with the specified sizes along a given dimension.
2037    /// Each chunk is a view of the original tensor.
2038    ///
2039    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2040    /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2041    ///
2042    /// # Panics
2043    ///
2044    /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
    /// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
2046    ///
2047    /// # Returns
2048    ///
2049    /// A vector of tensors.
2050    ///
2051    /// # Example
2052    /// ```rust
2053    /// use burn_tensor::backend::Backend;
2054    /// use burn_tensor::Tensor;
2055    ///
2056    /// fn example<B: Backend>() {
2057    ///     let device = Default::default();
2058    ///     // Create a 1D tensor with 5 elements
2059    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2060    ///     // Split the tensor into chunks with sizes [2, 3] along dimension 0
2061    ///     let chunks = tensor.split_with_sizes(vec![2, 3], 0);
2062    ///     // The result is a vector of tensors:
2063    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
2064    ///     println!("{:?}", chunks);
2065    /// }
2066    /// ```
2067    pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
2068        check!(TensorCheck::split_with_sizes::<D>(
2069            &self.shape(),
2070            &split_sizes,
2071            dim
2072        ));
2073        let mut tensors = Vec::new();
2074
2075        let mut start = 0;
2076        for length in split_sizes {
2077            if length == 0 {
2078                continue;
2079            }
2080            tensors.push(Self::narrow(self.clone(), dim, start, length));
2081            start += length;
2082        }
2083
2084        tensors
2085    }
2086
2087    /// Tests if any element in the `tensor` evaluates to True.
2088    ///
2089    /// # Arguments
2090    ///
2091    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2092    ///
2093    /// # Returns
2094    ///
2095    /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
2096    /// evaluates to True, False otherwise.
2097    ///
2098    /// # Example
2099    ///
2100    /// ```rust
2101    /// use burn_tensor::backend::Backend;
2102    /// use burn_tensor::{Tensor, Bool};
2103    ///
2104    /// fn example<B: Backend>() {
2105    ///   let device = Default::default();
2106    ///   let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);
2107    ///   let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);
2108    ///
2109    ///   // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2110    ///   let any_tensor = tensor.any();
2111    ///   println!("{}", any_tensor);
2112    ///   // Tensor { data: [true], ... }
2113    ///
    ///   // Same test on the second tensor, where every element is False.
2115    ///   let any_tensor_two = tensor_two.any();
2116    ///   println!("{}", any_tensor_two);
2117    ///   // Tensor { data: [false], ... }
2118    /// }
2119    /// ```
2120    pub fn any(self) -> Tensor<B, 1, Bool> {
2121        Tensor::new(K::any(self.primitive))
2122    }
2123
2124    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
2125    ///
2126    /// # Arguments
2127    ///
2128    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2129    /// * `dim` - The axis along which to test.
2130    ///
2131    /// # Returns
2132    ///
2133    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
    /// where the size is 1. The element in the `dim` axis is True if any element along this dimension
    /// in the input evaluates to True, False otherwise.
2136    ///
2137    /// # Example
2138    ///
2139    /// ```rust
2140    /// use burn_tensor::backend::Backend;
2141    /// use burn_tensor::{Tensor, Bool};
2142    ///
2143    /// fn example<B: Backend>() {
2144    ///     let device = Default::default();
2145    ///     let tensor =
2146    ///         Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
2147    ///     // Check if any element in the tensor evaluates to True along the dimension 1.
2148    ///     // [[true], [true]],
2149    ///     let any_dim = tensor.clone().any_dim(1);
2150    ///     println!("{any_dim}");
2151    /// }
2152    /// ```
2153    pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2154        Tensor::new(K::any_dim(self.primitive, dim))
2155    }
2156
2157    /// Tests if all elements in the `tensor` evaluate to True.
2158    ///
2159    /// # Arguments
2160    ///
2161    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2162    ///
2163    /// # Returns
2164    ///
2165    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
2166    /// evaluate to True, False otherwise.
2167    ///
2168    /// # Example
2169    ///
2170    /// ```rust
2171    /// use burn_tensor::backend::Backend;
2172    /// use burn_tensor::{Tensor, Bool};
2173    ///
2174    /// fn example<B: Backend>() {
2175    ///     let device = Default::default();
2176    ///     let tensor =
2177    ///         Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
2178    ///     // Check if all elements in the tensor evaluate to True (which is not the case).
2179    ///     // [false]
2180    ///     let all = tensor.all();
2181    ///     println!("{all}");
2182    /// }
2183    /// ```
2184    pub fn all(self) -> Tensor<B, 1, Bool> {
2185        Tensor::new(K::all(self.primitive))
2186    }
2187
2188    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2189    ///
2190    /// # Arguments
2191    ///
2192    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2193    /// * `dim` - The axis along which to test.
2194    ///
2195    /// # Returns
2196    ///
2197    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
    /// where the size is 1. The element in the `dim` axis is True if all elements along this dimension
    /// in the input evaluate to True, False otherwise.
2200    ///
2201    /// # Example
2202    ///
2203    /// ```rust
2204    /// use burn_tensor::backend::Backend;
2205    /// use burn_tensor::{Tensor, Bool};
2206    ///
2207    /// fn example<B: Backend>() {
2208    ///     let device = Default::default();
2209    ///     let tensor =
2210    ///         Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
2211    ///     // Check if all elements in the tensor evaluate to True along the dimension 1.
2212    ///     // [[true, true, false]]
2213    ///     let all_dim = tensor.clone().all_dim(0);
2214    ///     println!("{all_dim}");
2215    /// }
2216    /// ```
2217    pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2218        Tensor::new(K::all_dim(self.primitive, dim))
2219    }
2220
2221    /// Convert the tensor into a scalar.
2222    ///
2223    /// # Panics
2224    ///
2225    /// - If the tensor doesn't have one element.
2226    /// - If the backend fails to read the tensor data synchronously.
2227    ///
2228    /// # Returns
2229    ///
2230    /// The scalar value of the tensor.
2231    ///
2232    /// # Example
2233    ///
2234    /// ```rust
2235    /// use burn_tensor::backend::Backend;
2236    /// use burn_tensor::Tensor;
2237    ///
2238    /// fn example<B: Backend>() {
2239    ///     let device = Default::default();
2240    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
2241    ///     // Convert the tensor with a single element into a scalar.
2242    ///     let scalar = tensor.into_scalar();
2243    ///     println!("{scalar}");
2244    /// }
2245    /// ```
2246    pub fn into_scalar(self) -> K::Elem {
2247        crate::try_read_sync(self.into_scalar_async()).expect(
2248            "Failed to read tensor data synchronously. This can happen on platforms
2249            that don't support blocking futures like WASM. Try into_scalar_async instead.",
2250        )
2251    }
2252
2253    /// Convert the tensor into a scalar.
2254    ///
2255    /// # Panics
2256    ///
2257    /// If the tensor doesn't have one element.
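    ///
    /// # Example
    ///
    /// A minimal sketch in an async context:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// async fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
    ///     // Reads without blocking, which also works on targets like WASM.
    ///     let scalar = tensor.into_scalar_async().await;
    ///     println!("{scalar}");
    /// }
    /// ```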
2258    pub async fn into_scalar_async(self) -> K::Elem {
2259        check!(TensorCheck::into_scalar::<D>(&self.shape()));
2260
2261        self.into_data_async().await.iter().next().unwrap()
2262    }
2263
2264    /// Broadcast the tensor to the given shape.
2265    ///
2266    /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size
2267    /// (which can be inferred with `-1`).
2268    ///
2269    /// # Arguments
2270    ///
2271    /// * `shape` - The shape to broadcast the tensor to.
2272    ///   Can contain -1 for dimensions that should be inferred.
    ///   The number of elements in the shape must be greater than or equal to
    ///   the number of dimensions of the tensor.
2275    ///
2276    /// # Panics
2277    ///
    /// If the tensor cannot be broadcast to the given shape.
2279    ///
2280    /// # Returns
2281    ///
2282    /// A new tensor with the given shape.
2283    ///
2284    /// # Example
2285    ///
2286    /// ```rust
2287    /// use burn_tensor::backend::Backend;
2288    /// use burn_tensor::Tensor;
2289    ///
2290    /// fn example<B: Backend>() {
2291    ///     let device = Default::default();
2292    ///     // Create a 2D tensor with dimensions [3, 1]
2293    ///     let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
2294    ///     // Expand the tensor to a new shape [3, 4]
2295    ///     // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
2296    ///     let expanded = tensor.expand([3, 4]);
2297    ///     println!("{}", expanded);
2298    /// }
2299    /// ```
2300    pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
2301        let shape = shape.into_shape(&self.shape());
2302        check!(TensorCheck::expand::<D, D2>(
2303            "expand",
2304            &self.shape(),
2305            &shape,
2306        ));
2307
2308        Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
2309    }
2310
2311    /// Unfold windows along a dimension.
2312    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`,
    /// where each window is advanced by `step` from the previous one.
    ///
    /// The number of windows is `(shape[dim] - size) / step + 1` when `shape[dim] >= size`, and 0 otherwise.
2317    ///
2318    /// The new view will have the unfolded dimension replaced by two dimensions;
2319    /// one in the position of the original dimension, with size equal to the number of windows,
2320    /// and one appended to the right-most position, with size equal to `size`.
2321    ///
2322    /// # Warning
2323    ///
    /// For the `ndarray` and `candle` backends, this is not a view but a copy
2325    /// with duplicated data.
2326    ///
2327    /// # Arguments
2328    ///
2329    /// * `dim` - the dimension to unfold.
2330    /// * `size` - the size of each unfolded window.
2331    /// * `step` - the step between each window.
2332    ///
2333    /// # Returns
2334    ///
    /// A tensor view with the shape `[pre..., windows, post..., size]`.
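    ///
    /// # Example
    ///
    /// A minimal sketch of the window layout described above:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &device);
    ///     // Windows of size 3 advanced by 1: [0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]
    ///     let unfolded: Tensor<B, 2> = tensor.unfold(0, 3, 1);
    ///     assert_eq!(unfolded.shape().dims, [4, 3]);
    /// }
    /// ```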
2336    pub fn unfold<const D2: usize, I: AsIndex>(
2337        self,
2338        dim: I,
2339        size: usize,
2340        step: usize,
2341    ) -> Tensor<B, D2, K> {
2342        let dim = canonicalize_dim(dim, D, false);
2343        check!(TensorCheck::unfold::<D, D2>(
2344            "unfold",
2345            &self.shape(),
2346            dim,
2347            size,
2348            step,
2349        ));
2350        Tensor::<B, D2, K>::new(K::unfold(self.primitive, dim, size, step))
2351    }
2352}
2353
/// Iterator returned by [Tensor::iter_dim](Tensor::iter_dim).
2355pub struct DimIter<B, const D: usize, K>
2356where
2357    B: Backend,
2358    K: BasicOps<B>,
2359{
2360    start: usize,
2361    end: usize,
2362    dim: usize,
2363    ranges: [Range<usize>; D],
2364    tensor: Tensor<B, D, K>,
2365}
2366
2367impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
2368    type Item = Tensor<B, D, K>;
2369
2370    fn next(&mut self) -> Option<Self::Item> {
2371        if self.start >= self.end {
2372            return None;
2373        }
2374
2375        let mut ranges = self.ranges.clone();
2376        ranges[self.dim] = self.start..(self.start + 1);
2377
2378        let slice = self.tensor.clone().slice(ranges);
2379        self.start += 1;
2380
2381        Some(slice)
2382    }
2383}
2384
2385impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
2386    fn next_back(&mut self) -> Option<Self::Item> {
2387        if self.start >= self.end {
2388            return None;
2389        }
2390
2391        let mut ranges = self.ranges.clone();
2392        ranges[self.dim] = (self.end - 1)..self.end;
2393
2394        let slice = self.tensor.clone().slice(ranges);
2395        self.end = self.end.saturating_sub(1);
2396
2397        Some(slice)
2398    }
2399}
2400
2401impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
2402    fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
2403        let dims = tensor.dims();
2404        let ranges = dims
2405            .iter()
2406            .map(|&dim| 0..dim)
2407            .collect::<Vec<Range<usize>>>();
2408        let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
2409        Self {
2410            end: dims[dim],
2411            ranges,
2412            start: 0,
2413            dim,
2414            tensor,
2415        }
2416    }
2417}
2418
2419impl<B, const D: usize, K> Tensor<B, D, K>
2420where
2421    B: Backend,
2422    K: BasicOps<B>,
2423    <K as BasicOps<B>>::Elem: Debug,
2424{
2425    #[inline]
2426    fn push_newline_indent(acc: &mut String, indent: usize) {
2427        acc.push('\n');
2428        for _ in 0..indent {
2429            acc.push(' ');
2430        }
2431    }
2432    fn fmt_inner_tensor(
2433        &self,
2434        acc: &mut String,
2435        depth: usize,
2436        multi_index: &mut [usize],
2437        range: (usize, usize),
2438        precision: Option<usize>,
2439    ) {
2440        let (start, end) = range;
2441        for i in start..end {
2442            if i > 0 {
2443                acc.push_str(", ");
2444            }
2445            multi_index[depth] = i;
2446            let range: [Range<usize>; D] =
2447                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
2448
2449            let data =
2450                burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());
2451
2452            if let Some(data) = data {
2453                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
2454                match (precision, K::name()) {
2455                    (Some(p), "Float") => acc.push_str(&format!("{elem:.p$}")),
2456                    (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
2457                    _ => acc.push_str(&format!("{elem:?}")),
2458                }
2459            } else {
2460                acc.push_str("<Tensor data not available>");
2461            }
2462        }
2463    }
2464
2465    fn fmt_outer_tensor(
2466        &self,
2467        acc: &mut String,
2468        depth: usize,
2469        multi_index: &mut [usize],
2470        print_options: &PrintOptions,
2471        summarize: bool,
2472        range: (usize, usize),
2473    ) {
2474        let (start, end) = range;
2475        for i in start..end {
2476            if i > start {
2477                acc.push(',');
2478                Self::push_newline_indent(acc, depth + 1);
2479            }
2480            acc.push('[');
2481            multi_index[depth] = i;
2482            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
2483            acc.push(']');
2484        }
2485    }
2486
2487    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
2488    ///
2489    /// This function is designed to work with tensors of any dimensionality.
2490    /// It traverses the tensor dimensions recursively, converting the elements
2491    /// to strings and appending them to the accumulator string with the
2492    /// appropriate formatting.
2493    ///
2494    /// # Arguments
2495    ///
2496    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
2497    /// * `depth` - The current depth of the tensor dimensions being processed.
2498    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
2499    fn display_recursive(
2500        &self,
2501        acc: &mut String,
2502        depth: usize,
2503        multi_index: &mut [usize],
2504        print_options: &PrintOptions,
2505        summarize: bool,
2506    ) {
2507        let edge_items = print_options.edge_items;
2508
2509        if depth == 0 {
2510            acc.push('[');
2511        }
2512
2513        if depth == self.dims().len() - 1 {
2514            // if we are at the innermost dimension, just push its elements into the accumulator
2515            if summarize && self.dims()[depth] > 2 * edge_items {
2516                // print the starting `edge_items` elements
2517                self.fmt_inner_tensor(
2518                    acc,
2519                    depth,
2520                    multi_index,
2521                    (0, edge_items),
2522                    print_options.precision,
2523                );
2524                acc.push_str(", ...");
2525                // print the last `edge_items` elements
2526                self.fmt_inner_tensor(
2527                    acc,
2528                    depth,
2529                    multi_index,
2530                    (self.dims()[depth] - edge_items, self.dims()[depth]),
2531                    print_options.precision,
2532                );
2533            } else {
2534                // print all the elements
2535                self.fmt_inner_tensor(
2536                    acc,
2537                    depth,
2538                    multi_index,
2539                    (0, self.dims()[depth]),
2540                    print_options.precision,
2541                );
2542            }
2543        } else {
2544            // otherwise, iterate through the current dimension and recursively display the inner tensors
2545            if summarize && self.dims()[depth] > 2 * edge_items {
2546                self.fmt_outer_tensor(
2547                    acc,
2548                    depth,
2549                    multi_index,
2550                    print_options,
2551                    summarize,
2552                    (0, edge_items),
2553                );
2554
2555                acc.push(',');
2556                Self::push_newline_indent(acc, depth + 1);
2557                acc.push_str("...");
2558                Self::push_newline_indent(acc, depth + 1);
2559
2560                self.fmt_outer_tensor(
2561                    acc,
2562                    depth,
2563                    multi_index,
2564                    print_options,
2565                    summarize,
2566                    (self.dims()[depth] - edge_items, self.dims()[depth]),
2567                );
2568            } else {
2569                self.fmt_outer_tensor(
2570                    acc,
2571                    depth,
2572                    multi_index,
2573                    print_options,
2574                    summarize,
2575                    (0, self.dims()[depth]),
2576                );
2577            }
2578        }
2579
2580        if depth == 0 {
2581            acc.push(']');
2582        }
2583    }
2584}
2585
2586#[derive(Clone, Debug)]
2587/// Options for Tensor pretty printing
2588pub struct PrintOptions {
    /// Number of elements above which the tensor is summarized
2590    pub threshold: usize,
2591
    /// Number of starting and ending elements to display when summarizing
2593    pub edge_items: usize,
2594
2595    /// Precision for floating point numbers
2596    pub precision: Option<usize>,
2597}
2598
2599static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
2600
2601impl PrintOptions {
2602    /// Print options with default values
2603    pub const fn const_default() -> Self {
2604        Self {
2605            threshold: 1000,
2606            edge_items: 3,
2607            precision: None,
2608        }
2609    }
2610}
2611
2612impl Default for PrintOptions {
2613    fn default() -> Self {
2614        Self::const_default()
2615    }
2616}
2617
2618/// Set print options
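///
/// # Example
///
/// A minimal sketch: keep the defaults but print floats with two decimal places.
///
/// ```rust
/// use burn_tensor::{PrintOptions, set_print_options};
///
/// let options = PrintOptions {
///     precision: Some(2),
///     ..PrintOptions::default()
/// };
/// set_print_options(options);
/// ```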
2619pub fn set_print_options(options: PrintOptions) {
2620    let mut print_opts = PRINT_OPTS.write().unwrap();
2621    *print_opts = options;
2622}
2623
2624/// Pretty print tensors
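///
/// # Example
///
/// A minimal sketch; a precision passed to the formatter overrides the configured one:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 1>::from_data([1.23456, 2.34567], &device);
///     // Prints the tensor with floats rounded to 2 decimal places.
///     println!("{tensor:.2}");
/// }
/// ```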
2625impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
2626where
2627    B: Backend,
2628    B::IntElem: core::fmt::Display,
2629    K: BasicOps<B>,
2630    <K as BasicOps<B>>::Elem: Debug,
2631{
2632    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2633        writeln!(f, "Tensor {{")?;
2634
2635        {
            // Do not hold the read lock for the whole function
2637            let mut po = { PRINT_OPTS.read().unwrap().clone() };
2638
2639            // Override the precision if it is set from the formatter
2640            // This will be possible when the tensor is printed using the `{:.*}` syntax
2641            if let Some(precision) = f.precision() {
2642                po.precision = Some(precision);
2643            }
2644
2645            let mut acc = String::new();
2646            let mut multi_index = vec![0; D];
2647            let summarize = self.shape().num_elements() > po.threshold;
2648
2649            self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);
2650
2651            writeln!(f, "  data:")?;
2652            write!(f, "{acc}")?;
2653            writeln!(f, ",")?;
2654        }
2655
2656        writeln!(f, "  shape:  {:?},", self.dims())?;
2657        writeln!(f, "  device:  {:?},", self.device())?;
2658        writeln!(f, "  backend:  {:?},", B::name(&self.device()))?;
2659        writeln!(f, "  kind:  {:?},", K::name())?;
2660
2661        let dtype = self.primitive.dtype();
2662
2663        writeln!(f, "  dtype:  {:?},", dtype.name())?;
2664        write!(f, "}}")
2665    }
2666}
2667
/// Trait that lists all operations that can be applied to all tensors.
2669///
2670/// # Warnings
2671///
2672/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
2673pub trait BasicOps<B: Backend>: TensorKind<B> {
2674    /// The type of the tensor elements.
2675    type Elem: Element;
2676
2677    /// Creates an empty tensor with the given shape.
2678    ///
2679    /// # Arguments
2680    ///
2681    /// * `shape` - The shape of the tensor.
2682    /// * `device` - The device on which the tensor will be allocated.
2683    /// * `dtype` - The target data type.
2684    ///
2685    /// # Returns
2686    ///
2687    /// The empty tensor.
2688    ///
2689    /// # Remarks
2690    ///
2691    /// This is a low-level function used internally by the library to call different backend functions
2692    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2693    /// or use this function directly.
2694    ///
2695    /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function,
2696    /// which is more high-level and designed for public use.
2697    fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
2698
2699    /// Creates a tensor of the given shape where each element is equal to the provided value.
2700    ///
2701    /// # Arguments
2702    ///
2703    /// * `shape` - The shape of the tensor.
2704    /// * `fill_value` - The value with which to fill the tensor.
2705    /// * `device` - The device on which the tensor will be allocated.
2706    /// * `dtype` - The target data type.
2707    ///
2708    /// # Returns
2709    ///
    /// The filled tensor.
2711    ///
2712    /// # Remarks
2713    ///
2714    /// This is a low-level function used internally by the library to call different backend functions
2715    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2716    /// or use this function directly.
2717    ///
2718    /// For creating full tensors, users should prefer the [Tensor::full](Tensor::full) function,
2719    /// which is more high-level and designed for public use.
2720    fn full<E: ElementConversion>(
2721        shape: Shape,
2722        fill_value: E,
2723        device: &B::Device,
2724        dtype: DType,
2725    ) -> Self::Primitive;
2726
2727    /// Reshapes the tensor.
2728    ///
2729    /// # Arguments
2730    ///
2731    /// * `tensor` - The tensor.
2732    /// * `shape` - The new shape of the tensor.
2733    ///
2734    /// # Returns
2735    ///
2736    /// The reshaped tensor.
2737    ///
2738    /// # Remarks
2739    ///
2740    /// This is a low-level function used internally by the library to call different backend functions
2741    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2742    /// or use this function directly.
2743    ///
2744    /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function,
2745    /// which is more high-level and designed for public use.
2746    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
2747
2748    /// Transposes a tensor.
2749    ///
2750    /// # Arguments
2751    ///
2752    /// * `tensor` - The tensor to transpose.
2753    ///
2754    /// # Returns
2755    ///
2756    /// The transposed tensor.
2757    fn transpose(tensor: Self::Primitive) -> Self::Primitive;
2758
2759    /// Swaps two dimensions of a tensor.
2760    ///
2761    /// # Arguments
2762    ///
2763    /// * `tensor` - The tensor to swap the dimensions of.
2764    /// * `dim1` - The first dimension to swap.
2765    /// * `dim2` - The second dimension to swap.
2766    ///
2767    /// # Returns
2768    ///
2769    /// The tensor with the dimensions swapped.
2770    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
2771
2772    /// Permutes the dimensions of a tensor.
2773    ///
2774    /// # Arguments
2775    ///
2776    /// * `tensor` - The tensor to permute the dimensions of.
2777    /// * `axes` - The new order of the dimensions.
2778    ///
2779    /// # Returns
2780    ///
2781    /// The tensor with the dimensions permuted.
2782    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2783
2784    /// Flips the tensor along the given axes.
2785    ///
2786    /// # Arguments
2787    ///
2788    /// * `tensor` - The tensor to flip.
2789    /// * `axes` - The axes to flip the tensor along.
2790    ///
2791    /// # Returns
2792    ///
2793    /// The tensor with the axes flipped.
2794    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2795
    /// Select tensor elements corresponding to the given slices.
2797    ///
2798    /// # Arguments
2799    ///
2800    /// * `tensor` - The tensor.
2801    /// * `slices` - The slices specifying ranges and steps for each dimension.
2802    ///
2803    /// # Returns
2804    ///
2805    /// The selected elements.
2806    ///
2807    /// # Remarks
2808    ///
2809    /// This is a low-level function used internally by the library to call different backend functions
2810    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2811    /// or use this function directly.
2812    ///
2813    /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function,
2814    /// which is more high-level and designed for public use.
2815    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive;
2816
2817    /// Assigns the given value to the tensor elements corresponding to the given slices.
2818    ///
2819    /// # Arguments
2820    ///
2821    /// * `tensor` - The tensor.
2822    /// * `slices` - The slices specifying which elements to assign, including support for steps.
2823    /// * `value` - The value to assign.
2824    ///
2825    /// # Returns
2826    ///
2827    /// The tensor with the assigned values.
2828    ///
2829    /// # Remarks
2830    ///
2831    /// This is a low-level function used internally by the library to call different backend functions
2832    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2833    /// or use this function directly.
2834    ///
2835    /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function,
2836    /// which is more high-level and designed for public use.
2837    fn slice_assign(
2838        tensor: Self::Primitive,
2839        slices: &[Slice],
2840        value: Self::Primitive,
2841    ) -> Self::Primitive;
2842
2843    /// Fills the tensor elements corresponding to the given slices with the given value.
2844    ///
2845    /// # Arguments
2846    ///
2847    /// * `tensor` - The tensor.
2848    /// * `slices` - The slices specifying ranges and steps for each dimension.
2849    /// * `value` - The value to fill.
2850    ///
2851    /// # Returns
2852    ///
2853    /// The tensor with the filled values at the specified slice positions.
2854    ///
2855    /// # Remarks
2856    ///
2857    /// This is a low-level function used internally by the library to call different backend functions
2858    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2859    /// or use this function directly.
2860    ///
2861    /// For filling values in a tensor, users should prefer the [Tensor::slice_fill](Tensor::slice_fill) function,
2862    /// which is more high-level and designed for public use.
2863    fn slice_fill(tensor: Self::Primitive, slices: &[Slice], value: Self::Elem) -> Self::Primitive {
2864        let slice_shape = tensor.shape().slice(slices).unwrap();
2865        let value = Self::from_data_dtype(
2866            TensorData::from([value]),
2867            &Self::device(&tensor),
2868            Self::dtype(&tensor),
2869        );
2870        let value = Self::expand(value, slice_shape);
2871        Self::slice_assign(tensor, slices, value)
2872    }
2873
2874    /// Slices the tensor along a given dimension.
2875    ///
2876    /// # Arguments
2877    ///
2878    /// * `tensor` - The tensor to slice.
2879    /// * `dim` - The dimension along which to slice.
2880    /// * `slice` - The slice information containing range and step.
2881    ///
2882    /// # Returns
2883    ///
2884    /// The sliced tensor.
2885    ///
2886    /// # Remarks
2887    ///
2888    /// This is a low-level function used internally by the library to call different backend functions
2889    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2890    /// or use this function directly.
2891    fn slice_dim(tensor: Self::Primitive, dim: usize, slice: &Slice) -> Self::Primitive {
2892        let mut slices = vec![Slice::full(); tensor.shape().num_dims()];
2893        slices[dim] = *slice;
2894
2895        Self::slice(tensor, &slices)
2896    }
2897
2898    /// Select tensor elements along the given dimension corresponding to the given indices.
2899    ///
2900    /// # Arguments
2901    ///
2902    /// * `tensor` - The tensor to select from.
2903    /// * `dim` - The dimension along which to select.
2904    /// * `indices` - The indices of the elements to select.
2905    ///
2906    /// # Returns
2907    ///
2908    /// The selected tensor elements.
2909    ///
2910    /// # Remarks
2911    ///
2912    /// This is a low-level function used internally by the library to call different backend functions
2913    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2914    /// or use this function directly.
2915    ///
2916    /// For selecting elements from a tensor along an axis, users should prefer the
2917    /// [Tensor::select](Tensor::select) function, which is more high-level and designed for public use.
2918    fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive;
2919
2920    /// Assign the selected elements along the given dimension corresponding to the given indices
2921    /// from the value tensor.
2922    ///
2923    /// # Arguments
2924    ///
2925    /// * `tensor` - The tensor to assign elements to.
2926    /// * `dim` - The axis along which to assign elements.
2927    /// * `indices` - The indices of the elements to assign.
2928    /// * `values` - The values to assign to the tensor.
2929    ///
2930    /// # Returns
2931    ///
    /// A tensor with the same shape as the input tensor, where the elements at the specified
    /// indices along the given dimension have the corresponding elements of the values tensor
    /// added to them (sum reduction), and all other elements are unchanged.
2936    ///
2937    /// # Remarks
2938    ///
2939    /// This is a low-level function used internally by the library to call different backend functions
2940    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2941    /// or use this function directly.
2942    ///
2943    /// For assigning elements to a tensor along an axis, users should prefer the
2944    /// [Tensor::select_assign](Tensor::select_assign) function, which is more high-level and designed for public use.
2945    fn select_assign(
2946        tensor: Self::Primitive,
2947        dim: usize,
2948        indices: Tensor<B, 1, Int>,
2949        values: Self::Primitive,
2950    ) -> Self::Primitive;
2951
2952    /// Returns the device on which the tensor is allocated.
2953    ///
2954    /// # Arguments
2955    ///
2956    /// * `tensor` - The tensor.
2957    ///
2958    /// # Returns
2959    ///
2960    /// The device on which the tensor is allocated.
2961    ///
2962    /// # Remarks
2963    ///
2964    /// This is a low-level function used internally by the library to call different backend functions
2965    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2966    /// or use this function directly.
2967    ///
2968    /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function,
2969    /// which is more high-level and designed for public use.
2970    fn device(tensor: &Self::Primitive) -> B::Device;
2971
2972    /// Moves the tensor to the given device.
2973    ///
2974    /// # Arguments
2975    ///
2976    /// * `tensor` - The tensor.
2977    /// * `device` - The device on which the tensor will be moved.
2978    ///
2979    /// # Returns
2980    ///
2981    /// The tensor on the given device.
2982    ///
2983    /// # Remarks
2984    ///
2985    /// This is a low-level function used internally by the library to call different backend functions
2986    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2987    /// or use this function directly.
2988    ///
2989    /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function,
2990    /// which is more high-level and designed for public use.
2991    #[allow(clippy::wrong_self_convention)]
2992    fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
2993
2994    /// Extracts the data from the tensor asynchronously.
2995    ///
2996    /// # Arguments
2997    ///
2998    /// * `tensor` - The tensor.
2999    ///
3000    /// # Returns
3001    ///
3002    /// The data of the tensor.
3003    ///
3004    /// # Remarks
3005    ///
3006    /// This is a low-level function used internally by the library to call different backend functions
3007    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3008    /// or use this function directly.
3009    ///
3010    /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
3011    /// which is more high-level and designed for public use.
3012    #[allow(clippy::wrong_self_convention)]
3013    fn into_data_async(tensor: Self::Primitive) -> impl Future<Output = TensorData> + Send;
3014
3015    /// Read the data from the tensor using a transaction.
3016    ///
3017    /// # Remarks
3018    ///
3019    /// This is a low-level function used internally by the library to call different backend functions
3020    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3021    /// or use this function directly.
3022    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive);
3023
3024    /// Creates a tensor from the given data.
3025    ///
3026    /// # Arguments
3027    ///
3028    /// * `data` - The data of the tensor.
3029    /// * `device` - The device on which the tensor will be allocated.
3030    ///
3031    /// # Returns
3032    ///
3033    /// The tensor.
3034    ///
3035    /// # Remarks
3036    ///
3037    /// This is a low-level function used internally by the library to call different backend functions
3038    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3039    /// or use this function directly.
3040    ///
3041    /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function,
3042    /// which is more high-level and designed for public use.
3043    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;

3044    /// Creates a tensor from the given data, enforcing the given data type.
3045    ///
3046    /// # Remarks
3047    ///
3048    /// This is a low-level function used internally by the library to call different backend functions
3049    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3050    /// or use this function directly.
3051    ///
3052    /// For creating a tensor from data, users should prefer the [Tensor::from_data_dtype](Tensor::from_data_dtype)
3053    /// function, which is more high-level and designed for public use.
3054    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;
3055
3056    /// Repeat the tensor along the given dimension.
3057    ///
3058    /// # Arguments
3059    ///
3060    /// * `tensor` - The tensor.
3061    /// * `dim` - The dimension along which the tensor will be repeated.
3062    /// * `times` - The number of times the tensor will be repeated.
3063    ///
3064    /// # Returns
3065    ///
3066    /// The repeated tensor.
3067    ///
3068    /// # Remarks
3069    ///
3070    /// This is a low-level function used internally by the library to call different backend functions
3071    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3072    /// or use this function directly.
3073    ///
3074    /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
3075    /// which is more high-level and designed for public use.
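    ///
    /// A sketch of that high-level counterpart (hypothetical shapes):
    ///
    /// ```rust,ignore
    /// // Repeating a [2, 3] tensor 4 times along dim 0 yields [8, 3].
    /// let repeated = tensor.repeat_dim(0, 4);
    /// ```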
3076    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
3077
3078    /// Concatenates the given tensors along the given dimension.
3079    ///
3080    /// # Arguments
3081    ///
3082    /// * `vectors` - The tensors to concatenate.
3083    /// * `dim` - The dimension along which the tensors will be concatenated.
3084    ///
3085    /// # Returns
3086    ///
3087    /// The concatenated tensor.
3088    ///
3089    /// # Remarks
3090    ///
3091    /// This is a low-level function used internally by the library to call different backend functions
3092    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3093    /// or use this function directly.
3094    ///
3095    /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function,
3096    /// which is more high-level and designed for public use.
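    ///
    /// A sketch of that high-level counterpart (hypothetical shapes):
    ///
    /// ```rust,ignore
    /// // Two [2, 3] tensors concatenated along dim 0 yield [4, 3].
    /// let out = Tensor::cat(vec![a, b], 0);
    /// ```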
3097    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
3098
3099    /// Applies element-wise equality comparison between the given tensors.
3100    ///
3101    /// # Arguments
3102    ///
3103    /// * `lhs` - The left hand side tensor.
3104    /// * `rhs` - The right hand side tensor.
3105    ///
3106    /// # Returns
3107    ///
3108    /// The tensor of booleans indicating whether the corresponding elements are equal.
3109    ///
3110    /// # Remarks
3111    ///
3112    /// This is a low-level function used internally by the library to call different backend functions
3113    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3114    /// or use this function directly.
3115    ///
3116    /// For comparing tensors for equality, users should prefer the [Tensor::equal](Tensor::equal) function,
3117    /// which is more high-level and designed for public use.
3118    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3119
3120    /// Applies element-wise non-equality comparison between the given tensors.
3121    ///
3122    /// # Arguments
3123    ///
3124    /// * `lhs` - The left hand side tensor.
3125    /// * `rhs` - The right hand side tensor.
3126    ///
3127    /// # Returns
3128    ///
3129    /// The tensor of booleans indicating whether the corresponding elements are not equal.
3130    ///
3131    /// # Remarks
3132    ///
3133    /// This is a low-level function used internally by the library to call different backend functions
3134    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3135    /// or use this function directly.
3136    ///
3137    /// For non-equality comparison of tensors, users should prefer the [Tensor::not_equal](Tensor::not_equal)
3138    /// function, which is more high-level and designed for public use.
3139    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3140
3141    /// Returns the name of the element type.
3142    fn elem_type_name() -> &'static str {
3143        core::any::type_name::<Self::Elem>()
3144    }
3145
3146    /// Returns the tensor data type.
3147    fn dtype(tensor: &Self::Primitive) -> DType {
3148        tensor.dtype()
3149    }
3150
3151    /// Tests if any element in the `tensor` evaluates to True.
3152    ///
3153    /// # Arguments
3154    ///
3155    /// * `tensor` - The tensor to test.
3156    ///
3157    /// # Returns
3158    ///
3159    /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
3160    ///
3161    /// # Remarks
3162    ///
3163    /// This is a low-level function used internally by the library to call different backend functions
3164    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3165    /// or use this function directly. Users should prefer the [Tensor::any](Tensor::any) function,
3166    /// which is more high-level and designed for public use.
3167    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
3168
3169    /// Tests if any element in the tensor evaluates to True along a given dimension `dim`.
3170    ///
3171    /// # Arguments
3172    ///
3173    /// * `tensor` - The tensor to test.
3174    /// * `dim` - The axis along which to test.
3175    ///
3176    /// # Returns
3177    ///
3178    /// A boolean tensor with the same size as the input tensor, except in the `dim` axis where the size is 1.
3179    /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
3180    ///
3181    /// # Remarks
3182    ///
3183    /// This is a low-level function used internally by the library to call different backend functions
3184    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3185    /// or use this function directly. Users should prefer the [Tensor::any_dim](Tensor::any_dim) function,
3186    /// which is more high-level and designed for public use.
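    ///
    /// A shape sketch of that high-level counterpart (hypothetical shapes):
    ///
    /// ```rust,ignore
    /// // Reducing a [4, 3] tensor over dim 1 yields a [4, 1] boolean tensor.
    /// let any_row = tensor.any_dim(1);
    /// ```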
3187    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
3188
3189    /// Tests if all elements in the `tensor` evaluate to True.
3190    ///
3191    /// # Arguments
3192    ///
3193    /// * `tensor` - The tensor to test.
3194    ///
3195    /// # Returns
3196    ///
3197    /// A boolean tensor with a single element, True if all elements in the input tensor evaluate to True, False otherwise.
3198    ///
3199    /// # Remarks
3200    ///
3201    /// This is a low-level function used internally by the library to call different backend functions
3202    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3203    /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function,
3204    /// which is more high-level and designed for public use.
3205    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
3206
3207    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
3208    ///
3209    /// # Arguments
3210    ///
3211    /// * `tensor` - The tensor to test.
    /// * `dim` - The axis along which to test.
3212    ///
3213    /// # Returns
3214    ///
3215    /// A boolean tensor with the same size as the input `tensor`, except in the `dim` axis where the size is 1.
3216    /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
3217    ///
3218    /// # Remarks
3219    ///
3220    /// This is a low-level function used internally by the library to call different backend functions
3221    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3222    /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function,
3223    /// which is more high-level and designed for public use.
3224    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
3225
3226    /// Broadcasts the given tensor to the specified shape.
3227    ///
3228    /// # Arguments
3229    ///
3230    /// * `tensor` - The tensor to broadcast.
3231    /// * `shape` - The shape to broadcast to.
3232    ///
3233    /// # Returns
3234    ///
3235    /// The broadcasted tensor.
3236    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
3237
3238    /// Unfold windows along a dimension.
3239    ///
3240    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`,
3241    /// where the start of each successive window advances by `step` elements.
3242    ///
3243    /// The number of windows is `(shape[dim] - size) / step + 1` (integer division), or `0` when `shape[dim] < size`.
3244    ///
3245    /// # Warning
3246    ///
3247    /// For the `ndarray` and `candle` backends, this is not a view but a full copy.
3248    ///
3249    /// # Arguments
3250    ///
3251    /// * `tensor` - The input tensor to unfold, of shape ``[pre=..., shape[dim], post=...]``.
3252    /// * `dim` - The dimension to unfold.
3253    /// * `size` - The size of each unfolded window.
3254    /// * `step` - The step between the starts of consecutive windows.
3255    ///
3256    /// # Returns
3257    ///
3258    /// A tensor view with shape ``[pre=..., windows, post=..., size]``.
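    ///
    /// A worked sketch (hypothetical shapes, assuming the high-level
    /// counterpart mirrors this signature):
    ///
    /// ```rust,ignore
    /// // A [2, 5] tensor unfolded along dim 1 with size = 3 and step = 1
    /// // has (5 - 3) / 1 + 1 = 3 windows, so the output shape is [2, 3, 3].
    /// let windows = tensor.unfold(1, 3, 1);
    /// ```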
3259    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive;
3260}
3261
3262impl<B: Backend> BasicOps<B> for Float {
3263    type Elem = B::FloatElem;
3264
3265    fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
3266        TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))
3267    }
3268
3269    fn full<E: ElementConversion>(
3270        shape: Shape,
3271        fill_value: E,
3272        device: &B::Device,
3273        dtype: DType,
3274    ) -> Self::Primitive {
3275        TensorPrimitive::Float(B::float_full(
3276            shape,
3277            fill_value.elem(),
3278            device,
3279            dtype.into(),
3280        ))
3281    }
3282
3283    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
3284        tr.register_float(tensor);
3285    }
3286
3287    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3288        match tensor {
3289            TensorPrimitive::Float(tensor) => {
3290                TensorPrimitive::Float(B::float_reshape(tensor, shape))
3291            }
3292            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
3293        }
3294    }
3295
3296    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
3297        match tensor {
3298            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
3299            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
3300        }
3301    }
3302
3303    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
3304        match tensor {
3305            TensorPrimitive::Float(tensor) => {
3306                TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
3307            }
3308            TensorPrimitive::QFloat(tensor) => {
3309                TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
3310            }
3311        }
3312    }
3313
3314    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
3315        match tensor {
3316            TensorPrimitive::Float(tensor) => {
3317                TensorPrimitive::Float(B::float_slice(tensor, slices))
3318            }
3319            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),
3320        }
3321    }
3322
3323    fn slice_assign(
3324        tensor: Self::Primitive,
3325        slices: &[Slice],
3326        value: Self::Primitive,
3327    ) -> Self::Primitive {
3328        TensorPrimitive::Float(B::float_slice_assign(
3329            tensor.tensor(),
3330            slices,
3331            value.tensor(),
3332        ))
3333    }
3334
3335    fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3336        match tensor {
3337            TensorPrimitive::Float(tensor) => {
3338                TensorPrimitive::Float(B::float_select(tensor, dim, indices.primitive))
3339            }
3340            TensorPrimitive::QFloat(tensor) => {
3341                TensorPrimitive::QFloat(B::q_select(tensor, dim, indices.primitive))
3342            }
3343        }
3344    }
3345
3346    fn select_assign(
3347        tensor: Self::Primitive,
3348        dim: usize,
3349        indices: Tensor<B, 1, Int>,
3350        values: Self::Primitive,
3351    ) -> Self::Primitive {
3352        // Select assign is ambiguous for QFloat
3353        TensorPrimitive::Float(B::float_select_assign(
3354            tensor.tensor(),
3355            dim,
3356            indices.primitive,
3357            values.tensor(),
3358        ))
3359    }
3360
3361    fn device(tensor: &Self::Primitive) -> Device<B> {
3362        match tensor {
3363            TensorPrimitive::Float(tensor) => B::float_device(tensor),
3364            TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
3365        }
3366    }
3367
3368    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
3369        match tensor {
3370            TensorPrimitive::Float(tensor) => {
3371                TensorPrimitive::Float(B::float_to_device(tensor, device))
3372            }
3373            TensorPrimitive::QFloat(tensor) => {
3374                TensorPrimitive::QFloat(B::q_to_device(tensor, device))
3375            }
3376        }
3377    }
3378
3379    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
3380        match tensor {
3381            TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
3382            TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
3383        }
3384    }
3385
3386    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
3387        match &data.dtype {
3388            DType::QFloat(_scheme) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
3389            _ => TensorPrimitive::Float(B::float_from_data(data.convert::<B::FloatElem>(), device)),
3390        }
3391    }
3392
3393    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
3394        match dtype {
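            // Quantized dtypes keep their packed representation; any other
            // float dtype converts the data before handing it to the backend.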
3395            DType::QFloat(_scheme) => {
3396                TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device))
3397            }
3398            _ if dtype.is_float() => {
3399                TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
3400            }
3401            _ => panic!("Expected float dtype, got {dtype:?}"),
3402        }
3403    }
3404
3405    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
3406        match tensor {
3407            TensorPrimitive::Float(tensor) => {
3408                TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
3409            }
3410            TensorPrimitive::QFloat(tensor) => {
3411                TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
3412            }
3413        }
3414    }
3415
3416    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
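        // Dispatch on the representation of the first tensor; mixing `Float`
        // and `QFloat` inputs is unsupported and panics below.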
3417        match vectors.first().unwrap() {
3418            TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
3419                vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
3420                dim,
3421            )),
3422            TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
3423                vectors
3424                    .into_iter()
3425                    .map(|tensor| {
3426                        if let TensorPrimitive::QFloat(t) = tensor {
3427                            t
3428                        } else {
3429                            panic!("Concatenation only works with a vector of QFloat tensors")
3430                        }
3431                    })
3432                    .collect(),
3433                dim,
3434            )),
3435        }
3436    }
3437
3438    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3439        B::float_equal(lhs.tensor(), rhs.tensor())
3440    }
3441
3442    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3443        B::float_not_equal(lhs.tensor(), rhs.tensor())
3444    }
3445
3446    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3447        B::float_any(tensor.tensor())
3448    }
3449
3450    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3451        B::float_any_dim(tensor.tensor(), dim)
3452    }
3453
3454    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3455        B::float_all(tensor.tensor())
3456    }
3457
3458    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3459        B::float_all_dim(tensor.tensor(), dim)
3460    }
3461
3462    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3463        match tensor {
3464            TensorPrimitive::Float(tensor) => {
3465                TensorPrimitive::Float(B::float_permute(tensor, axes))
3466            }
3467            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
3468        }
3469    }
3470
3471    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3472        TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
3473    }
3474
3475    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3476        match tensor {
3477            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
3478            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
3479        }
3480    }
3481
3482    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
3483        TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step))
3484    }
3485}
3486
3487impl<B: Backend> BasicOps<B> for Int {
3488    type Elem = B::IntElem;
3489
3490    fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
3491        B::int_empty(shape, device, dtype.into())
3492    }
3493
3494    fn full<E: ElementConversion>(
3495        shape: Shape,
3496        fill_value: E,
3497        device: &B::Device,
3498        dtype: DType,
3499    ) -> Self::Primitive {
3500        B::int_full(shape, fill_value.elem(), device, dtype.into())
3501    }
3502
3503    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
3504        tr.register_int(tensor);
3505    }
3506
3507    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3508        B::int_reshape(tensor, shape)
3509    }
3510
3511    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
3512        B::int_transpose(tensor)
3513    }
3514
3515    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
3516        B::int_swap_dims(tensor, dim1, dim2)
3517    }
3518
3519    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
3520        B::int_slice(tensor, slices)
3521    }
3522
3523    fn slice_assign(
3524        tensor: Self::Primitive,
3525        slices: &[Slice],
3526        value: Self::Primitive,
3527    ) -> Self::Primitive {
3528        B::int_slice_assign(tensor, slices, value)
3529    }
3530
3531    fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3532        B::int_select(tensor, dim, indices.primitive)
3533    }
3534
3535    fn select_assign(
3536        tensor: Self::Primitive,
3537        dim: usize,
3538        indices: Tensor<B, 1, Int>,
3539        values: Self::Primitive,
3540    ) -> Self::Primitive {
3541        B::int_select_assign(tensor, dim, indices.primitive, values)
3542    }
3543
3544    fn device(tensor: &Self::Primitive) -> Device<B> {
3545        B::int_device(tensor)
3546    }
3547
3548    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
3549        B::int_to_device(tensor, device)
3550    }
3551
3552    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
3553        B::int_into_data(tensor).await
3554    }
3555
3556    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
3557        B::int_from_data(data.convert::<B::IntElem>(), device)
3558    }
3559
3560    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
3561        if !dtype.is_int() {
3562            panic!("Expected int dtype, got {dtype:?}")
3563        }
3564
3565        B::int_from_data(data.convert_dtype(dtype), device)
3566    }
3567
3568    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
3569        B::int_repeat_dim(tensor, dim, times)
3570    }
3571
3572    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3573        B::int_equal(lhs, rhs)
3574    }
3575
3576    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3577        B::int_not_equal(lhs, rhs)
3578    }
3579
3580    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
3581        B::int_cat(vectors, dim)
3582    }
3583
3584    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3585        B::int_any(tensor)
3586    }
3587
3588    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3589        B::int_any_dim(tensor, dim)
3590    }
3591
3592    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3593        B::int_all(tensor)
3594    }
3595
3596    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3597        B::int_all_dim(tensor, dim)
3598    }
3599
3600    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3601        B::int_permute(tensor, axes)
3602    }
3603
3604    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3605        B::int_expand(tensor, shape)
3606    }
3607
3608    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3609        B::int_flip(tensor, axes)
3610    }
3611
3612    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
3613        B::int_unfold(tensor, dim, size, step)
3614    }
3615}
3616
3617impl<B: Backend> BasicOps<B> for Bool {
3618    type Elem = B::BoolElem;
3619
3620    fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
3621        if dtype != Self::Elem::dtype() {
3622            panic!("Expected bool data type, got {dtype:?}");
3623        }
3624        B::bool_empty(shape, device)
3625    }
3626
3627    fn full<E: ElementConversion>(
3628        shape: Shape,
3629        fill_value: E,
3630        device: &B::Device,
3631        dtype: DType,
3632    ) -> Self::Primitive {
3633        if dtype != Self::Elem::dtype() {
3634            panic!("Expected bool data type, got {dtype:?}");
3635        }
3636        if fill_value.elem() {
3637            B::bool_ones(shape, device)
3638        } else {
3639            B::bool_zeros(shape, device)
3640        }
3641    }
3642
3643    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
3644        tr.register_bool(tensor);
3645    }
3646
3647    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3648        B::bool_reshape(tensor, shape)
3649    }
3650
3651    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
3652        B::bool_transpose(tensor)
3653    }
3654
3655    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
3656        B::bool_swap_dims(tensor, dim1, dim2)
3657    }
3658
3659    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
3660        B::bool_slice(tensor, slices)
3661    }
3662
3663    fn slice_assign(
3664        tensor: Self::Primitive,
3665        slices: &[Slice],
3666        value: Self::Primitive,
3667    ) -> Self::Primitive {
3668        B::bool_slice_assign(tensor, slices, value)
3669    }
3670
3671    fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3672        B::bool_select(tensor, dim, indices.primitive)
3673    }
3674
3675    fn select_assign(
3676        tensor: Self::Primitive,
3677        dim: usize,
3678        indices: Tensor<B, 1, Int>,
3679        values: Self::Primitive,
3680    ) -> Self::Primitive {
3681        B::bool_select_assign(tensor, dim, indices.primitive, values)
3682    }
3683
3684    fn device(tensor: &Self::Primitive) -> Device<B> {
3685        B::bool_device(tensor)
3686    }
3687
3688    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
3689        B::bool_to_device(tensor, device)
3690    }
3691
3692    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
3693        B::bool_into_data(tensor).await
3694    }
3695
3696    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
3697        B::bool_from_data(data.convert::<B::BoolElem>(), device)
3698    }
3699
3700    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
3701        // Backends only use one bool representation dtype
3702        if dtype != B::BoolElem::dtype() {
3703            panic!("Expected bool dtype, got {dtype:?}")
3704        }
3705        B::bool_from_data(data.convert_dtype(dtype), device)
3706    }
3707
3708    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
3709        B::bool_repeat_dim(tensor, dim, times)
3710    }
3711
3712    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3713        B::bool_equal(lhs, rhs)
3714    }
3715
3716    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3717        B::bool_not_equal(lhs, rhs)
3718    }
3719
3720    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
3721        B::bool_cat(vectors, dim)
3722    }
3723
3724    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3725        B::bool_any(tensor)
3726    }
3727
3728    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3729        B::bool_any_dim(tensor, dim)
3730    }
3731
3732    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3733        B::bool_all(tensor)
3734    }
3735
3736    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3737        B::bool_all_dim(tensor, dim)
3738    }
3739
3740    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3741        B::bool_permute(tensor, axes)
3742    }
3743
3744    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3745        B::bool_expand(tensor, shape)
3746    }
3747
3748    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3749        B::bool_flip(tensor, axes)
3750    }
3751
3752    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
3753        B::bool_unfold(tensor, dim, size, step)
3754    }
3755}
3756
3757/// Trait used for movedim arguments.
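///
/// A sketch of accepted argument forms (hypothetical rank-3 tensor; negative
/// dimensions count from the end):
///
/// ```rust,ignore
/// let t = tensor.clone().movedim(0, 2); // [a, b, c] -> [b, c, a]
/// let t = tensor.movedim(-1, 0);        // [a, b, c] -> [c, a, b]
/// ```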
3758pub trait MovedimArgs {
3759    /// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function
3760    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
3761}
3762
3763impl MovedimArgs for Vec<i32> {
3764    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3765        let set = self
3766            .iter()
3767            .map(|&dim| {
3768                if dim < 0 {
3769                    (D as i32 + dim) as usize
3770                } else {
3771                    dim as usize
3772                }
3773            })
3774            .collect::<Vec<usize>>();
3775        check!(TensorCheck::movedim_args_vec::<D>(&set));
3776
3777        set
3778    }
3779}
3780
3781impl MovedimArgs for Vec<usize> {
3782    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3783        check!(TensorCheck::movedim_args_vec::<D>(&self));
3784        self
3785    }
3786}
3787
3788impl MovedimArgs for usize {
3789    #[allow(clippy::vec_init_then_push)]
3790    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3791        check!(TensorCheck::movedim_args_usize::<D>(self));
3792
3793        let mut set = Vec::with_capacity(1);
3794        set.push(self);
3795
3796        set
3797    }
3798}
3799
3800impl MovedimArgs for i32 {
3801    #[allow(clippy::vec_init_then_push)]
3802    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3803        check!(TensorCheck::movedim_args_i32::<D>(self));
3804
3805        let dim = if self < 0 {
3806            (D as i32 + self) as usize
3807        } else {
3808            self as usize
3809        };
3810
3811        let mut set = Vec::with_capacity(1);
3812        set.push(dim);
3813
3814        set
3815    }
3816}
3817
3818/// Trait used for reshape arguments.
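///
/// For integer arrays, `-1` infers one dimension from the remaining element
/// count and `0` keeps the corresponding source dimension; a sketch with a
/// hypothetical `[4, 6]` tensor:
///
/// ```rust,ignore
/// let a = tensor.clone().reshape([2, -1]); // -1 is inferred: [2, 12]
/// let b = tensor.reshape([0, -1]);         // 0 keeps dim 0: [4, 6]
/// ```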
3819pub trait ReshapeArgs<const D2: usize> {
3820    /// Converts to a shape.
3821    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3822        self,
3823        tensor: &Tensor<B, D, K>,
3824    ) -> Shape;
3825}
3826
3827impl<const D2: usize> ReshapeArgs<D2> for Shape {
3828    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3829        self,
3830        tensor: &Tensor<B, D, K>,
3831    ) -> Shape {
3832        check!(TensorCheck::reshape_args_usize::<D, D2>(
3833            &tensor.shape(),
3834            &self
3835        ));
3836
3837        self
3838    }
3839}

3840impl<const D2: usize> ReshapeArgs<D2> for [usize; D2] {
3841    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3842        self,
3843        tensor: &Tensor<B, D, K>,
3844    ) -> Shape {
3845        let shape = Shape::from(self);
3846
3847        check!(TensorCheck::reshape_args_usize::<D, D2>(
3848            &tensor.shape(),
3849            &shape
3850        ));
3851
3852        shape
3853    }
3854}
3855
3856impl<const D2: usize> ReshapeArgs<D2> for [i64; D2] {
3857    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3858        self,
3859        tensor: &Tensor<B, D, K>,
3860    ) -> Shape {
3861        // Validate the reshape arguments
3862        check!(TensorCheck::reshape_args_i64(&self));
3863
3864        // Temporary shape
3865        let mut new_shape: [i64; D2] = [1; D2];
3866
3867        // We need to find the index of the 0 dimension and
3868        // replace it with the actual dimension value.
3869        for (i, &s) in self.iter().enumerate() {
3870            if s != 0 {
3871                new_shape[i] = s;
3872            } else {
3873                new_shape[i] = tensor.dims()[i] as i64;
3874            }
3875        }
3876
3877        // Find the index of the inferred dimension (-1)
3878        let infer_index = new_shape.iter().position(|x| x == &-1);
3879
3880        // Handle the case where the dimension is inferred (via -1)
3881        if let Some(index) = infer_index {
3882            // Handle the case where the dimension is inferred
3883            let mut product = 1;
3884            for (i, &s) in new_shape.iter().enumerate() {
3885                if i != index {
3886                    product *= s;
3887                }
3888            }
3889            let product_current = tensor.shape().num_elements() as i64;
3890
3891            new_shape[index] = product_current / product;
3892
3893            // Check if the reshape is valid
3894            if product_current % product != 0 {
3895                panic!(
3896                    "Cannot reshape tensor of shape {:?} to shape {:?}",
3897                    tensor.shape(),
3898                    new_shape
3899                );
3900            }
3901        };
3902
3903        // Convert each element to usize
3904        let new_shape: [usize; D2] = new_shape.map(|x| x as usize);
3905
3906        Shape::from(new_shape)
3907    }
3908}
3909
3910impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
3911    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3912        self,
3913        tensor: &Tensor<B, D, K>,
3914    ) -> Shape {
3915        // Convert i32 array to i64 array and use existing implementation
3916        let i64_array: [i64; D2] = self.map(|x| x as i64);
3917        ReshapeArgs::into_shape(i64_array, tensor)
3918    }
3919}
3920
3921/// Trait used for broadcast arguments.
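///
/// For integer arrays, `-1` keeps the size of the matching source dimension;
/// a sketch with a hypothetical `[1, 4]` tensor:
///
/// ```rust,ignore
/// let out = tensor.expand([3, -1]); // broadcast to [3, 4]
/// ```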
3922pub trait BroadcastArgs<const D1: usize, const D2: usize> {
3923    /// Converts to a shape.
3924    fn into_shape(self, shape: &Shape) -> Shape;
3925}
3926
3927impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
3928    fn into_shape(self, _shape: &Shape) -> Shape {
3929        self
3930    }
3931}

3932impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
3933    fn into_shape(self, _shape: &Shape) -> Shape {
3934        Shape::from(self)
3935    }
3936}
3937
3938impl<const D1: usize, const D2: usize, E: Element> BroadcastArgs<D1, D2> for [E; D2] {
3939    // Passing -1 as the size for a dimension means not changing the size of that dimension.
3940    fn into_shape(self, shape: &Shape) -> Shape {
3941        if self.len() < shape.num_dims() {
3942            panic!("The number of broadcast arguments must be at least the number of dimensions");
3943        }
3944
3945        // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
3946        let new_shape: Vec<_> = self
3947            .iter()
3948            .rev()
3949            .map(|x| {
3950                let primitive = x.to_i64();
3951                if primitive < -1 || primitive == 0 {
3952                    panic!("Broadcast arguments must be positive or -1");
3953                }
3954                primitive
3955            })
3956            .zip(shape.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
3957            .map(|(x, &y)| if x == -1 { y } else { x as usize })
3958            .collect::<Vec<_>>()
3959            .into_iter()
3960            .rev()
3961            .collect();
3962
3963        if new_shape.contains(&0) {
3964            panic!("Cannot substitute -1 for a non-existing dimension");
3965        }
3966
3967        let new_shape: [usize; D2] = new_shape.try_into().unwrap();
3968
3969        Shape::from(new_shape)
3970    }
3971}
3972
3973impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
3974where
3975    B: Backend,
3976    K: BasicOps<B>,
3977    K::Elem: Debug + Copy + Serialize,
3978{
3979    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
3980        let data = self.to_data();
3981        data.serialize(serializer)
3982    }
3983}
3984
3985impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
3986where
3987    B: Backend,
3988    K: BasicOps<B>,
3989    K::Elem: Debug + Copy + Deserialize<'de>,
3990{
3991    fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
3992        let tensor = Tensor::from_data(
3993            TensorData::deserialize(deserializer)?,
3994            &<B::Device as Default>::default(),
3995        );
3996        Ok(tensor)
3997    }
3998}
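
// Round-trip sketch: a `Tensor` serializes through its `TensorData`, and
// deserialization targets the default device. A minimal example, assuming
// `serde_json` is available as a dependency and `B` is a backend:
//
//     let json = serde_json::to_string(&tensor).unwrap();
//     let back: Tensor<B, 2> = serde_json::from_str(&json).unwrap();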
3999
4000#[cfg(test)]
4001mod tests {
4002    use crate::{Shape, s};
4003
4004    #[test]
4005    fn slice_range_single_dim_leading() {
4006        let shape = Shape::new([8, 4]);
4007
4008        // Half-open range
4009        let slices = shape.clone().into_slices([0..5]);
4010        assert_eq!(slices[0].to_range(8), 0..5);
4011        let slices = shape.clone().into_slices([-3..-1]);
4012        assert_eq!(slices[0].to_range(8), 5..7);
4013
4014        // Inclusive range
4015        let slices = shape.clone().into_slices([0..=4]);
4016        assert_eq!(slices[0].to_range(8), 0..5);
4017        let slices = shape.clone().into_slices([-2..=-1]);
4018        assert_eq!(slices[0].to_range(8), 6..8);
4019
4020        // Unbounded start
4021        let slices = shape.clone().into_slices([..3]);
4022        assert_eq!(slices[0].to_range(8), 0..3);
4023        let slices = shape.clone().into_slices([..-5]);
4024        assert_eq!(slices[0].to_range(8), 0..3);
4025
4026        // Unbounded end
4027        let slices = shape.clone().into_slices([5..]);
4028        assert_eq!(slices[0].to_range(8), 5..8);
4029        let slices = shape.clone().into_slices([-3..]);
4030        assert_eq!(slices[0].to_range(8), 5..8);
4031
4032        // Full range
4033        let slices = shape.into_slices([..]);
4034        assert_eq!(slices[0].to_range(8), 0..8);
4035    }
4036
4037    #[test]
4038    fn test_negative_slice_indices() {
4039        use crate::Slice;
4040
4041        // Test negative indices conversion
4042        let slice: Slice = (-3..-1).into();
4043        assert_eq!(slice.start, -3);
4044        assert_eq!(slice.end, Some(-1));
4045
4046        // Test to_range conversion with size 8
4047        let range = slice.to_range(8);
4048        assert_eq!(range, 5..7);
4049
4050        // Test with shape slice
4051        let shape = Shape::new([8, 4]);
4052        let result = shape.clone().into_slices([-3..-1]);
4053        assert_eq!(result[0].to_range(8), 5..7);
4054
4055        // Test more negative index cases
4056        let slice2: Slice = (-5..).into();
4057        assert_eq!(slice2.to_range(10), 5..10);
4058
4059        let slice3: Slice = (..-2).into();
4060        assert_eq!(slice3.to_range(10), 0..8);
4061
4062        // Test with s! macro - single dimension returns Slice directly
4063        let slice4 = s![-3..-1];
4064        assert_eq!(slice4.start, -3);
4065        assert_eq!(slice4.end, Some(-1));
4066    }
4067
4068    #[test]
4069    fn slice_range_multi_dim() {
4070        let shape = Shape::new([8, 4]);
4071
4072        // Multiple ways to provide ranges
4073        let slices = shape.clone().into_slices([0..5, 0..4]);
4074        assert_eq!(slices[0].to_range(8), 0..5);
4075        assert_eq!(slices[1].to_range(4), 0..4);
4076
4077        let slices = shape.clone().into_slices([0.., 0..]);
4078        assert_eq!(slices[0].to_range(8), 0..8);
4079        assert_eq!(slices[1].to_range(4), 0..4);
4080
4081        let slices = shape.clone().into_slices([0..=7, 0..=3]);
4082        assert_eq!(slices[0].to_range(8), 0..8);
4083        assert_eq!(slices[1].to_range(4), 0..4);
4084
4085        let slices = shape.clone().into_slices([0..5, 0..3]);
4086        assert_eq!(slices[0].to_range(8), 0..5);
4087        assert_eq!(slices[1].to_range(4), 0..3);
4088
4089        let slices = shape.into_slices([0.., 0..]);
4090        assert_eq!(slices[0].to_range(8), 0..8);
4091        assert_eq!(slices[1].to_range(4), 0..4);
4092    }
4093
4094    #[test]
4095    fn slice_range_multi_dim_index() {
4096        let shape = Shape::new([8, 4]);
4097
4098        // Indices (single integer) should also convert to correct range
4099        let slices = shape.clone().into_slices([0, 2]);
4100        assert_eq!(slices[0].to_range(8), 0..1);
4101        assert_eq!(slices[1].to_range(4), 2..3);
4102
4103        let slices = shape.into_slices([-1, -1]);
4104        assert_eq!(slices[0].to_range(8), 7..8);
4105        assert_eq!(slices[1].to_range(4), 3..4);
4106    }
4107
4108    #[test]
4109    fn slice_range_multi_dim_heterogeneous() {
4110        // Slice macro `s![]` can be used to provide different range types
4111        let shape = Shape::new([8, 4, 2]);
4112        let slice = s![0..5, .., -1];
4113        let slices = shape.into_slices(slice);
4114        assert_eq!(slices[0].to_range(8), 0..5);
4115        assert_eq!(slices[1].to_range(4), 0..4);
4116        assert_eq!(slices[2].to_range(2), 1..2);
4117
4118        let shape = Shape::new([8, 4, 2, 3]);
4119        let slice = s![..=4, 0..=3, .., -2..];
4120        let slices = shape.into_slices(slice);
4121        assert_eq!(slices[0].to_range(8), 0..5);
4122        assert_eq!(slices[1].to_range(4), 0..4);
4123        assert_eq!(slices[2].to_range(2), 0..2);
4124        assert_eq!(slices[3].to_range(3), 1..3);
4125
4126        let shape = Shape::new([3, 4]);
4127        let slice = s![1..-1, ..];
4128        let slices = shape.into_slices(slice);
4129        assert_eq!(slices[0].to_range(3), 1..2);
4130        assert_eq!(slices[1].to_range(4), 0..4);
4131    }
4132}