burn_tensor/tensor/api/base.rs

#![allow(clippy::single_range_in_vec_init)]

use alloc::vec::Vec;

use alloc::format;
use alloc::string::String;
use alloc::vec;

use burn_common::stub::RwLock;
use core::future::Future;
use core::iter::repeat;
use core::{fmt::Debug, ops::Range};
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::tensor::api::narrow::narrow;
use crate::{
    Bool, Float, Int, Shape, TensorData, TensorKind, backend::Backend, check, ops::Device,
};
use crate::{DType, Element, TensorPrimitive};
use crate::{cast::ToElement, check::TensorCheck};

use super::{Slice, TensorMetadata, Transaction};

/// A tensor with a given backend, shape and data type.
///
/// # Indexing
/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
/// or [`select`](Tensor::select) for numeric types.
///
/// ## Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
/// use burn_tensor::Int;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///
///     let tensor = Tensor::<B, 2>::from_data(
///         [
///             [3.0, 4.9, 2.0],
///             [2.0, 1.9, 3.0],
///             [6.0, 1.5, 7.0],
///             [3.0, 4.9, 9.0],
///         ],
///         &device,
///     );
///
///     // Slice the tensor to get the second and third rows:
///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
///     // The resulting tensor will have dimensions [2, 3].
///     let slice = tensor.clone().slice([1..3]);
///     println!("{slice}");
///
///     // Slice the tensor to get the first two rows and the first 2 columns:
///     // [[3.0, 4.9], [2.0, 1.9]]
///     // The resulting tensor will have dimensions [2, 2].
///     let slice = tensor.clone().slice([0..2, 0..2]);
///     println!("{slice}");
///
///     // Index the tensor along the dimension 1 to get the elements 0 and 2:
///     // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
///     // The resulting tensor will have dimensions [4, 2].
///     let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
///     let indexed = tensor.select(1, indices);
///     println!("{indexed}");
/// }
/// ```
#[derive(new, Clone, Debug)]
pub struct Tensor<B, const D: usize, K = Float>
where
    B: Backend,
    K: TensorKind<B>,
{
    pub(crate) primitive: K::Primitive,
}

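/// Create a tensor from anything that converts into [`TensorData`], allocated on the
/// default device.
///
/// A minimal sketch, assuming the nested-array conversions into `TensorData` provided
/// by this crate:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     // `From<T: Into<TensorData>>` places the tensor on the default device.
///     let tensor: Tensor<B, 2> = [[1.0, 2.0], [3.0, 4.0]].into();
///     println!("{tensor}");
/// }
/// ```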
impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    T: Into<TensorData>,
{
    fn from(value: T) -> Self {
        Tensor::from_data(value.into(), &Default::default())
    }
}

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
{
    /// Converts the tensor into a primitive tensor.
    pub fn into_primitive(self) -> K::Primitive {
        self.primitive
    }

    /// Converts from a primitive tensor into a tensor.
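    ///
    /// # Example
    ///
    /// A minimal round-trip sketch through the backend primitive:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::ones([3], &device);
    ///     // Unwrap to the backend primitive, then rebuild the tensor.
    ///     let primitive = tensor.into_primitive();
    ///     let tensor = Tensor::<B, 1>::from_primitive(primitive);
    ///     println!("{tensor}");
    /// }
    /// ```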
    pub fn from_primitive(tensor: K::Primitive) -> Self {
        Self::new(tensor)
    }

    /// Returns the tensor primitive data type.
    ///
    /// # Note
    /// Some element types are encoded in different primitive types depending on the backend
    /// (e.g., bool could be encoded as `u8` or `u32`).
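    ///
    /// # Example
    ///
    /// A minimal sketch; the exact `DType` reported depends on the backend:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::ones([2, 3], &device);
    ///     // Typically `DType::F32` for float tensors on most backends.
    ///     println!("{:?}", tensor.dtype());
    /// }
    /// ```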
    pub fn dtype(&self) -> DType {
        self.primitive.dtype()
    }

    /// Create an empty tensor of the given shape.
    ///
    /// # Arguments
    ///
    /// - shape: The shape of the tensor.
    /// - device: The device where the tensor will be created.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    // Create an empty tensor with dimensions [2, 3, 4].
    ///    let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
    /// }
    /// ```
    pub fn empty<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
        Self::new(K::empty(shape, device))
    }

    /// Returns the dimensions of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///   let device = Default::default();
    ///   let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///   let dims = tensor.dims(); // [2, 3, 4]
    ///   println!("{dims:?}");
    /// }
    /// ```
    pub fn dims(&self) -> [usize; D] {
        Self::shape(self).dims()
    }

    /// Returns the shape of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Shape { dims: [2, 3, 4] }
    ///    let shape = tensor.shape();
    /// }
    /// ```
    pub fn shape(&self) -> Shape {
        self.primitive.shape()
    }

    /// Reshape the tensor to have the given shape.
    ///
    /// The tensor has the same data and number of elements as the input.
    ///
    /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
    /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
    ///
    /// A `0` in the shape instructs to keep the current dimension from the original tensor,
    /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
    /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
    /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
    /// with [1, 3, 4] dimensions to [1, 12].
    ///
    /// # Arguments
    /// - `shape`: The new shape of the tensor.
    ///
    /// # Panics
    /// - If the shape contains more than one `-1`.
    /// - If the shape contains negative values other than `-1`.
    /// - If the new shape does not match the number of elements of the original shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    // Create a tensor with dimensions [2, 3, 4]
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Reshape it to [2, 12], where 12 is inferred from the number of elements.
    ///    let reshaped = tensor.reshape([2, -1]);
    ///    println!("{reshaped}");
    /// }
    /// ```
    pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
        // Convert reshape args to shape
        let shape = shape.into_shape(&self);
        Tensor::new(K::reshape(self.primitive, shape))
    }

    /// Transpose the tensor.
    ///
    /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
    /// applied on the last two dimensions. For example, the transpose of a tensor with shape
    /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
    ///
    /// See also [`permute`](Tensor::permute).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to transpose.
    ///
    /// # Returns
    ///
    /// The transposed tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Transpose the tensor:
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let transposed = tensor.transpose();
    ///     println!("{transposed}");
    /// }
    /// ```
    pub fn transpose(self) -> Tensor<B, D, K> {
        Tensor::new(K::transpose(self.primitive))
    }

    /// Swaps two dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Swap the dimensions 0 and 1 (equivalent to `tensor.transpose()`):
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let swapped = tensor.swap_dims(0, 1);
    ///     println!("{swapped}");
    /// }
    /// ```
    pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor<B, D, K> {
        check!(TensorCheck::swap_dims::<D>(dim1, dim2));
        Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
    }

    /// Permute the dimensions of the tensor.
    ///
    /// # Arguments
    ///
    /// * `axes` - The new order of the dimensions. The length of the axes
    ///   must be equal to the number of dimensions of the tensor.
    ///   The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
    ///
    ///     // Permute the dimensions 1 and 0:
    ///     // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
    ///     // The resulting tensor will have dimensions [2, 3].
    ///     let permuted = tensor.permute([1, 0]);
    ///     println!("{permuted}");
    /// }
    /// ```
    pub fn permute(self, axes: [isize; D]) -> Tensor<B, D, K> {
        // Convert the axes to usize and handle negative values without using vector
        let mut transformed_axes: [usize; D] = [0; D];
        for (i, &x) in axes.iter().enumerate() {
            transformed_axes[i] = if x < 0 {
                (D as isize + x) as usize
            } else {
                x as usize
            };
        }

        // Check if the axes are valid after the transformation
        check!(TensorCheck::permute(transformed_axes));

        Tensor::new(K::permute(self.primitive, &transformed_axes))
    }

    /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
    ///
    /// Other dimensions of input that are not explicitly moved remain in their original order and appear
    /// at the positions not specified in destination.
    ///
    /// # Arguments
    ///
    /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// * `dst` - Destination positions for each of the original dims. These must also be unique.
    ///
    /// # Panics
    ///
    /// - If the source and destination dimensions are not of the same length.
    /// - If the source and destination vectors contain duplicate values.
    /// - If the source and destination vectors contain values that are out of bounds.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions moved.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor of shape [3, 2, 1]
    ///     let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
    ///
    ///     // Move the dimensions 0 and 1:
    ///     // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
    ///     // The resulting tensor will have dimensions [2, 3, 1].
    ///     let moved = tensor.movedim(1, 0);
    ///     println!("{moved}");
    /// }
    /// ```
    ///
    /// # Note
    ///
    /// This is syntactic sugar for `permute`. It is used widely enough that we define a
    /// separate op for it.
    pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
        let source_dims = src.into_dim_vec::<D>();
        let destination_dims = dst.into_dim_vec::<D>();

        check!(TensorCheck::movedim_args_length(
            &source_dims,
            &destination_dims
        ));

        // Map each destination position to its source dim (-1 marks unassigned positions).
        let mut m = [-1; D];
        for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
            m[d] = s as isize;
        }
        // Fill the remaining positions with the unmoved dims, preserving their original order.
        let mut axes: [isize; D] = [0; D];
        let mut source_i = 0;
        for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
            *item = if m[dest_i] != -1 {
                m[dest_i]
            } else {
                while source_dims.contains(&source_i) {
                    source_i += 1;
                }
                let result = source_i as isize;
                source_i += 1;
                result
            };
        }

        self.permute(axes)
    }

    /// Reverse the order of elements in the tensor along the given dimensions.
    ///
    /// # Arguments
    ///
    /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the axes flipped.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [4.0, 5.9, 8.0],
    ///             [1.4, 5.8, 6.0],
    ///         ],
    ///         &device,
    ///     );
    ///
    ///     // Flip the elements in dimensions 0 and 1:
    ///     // [[6.0, 5.8, 1.4],
    ///     //  [8.0, 5.9, 4.0],
    ///     //  [3.0, 1.9, 2.0],
    ///     //  [2.0, 4.9, 3.0]]
    ///     // The resulting tensor will have dimensions [4, 3].
    ///     let flipped = tensor.flip([0, 1]);
    ///     println!("{flipped}");
    /// }
    /// ```
    pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
        // Convert the axes to usize and handle negative values without using vector
        let mut transformed_axes: [usize; N] = [0; N];
        for (i, &x) in axes.iter().enumerate() {
            transformed_axes[i] = if x < 0 {
                (D as isize + x) as usize
            } else {
                x as usize
            };
        }

        // Check if the axes are valid
        check!(TensorCheck::flip(D, &transformed_axes));

        Tensor::new(K::flip(self.primitive, &transformed_axes))
    }

    /// Flatten the tensor along a given range of dimensions.
    ///
    /// This function collapses the specified range of dimensions into a single dimension,
    /// effectively flattening the tensor in that range.
    ///
    /// # Arguments
    ///
    /// - `start_dim`: The starting dimension of the range to be flattened.
    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive).
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the flattened tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [2, 3, 4]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
    ///
    ///     // Flatten the tensor from dimensions 1 to 2 (inclusive).
    ///     // The resulting tensor will have dimensions [2, 12].
    ///     let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
    ///     println!("{flattened}");
    /// }
    /// ```
    pub fn flatten<const D2: usize>(self, start_dim: usize, end_dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));

        let current_dims = self.shape().dims;
        let mut new_dims: [usize; D2] = [0; D2];
        let mut flatten_dims = 1;

        for i in current_dims[start_dim..=end_dim].iter() {
            flatten_dims *= i;
        }

        new_dims[..start_dim].copy_from_slice(&current_dims[..start_dim]);
        new_dims[start_dim] = flatten_dims;
        new_dims[start_dim + 1..].copy_from_slice(&current_dims[end_dim + 1..]);

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Squeeze the tensor along the given dimension, removing the specified dimension
    /// of size one, and effectively reducing the rank of the tensor by one.
    ///
    /// # Arguments
    ///
    /// - `dim`: The dimension to be squeezed.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Panics
    ///
    /// If the size in the squeezed dimension is not 1.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 1, 3]
    ///     let tensor = Tensor::<B, 3>::from_data(
    ///         [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
    ///         &device,
    ///     );
    ///
    ///     // Squeeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let squeezed = tensor.squeeze::<2>(1);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));

        let current_dims = self.shape().dims;
        let mut new_dims: [usize; D2] = [0; D2];

        new_dims[..dim].copy_from_slice(&current_dims[..dim]);
        new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
    /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
    /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
    /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
    /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
    /// from the back.
    ///
    /// # Arguments
    ///
    /// - `dims`: The dimension(s) to be squeezed.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 4D tensor with dimensions [2, 1, 4, 1]
    ///     let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
    ///
    ///     // Squeeze the dimensions 1 and 3.
    ///     // The resulting tensor will have dimensions [2, 4].
    ///     let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
        let current_dims = self.shape().dims;
        let mut dim_indices: Vec<usize>;

        // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
        if dims.is_empty() {
            dim_indices = current_dims
                .iter()
                .enumerate()
                .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
                .collect();
        } else {
            // If negative dims, count from the back
            dim_indices = dims
                .iter()
                .map(|&d| {
                    if d < 0 {
                        (current_dims.len() as isize + d) as usize
                    } else {
                        d as usize
                    }
                })
                .collect();
        }

        // Sort indices and remove duplicates
        dim_indices.sort_unstable();
        dim_indices.dedup();

        // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
        check!(TensorCheck::squeeze_dims_input::<D2>(
            &dim_indices,
            &current_dims
        ));

        // Calculate new dimensions
        let mut new_dims = Vec::new();
        for (index, &dim_size) in current_dims.iter().enumerate() {
            // Exclude the dimension if it's explicitly marked for squeezing
            if dim_indices.contains(&index) {
                check!(TensorCheck::squeeze::<D2>(index, &current_dims));
                continue;
            }
            new_dims.push(dim_size);
        }

        // Check that after squeezing, we still respect the D2 size
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the unsqueezed tensor.
    ///
    /// # Panics
    ///
    /// If the output size `D2` is smaller than the current number of dimensions.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the tensor up to 4 dimensions.
    ///     // The resulting tensor will have dimensions [1, 1, 3, 3].
    ///     let unsqueezed = tensor.unsqueeze::<4>();
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze::<D, D2>());

        let mut dims = [1; D2];
        let num_ones = D2 - D;
        let shape = self.shape();

        dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]);

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with a dimension of size one inserted at the specified position.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 1, 3].
    ///     let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));

        let mut dims = [1; D2];
        let shape = self.shape();

        dims[0..dim].copy_from_slice(&shape.dims[0..dim]);

        if dim < D {
            dims[dim] = 1;
            dims[(dim + 1)..].copy_from_slice(&shape.dims[dim..]);
        } else {
            dims[dim] = 1;
        }

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
    ///
    /// The indices can be negative, in which case they are counted from the last to the first dimension.
    /// The axes can contain duplicates, in which case the number of dimensions inserted at the index
    /// is the number of duplicates.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 4, 5]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
    ///     // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
    ///     // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
    ///     let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K> {
        let mut new_dims = [1; D2];
        let old_dims = self.shape().dims;

        // Part 1: convert the negative indices to positive.
        let mut neg_offset = D2;
        let mut dim_indices = axes
            .iter()
            .map(|d| {
                // Check if the dimension is in the acceptable range.
                check!(TensorCheck::unsqueeze_dims::<{ D2 }>(*d));
                (if *d < 0 {
                    neg_offset -= 1; // Handle multiple negative indices (decrease dim value in reverse).
                    d + neg_offset as isize + 1
                } else {
                    *d
                }) as usize
            })
            .collect::<Vec<usize>>();

        // Sort the indices.
        dim_indices.sort_unstable();

        // Part 2: use the sorted indices to copy over the chunks of the dims.
        let mut prev_idx: usize = 0;
        let mut current_left_b: usize = 0;
        let mut current_right_b: usize = 0;
        let mut offset: usize = 0;
        dim_indices.iter().for_each(|d| {
            // Check if there is space for at least one dimension.
            if prev_idx < *d {
                current_right_b = *d - offset;
                // Copy the chunks of the dims.
                if current_right_b < D {
                    new_dims[prev_idx..*d]
                        .copy_from_slice(&old_dims[current_left_b..current_right_b]);
                } else {
                    new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
                }
                prev_idx = *d + 1;
                // The offset is equal to the number of elements extracted from the original shape.
                offset += current_right_b - current_left_b;
                current_left_b = current_right_b;
            } else {
                // The indices are sorted, so the only way this can happen
                // is if multiple indices are the same.
                prev_idx += 1;
            }
        });
        // Copy over anything past the index of the last new dimension.
        if current_left_b < D {
            new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
        }

        // Lastly, create the shape and reshape.
        let shape = Shape::new(new_dims);
        self.reshape(shape)
    }

    /// Returns a tensor containing the elements selected from the given ranges.
    ///
    /// For more complex indexing with different slice ranges, see also the slice
    /// macro [`s!`](crate::s).
    ///
    /// # Arguments
    ///
    /// * `ranges` - A type implementing the `RangesArg` trait, which can be:
    ///   - A single range (slice the first dimension)
    ///   - A single index (slice the first dimension)
    ///   - An array of ranges
    ///
    /// # Behavior
    ///
    /// - Supports partial and full slicing in any number of dimensions.
    /// - Missing ranges are treated as full slices if D > D2.
    /// - Handles negative indices by wrapping around from the end of the dimension.
    /// - Clamps ranges to the tensor's dimensions if they exceed the bounds.
    ///
    /// # Panics
    ///
    /// - If the number of ranges provided exceeds the tensor's dimensions.
    /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1).
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // 1D slicing
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
    ///     let slice = tensor.slice([1..4]);
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
    ///
    ///     // 2D slicing
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 4]), &device);
    ///     let slice = tensor.slice([1..3, 0..2]);
    ///     assert_eq!(slice.dims(), [2, 2]);
    ///
    ///     // Using negative indices
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
    ///     let slice = tensor.slice([1..-1]); // Equivalent to 1..4
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
    ///
    ///     // Using the slice macro to select different ranges
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device).reshape([3, 4]);
    ///     let slice = tensor.slice(s![1.., ..]); // Select rows 1 and 2, all columns
    ///     assert_eq!(slice.dims(), [2, 4]);
    ///
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..16, &device).reshape([2, 4, 2]);
    ///     let slice = tensor.slice(s![1.., 1..=3, -1]);
    ///     assert_eq!(slice.dims(), [1, 3, 1]);
    /// }
    /// ```
    ///
    /// # Note
    ///
    /// This function uses the `RangesArg` trait for flexible range specification. The trait
    /// handles the conversion of various range formats and applies clamping and negative
    /// index handling internally.
    pub fn slice<const D2: usize, R: RangesArg<D2>>(self, ranges: R) -> Self {
        let ranges = ranges.into_ranges(self.shape());

        check!(TensorCheck::slice::<D, D2>(&self.shape(), &ranges));
        Self::new(K::slice(self.primitive, &ranges))
    }

    /// Returns a copy of the current tensor with the selected elements changed to the new ones at
    /// the selected indices.
    ///
    /// # Panics
    ///
    /// - If a range exceeds the number of elements on a dimension.
    /// - If the given values don't match the given ranges.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 3>::ones([2, 3, 3], &device);
    ///     let values = Tensor::<B, 3>::zeros([1, 1, 1], &device);
    ///     let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values);
    ///     println!("{:?}", tensor_sliced.dims()); // [2, 3, 3]
    /// }
    /// ```
    pub fn slice_assign<const D2: usize>(self, ranges: [Range<usize>; D2], values: Self) -> Self {
        check!(TensorCheck::slice_assign::<D, D2>(
            &self.shape(),
            &values.shape(),
            &ranges
        ));
        Self::new(K::slice_assign(self.primitive, &ranges, values.primitive))
    }

    /// Returns the device of the current tensor.
    pub fn device(&self) -> B::Device {
        K::device(&self.primitive)
    }

    /// Move the tensor to the given device.
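    ///
    /// # Example
    ///
    /// A minimal sketch; here the target is simply the default device:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::ones([3], &device);
    ///     // Move the tensor; a no-op when it already lives on `device`.
    ///     let tensor = tensor.to_device(&device);
    ///     println!("{tensor}");
    /// }
    /// ```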
    pub fn to_device(self, device: &B::Device) -> Self {
        Self::new(K::to_device(self.primitive, device))
    }

    /// Converts the data of the current tensor.
    ///
    /// # Note
    ///
    /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
    /// tensors at once. This may improve laziness, especially if executed on a different
    /// thread in native environments.
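    ///
    /// # Example
    ///
    /// A minimal sketch of reading the values back on the host:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     // Materialize the tensor as host-side data.
    ///     let data = tensor.into_data();
    ///     println!("{data:?}");
    /// }
    /// ```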
    pub fn into_data(self) -> TensorData {
        crate::try_read_sync(self.into_data_async()).expect(
            "Failed to read tensor data synchronously.
        This can happen on platforms that don't support blocking futures like WASM.
        If possible, try using into_data_async instead.",
        )
    }

    /// Converts the data of the current tensor.
    ///
    /// # Note
    ///
    /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
    /// tensors at once. This may improve laziness, especially if executed on a different
    /// thread in native environments.
    pub fn to_data(&self) -> TensorData {
        self.clone().into_data()
    }

    /// Returns the data of the current tensor.
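    ///
    /// # Example
    ///
    /// A minimal sketch in an async context (the caller provides the executor):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// async fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     // Await the read without blocking; useful on targets like WASM.
    ///     let data = tensor.into_data_async().await;
    ///     println!("{data:?}");
    /// }
    /// ```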
    pub async fn into_data_async(self) -> TensorData {
        K::into_data_async(self.primitive).await
    }

    /// Returns the data of the current tensor.
    pub async fn to_data_async(&self) -> TensorData {
        self.clone().into_data_async().await
    }

    /// Create a tensor from the given data on the given device.
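    ///
    /// # Example
    ///
    /// A minimal sketch using a nested array literal:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // The nested array converts into `TensorData`.
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     println!("{tensor}");
    /// }
    /// ```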
    pub fn from_data<T>(data: T, device: &B::Device) -> Self
    where
        T: Into<TensorData>,
    {
        let data = data.into();
        check!(TensorCheck::creation_ops::<D>(
            "From Data",
            data.shape.as_slice()
        ));
        Self::new(K::from_data(data, device))
    }

    /// Create a tensor from the given data on the given device enforcing the given data type.
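    ///
    /// # Example
    ///
    /// A minimal sketch; `DType::F32` here stands in for whatever precision is needed:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{DType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Force the tensor to be stored as `f32`, whatever the input data was.
    ///     let tensor = Tensor::<B, 1>::from_data_dtype([1.0, 2.0, 3.0], &device, DType::F32);
    ///     println!("{:?}", tensor.dtype());
    /// }
    /// ```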
    pub fn from_data_dtype<T>(data: T, device: &B::Device, dtype: DType) -> Self
    where
        T: Into<TensorData>,
    {
        let data = data.into();
        check!(TensorCheck::creation_ops::<D>(
            "From Data",
            data.shape.as_slice()
        ));
        Self::new(K::from_data_dtype(data, device, dtype))
    }

    /// Repeat the tensor along the given dimension.
    ///
    /// The output tensor has the same shape, except along the given dimension.
    ///
    /// # Arguments
    /// - `dim`: The dimension to repeat.
    /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
    ///
    /// # Returns
    ///
    /// A new tensor with the given dimension repeated `times` times.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///
    ///     // Repeat the tensor along the dimension 0 twice.
    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
    ///     // The resulting tensor will have dimensions [6, 2].
    ///     let repeated = tensor.repeat_dim(0, 2);
    ///     println!("{repeated}");
    /// }
    /// ```
    pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
        Self::new(K::repeat_dim(self.primitive, dim, times))
    }

    /// Repeat the tensor along the given dimensions.
    ///
    /// # Arguments
    ///
    /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
    ///
    /// # Returns
    ///
    /// A new tensor with each dimension repeated according to `sizes`.
    ///
    /// # Panics
    ///
    /// If `sizes` contains more elements than the number of dimensions.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///
    ///     // Repeat the tensor along the dimension 0 twice and the dimension 1 once.
    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
    ///     // The resulting tensor will have dimensions [6, 2].
    ///     let repeated = tensor.repeat(&[2, 1]);
    /// }
    /// ```
    pub fn repeat(self, sizes: &[usize]) -> Self {
        let mut tensor = self;
        for (dim, &times) in sizes.iter().enumerate() {
            if times > 1 {
                tensor = tensor.repeat_dim(dim, times);
            }
        }
        tensor
    }

    /// Applies element-wise equal comparison.
    ///
    /// # Returns
    ///
    /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     // Compare the elements of the two 2D tensors with dimensions [3, 2].
    ///     // [[false, true], [true, true], [true, true]]
    ///     let equal = t1.equal(t2);
    ///     println!("{equal}");
    /// }
    /// ```
    pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
        Tensor::new(K::equal(self.primitive, other.primitive))
    }

    /// Applies element-wise non-equality comparison.
    ///
    /// # Returns
    ///
    /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     // Compare the elements of the two 2D tensors for inequality.
    ///     // [[true, false], [false, false], [false, false]]
    ///     let not_equal = t1.not_equal(t2);
    ///     println!("{not_equal}");
    /// }
    /// ```
    pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
        Tensor::new(K::not_equal(self.primitive, other.primitive))
    }

    /// Concatenates all tensors into a new one along the given dimension.
    ///
    /// # Panics
    ///
    /// - If `dim` is greater than or equal to the rank.
    /// - If `tensors` is an empty vector.
    /// - If the tensors don't all have the same shape (ignoring the dimension `dim`).
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
    ///
    ///     // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
    ///     // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
    ///     // The resulting tensor will have shape [2, 7].
    ///     let concat = Tensor::cat(vec![t1, t2], 1);
    ///     println!("{concat}");
    /// }
    /// ```
    pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
        check!(TensorCheck::cat(&tensors, dim));

        Self::new(K::cat(
            tensors.into_iter().map(|vector| vector.primitive).collect(),
            dim,
        ))
    }

    /// Concatenates all tensors into a new one along a new dimension.
    ///
    /// # Panics
    ///
    /// - If the tensors don't all have the same shape.
    /// - If the given dimension is not within the range `0..D2`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
    ///     let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
    ///
    ///     // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
    ///     // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
    ///     // The resulting tensor will have shape [3, 2, 3].
    ///     let stacked = Tensor::stack::<3>(vec![t1, t2, t3], 0);
    ///     println!("{stacked}");
    /// }
    /// ```
    pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
        let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
        Tensor::<B, D2, K>::cat(tensors, dim)
    }

    /// Iterate over slices of tensors along a given dimension.
    ///
    /// # Panics
    ///
    /// If the given dimension is greater than or equal to the tensor rank.
    ///
    /// # Returns
    ///
    /// A tensor iterator.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
    ///     // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
    ///     let iter = tensor.iter_dim(0);
    ///     for (i, tensor) in iter.enumerate() {
    ///         println!("Tensor {}: {}", i, tensor);
    ///         // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
    ///         // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
    ///     }
    /// }
    /// ```
    pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
        check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
        DimIter::new(self, dim)
    }
    /// Returns a new tensor with the given dimension narrowed to the given range.
    ///
    /// # Panics
    ///
    /// - If the dimension is greater than the number of dimensions of the tensor.
    /// - If the given range exceeds the number of elements on the given dimension.
    ///
    /// # Returns
    ///
    /// A new tensor with the given dimension narrowed to the given range.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [6.0, 1.5, 7.0],
    ///             [3.0, 4.9, 9.0],
    ///         ],
    ///         &device,
    ///     );
    ///     // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
    ///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let narrowed = tensor.narrow(0, 1, 3);
    ///     println!("{narrowed}");
    /// }
    /// ```
    pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
        check!(TensorCheck::dim_ops::<D>("narrow", dim));
        check!(TensorCheck::narrow(&self, dim, start, length));
        Self::new(narrow::<B, K>(self.primitive, dim, start, length))
    }

    /// Attempts to split the tensor into a specified number of chunks along a given dimension.
    /// May return fewer chunks than requested if the tensor size is not divisible by the number of chunks.
    ///
    /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
    /// Otherwise all chunks will be of equal size except for the last one.
    ///
    /// # Panics
    ///
    /// If the dimension is greater than the number of dimensions of the tensor.
    ///
    /// # Returns
    ///
    /// A vector of tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [6.0, 1.5, 7.0],
    ///             [3.0, 4.9, 9.0],
    ///         ],
    ///         &device,
    ///     );
    ///     // Split the tensor along the dimension 1 into 2 chunks.
    ///     // The first chunk will have shape [4, 2]:
    ///     // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
    ///     // The second chunk will have shape [4, 1]:
    ///     // [[2.0], [3.0], [7.0], [9.0]]
    ///     let chunks = tensor.chunk(2, 1);
    ///     println!("{chunks:?}");
    /// }
    /// ```
    pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
        check!(TensorCheck::dim_ops::<D>("chunk", dim));
        K::chunk(self.primitive, chunks, dim)
            .into_iter()
            .map(Self::new)
            .collect()
    }

    /// Splits the tensor into chunks of a specified size along a given dimension.
    /// Each chunk is a view of the original tensor.
    ///
    /// If the tensor size along the given dimension is not divisible by `split_size`,
    /// then the last chunk will be smaller.
    ///
    /// # Panics
    ///
    /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
    ///
    /// # Returns
    ///
    /// A vector of tensors.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 1D tensor with 5 elements
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
    ///     // Split the tensor into chunks of size 2 along dimension 0
    ///     let chunks = tensor.split(2, 0);
    ///     // The result is a vector of tensors:
    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
    ///     println!("{:?}", chunks);
    /// }
    /// ```
    pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
        check!(TensorCheck::split::<D>(
            self.shape().dims.as_ref(),
            split_size,
            dim
        ));
        K::split(self.primitive, split_size, dim)
            .into_iter()
            .map(Self::new)
            .collect()
    }

    /// Splits the tensor into chunks with the specified sizes along a given dimension.
    /// Each chunk is a view of the original tensor.
    ///
    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
    /// in `split_sizes` must equal the size of the tensor along the specified dimension.
    ///
    /// # Panics
    ///
    /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
    /// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
    ///
    /// # Returns
    ///
    /// A vector of tensors.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 1D tensor with 5 elements
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
    ///     // Split the tensor into chunks with sizes [2, 3] along dimension 0
    ///     let chunks = tensor.split_with_sizes(vec![2, 3], 0);
    ///     // The result is a vector of tensors:
    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
    ///     println!("{:?}", chunks);
    /// }
    /// ```
    pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
        check!(TensorCheck::split_with_sizes::<D>(
            self.shape().dims.as_ref(),
            &split_sizes,
            dim
        ));
        K::split_with_sizes(self.primitive, split_sizes, dim)
            .into_iter()
            .map(Self::new)
            .collect()
    }

    /// Tests if any element in the `tensor` evaluates to True.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
    ///
    /// # Returns
    ///
    /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
    /// evaluates to True, False otherwise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
    ///     let tensor_two = Tensor::<B, 2, Bool>::from_data([[false, false, false], [false, false, false]], &device);
    ///
    ///     // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
    ///     let any_tensor = tensor.any();
    ///     println!("{}", any_tensor);
    ///     // Tensor { data: [true], ... }
    ///
    ///     // Test if any element in `tensor_two` evaluates to True.
    ///     let any_tensor_two = tensor_two.any();
    ///     println!("{}", any_tensor_two);
    ///     // Tensor { data: [false], ... }
    /// }
    /// ```
    pub fn any(self) -> Tensor<B, 1, Bool> {
        Tensor::new(K::any(self.primitive))
    }

    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
    /// * `dim` - The axis along which to test.
    ///
    /// # Returns
    ///
    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
    /// where the size is 1. The element in the `dim` axis is True if any element along this dim in the input
    /// evaluates to True, False otherwise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor =
    ///         Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
    ///     // Check if any element in the tensor evaluates to True along the dimension 1.
    ///     // [[true], [true]]
    ///     let any_dim = tensor.clone().any_dim(1);
    ///     println!("{any_dim}");
    /// }
    /// ```
    pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
        Tensor::new(K::any_dim(self.primitive, dim))
    }
1480
1481    /// Tests if all elements in the `tensor` evaluate to True.
1482    ///
1483    /// # Arguments
1484    ///
1485    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1486    ///
1487    /// # Returns
1488    ///
1489    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1490    /// evaluate to True, False otherwise.
1491    ///
1492    /// # Example
1493    ///
1494    /// ```rust
1495    /// use burn_tensor::backend::Backend;
1496    /// use burn_tensor::{Tensor, Bool};
1497    ///
1498    /// fn example<B: Backend>() {
1499    ///     let device = Default::default();
1500    ///     let tensor =
1501    ///         Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
1502    ///     // Check if all elements in the tensor evaluate to True (which is not the case).
1503    ///     // [false]
1504    ///     let all = tensor.all();
1505    ///     println!("{all}");
1506    /// }
1507    /// ```
1508    pub fn all(self) -> Tensor<B, 1, Bool> {
1509        Tensor::new(K::all(self.primitive))
1510    }
1511
1512    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1513    ///
1514    /// # Arguments
1515    ///
1516    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1517    /// * `dim` - The axis along which to test.
1518    ///
1519    /// # Returns
1520    ///
1521    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
1522    /// where the size is 1. The element in the `dim` axis is True if all elements along this dim in the input
1523    /// evaluate to True, False otherwise.
1524    ///
1525    /// # Example
1526    ///
1527    /// ```rust
1528    /// use burn_tensor::backend::Backend;
1529    /// use burn_tensor::{Tensor, Bool};
1530    ///
1531    /// fn example<B: Backend>() {
1532    ///     let device = Default::default();
1533    ///     let tensor =
1534    ///         Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
1535    ///     // Check if all elements in the tensor evaluate to True along the dimension 1.
1536    ///     // [[true, true, false]]
1537    ///     let all_dim = tensor.clone().all_dim(0);
1538    ///     println!("{all_dim}");
1539    /// }
1540    /// ```
1541    pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
1542        Tensor::new(K::all_dim(self.primitive, dim))
1543    }
1544
1545    /// Convert the tensor into a scalar.
1546    ///
1547    /// # Panics
1548    ///
1549    /// - If the tensor doesn't contain exactly one element.
1550    /// - If the backend fails to read the tensor data synchronously.
1551    ///
1552    /// # Returns
1553    ///
1554    /// The scalar value of the tensor.
1555    ///
1556    /// # Example
1557    ///
1558    /// ```rust
1559    /// use burn_tensor::backend::Backend;
1560    /// use burn_tensor::Tensor;
1561    ///
1562    /// fn example<B: Backend>() {
1563    ///     let device = Default::default();
1564    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
1565    ///     // Convert the tensor with a single element into a scalar.
1566    ///     let scalar = tensor.into_scalar();
1567    ///     println!("{scalar}");
1568    /// }
1569    /// ```
1570    pub fn into_scalar(self) -> K::Elem {
1571        crate::try_read_sync(self.into_scalar_async()).expect(
1572            "Failed to read tensor data synchronously. This can happen on platforms
1573            that don't support blocking futures like WASM. Try into_scalar_async instead.",
1574        )
1575    }
1576
1577    /// Convert the tensor into a scalar.
1578    ///
1579    /// # Panics
1580    ///
1581    /// If the tensor doesn't contain exactly one element.
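    ///
    /// # Example
    ///
    /// A minimal sketch of reading the value without blocking the current thread
    /// (useful on platforms like WASM; the surrounding async context is assumed):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// async fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
    ///     // Await the single value instead of blocking on it.
    ///     let scalar = tensor.into_scalar_async().await;
    ///     println!("{scalar}");
    /// }
    /// ```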
1582    pub async fn into_scalar_async(self) -> K::Elem {
1583        check!(TensorCheck::into_scalar::<D>(&self.shape()));
1584        self.into_data_async().await.iter().next().unwrap()
1586    }
1587
1588    /// Broadcast the tensor to the given shape.
1589    ///
1590    /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size
1591    /// (which can be inferred with `-1`).
1592    ///
1593    /// # Arguments
1594    ///
1595    /// * `shape` - The shape to broadcast the tensor to.
1596    ///   Can contain -1 for dimensions that should be inferred.
1597    ///   The number of dimensions in the shape must be greater than or equal to
1598    ///   the number of dimensions of the tensor.
1599    ///
1600    /// # Panics
1601    ///
1602    /// If the tensor cannot be broadcast to the given shape.
1603    ///
1604    /// # Returns
1605    ///
1606    /// A new tensor with the given shape.
1607    ///
1608    /// # Example
1609    ///
1610    /// ```rust
1611    /// use burn_tensor::backend::Backend;
1612    /// use burn_tensor::Tensor;
1613    ///
1614    /// fn example<B: Backend>() {
1615    ///     let device = Default::default();
1616    ///     // Create a 2D tensor with dimensions [3, 1]
1617    ///     let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
1618    ///     // Expand the tensor to a new shape [3, 4]
1619    ///     // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
1620    ///     let expanded = tensor.expand([3, 4]);
1621    ///     println!("{}", expanded);
1622    /// }
1623    /// ```
1624    pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
1625        let shape = shape.into_shape(&self.shape());
1626        check!(TensorCheck::expand::<D, D2>(
1627            "expand",
1628            &self.shape(),
1629            &shape,
1630        ));
1631
1632        Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
1633    }
1634}
1635
1636/// Iterator returned by [Tensor::iter_dim](Tensor::iter_dim).
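///
/// # Example
///
/// A minimal sketch of iterating over one dimension; every yielded tensor keeps
/// the full rank `D`, with size 1 along the iterated dimension:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
///     // Yields [[1.0, 2.0]] then [[3.0, 4.0]]; `.rev()` walks backwards.
///     for item in tensor.iter_dim(0) {
///         println!("{item}");
///     }
/// }
/// ```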
1637pub struct DimIter<B, const D: usize, K>
1638where
1639    B: Backend,
1640    K: BasicOps<B>,
1641{
1642    start: usize,
1643    end: usize,
1644    dim: usize,
1645    ranges: [Range<usize>; D],
1646    tensor: Tensor<B, D, K>,
1647}
1648
1649impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
1650    type Item = Tensor<B, D, K>;
1651
1652    fn next(&mut self) -> Option<Self::Item> {
1653        if self.start >= self.end {
1654            return None;
1655        }
1656
1657        let mut ranges = self.ranges.clone();
1658        ranges[self.dim] = self.start..(self.start + 1);
1659
1660        let slice = self.tensor.clone().slice(ranges);
1661        self.start += 1;
1662
1663        Some(slice)
1664    }
1665}
1666
1667impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
1668    fn next_back(&mut self) -> Option<Self::Item> {
1669        if self.start >= self.end {
1670            return None;
1671        }
1672
1673        let mut ranges = self.ranges.clone();
1674        ranges[self.dim] = (self.end - 1)..self.end;
1675
1676        let slice = self.tensor.clone().slice(ranges);
1677        self.end = self.end.saturating_sub(1);
1678
1679        Some(slice)
1680    }
1681}
1682
1683impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
1684    fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
1685        let dims = tensor.dims();
1686        let ranges = dims
1687            .iter()
1688            .map(|&dim| 0..dim)
1689            .collect::<Vec<Range<usize>>>();
1690        let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
1691        Self {
1692            end: dims[dim],
1693            ranges,
1694            start: 0,
1695            dim,
1696            tensor,
1697        }
1698    }
1699}
1700
1701impl<B, const D: usize, K> Tensor<B, D, K>
1702where
1703    B: Backend,
1704    K: BasicOps<B>,
1705    <K as BasicOps<B>>::Elem: Debug,
1706{
1707    #[inline]
1708    fn push_newline_indent(acc: &mut String, indent: usize) {
1709        acc.push('\n');
1710        for _ in 0..indent {
1711            acc.push(' ');
1712        }
1713    }
1714    fn fmt_inner_tensor(
1715        &self,
1716        acc: &mut String,
1717        depth: usize,
1718        multi_index: &mut [usize],
1719        range: (usize, usize),
1720        precision: Option<usize>,
1721    ) {
1722        let (start, end) = range;
1723        for i in start..end {
1724            if i > 0 {
1725                acc.push_str(", ");
1726            }
1727            multi_index[depth] = i;
1728            let range: [Range<usize>; D] =
1729                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
1730
1731            let data =
1732                burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());
1733
1734            if let Some(data) = data {
1735                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
1736                match (precision, K::name()) {
1737                    (Some(p), "Float") => acc.push_str(&format!("{:.1$}", elem, p)),
1738                    (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
1739                    _ => acc.push_str(&format!("{:?}", elem)),
1740                }
1741            } else {
1742                acc.push_str("<Tensor data not available>");
1743            }
1744        }
1745    }
1746
1747    fn fmt_outer_tensor(
1748        &self,
1749        acc: &mut String,
1750        depth: usize,
1751        multi_index: &mut [usize],
1752        print_options: &PrintOptions,
1753        summarize: bool,
1754        range: (usize, usize),
1755    ) {
1756        let (start, end) = range;
1757        for i in start..end {
1758            if i > start {
1759                acc.push(',');
1760                Self::push_newline_indent(acc, depth + 1);
1761            }
1762            acc.push('[');
1763            multi_index[depth] = i;
1764            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
1765            acc.push(']');
1766        }
1767    }
1768
1769    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
1770    ///
1771    /// This function is designed to work with tensors of any dimensionality.
1772    /// It traverses the tensor dimensions recursively, converting the elements
1773    /// to strings and appending them to the accumulator string with the
1774    /// appropriate formatting.
1775    ///
1776    /// # Arguments
1777    ///
1778    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
1779    /// * `depth` - The current depth of the tensor dimensions being processed.
1780    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
    /// * `print_options` - The print options controlling precision and summarization.
    /// * `summarize` - Whether large dimensions are elided with `...`.
1781    fn display_recursive(
1782        &self,
1783        acc: &mut String,
1784        depth: usize,
1785        multi_index: &mut [usize],
1786        print_options: &PrintOptions,
1787        summarize: bool,
1788    ) {
1789        let edge_items = print_options.edge_items;
1790
1791        if depth == 0 {
1792            acc.push('[');
1793        }
1794
1795        if depth == self.dims().len() - 1 {
1796            // if we are at the innermost dimension, just push its elements into the accumulator
1797            if summarize && self.dims()[depth] > 2 * edge_items {
1798                // print the starting `edge_items` elements
1799                self.fmt_inner_tensor(
1800                    acc,
1801                    depth,
1802                    multi_index,
1803                    (0, edge_items),
1804                    print_options.precision,
1805                );
1806                acc.push_str(", ...");
1807                // print the last `edge_items` elements
1808                self.fmt_inner_tensor(
1809                    acc,
1810                    depth,
1811                    multi_index,
1812                    (self.dims()[depth] - edge_items, self.dims()[depth]),
1813                    print_options.precision,
1814                );
1815            } else {
1816                // print all the elements
1817                self.fmt_inner_tensor(
1818                    acc,
1819                    depth,
1820                    multi_index,
1821                    (0, self.dims()[depth]),
1822                    print_options.precision,
1823                );
1824            }
1825        } else {
1826            // otherwise, iterate through the current dimension and recursively display the inner tensors
1827            if summarize && self.dims()[depth] > 2 * edge_items {
1828                self.fmt_outer_tensor(
1829                    acc,
1830                    depth,
1831                    multi_index,
1832                    print_options,
1833                    summarize,
1834                    (0, edge_items),
1835                );
1836
1837                acc.push(',');
1838                Self::push_newline_indent(acc, depth + 1);
1839                acc.push_str("...");
1840                Self::push_newline_indent(acc, depth + 1);
1841
1842                self.fmt_outer_tensor(
1843                    acc,
1844                    depth,
1845                    multi_index,
1846                    print_options,
1847                    summarize,
1848                    (self.dims()[depth] - edge_items, self.dims()[depth]),
1849                );
1850            } else {
1851                self.fmt_outer_tensor(
1852                    acc,
1853                    depth,
1854                    multi_index,
1855                    print_options,
1856                    summarize,
1857                    (0, self.dims()[depth]),
1858                );
1859            }
1860        }
1861
1862        if depth == 0 {
1863            acc.push(']');
1864        }
1865    }
1866}
1867
1868/// Options for Tensor pretty printing
1869#[derive(Clone, Debug)]
1870pub struct PrintOptions {
1871    /// Number of elements above which the tensor output is summarized
1872    pub threshold: usize,
1873
1874    /// Number of leading and trailing elements to display when summarizing
1875    pub edge_items: usize,
1876
1877    /// Precision for floating point numbers
1878    pub precision: Option<usize>,
1879}
1880
1881static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
1882
1883impl PrintOptions {
1884    /// Print options with default values
1885    pub const fn const_default() -> Self {
1886        Self {
1887            threshold: 1000,
1888            edge_items: 3,
1889            precision: None,
1890        }
1891    }
1892}
1893
1894impl Default for PrintOptions {
1895    fn default() -> Self {
1896        Self::const_default()
1897    }
1898}
1899
1900/// Set the global print options used when displaying tensors.
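///
/// A minimal sketch with illustrative (non-default) values; the paths assume the
/// crate-root re-exports:
///
/// ```rust
/// use burn_tensor::{PrintOptions, set_print_options};
///
/// let options = PrintOptions {
///     // Summarize tensors with more than 100 elements.
///     threshold: 100,
///     // Show 2 elements at each edge when summarizing.
///     edge_items: 2,
///     // Print floats with 3 decimal places.
///     precision: Some(3),
/// };
/// set_print_options(options);
/// ```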
1901pub fn set_print_options(options: PrintOptions) {
1902    let mut print_opts = PRINT_OPTS.write().unwrap();
1903    *print_opts = options;
1904}
1905
1906/// Pretty print tensors
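///
/// A precision supplied through the format string takes priority over the global
/// print options. A minimal sketch:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 1>::from_data([1.23456, 7.0], &device);
///     // Prints the float elements with 2 decimal places.
///     println!("{:.2}", tensor);
/// }
/// ```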
1907impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
1908where
1909    B: Backend,
1910    B::IntElem: core::fmt::Display,
1911    K: BasicOps<B>,
1912    <K as BasicOps<B>>::Elem: Debug,
1913{
1914    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1915        writeln!(f, "Tensor {{")?;
1916
1917        {
1918            // Do not hold the read lock for the whole function
1919            let mut po = { PRINT_OPTS.read().unwrap().clone() };
1920
1921            // Override the precision if it is set from the formatter
1922            // This will be possible when the tensor is printed using the `{:.*}` syntax
1923            if let Some(precision) = f.precision() {
1924                po.precision = Some(precision);
1925            }
1926
1927            let mut acc = String::new();
1928            let mut multi_index = vec![0; D];
1929            let summarize = self.shape().num_elements() > po.threshold;
1930
1931            self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);
1932
1933            writeln!(f, "  data:")?;
1934            write!(f, "{acc}")?;
1935            writeln!(f, ",")?;
1936        }
1937
1938        writeln!(f, "  shape:  {:?},", self.dims())?;
1939        writeln!(f, "  device:  {:?},", self.device())?;
1940        writeln!(f, "  backend:  {:?},", B::name(&self.device()))?;
1941        writeln!(f, "  kind:  {:?},", K::name())?;
1942
1943        let dtype = self.primitive.dtype();
1944
1945        writeln!(f, "  dtype:  {:?},", dtype.name())?;
1946        write!(f, "}}")
1947    }
1948}
1949
1950/// Transpose marker (zero-sized type). Provides syntactic sugar for transposing a tensor, e.g.
1951/// ```rust
1952/// use burn_tensor::backend::Backend;
1953/// use burn_tensor::{Tensor, T};
1954///
1955/// fn example<B: Backend>() {
1956///     let device = Default::default();
1957///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
1958///     let transposed = tensor^T;
1959/// }
1960/// ```
1961pub struct T;
1962
1963impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
1964    type Output = Self;
1965    fn bitxor(self, _: T) -> Self::Output {
1966        self.transpose()
1967    }
1968}
1969
1970/// Trait that lists all operations that can be applied to all tensors.
1971///
1972/// # Warnings
1973///
1974/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
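///
/// It remains useful as a bound when writing code that is generic over the tensor
/// kind. A minimal sketch (paths assume the crate-root re-exports):
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{BasicOps, Tensor};
///
/// // Works for Float, Int and Bool tensors alike.
/// fn first_row<B: Backend, K: BasicOps<B>>(tensor: Tensor<B, 2, K>) -> Tensor<B, 2, K> {
///     tensor.slice([0..1])
/// }
/// ```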
1975pub trait BasicOps<B: Backend>: TensorKind<B> {
1976    /// The type of the tensor elements.
1977    type Elem: Element;
1978
1979    /// Creates an empty tensor with the given shape.
1980    ///
1981    /// # Arguments
1982    ///
1983    /// * `shape` - The shape of the tensor.
1984    /// * `device` - The device on which the tensor will be allocated.
1985    ///
1986    /// # Returns
1987    ///
1988    /// The empty tensor.
1989    ///
1990    /// # Remarks
1991    ///
1992    /// This is a low-level function used internally by the library to call different backend functions
1993    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1994    /// or use this function directly.
1995    ///
1996    /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function,
1997    /// which is more high-level and designed for public use.
1998    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive;
1999
2000    /// Reshapes the tensor.
2001    ///
2002    /// # Arguments
2003    ///
2004    /// * `tensor` - The tensor.
2005    /// * `shape` - The new shape of the tensor.
2006    ///
2007    /// # Returns
2008    ///
2009    /// The reshaped tensor.
2010    ///
2011    /// # Remarks
2012    ///
2013    /// This is a low-level function used internally by the library to call different backend functions
2014    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2015    /// or use this function directly.
2016    ///
2017    /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function,
2018    /// which is more high-level and designed for public use.
2019    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
2020
2021    /// Transposes a tensor.
2022    ///
2023    /// # Arguments
2024    ///
2025    /// * `tensor` - The tensor to transpose.
2026    ///
2027    /// # Returns
2028    ///
2029    /// The transposed tensor.
2030    fn transpose(tensor: Self::Primitive) -> Self::Primitive;
2031
2032    /// Swaps two dimensions of a tensor.
2033    ///
2034    /// # Arguments
2035    ///
2036    /// * `tensor` - The tensor to swap the dimensions of.
2037    /// * `dim1` - The first dimension to swap.
2038    /// * `dim2` - The second dimension to swap.
2039    ///
2040    /// # Returns
2041    ///
2042    /// The tensor with the dimensions swapped.
2043    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
2044
2045    /// Permutes the dimensions of a tensor.
2046    ///
2047    /// # Arguments
2048    ///
2049    /// * `tensor` - The tensor to permute the dimensions of.
2050    /// * `axes` - The new order of the dimensions.
2051    ///
2052    /// # Returns
2053    ///
2054    /// The tensor with the dimensions permuted.
2055    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2056
2057    /// Flips the tensor along the given axes.
2058    ///
2059    /// # Arguments
2060    ///
2061    /// * `tensor` - The tensor to flip.
2062    /// * `axes` - The axes to flip the tensor along.
2063    ///
2064    /// # Returns
2065    ///
2066    /// The tensor with the axes flipped.
2067    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2068
2069    /// Selects the tensor elements corresponding to the given ranges.
2070    ///
2071    /// # Arguments
2072    ///
2073    /// * `tensor` - The tensor.
2074    /// * `ranges` - The ranges of the elements to select.
2075    ///
2076    /// # Returns
2077    ///
2078    /// The selected elements.
2079    ///
2080    /// # Remarks
2081    ///
2082    /// This is a low-level function used internally by the library to call different backend functions
2083    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2084    /// or use this function directly.
2085    ///
2086    /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function,
2087    /// which is more high-level and designed for public use.
2088    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive;
2089
2090    /// Assigns the given value to the tensor elements corresponding to the given ranges.
2091    ///
2092    /// # Arguments
2093    ///
2094    /// * `tensor` - The tensor.
2095    /// * `ranges` - The ranges of the elements to select.
2096    /// * `value` - The value to assign.
2097    ///
2098    /// # Returns
2099    ///
2100    /// The tensor with the assigned values.
2101    ///
2102    /// # Remarks
2103    ///
2104    /// This is a low-level function used internally by the library to call different backend functions
2105    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2106    /// or use this function directly.
2107    ///
2108    /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function,
2109    /// which is more high-level and designed for public use.
2110    fn slice_assign(
2111        tensor: Self::Primitive,
2112        ranges: &[Range<usize>],
2113        value: Self::Primitive,
2114    ) -> Self::Primitive;
2115
2116    /// Returns the device on which the tensor is allocated.
2117    ///
2118    /// # Arguments
2119    ///
2120    /// * `tensor` - The tensor.
2121    ///
2122    /// # Returns
2123    ///
2124    /// The device on which the tensor is allocated.
2125    ///
2126    /// # Remarks
2127    ///
2128    /// This is a low-level function used internally by the library to call different backend functions
2129    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2130    /// or use this function directly.
2131    ///
2132    /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function,
2133    /// which is more high-level and designed for public use.
2134    fn device(tensor: &Self::Primitive) -> B::Device;
2135
2136    /// Moves the tensor to the given device.
2137    ///
2138    /// # Arguments
2139    ///
2140    /// * `tensor` - The tensor.
2141    /// * `device` - The device on which the tensor will be moved.
2142    ///
2143    /// # Returns
2144    ///
2145    /// The tensor on the given device.
2146    ///
2147    /// # Remarks
2148    ///
2149    /// This is a low-level function used internally by the library to call different backend functions
2150    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2151    /// or use this function directly.
2152    ///
2153    /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function,
2154    /// which is more high-level and designed for public use.
2155    fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
2156
2157    /// Extracts the data from the tensor asynchronously.
2158    ///
2159    /// # Arguments
2160    ///
2161    /// * `tensor` - The tensor.
2162    ///
2163    /// # Returns
2164    ///
2165    /// The data of the tensor.
2166    ///
2167    /// # Remarks
2168    ///
2169    /// This is a low-level function used internally by the library to call different backend functions
2170    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2171    /// or use this function directly.
2172    ///
2173    /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
2174    /// which is more high-level and designed for public use.
2175    fn into_data_async(
2176        tensor: Self::Primitive,
2177    ) -> impl Future<Output = TensorData> + 'static + Send;
2178
2179    /// Read the data from the tensor using a transaction.
2180    ///
2181    /// # Remarks
2182    ///
2183    /// This is a low-level function used internally by the library to call different backend functions
2184    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2185    /// or use this function directly.
2186    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive);
2187
2188    /// Creates a tensor from the given data.
2189    ///
2190    /// # Arguments
2191    ///
2192    /// * `data` - The data of the tensor.
2193    /// * `device` - The device on which the tensor will be allocated.
2194    ///
2195    /// # Returns
2196    ///
2197    /// The tensor.
2198    ///
2199    /// # Remarks
2200    ///
2201    /// This is a low-level function used internally by the library to call different backend functions
2202    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2203    /// or use this function directly.
2204    ///
2205    /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function,
2206    /// which is more high-level and designed for public use.
2207    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;

2208    /// Creates a tensor from the given data enforcing the given data type.
2209    ///
2210    /// # Remarks
2211    ///
2212    /// This is a low-level function used internally by the library to call different backend functions
2213    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2214    /// or use this function directly.
2215    ///
2216    /// For creating a tensor from data, users should prefer the [Tensor::from_data_dtype](Tensor::from_data_dtype)
2217    /// function, which is more high-level and designed for public use.
2218    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;
2219
2220    /// Repeat the tensor along the given dimension.
2221    ///
2222    /// # Arguments
2223    ///
2224    /// * `tensor` - The tensor.
2225    /// * `dim` - The dimension along which the tensor will be repeated.
2226    /// * `times` - The number of times the tensor will be repeated.
2227    ///
2228    /// # Returns
2229    ///
2230    /// The repeated tensor.
2231    ///
2232    /// # Remarks
2233    ///
2234    /// This is a low-level function used internally by the library to call different backend functions
2235    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2236    /// or use this function directly.
2237    ///
2238    /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
2239    /// which is more high-level and designed for public use.
2240    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
2241
2242    /// Concatenates the given tensors along the given dimension.
2243    ///
2244    /// # Arguments
2245    ///
2246    /// * `vectors` - The tensors to concatenate.
2247    /// * `dim` - The dimension along which the tensors will be concatenated.
2248    ///
2249    /// # Returns
2250    ///
2251    /// The concatenated tensor.
2252    ///
2253    /// # Remarks
2254    ///
2255    /// This is a low-level function used internally by the library to call different backend functions
2256    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2257    /// or use this function directly.
2258    ///
2259    /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function,
2260    /// which is more high-level and designed for public use.
2261    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
2262
2263    /// Attempts to split the tensor along the given dimension into chunks.
2264    /// May return fewer chunks than requested if the tensor size is not divisible by the number of chunks.
2265    ///
2266    /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
2267    /// Otherwise all chunks will be of equal size except for the last one.
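    ///
    /// For example, chunking a dimension of size 5 into 2 chunks yields chunk sizes `[3, 2]`.
    /// A minimal sketch through the public [Tensor::chunk](Tensor::chunk) API:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
    ///     // Two chunks along dimension 0, with sizes [3, 2].
    ///     let chunks = tensor.chunk(2, 0);
    ///     println!("{}", chunks.len());
    /// }
    /// ```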
2268    ///
2269    /// # Panics
2270    ///
2271    /// If the dimension is greater than or equal to the number of dimensions of the tensor.
2272    ///
2273    /// # Returns
2274    /// A vector of tensors.
2275    ///
2276    /// # Remarks
2277    ///
2278    /// This is a low-level function used internally by the library to call different backend functions
2279    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2280    /// or use this function directly.
2281    ///
2282    /// To chunk a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function,
2283    /// which is more high-level and designed for public use.
2284    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive>;
2285
2286    /// Splits the tensor into chunks of a specified size along a given dimension.
2287    /// Each chunk is a view of the original tensor.
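    ///
    /// For example, splitting a dimension of size 5 with a split size of 2 yields chunk
    /// sizes `[2, 2, 1]`. A minimal sketch through the public [Tensor::split](Tensor::split) API:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
    ///     // Three parts along dimension 0, with sizes [2, 2, 1].
    ///     let parts = tensor.split(2, 0);
    ///     println!("{}", parts.len());
    /// }
    /// ```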
2288    ///
2289    /// # Panics
2290    ///
2291    /// If the dimension to split along is greater than or equal to the number of dimensions of the tensor.
2292    ///
2293    /// # Returns
2294    ///
2295    /// A vector of tensors.
2296    ///
2297    /// # Remarks
2298    /// This is a low-level function used internally by the library to call different backend functions
2299    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2300    /// or use this function directly.
2301    ///
2302    /// To split a tensor, users should prefer the [Tensor::split](Tensor::split) function,
2303    /// which is more high-level and designed for public use.
2304    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive>;
2305
2306    /// Splits the tensor into chunks with the specified sizes along a given dimension.
2307    /// Each chunk is a view of the original tensor.
2308    ///
2309    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2310    /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2311    ///
2312    /// # Panics
2313    ///
2314    /// If the dimension to split along is greater than or equal to the number of dimensions of the tensor or
2315    /// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
2316    ///
2317    /// # Returns
2318    ///
2319    /// A vector of tensors.
2320    ///
2321    /// # Remarks
2322    /// This is a low-level function used internally by the library to call different backend functions
2323    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2324    /// or use this function directly.
2325    ///
2326    /// To split a tensor, users should prefer the [Tensor::split_with_sizes](Tensor::split_with_sizes) function,
2327    /// which is more high-level and designed for public use.
2328    fn split_with_sizes(
2329        tensor: Self::Primitive,
2330        split_sizes: Vec<usize>,
2331        dim: usize,
2332    ) -> Vec<Self::Primitive>;
2333
2334    /// Applies element-wise equality comparison between the given tensors.
2335    ///
2336    /// # Arguments
2337    ///
2338    /// * `lhs` - The left hand side tensor.
2339    /// * `rhs` - The right hand side tensor.
2340    ///
2341    /// # Returns
2342    ///
2343    /// The tensor of booleans indicating whether the corresponding elements are equal.
2344    ///
2345    /// # Remarks
2346    ///
2347    /// This is a low-level function used internally by the library to call different backend functions
2348    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2349    /// or use this function directly.
2350    ///
2351    /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function,
2352    /// which is more high-level and designed for public use.
2353    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2354
2355    /// Applies element-wise non-equality comparison between the given tensors.
2356    ///
2357    /// # Arguments
2358    ///
2359    /// * `lhs` - The left hand side tensor.
2360    /// * `rhs` - The right hand side tensor.
2361    ///
2362    /// # Returns
2363    ///
2364    /// The tensor of booleans indicating whether the corresponding elements are not equal.
2365    ///
2366    /// # Remarks
2367    ///
2368    /// This is a low-level function used internally by the library to call different backend functions
2369    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2370    /// or use this function directly.
2371    ///
2372    /// For non-equality comparison of tensors, users should prefer the [Tensor::not_equal](Tensor::not_equal)
2373    /// function, which is more high-level and designed for public use.
2374    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2375
2376    /// Returns the name of the element type.
2377    fn elem_type_name() -> &'static str {
2378        core::any::type_name::<Self::Elem>()
2379    }
2380
2381    /// Returns the tensor data type.
2382    fn dtype(tensor: &Self::Primitive) -> DType {
2383        tensor.dtype()
2384    }
2385
2386    /// Tests if any element in the `tensor` evaluates to True.
2387    ///
2388    /// # Arguments
2389    ///
2390    /// * `tensor` - The tensor to test.
2391    ///
2392    /// # Returns
2393    ///
2394    /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
2395    ///
2396    /// # Remarks
2397    ///
2398    /// This is a low-level function used internally by the library to call different backend functions
2399    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2400    /// or use this function directly. Users should prefer the [Tensor::any](Tensor::any) function
2401    /// which is more high-level and designed for public use.
2402    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
2403
2404    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
2405    ///
2406    /// # Arguments
2407    ///
2408    /// * `tensor` - The tensor to test.
2409    /// * `dim` - The axis along which to test.
2410    ///
2411    /// # Returns
2412    ///
2413    /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
2414    /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
2415    ///
2416    /// # Remarks
2417    ///
2418    /// This is a low-level function used internally by the library to call different backend functions
2419    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2420    /// or use this function directly. Users should prefer the [Tensor::any_dim](Tensor::any_dim) function,
2421    /// which is more high-level and designed for public use.
2422    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
2423
2424    /// Tests if all elements in the `tensor` evaluate to True.
2425    ///
2426    /// # Arguments
2427    ///
2428    /// * `tensor` - The tensor to test.
2429    ///
2430    /// # Returns
2431    ///
2432    /// A boolean tensor with a single element, True if all elements in the input tensor evaluate to True, False otherwise.
2433    ///
2434    /// # Remarks
2435    ///
2436    /// This is a low-level function used internally by the library to call different backend functions
2437    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2438    /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function,
2439    /// which is more high-level and designed for public use.
2440    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
2441
2442    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2443    ///
2444    /// # Arguments
2445    ///
2446    /// * `tensor` - The tensor to test.
    /// * `dim` - The axis along which to test.
2447    ///
2448    /// # Returns
2449    ///
2450    /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
2451    /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
2452    ///
2453    /// # Remarks
2454    ///
2455    /// This is a low-level function used internally by the library to call different backend functions
2456    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2457    /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function,
2458    /// which is more high-level and designed for public use.
2459    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
2460
2461    /// Broadcasts the given tensor to the specified shape.
2462    ///
2463    /// # Arguments
2464    ///
2465    /// * `tensor` - The tensor to broadcast.
2466    /// * `shape` - The shape to broadcast to.
2467    ///
2468    /// # Returns
2469    ///
2470    /// The broadcasted tensor.
2471    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
2472}
2473
2474impl<B: Backend> BasicOps<B> for Float {
2475    type Elem = B::FloatElem;
2476
2477    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2478        TensorPrimitive::Float(B::float_empty(shape, device))
2479    }
2480
2481    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2482        tr.register_float(tensor);
2483    }
2484
2485    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2486        match tensor {
2487            TensorPrimitive::Float(tensor) => {
2488                TensorPrimitive::Float(B::float_reshape(tensor, shape))
2489            }
2490            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
2491        }
2492    }
2493
2494    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2495        match tensor {
2496            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
2497            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
2498        }
2499    }
2500
2501    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2502        match tensor {
2503            TensorPrimitive::Float(tensor) => {
2504                TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
2505            }
2506            TensorPrimitive::QFloat(tensor) => {
2507                TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
2508            }
2509        }
2510    }
2511
2512    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2513        match tensor {
2514            TensorPrimitive::Float(tensor) => {
2515                TensorPrimitive::Float(B::float_slice(tensor, ranges))
2516            }
2517            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, ranges)),
2518        }
2519    }
2520
2521    fn slice_assign(
2522        tensor: Self::Primitive,
2523        ranges: &[Range<usize>],
2524        value: Self::Primitive,
2525    ) -> Self::Primitive {
2526        match (tensor, value) {
2527            (TensorPrimitive::Float(tensor), TensorPrimitive::Float(value)) => {
2528                TensorPrimitive::Float(B::float_slice_assign(tensor, ranges, value))
2529            }
2530            (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(value)) => {
2531                TensorPrimitive::QFloat(B::q_slice_assign(tensor, ranges, value))
2532            }
2533            _ => panic!("Primitive type mismatch for tensor and value"),
2534        }
2535    }
2536
2537    fn device(tensor: &Self::Primitive) -> Device<B> {
2538        match tensor {
2539            TensorPrimitive::Float(tensor) => B::float_device(tensor),
2540            TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
2541        }
2542    }
2543
2544    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2545        match tensor {
2546            TensorPrimitive::Float(tensor) => {
2547                TensorPrimitive::Float(B::float_to_device(tensor, device))
2548            }
2549            TensorPrimitive::QFloat(tensor) => {
2550                TensorPrimitive::QFloat(B::q_to_device(tensor, device))
2551            }
2552        }
2553    }
2554
2555    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2556        match tensor {
2557            TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
2558            TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
2559        }
2560    }
2561
2562    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2563        match data.dtype {
2564            DType::QFloat(_strategy) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
2565            _ => TensorPrimitive::Float(B::float_from_data(data.convert::<B::FloatElem>(), device)),
2566        }
2567    }
2568
2569    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
2570        match dtype {
2571            DType::QFloat(_strategy) => {
2572                TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device))
2573            }
2574            _ if dtype.is_float() => {
2575                TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
2576            }
2577            _ => panic!("Expected float dtype, got {dtype:?}"),
2578        }
2579    }
2580
2581    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2582        match tensor {
2583            TensorPrimitive::Float(tensor) => {
2584                TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
2585            }
2586            TensorPrimitive::QFloat(tensor) => {
2587                TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
2588            }
2589        }
2590    }
2591
2592    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2593        match vectors.first().unwrap() {
2594            TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
2595                vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
2596                dim,
2597            )),
2598            TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
2599                vectors
2600                    .into_iter()
2601                    .map(|tensor| {
2602                        if let TensorPrimitive::QFloat(t) = tensor {
2603                            t
2604                        } else {
2605                            panic!("Concatenation only works with a vector of QFloat tensors")
2606                        }
2607                    })
2608                    .collect(),
2609                dim,
2610            )),
2611        }
2612    }
2613
2614    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2615        B::float_equal(lhs.tensor(), rhs.tensor())
2616    }
2617
2618    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2619        B::float_not_equal(lhs.tensor(), rhs.tensor())
2620    }
2621
2622    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2623        B::float_any(tensor.tensor())
2624    }
2625
2626    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2627        B::float_any_dim(tensor.tensor(), dim)
2628    }
2629
2630    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2631        B::float_all(tensor.tensor())
2632    }
2633
2634    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2635        B::float_all_dim(tensor.tensor(), dim)
2636    }
2637
2638    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2639        match tensor {
2640            TensorPrimitive::Float(tensor) => {
2641                TensorPrimitive::Float(B::float_permute(tensor, axes))
2642            }
2643            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
2644        }
2645    }
2646
2647    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2648        TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
2649    }
2650
2651    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2652        match tensor {
2653            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
2654            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
2655        }
2656    }
2657
2658    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2659        match tensor {
2660            TensorPrimitive::Float(tensor) => B::float_chunk(tensor, chunks, dim)
2661                .into_iter()
2662                .map(TensorPrimitive::Float)
2663                .collect(),
2664            TensorPrimitive::QFloat(tensor) => B::q_chunk(tensor, chunks, dim)
2665                .into_iter()
2666                .map(TensorPrimitive::QFloat)
2667                .collect(),
2668        }
2669    }
2670
2671    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2672        match tensor {
2673            TensorPrimitive::Float(tensor) => B::float_split(tensor, split_size, dim)
2674                .into_iter()
2675                .map(TensorPrimitive::Float)
2676                .collect(),
2677            TensorPrimitive::QFloat(tensor) => B::q_split(tensor, split_size, dim)
2678                .into_iter()
2679                .map(TensorPrimitive::QFloat)
2680                .collect(),
2681        }
2682    }
2683
2684    fn split_with_sizes(
2685        tensor: Self::Primitive,
2686        split_sizes: Vec<usize>,
2687        dim: usize,
2688    ) -> Vec<Self::Primitive> {
2689        match tensor {
2690            TensorPrimitive::Float(tensor) => B::float_split_with_sizes(tensor, split_sizes, dim)
2691                .into_iter()
2692                .map(TensorPrimitive::Float)
2693                .collect(),
2694            TensorPrimitive::QFloat(tensor) => B::q_split_with_sizes(tensor, split_sizes, dim)
2695                .into_iter()
2696                .map(TensorPrimitive::QFloat)
2697                .collect(),
2698        }
2699    }
2700}
2701
2702impl<B: Backend> BasicOps<B> for Int {
2703    type Elem = B::IntElem;
2704
2705    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2706        B::int_empty(shape, device)
2707    }
2708
2709    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2710        tr.register_int(tensor);
2711    }
2712
2713    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2714        B::int_reshape(tensor, shape)
2715    }
2716
2717    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2718        B::int_transpose(tensor)
2719    }
2720
2721    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2722        B::int_swap_dims(tensor, dim1, dim2)
2723    }
2724
2725    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2726        B::int_slice(tensor, ranges)
2727    }
2728
2729    fn slice_assign(
2730        tensor: Self::Primitive,
2731        ranges: &[Range<usize>],
2732        value: Self::Primitive,
2733    ) -> Self::Primitive {
2734        B::int_slice_assign(tensor, ranges, value)
2735    }
2736
2737    fn device(tensor: &Self::Primitive) -> Device<B> {
2738        B::int_device(tensor)
2739    }
2740
2741    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2742        B::int_to_device(tensor, device)
2743    }
2744
2745    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2746        B::int_into_data(tensor).await
2747    }
2748
2749    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2750        B::int_from_data(data.convert::<B::IntElem>(), device)
2751    }
2752
2753    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
2754        if !dtype.is_int() {
2755            panic!("Expected int dtype, got {dtype:?}")
2756        }
2757
2758        B::int_from_data(data.convert_dtype(dtype), device)
2759    }
2760
2761    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2762        B::int_repeat_dim(tensor, dim, times)
2763    }
2764
2765    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2766        B::int_equal(lhs, rhs)
2767    }
2768
2769    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2770        B::int_not_equal(lhs, rhs)
2771    }
2772
2773    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2774        B::int_cat(vectors, dim)
2775    }
2776
2777    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2778        B::int_any(tensor)
2779    }
2780
2781    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2782        B::int_any_dim(tensor, dim)
2783    }
2784
2785    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2786        B::int_all(tensor)
2787    }
2788
2789    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2790        B::int_all_dim(tensor, dim)
2791    }
2792
2793    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2794        B::int_permute(tensor, axes)
2795    }
2796
2797    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2798        B::int_expand(tensor, shape)
2799    }
2800
2801    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2802        B::int_flip(tensor, axes)
2803    }
2804
2805    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2806        B::int_chunk(tensor, chunks, dim)
2807    }
2808
2809    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2810        B::int_split(tensor, split_size, dim)
2811    }
2812
2813    fn split_with_sizes(
2814        tensor: Self::Primitive,
2815        split_sizes: Vec<usize>,
2816        dim: usize,
2817    ) -> Vec<Self::Primitive> {
2818        B::int_split_with_sizes(tensor, split_sizes, dim)
2819    }
2820}
2821
2822impl<B: Backend> BasicOps<B> for Bool {
2823    type Elem = B::BoolElem;
2824
2825    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2826        B::bool_empty(shape, device)
2827    }
2828
2829    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2830        tr.register_bool(tensor);
2831    }
2832
2833    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2834        B::bool_reshape(tensor, shape)
2835    }
2836
2837    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2838        B::bool_transpose(tensor)
2839    }
2840
2841    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2842        B::bool_swap_dims(tensor, dim1, dim2)
2843    }
2844
2845    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2846        B::bool_slice(tensor, ranges)
2847    }
2848
2849    fn slice_assign(
2850        tensor: Self::Primitive,
2851        ranges: &[Range<usize>],
2852        value: Self::Primitive,
2853    ) -> Self::Primitive {
2854        B::bool_slice_assign(tensor, ranges, value)
2855    }
2856
2857    fn device(tensor: &Self::Primitive) -> Device<B> {
2858        B::bool_device(tensor)
2859    }
2860
2861    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2862        B::bool_to_device(tensor, device)
2863    }
2864
2865    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2866        B::bool_into_data(tensor).await
2867    }
2868
2869    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2870        B::bool_from_data(data.convert::<B::BoolElem>(), device)
2871    }
2872
2873    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
2874        // Backends only use one bool representation dtype
2875        if dtype != B::BoolElem::dtype() {
2876            panic!("Expected bool dtype, got {dtype:?}")
2877        }
2878        B::bool_from_data(data.convert_dtype(dtype), device)
2879    }

    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
        B::bool_repeat_dim(tensor, dim, times)
    }

    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_equal(lhs, rhs)
    }

    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_not_equal(lhs, rhs)
    }

    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
        B::bool_cat(vectors, dim)
    }

    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_any(tensor)
    }

    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
        B::bool_any_dim(tensor, dim)
    }

    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_all(tensor)
    }

    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
        B::bool_all_dim(tensor, dim)
    }

    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
        B::bool_permute(tensor, axes)
    }

    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
        B::bool_expand(tensor, shape)
    }

    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
        B::bool_flip(tensor, axes)
    }

    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
        B::bool_chunk(tensor, chunks, dim)
    }

    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
        B::bool_split(tensor, split_size, dim)
    }

    fn split_with_sizes(
        tensor: Self::Primitive,
        split_sizes: Vec<usize>,
        dim: usize,
    ) -> Vec<Self::Primitive> {
        B::bool_split_with_sizes(tensor, split_sizes, dim)
    }
}

/// Trait used for movedim arguments.
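///
/// # Example
///
/// A minimal sketch of the argument forms accepted by `tensor.movedim()`
/// (the tensor shape and the moves below are illustrative only):
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
///
///     // A single `usize` or `i32` moves one dimension; negative indices
///     // count from the end, so -1 is the last dimension.
///     let moved = tensor.clone().movedim(0, 2); // shape [3, 4, 2]
///     let moved = tensor.clone().movedim(-1, 0); // shape [4, 2, 3]
///
///     // `Vec` arguments move several dimensions at once.
///     let moved = tensor.movedim(vec![0, 1], vec![1, 2]); // shape [4, 2, 3]
/// }
/// ```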
pub trait MovedimArgs {
    /// Converts the argument into a `Vec<usize>` of dimension indices for the
    /// `tensor.movedim()` function.
    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}

impl MovedimArgs for Vec<i32> {
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        let set = self
            .iter()
            .map(|&dim| {
                if dim < 0 {
                    (D as i32 + dim) as usize
                } else {
                    dim as usize
                }
            })
            .collect::<Vec<usize>>();
        check!(TensorCheck::movedim_args_vec::<D>(&set));

        set
    }
}

impl MovedimArgs for Vec<usize> {
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        check!(TensorCheck::movedim_args_vec::<D>(&self));
        self
    }
}

impl MovedimArgs for usize {
    #[allow(clippy::vec_init_then_push)]
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        check!(TensorCheck::movedim_args_usize::<D>(self));

        let mut set = Vec::with_capacity(1);
        set.push(self);

        set
    }
}

impl MovedimArgs for i32 {
    #[allow(clippy::vec_init_then_push)]
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        check!(TensorCheck::movedim_args_i32::<D>(self));

        let dim = if self < 0 {
            (D as i32 + self) as usize
        } else {
            self as usize
        };

        let mut set = Vec::with_capacity(1);
        set.push(dim);

        set
    }
}

/// Trait used for slice arguments.
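///
/// # Example
///
/// A short sketch of the conversions performed for `tensor.slice()`; the
/// shape and indices are illustrative (see the tests below for more forms):
///
/// ```rust
/// use burn_tensor::{RangesArg, Shape};
///
/// let shape = Shape::new([8, 4]);
///
/// // Ranges are resolved per dimension; negative indices count from the end.
/// assert_eq!([0..5, 0..4], [0..5, 0..4].into_ranges(shape.clone()));
/// assert_eq!([5..7, 3..4], [-3..-1, -1..4].into_ranges(shape.clone()));
///
/// // A single integer selects one element along the first dimension.
/// assert_eq!([7..8], (-1).into_ranges(shape));
/// ```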
pub trait RangesArg<const D2: usize> {
    /// Converts the argument into an array of ranges `[Range<usize>; D2]` for the
    /// `tensor.slice()` function.
    fn into_ranges(self, shape: Shape) -> [Range<usize>; D2];
}

impl<const D2: usize, T: Into<Slice>> RangesArg<D2> for [T; D2] {
    fn into_ranges(self, shape: Shape) -> [Range<usize>; D2] {
        // Resolve each slice argument into a concrete range clamped to the
        // corresponding dimension size.
        let ranges = self
            .into_iter()
            .enumerate()
            .map(|(i, range)| range.into().into_range(shape.dims[i]))
            .collect::<Vec<_>>();
        ranges.try_into().unwrap()
    }
}

impl<T: Into<Slice>> RangesArg<1> for T {
    fn into_ranges(self, shape: Shape) -> [Range<usize>; 1] {
        [self.into().into_range(shape.dims[0])]
    }
}

/// Trait used for reshape arguments.
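///
/// # Example
///
/// A minimal sketch of the accepted forms for `tensor.reshape()` (the shapes
/// are illustrative only):
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::ones([4, 3], &device);
///
///     // Plain `usize` dimensions must match the element count exactly.
///     let reshaped = tensor.clone().reshape([2, 6]);
///
///     // With `i32` arguments, `0` keeps the original dimension and `-1`
///     // infers one dimension from the remaining elements: 4 * 3 = 12
///     // elements, so [0, -1, 1] resolves to [4, 3, 1].
///     let reshaped = tensor.reshape([0, -1, 1]);
/// }
/// ```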
pub trait ReshapeArgs<const D2: usize> {
    /// Converts to a shape.
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape;
}

impl<const D2: usize> ReshapeArgs<D2> for Shape {
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape {
        check!(TensorCheck::reshape_args_usize::<D, D2>(
            &tensor.shape(),
            &self
        ));

        self
    }
}

impl<const D2: usize> ReshapeArgs<D2> for [usize; D2] {
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape {
        let shape = Shape::from(self);

        check!(TensorCheck::reshape_args_usize::<D, D2>(
            &tensor.shape(),
            &shape
        ));

        shape
    }
}

impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape {
        // Validate the reshape arguments
        check!(TensorCheck::reshape_args_i32(&self));

        // Temporary shape
        let mut new_shape: [i32; D2] = [1; D2];

        // A `0` keeps the dimension unchanged, so copy it from the tensor;
        // every other value (including `-1`) is taken as-is for now.
        for (i, &s) in self.iter().enumerate() {
            if s != 0 {
                new_shape[i] = s;
            } else {
                new_shape[i] = tensor.dims()[i] as i32;
            }
        }

        // Handle the case where one dimension is inferred (via -1)
        if let Some(index) = new_shape.iter().position(|x| x == &-1) {
            // Product of all dimensions except the inferred one
            let mut product = 1;
            for (i, &s) in new_shape.iter().enumerate() {
                if i != index {
                    product *= s;
                }
            }
            let product_current = tensor.shape().num_elements() as i32;

            // The known dimensions must divide the element count evenly,
            // otherwise no integer size exists for the inferred dimension.
            if product_current % product != 0 {
                panic!(
                    "Cannot reshape tensor of shape {:?} to shape {:?}",
                    tensor.shape(),
                    self
                );
            }

            new_shape[index] = product_current / product;
        }

        // Convert each element to usize
        let new_shape: [usize; D2] = new_shape.map(|x| x as usize);

        Shape::from(new_shape)
    }
}

/// Trait used for broadcast arguments.
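///
/// # Example
///
/// A minimal sketch of the accepted forms for `tensor.expand()` (the shapes
/// are illustrative only):
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::ones([1, 4], &device);
///
///     // Plain `usize` dimensions give the target shape directly.
///     let expanded = tensor.clone().expand([3, 4]);
///
///     // With integer arguments, `-1` keeps the size of the matching
///     // (trailing-aligned) dimension: here [2, -1] resolves to [2, 4].
///     let expanded = tensor.expand([2, -1]);
/// }
/// ```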
pub trait BroadcastArgs<const D1: usize, const D2: usize> {
    /// Converts to a shape.
    fn into_shape(self, shape: &Shape) -> Shape;
}

impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
    fn into_shape(self, _shape: &Shape) -> Shape {
        self
    }
}

impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
    fn into_shape(self, _shape: &Shape) -> Shape {
        Shape::from(self)
    }
}

impl<const D1: usize, const D2: usize, E: Element> BroadcastArgs<D1, D2> for [E; D2] {
    // Passing -1 as the size for a dimension means not changing the size of that dimension.
    fn into_shape(self, shape: &Shape) -> Shape {
        if self.len() < shape.num_dims() {
            panic!("Broadcast arguments cannot have fewer dimensions than the original shape");
        }

        // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
        let new_shape: Vec<_> = self
            .iter()
            .rev()
            .map(|x| {
                let primitive = x.to_i64();
                if primitive < -1 || primitive == 0 {
                    panic!("Broadcast arguments must be positive or -1");
                }
                primitive
            })
            .zip(shape.dims.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
            .map(|(x, &y)| if x == -1 { y } else { x as usize })
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .collect();

        if new_shape.iter().any(|&x| x == 0) {
            panic!("Cannot substitute -1 for a non-existing dimension");
        }

        let new_shape: [usize; D2] = new_shape.try_into().unwrap();

        Shape::from(new_shape)
    }
}

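// Serialization round-trips through `TensorData`: serializing reads the tensor
// back from the device, and deserializing materializes it on the default device.
// A minimal usage sketch, assuming a serde format crate such as `serde_json`
// (not a dependency of this crate) and a backend `B` in scope:
//
//     let json = serde_json::to_string(&tensor).unwrap();
//     let restored: Tensor<B, 2> = serde_json::from_str(&json).unwrap();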
impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    K::Elem: Debug + Copy + Serialize,
{
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let data = self.to_data();
        data.serialize(serializer)
    }
}

impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    K::Elem: Debug + Copy + Deserialize<'de>,
{
    fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
        let tensor = Tensor::from_data(
            TensorData::deserialize(deserializer)?,
            &<B::Device as Default>::default(),
        );
        Ok(tensor)
    }
}

#[cfg(test)]
mod tests {
    use crate::Shape;
    use crate::s;

    use super::*;

    #[test]
    fn slice_range_single_dim_leading() {
        let shape = Shape::new([8, 4]);

        // Half-open range
        assert_eq!([0..5], (0..5).into_ranges(shape.clone()));
        assert_eq!([0..5], [0..5].into_ranges(shape.clone()));
        assert_eq!([5..7], [-3..-1].into_ranges(shape.clone()));

        // Inclusive range
        assert_eq!([0..5], (0..=4).into_ranges(shape.clone()));
        assert_eq!([0..5], [0..=4].into_ranges(shape.clone()));
        assert_eq!([6..8], [-2..=-1].into_ranges(shape.clone()));

        // Unbounded start
        assert_eq!([0..3], (..3).into_ranges(shape.clone()));
        assert_eq!([0..3], [..3].into_ranges(shape.clone()));
        assert_eq!([0..3], [..-5].into_ranges(shape.clone()));

        // Unbounded end
        assert_eq!([5..8], (5..).into_ranges(shape.clone()));
        assert_eq!([5..8], [5..].into_ranges(shape.clone()));
        assert_eq!([5..8], [-3..].into_ranges(shape.clone()));

        // Full range
        assert_eq!([0..8], [..].into_ranges(shape));
    }

    #[test]
    fn slice_range_multi_dim() {
        let shape = Shape::new([8, 4]);

        // Multiple ways to provide ranges
        assert_eq!([0..5, 0..4], [0..5, 0..4].into_ranges(shape.clone()));
        assert_eq!([0..8, 0..4], [0.., 0..].into_ranges(shape.clone()));
        assert_eq!([0..8, 0..4], [0..=7, 0..=3].into_ranges(shape.clone()));

        assert_eq!([0..5, 0..3], [0..5, 0..3].into_ranges(shape.clone()));

        assert_eq!([0..8, 0..4], [0.., 0..].into_ranges(shape));
    }

    #[test]
    fn slice_range_multi_dim_index() {
        let shape = Shape::new([8, 4]);

        // Indices (single integer) should also convert to correct range
        assert_eq!([0..1, 2..3], [0, 2].into_ranges(shape.clone()));
        assert_eq!([7..8, 3..4], [-1, -1].into_ranges(shape.clone()));
        assert_eq!([7..8], (-1).into_ranges(shape.clone()));
        assert_eq!([7..8], 7.into_ranges(shape));
    }

    #[test]
    fn slice_range_multi_dim_heterogeneous() {
        // Slice macro `s![]` can be used to provide different range types
        let shape = Shape::new([8, 4, 2]);
        let slice = s![0..5, .., -1];
        assert_eq!([0..5, 0..4, 1..2], slice.into_ranges(shape));

        let shape = Shape::new([8, 4, 2, 3]);
        let slice = s![..=4, 0..=3, .., -2..];
        assert_eq!([0..5, 0..4, 0..2, 1..3], slice.into_ranges(shape));

        let shape = Shape::new([3, 4]);
        let slice = s![1..-1, ..];
        assert_eq!([1..2, 0..4], slice.into_ranges(shape));
    }
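
    // Minimal checks of the argument-conversion traits above; the expected
    // values are sketches that mirror the conversion rules implemented in
    // `MovedimArgs` and `BroadcastArgs`.
    #[test]
    fn movedim_args_convert_negative_index() {
        // For a rank-3 tensor, -1 resolves to the last dimension.
        assert_eq!(vec![2], (-1i32).into_dim_vec::<3>());
        assert_eq!(vec![0, 2], vec![0i32, -1].into_dim_vec::<3>());
    }

    #[test]
    fn broadcast_args_negative_one_keeps_dim() {
        let shape = Shape::new([2, 3]);
        // Dimensions align from the right; -1 keeps the matching original size.
        let new_shape = <[i32; 3] as BroadcastArgs<2, 3>>::into_shape([4, -1, 3], &shape);
        assert_eq!(Shape::new([4, 2, 3]), new_shape);
    }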
}