burn_tensor/tensor/api/base.rs

#![allow(clippy::single_range_in_vec_init)]

use alloc::vec::Vec;

use alloc::format;
use alloc::string::String;
use alloc::vec;

use burn_common::stub::RwLock;
use core::future::Future;
use core::iter::repeat;
use core::{fmt::Debug, ops::Range};
use serde::{Deserialize, Deserializer};

use serde::{Serialize, Serializer};

use crate::tensor::api::narrow::narrow;
use crate::{
    Bool, Float, Int, Shape, TensorData, TensorKind, backend::Backend, check, ops::Device,
};
use crate::{DType, Element, TensorPrimitive};
use crate::{cast::ToElement, check::TensorCheck};

use super::{Slice, TensorMetadata, Transaction};

/// A tensor with a given backend, shape and data type.
///
/// # Indexing
/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
/// or [`select`](Tensor::select) for numeric types.
///
/// ## Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
/// use burn_tensor::Int;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///
///     let tensor = Tensor::<B, 2>::from_data(
///         [
///             [3.0, 4.9, 2.0],
///             [2.0, 1.9, 3.0],
///             [6.0, 1.5, 7.0],
///             [3.0, 4.9, 9.0],
///         ],
///         &device,
///     );
///
///     // Slice the tensor to get the second and third rows:
///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
///     // The resulting tensor will have dimensions [2, 3].
///     let slice = tensor.clone().slice([1..3]);
///     println!("{slice}");
///
///     // Slice the tensor to get the first two rows and the first 2 columns:
///     // [[3.0, 4.9], [2.0, 1.9]]
///     // The resulting tensor will have dimensions [2, 2].
///     let slice = tensor.clone().slice([0..2, 0..2]);
///     println!("{slice}");
///
///     // Index the tensor along the dimension 1 to get the elements 0 and 2:
///     // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
///     // The resulting tensor will have dimensions [4, 2]
///     let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
///     let indexed = tensor.select(1, indices);
///     println!("{indexed}");
/// }
/// ```
#[derive(new, Clone, Debug)]
pub struct Tensor<B, const D: usize, K = Float>
where
    B: Backend,
    K: TensorKind<B>,
{
    pub(crate) primitive: K::Primitive,
}

impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    T: Into<TensorData>,
{
    fn from(value: T) -> Self {
        Tensor::from_data(value.into(), &Default::default())
    }
}

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
{
    /// Converts the tensor into a primitive tensor.
    pub fn into_primitive(self) -> K::Primitive {
        self.primitive
    }

    /// Converts from a primitive tensor into a tensor.
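    ///
    /// # Example
    ///
    /// A minimal sketch of round-tripping through the backend primitive (assumes the backend's
    /// default device is valid):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::ones([2, 3], &device);
    ///     // Unwrap the backend primitive, then wrap it back into a tensor.
    ///     let primitive = tensor.into_primitive();
    ///     let tensor = Tensor::<B, 2>::from_primitive(primitive);
    ///     println!("{tensor}");
    /// }
    /// ```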
    pub fn from_primitive(tensor: K::Primitive) -> Self {
        Self::new(tensor)
    }

    /// Returns the tensor primitive data type.
    ///
    /// # Note
    /// Some element types are encoded in different primitive types depending on the backend
    /// (e.g., bool could be encoded as `u8` or `u32`).
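    ///
    /// # Example
    ///
    /// A minimal sketch; the exact `DType` returned depends on the backend's float element type:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::ones([2, 3], &device);
    ///     // Typically `DType::F32` for float tensors on common backends.
    ///     let dtype = tensor.dtype();
    ///     println!("{dtype:?}");
    /// }
    /// ```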
    pub fn dtype(&self) -> DType {
        self.primitive.dtype()
    }

    /// Create an empty tensor of the given shape.
    ///
    /// # Arguments
    ///
    /// - shape: The shape of the tensor.
    /// - device: The device where the tensor will be created.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    // Create an empty tensor with dimensions [2, 3, 4].
    ///    let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
    /// }
    /// ```
    pub fn empty<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
        Self::new(K::empty(shape, device))
    }

    /// Returns the dimensions of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///   let device = Default::default();
    ///   let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///   let dims = tensor.dims(); // [2, 3, 4]
    ///   println!("{dims:?}");
    /// }
    /// ```
    pub fn dims(&self) -> [usize; D] {
        Self::shape(self).dims()
    }

    /// Returns the shape of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Shape { dims: [2, 3, 4] }
    ///    let shape = tensor.shape();
    /// }
    /// ```
    pub fn shape(&self) -> Shape {
        self.primitive.shape()
    }

    /// Reshape the tensor to have the given shape.
    ///
    /// The tensor has the same data and number of elements as the input.
    ///
    /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
    /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
    ///
    /// A `0` in the shape instructs to keep the current dimension from the original tensor,
    /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
    /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
    /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
    /// with [1, 3, 4] dimensions to [1, 12].
    ///
    /// # Arguments
    /// - `shape`: The new shape of the tensor.
    ///
    /// # Panics
    /// - If the shape contains more than one `-1`.
    /// - If the shape contains negative values other than `-1`.
    /// - If the shape does not match the number of elements of the original shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    // Create a tensor with dimensions [2, 3, 4]
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Reshape it to [2, 12], where 12 is inferred from the number of elements.
    ///    let reshaped = tensor.reshape([2, -1]);
    ///    println!("{reshaped}");
    /// }
    /// ```
    pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
        // Convert reshape args to shape
        let shape = shape.into_shape(&self);
        Tensor::new(K::reshape(self.primitive, shape))
    }

    /// Transpose the tensor.
    ///
    /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
    /// applied on the last two dimensions. For example, the transpose of a tensor with shape
    /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
    ///
    /// See also [`permute`](Tensor::permute).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to transpose.
    ///
    /// # Returns
    ///
    /// The transposed tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Transpose the tensor:
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let transposed = tensor.transpose();
    ///     println!("{transposed}");
    /// }
    /// ```
    pub fn transpose(self) -> Tensor<B, D, K> {
        Tensor::new(K::transpose(self.primitive))
    }

    /// Swaps two dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Swap the dimensions 0 and 1 (equivalent to `tensor.transpose()`):
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let swapped = tensor.swap_dims(0, 1);
    ///     println!("{swapped}");
    /// }
    /// ```
    pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor<B, D, K> {
        check!(TensorCheck::swap_dims::<D>(dim1, dim2));
        Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
    }

    /// Permute the dimensions of the tensor.
    ///
    /// # Arguments
    ///
    /// * `axes` - The new order of the dimensions. The length of the axes
    ///   must be equal to the number of dimensions of the tensor.
    ///   The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
    ///
    ///     // Permute the dimensions 1 and 0:
    ///     // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
    ///     // The resulting tensor will have dimensions [2, 3].
    ///     let permuted = tensor.permute([1, 0]);
    ///     println!("{permuted}");
    /// }
    /// ```
    pub fn permute(self, axes: [isize; D]) -> Tensor<B, D, K> {
        // Convert the axes to usize and handle negative values without using vector
        let mut transformed_axes: [usize; D] = [0; D];
        for (i, &x) in axes.iter().enumerate() {
            transformed_axes[i] = if x < 0 {
                (D as isize + x) as usize
            } else {
                x as usize
            };
        }

        // Check if the axes are valid after the transformation
        check!(TensorCheck::permute(transformed_axes));

        Tensor::new(K::permute(self.primitive, &transformed_axes))
    }

    /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
    ///
    /// Other dimensions of input that are not explicitly moved remain in their original order and appear
    /// at the positions not specified in destination.
    ///
    /// # Arguments
    ///
    /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// * `dst` - Destination positions for each of the original dims. These must also be unique.
    ///
    /// # Panics
    ///
    /// - If the source and destination dimensions are not of the same length.
    /// - If the source and destination vectors contain duplicate values.
    /// - If the source and destination vectors contain values that are out of bounds.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions moved.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor of shape [3, 2, 1]
    ///     let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
    ///
    ///     // Move the dimensions 0 and 1:
    ///     // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
    ///     // The resulting tensor will have dimensions [2, 3, 1].
    ///     let moved = tensor.movedim(1, 0);
    ///     println!("{moved}");
    /// }
    /// ```
    ///
    /// # Note
    ///
    /// This is syntactic sugar for `permute`. It is used widely enough that a dedicated op is
    /// defined for it.
    pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
        let source_dims = src.into_dim_vec::<D>();
        let destination_dims = dst.into_dim_vec::<D>();

        check!(TensorCheck::movedim_args_length(
            &source_dims,
            &destination_dims
        ));

        let mut m = [-1; D];
        for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
            m[d] = s as isize;
        }
        let mut axes: [isize; D] = [0; D];
        let mut source_i = 0;
        for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
            *item = if m[dest_i] != -1 {
                m[dest_i]
            } else {
                while source_dims.contains(&source_i) {
                    source_i += 1;
                }
                let result = source_i as isize;
                source_i += 1;
                result
            };
        }

        self.permute(axes)
    }

    /// Reverse the order of elements in the tensor along the given dimensions.
    ///
    /// # Arguments
    ///
    /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the axes flipped.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [4.0, 5.9, 8.0],
    ///             [1.4, 5.8, 6.0],
    ///         ],
    ///         &device,
    ///     );
    ///
    ///     // Flip the elements in dimensions 0 and 1:
    ///     // [[6.0, 5.8, 1.4],
    ///     //  [8.0, 5.9, 4.0],
    ///     //  [3.0, 1.9, 2.0],
    ///     //  [2.0, 4.9, 3.0]]
    ///     // The resulting tensor will have dimensions [4, 3].
    ///     let flipped = tensor.flip([0, 1]);
    ///     println!("{flipped}");
    /// }
    /// ```
    pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
        // Convert the axes to usize and handle negative values without using vector
        let mut transformed_axes: [usize; N] = [0; N];
        for (i, &x) in axes.iter().enumerate() {
            transformed_axes[i] = if x < 0 {
                (D as isize + x) as usize
            } else {
                x as usize
            };
        }

        // Check if the axes are valid
        check!(TensorCheck::flip(D, &transformed_axes));

        Tensor::new(K::flip(self.primitive, &transformed_axes))
    }

    /// Flatten the tensor along a given range of dimensions.
    ///
    /// This function collapses the specified range of dimensions into a single dimension,
    /// effectively flattening the tensor in that range.
    ///
    /// # Arguments
    ///
    /// - `start_dim`: The starting dimension of the range to be flattened.
    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive).
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the flattened tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [2, 3, 4]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
    ///
    ///     // Flatten the tensor from dimensions 1 to 2 (inclusive).
    ///     // The resulting tensor will have dimensions [2, 12]
    ///     let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
    ///     println!("{flattened}");
    /// }
    /// ```
    pub fn flatten<const D2: usize>(self, start_dim: usize, end_dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));

        let current_dims = self.shape().dims;
        let mut new_dims: [usize; D2] = [0; D2];
        let mut flatten_dims = 1;

        for i in current_dims[start_dim..=end_dim].iter() {
            flatten_dims *= i;
        }

        new_dims[..start_dim].copy_from_slice(&current_dims[..start_dim]);
        new_dims[start_dim] = flatten_dims;
        new_dims[start_dim + 1..].copy_from_slice(&current_dims[end_dim + 1..]);

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Squeeze the tensor along the given dimension, removing the specified dimension
    /// of size one, and effectively reducing the rank of the tensor by one.
    ///
    /// # Arguments
    ///
    /// - `dim`: The dimension to be squeezed.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Panics
    ///
    /// If the size in the squeezed dimension is not 1.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 1, 3]
    ///     let tensor = Tensor::<B, 3>::from_data(
    ///         [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
    ///         &device,
    ///     );
    ///
    ///     // Squeeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let squeezed = tensor.squeeze::<2>(1);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));

        let current_dims = self.shape().dims;
        let mut new_dims: [usize; D2] = [0; D2];

        new_dims[..dim].copy_from_slice(&current_dims[..dim]);
        new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
    /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
    /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
    /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
    /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
    /// from the back.
    ///
    /// # Arguments
    ///
    /// - `dims`: The dimension(s) to be squeezed.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 4D tensor with dimensions [2, 1, 4, 1]
    ///     let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
    ///
    ///     // Squeeze the dimensions 1 and 3.
    ///     // The resulting tensor will have dimensions [2, 4].
    ///     let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
        let current_dims = self.shape().dims;
        let mut dim_indices: Vec<usize>;

        // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
        if dims.is_empty() {
            dim_indices = current_dims
                .iter()
                .enumerate()
                .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
                .collect();
        } else {
            // If negative dims, count from the back
            dim_indices = dims
                .iter()
                .map(|&d| {
                    if d < 0 {
                        (current_dims.len() as isize + d) as usize
                    } else {
                        d as usize
                    }
                })
                .collect();
        }

        // Sort indices and remove duplicates
        dim_indices.sort_unstable();
        dim_indices.dedup();

        // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
        check!(TensorCheck::squeeze_dims_input::<D2>(
            &dim_indices,
            &current_dims
        ));

        // Calculate new dimensions
        let mut new_dims = Vec::new();
        for (index, &dim_size) in current_dims.iter().enumerate() {
            // Exclude the dimension if it's explicitly marked for squeezing
            if dim_indices.contains(&index) {
                check!(TensorCheck::squeeze::<D2>(index, &current_dims));
                continue;
            }
            new_dims.push(dim_size);
        }

        // Check that after squeezing, we still respect the D2 size
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
    ///
    /// # Type Parameters
    ///
    ///  - `D2`: The resulting number of dimensions in the unsqueezed tensor.
    ///
    /// # Panics
    ///
    /// If the output size `D2` is smaller than the current number of dimensions.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the tensor up to 4 dimensions.
    ///     // The resulting tensor will have dimensions [1, 1, 3, 3].
    ///     let unsqueezed = tensor.unsqueeze::<4>();
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze::<D, D2>());

        let mut dims = [1; D2];
        let num_ones = D2 - D;
        let shape = self.shape();

        dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]);

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with a dimension of size one inserted at the specified position.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 1, 3].
    ///     let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));

        let mut dims = [1; D2];
        let shape = self.shape();

        dims[0..dim].copy_from_slice(&shape.dims[0..dim]);

        if dim < D {
            dims[dim] = 1;
            dims[(dim + 1)..].copy_from_slice(&shape.dims[dim..]);
        } else {
            dims[dim] = 1;
        }

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with dimensions of size one inserted at the specified indices.
    /// The indices can be negative, in which case they are counted from the last to the first dimension.
    /// The axes can contain duplicates, in which case the number of dimensions inserted at that index
    /// equals the number of duplicates.
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 4, 5]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
    ///     // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
    ///     // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
    ///     let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K> {
        let mut new_dims = [1; D2];
        let old_dims = self.shape().dims;
        // for checking if the dimension is in the acceptable range

        // part 1: convert the negative indices to positive
        let mut neg_offset = D2;
        let mut dim_indices = axes
            .iter()
            .map(|d| {
                // check if the dimension is in the acceptable range
                check!(TensorCheck::unsqueeze_dims::<{ D2 }>(*d));
                (if *d < 0 {
                    neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)
                    d + neg_offset as isize + 1
                } else {
                    *d
                }) as usize
            })
            .collect::<Vec<usize>>();

        // sort the indices
        dim_indices.sort_unstable();

        // Now use this to copy the chunks of the dims
        let mut prev_idx: usize = 0;
        let mut current_left_b: usize = 0;
        let mut current_right_b: usize = 0;
        let mut offset: usize = 0;
        dim_indices.iter().for_each(|d| {
            // check if there is space for at least one dimension
            if prev_idx < *d {
                current_right_b = *d - offset;
                // copy the chunks of the dims
                if current_right_b < D {
                    new_dims[prev_idx..*d]
                        .copy_from_slice(&old_dims[current_left_b..current_right_b]);
                } else {
                    new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
                }
                prev_idx = *d + 1;
                // offset is equal to the number of extracted elements from the original shape
                offset += current_right_b - current_left_b;
                current_left_b = current_right_b;
            } else {
                // it's sorted so the only reason this would happen
                // is if multiple indices are the same
                prev_idx += 1;
            }
        });
        // copy over anything past the index of the last new dimension
        if current_left_b < D {
            new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
        }

        // lastly, create the shape and reshape
        let shape = Shape::new(new_dims);
        self.reshape(shape)
    }

    /// Returns a tensor containing the elements selected from the given ranges.
    ///
    /// For more complex indexing with different slice ranges, see also the slice
    /// macro [`s!`](crate::s).
    ///
    /// # Arguments
    ///
    /// * `ranges` - A type implementing the `RangesArg` trait, which can be:
    ///   - A single range (slice the first dimension)
    ///   - A single index (slice the first dimension)
    ///   - An array of ranges
    ///
    /// # Behavior
    ///
    /// - Supports partial and full slicing in any number of dimensions.
    /// - Missing ranges are treated as full slices if D > D2.
    /// - Handles negative indices by wrapping around from the end of the dimension.
    /// - Clamps ranges to the tensor's dimensions if they exceed the bounds.
    ///
    /// # Panics
    ///
    /// - If the number of ranges provided exceeds the tensor's dimensions.
    /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1).
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // 1D slicing
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
    ///     let slice = tensor.slice([1..4]);
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
    ///
    ///     // 2D slicing
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 4]), &device);
    ///     let slice = tensor.slice([1..3, 0..2]);
    ///     assert_eq!(slice.dims(), [2, 2]);
    ///
    ///     // Using negative indices
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
    ///     let slice = tensor.slice([1..-1]); // Equivalent to 1..4
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
    ///
    ///     // Using the slice macro to select different ranges
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device).reshape([3, 4]);
    ///     let slice = tensor.slice(s![1.., ..]); // Select rows 1 and 2, all columns
    ///     assert_eq!(slice.dims(), [2, 4]);
    ///
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..16, &device).reshape([2, 4, 2]);
    ///     let slice = tensor.slice(s![1.., 1..=3, -1]);
    ///     assert_eq!(slice.dims(), [1, 3, 1]);
    /// }
    /// ```
    ///
    /// # Note
    ///
    /// This function uses the `RangesArg` trait for flexible range specification. The trait
    /// handles the conversion of various range formats and applies clamping and negative
    /// index handling internally.
    pub fn slice<const D2: usize, R: RangesArg<D2>>(self, ranges: R) -> Self {
        let ranges = ranges.into_ranges(self.shape());

        check!(TensorCheck::slice::<D, D2>(&self.shape(), &ranges));
        Self::new(K::slice(self.primitive, &ranges))
    }

    /// Returns a copy of the current tensor with the selected elements changed to the new ones at
    /// the selected indices.
    ///
    /// # Panics
    ///
    /// - If a range exceeds the number of elements on a dimension.
    /// - If the given values don't match the given ranges.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 3>::ones([2, 3, 3], &device);
    ///     let values = Tensor::<B, 3>::zeros([1, 1, 1], &device);
    ///     let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values);
    ///     println!("{:?}", tensor_sliced.dims()); // [2, 3, 3]
    /// }
    /// ```
    pub fn slice_assign<const D2: usize>(self, ranges: [Range<usize>; D2], values: Self) -> Self {
        check!(TensorCheck::slice_assign::<D, D2>(
            &self.shape(),
            &values.shape(),
            &ranges
        ));
        Self::new(K::slice_assign(self.primitive, &ranges, values.primitive))
    }

    /// Returns the device of the current tensor.
    pub fn device(&self) -> B::Device {
        K::device(&self.primitive)
    }

    /// Move the tensor to the given device.
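    ///
    /// # Example
    ///
    /// A minimal sketch (uses the backend's default device; in a real setup you would pass the
    /// target device, e.g. a specific GPU):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::ones([2, 3], &device);
    ///     // Move the tensor to the target device.
    ///     let tensor = tensor.to_device(&device);
    ///     println!("{tensor}");
    /// }
    /// ```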
    pub fn to_device(self, device: &B::Device) -> Self {
        Self::new(K::to_device(self.primitive, device))
    }

    /// Converts the data of the current tensor.
    ///
    /// # Note
    ///
    /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
    /// tensors at once. This may improve laziness, especially if executed on a different
    /// thread in native environments.
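    ///
    /// # Example
    ///
    /// A minimal sketch of reading a tensor back to the host (the `f32` element type is an
    /// assumption; it depends on the backend's float element):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     // Consumes the tensor and returns a host-side copy of its data.
    ///     let data = tensor.into_data();
    ///     let values = data.to_vec::<f32>().unwrap();
    ///     println!("{values:?}");
    /// }
    /// ```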
    pub fn into_data(self) -> TensorData {
        crate::try_read_sync(self.into_data_async()).expect(
            "Failed to read tensor data synchronously.
        This can happen on platforms that don't support blocking futures like WASM.
        If possible, try using into_data_async instead.",
        )
    }

    /// Converts the data of the current tensor.
    ///
    /// # Note
    ///
    /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
    /// tensors at once. This may improve laziness, especially if executed on a different
    /// thread in native environments.
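    ///
    /// # Example
    ///
    /// A minimal sketch; unlike [`into_data`](Tensor::into_data), this borrows the tensor, so it
    /// stays usable afterwards:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     // Reads the data without consuming the tensor (an internal clone is used).
    ///     let data = tensor.to_data();
    ///     println!("{data:?}");
    ///     println!("{tensor}");
    /// }
    /// ```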
    pub fn to_data(&self) -> TensorData {
        self.clone().into_data()
    }

    /// Returns the data of the current tensor.
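    ///
    /// # Example
    ///
    /// A minimal sketch in an async context (useful on targets like WASM where blocking reads are
    /// unavailable):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// async fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     // Await the backend read instead of blocking the current thread.
    ///     let data = tensor.into_data_async().await;
    ///     println!("{data:?}");
    /// }
    /// ```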
    pub async fn into_data_async(self) -> TensorData {
        K::into_data_async(self.primitive).await
    }

    /// Returns the data of the current tensor.
    pub async fn to_data_async(&self) -> TensorData {
        self.clone().into_data_async().await
    }

    /// Create a tensor from the given data on the given device.
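    ///
    /// # Example
    ///
    /// A minimal sketch; nested arrays, vectors, and existing [`TensorData`] all convert via
    /// `Into<TensorData>`:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [2, 3] from a nested array.
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///     println!("{tensor}");
    /// }
    /// ```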
    pub fn from_data<T>(data: T, device: &B::Device) -> Self
    where
        T: Into<TensorData>,
    {
        let data = data.into();
        check!(TensorCheck::creation_ops::<D>(
            "From Data",
            data.shape.as_slice()
        ));
        Self::new(K::from_data(data, device))
    }

    /// Create a tensor from the given data on the given device enforcing the given data type.
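    ///
    /// # Example
    ///
    /// A minimal sketch (assumes the backend supports the requested data type; here `DType::F32`):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{DType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create the tensor with an explicit dtype instead of the backend's default float type.
    ///     let tensor = Tensor::<B, 1>::from_data_dtype([1.0, 2.0, 3.0], &device, DType::F32);
    ///     println!("{:?}", tensor.dtype());
    /// }
    /// ```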
    pub fn from_data_dtype<T>(data: T, device: &B::Device, dtype: DType) -> Self
    where
        T: Into<TensorData>,
    {
        let data = data.into();
        check!(TensorCheck::creation_ops::<D>(
            "From Data",
            data.shape.as_slice()
        ));
        Self::new(K::from_data_dtype(data, device, dtype))
    }

    /// Repeat the tensor along the given dimension.
    ///
    /// The output tensor has the same shape, except along the given dimension.
    ///
    /// # Arguments
    /// - `dim`: The dimension to repeat.
    /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
    ///
    /// # Returns
    ///
    /// A new tensor with the given dimension repeated `times` times.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///
    ///     // Repeat the tensor along the dimension 0 twice.
    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
    ///     // The resulting tensor will have dimensions [6, 2].
    ///     let repeated = tensor.repeat_dim(0, 2);
    ///     println!("{repeated}");
    /// }
    /// ```
    pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
        Self::new(K::repeat_dim(self.primitive, dim, times))
    }

    /// Repeat the tensor along the given dimensions.
    ///
    /// # Arguments
    /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
    ///
    /// # Returns
    ///
    /// A new tensor with each dimension repeated the number of times given in `sizes`.
    ///
    /// # Panics
    ///
    /// If `sizes` contains more elements than the number of dimensions.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///
    ///     // Repeat the tensor along the dimension 0 twice and the dimension 1 once.
    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
    ///     // The resulting tensor will have dimensions [6, 2].
    ///     let repeated = tensor.repeat(&[2, 1]);
    /// }
    /// ```
    pub fn repeat(self, sizes: &[usize]) -> Self {
        let mut tensor = self;
        for (dim, &times) in sizes.iter().enumerate() {
            if times > 1 {
                tensor = tensor.repeat_dim(dim, times);
            }
        }
        tensor
    }

    /// Applies element-wise equal comparison.
    ///
    /// # Returns
    /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     // Compare the elements of the two 2D tensors with dimensions [3, 2].
    ///     // [[false, true], [true, true], [true, true]]
    ///     let equal = t1.equal(t2);
    ///     println!("{equal}");
    /// }
    /// ```
    pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
        Tensor::new(K::equal(self.primitive, other.primitive))
    }

    /// Applies element-wise non-equality comparison.
    ///
    /// # Returns
    /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     // Compare the elements of the two 2D tensors for inequality.
    ///     // [[true, false], [false, false], [false, false]]
    ///     let not_equal = t1.not_equal(t2);
    ///     println!("{not_equal}");
    /// }
    /// ```
    pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
        Tensor::new(K::not_equal(self.primitive, other.primitive))
    }

    /// Concatenates all tensors into a new one along the given dimension.
    ///
    /// # Panics
    ///
    /// - If `dim` is higher than the rank.
    /// - If `tensors` is an empty vector.
    /// - If the tensors don't all have the same shape (ignoring the dimension `dim`).
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
    ///
    ///     // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
    ///     // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
    ///     // The resulting tensor will have shape [2, 7].
    ///     let concat = Tensor::cat(vec![t1, t2], 1);
    ///     println!("{concat}");
    /// }
    /// ```
    pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
        check!(TensorCheck::cat(&tensors, dim));

        Self::new(K::cat(
            tensors.into_iter().map(|vector| vector.primitive).collect(),
            dim,
        ))
    }

    /// Concatenates all tensors into a new one along a new dimension.
    ///
    /// # Panics
    ///
    /// - If the tensors don't all have the same shape.
    /// - If the given dimension is not within the range `0..D2`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
    ///     let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
    ///
    ///     // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
    ///     // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
    ///     // The resulting tensor will have shape [3, 2, 3].
    ///     let stacked = Tensor::stack::<3>(vec![t1, t2, t3], 0);
    ///     println!("{stacked}");
    /// }
    /// ```
    pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
        let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
        Tensor::<B, D2, K>::cat(tensors, dim)
    }

    /// Iterate over slices of tensors along a given dimension.
    ///
    /// # Panics
    ///
    /// If the given dimension is greater than or equal to the tensor rank.
    ///
    /// # Returns
    ///
    /// A tensor iterator.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    /// fn example<B: Backend>() {
    ///   let device = Default::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
    ///   // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
    ///   let iter = tensor.iter_dim(0);
    ///   for (i, tensor) in iter.enumerate() {
    ///     println!("Tensor {}: {}", i, tensor);
    ///     // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
    ///     // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
    ///   }
    /// }
    /// ```
    pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
        check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
        DimIter::new(self, dim)
    }

    /// Returns a new tensor with the given dimension narrowed to the given range.
    ///
    /// # Panics
    ///
    /// - If the dimension is greater than the number of dimensions of the tensor.
    /// - If the given range exceeds the number of elements on the given dimension.
    ///
    /// # Returns
    ///
    /// A new tensor with the given dimension narrowed to the given range.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [6.0, 1.5, 7.0],
    ///             [3.0, 4.9, 9.0],
    ///         ],
    ///         &device,
    ///     );
    ///     // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
    ///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let narrowed = tensor.narrow(0, 1, 3);
    ///     println!("{narrowed}");
    /// }
    /// ```
    pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
        check!(TensorCheck::dim_ops::<D>("narrow", dim));
        check!(TensorCheck::narrow(&self, dim, start, length));
        Self::new(narrow::<B, K>(self.primitive, dim, start, length))
    }

    /// Attempts to split the tensor into a specified number of chunks along a given dimension.
    /// May return fewer chunks than requested if the tensor size is not divisible by the number of chunks.
    ///
    /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
    /// Otherwise all chunks will be of equal size except for the last one.
    ///
    /// # Panics
    ///
    /// If the dimension is greater than the number of dimensions of the tensor.
    ///
    /// # Returns
    /// A vector of tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [6.0, 1.5, 7.0],
    ///             [3.0, 4.9, 9.0],
    ///         ],
    ///         &device,
    ///     );
    ///     // Split the tensor along the dimension 1 into 2 chunks.
    ///     // The first chunk will have shape [4, 2]:
    ///     // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
    ///     // The second chunk will have shape [4, 1]:
    ///     // [[2.0], [3.0], [7.0], [9.0]]
    ///     let chunks = tensor.chunk(2, 1);
    ///     println!("{chunks:?}");
    /// }
    /// ```
    pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
        check!(TensorCheck::dim_ops::<D>("chunk", dim));
        K::chunk(self.primitive, chunks, dim)
            .into_iter()
            .map(Self::new)
            .collect()
    }

    /// Splits the tensor into chunks of a specified size along a given dimension.
    /// Each chunk is a view of the original tensor.
    ///
    /// If the tensor size along the given dimension is not divisible by `split_size`,
    /// then the last chunk will be smaller.
    ///
    /// # Panics
    ///
    /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
    ///
    /// # Returns
    ///
    /// A vector of tensors.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 1D tensor with 5 elements
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
    ///     // Split the tensor into chunks of size 2 along dimension 0
    ///     let chunks = tensor.split(2, 0);
    ///     // The result is a vector of tensors:
    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
    ///     println!("{:?}", chunks);
    /// }
    /// ```
    pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
        check!(TensorCheck::split::<D>(
            self.shape().dims.as_ref(),
            split_size,
            dim
        ));
        K::split(self.primitive, split_size, dim)
            .into_iter()
            .map(Self::new)
            .collect()
    }

    /// Splits the tensor into chunks with the specified sizes along a given dimension.
    /// Each chunk is a view of the original tensor.
    ///
    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
    /// in `split_sizes` must equal the size of the tensor along the specified dimension.
    ///
    /// # Panics
    ///
    /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
    /// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
1378    ///
1379    /// # Returns
1380    ///
1381    /// A vector of tensors.
1382    ///
1383    /// # Example
1384    /// ```rust
1385    /// use burn_tensor::backend::Backend;
1386    /// use burn_tensor::Tensor;
1387    ///
1388    /// fn example<B: Backend>() {
1389    ///     let device = Default::default();
1390    ///     // Create a 1D tensor with 5 elements
1391    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1392    ///     // Split the tensor into chunks with sizes [2, 3] along dimension 0
1393    ///     let chunks = tensor.split_with_sizes(vec![2, 3], 0);
1394    ///     // The result is a vector of tensors:
1395    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
1396    ///     println!("{:?}", chunks);
1397    /// }
1398    /// ```
1399    pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
1400        check!(TensorCheck::split_with_sizes::<D>(
1401            self.shape().dims.as_ref(),
1402            &split_sizes,
1403            dim
1404        ));
1405        K::split_with_sizes(self.primitive, split_sizes, dim)
1406            .into_iter()
1407            .map(Self::new)
1408            .collect()
1409    }
1410
1411    /// Tests if any element in the `tensor` evaluates to True.
1412    ///
1413    /// # Arguments
1414    ///
1415    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1416    ///
1417    /// # Returns
1418    ///
1419    /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
1420    /// evaluates to True, False otherwise.
1421    ///
1422    /// # Example
1423    ///
1424    /// ```rust
1425    /// use burn_tensor::backend::Backend;
1426    /// use burn_tensor::{Tensor, Bool};
1427    ///
1428    /// fn example<B: Backend>() {
1429    ///   let device = Default::default();
1430    ///   let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);
1431    ///   let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);
1432    ///
1433    ///   // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
1434    ///   let any_tensor = tensor.any();
1435    ///   println!("{}", any_tensor);
1436    ///   // Tensor { data: [true], ... }
1437    ///
1438    ///   // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
1439    ///   let any_tensor_two = tensor_two.any();
1440    ///   println!("{}", any_tensor_two);
1441    ///   // Tensor { data: [false], ... }
1442    /// }
1443    /// ```
1444    pub fn any(self) -> Tensor<B, 1, Bool> {
1445        Tensor::new(K::any(self.primitive))
1446    }
1447
1448    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
1449    ///
1450    /// # Arguments
1451    ///
1452    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1453    /// * `dim` - The axis along which to test.
1454    ///
1455    /// # Returns
1456    ///
1457    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
1458    /// where the size is 1. The element in the `dim` axis is True if any element along this dimension in the
1459    /// input evaluates to True, False otherwise.
1460    ///
1461    /// # Example
1462    ///
1463    /// ```rust
1464    /// use burn_tensor::backend::Backend;
1465    /// use burn_tensor::{Tensor, Bool};
1466    ///
1467    /// fn example<B: Backend>() {
1468    ///     let device = Default::default();
1469    ///     let tensor =
1470    ///         Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
1471    ///     // Check if any element in the tensor evaluates to True along the dimension 1.
1472    ///     // [[true], [true]],
1473    ///     let any_dim = tensor.clone().any_dim(1);
1474    ///     println!("{any_dim}");
1475    /// }
1476    /// ```
1477    pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
1478        Tensor::new(K::any_dim(self.primitive, dim))
1479    }
1480
1481    /// Tests if all elements in the `tensor` evaluate to True.
1482    ///
1483    /// # Arguments
1484    ///
1485    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1486    ///
1487    /// # Returns
1488    ///
1489    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1490    /// evaluate to True, False otherwise.
1491    ///
1492    /// # Example
1493    ///
1494    /// ```rust
1495    /// use burn_tensor::backend::Backend;
1496    /// use burn_tensor::{Tensor, Bool};
1497    ///
1498    /// fn example<B: Backend>() {
1499    ///     let device = Default::default();
1500    ///     let tensor =
1501    ///         Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
1502    ///     // Check if all elements in the tensor evaluate to True (which is not the case).
1503    ///     // [false]
1504    ///     let all = tensor.all();
1505    ///     println!("{all}");
1506    /// }
1507    /// ```
1508    pub fn all(self) -> Tensor<B, 1, Bool> {
1509        Tensor::new(K::all(self.primitive))
1510    }
1511
1512    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1513    ///
1514    /// # Arguments
1515    ///
1516    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1517    /// * `dim` - The axis along which to test.
1518    ///
1519    /// # Returns
1520    ///
1521    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
1522    /// where the size is 1. The element in the `dim` axis is True if all elements along this dimension in the
1523    /// input evaluate to True, False otherwise.
1524    ///
1525    /// # Example
1526    ///
1527    /// ```rust
1528    /// use burn_tensor::backend::Backend;
1529    /// use burn_tensor::{Tensor, Bool};
1530    ///
1531    /// fn example<B: Backend>() {
1532    ///     let device = Default::default();
1533    ///     let tensor =
1534    ///         Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
1535    ///     // Check if all elements in the tensor evaluate to True along the dimension 1.
1536    ///     // [[true, true, false]]
1537    ///     let all_dim = tensor.clone().all_dim(0);
1538    ///     println!("{all_dim}");
1539    /// }
1540    /// ```
1541    pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
1542        Tensor::new(K::all_dim(self.primitive, dim))
1543    }
1544
1545    /// Convert the tensor into a scalar.
1546    ///
1547    /// # Panics
1548    ///
1549    /// - If the tensor doesn't contain exactly one element.
1550    /// - If the backend fails to read the tensor data synchronously.
1551    ///
1552    /// # Returns
1553    ///
1554    /// The scalar value of the tensor.
1555    ///
1556    /// # Example
1557    ///
1558    /// ```rust
1559    /// use burn_tensor::backend::Backend;
1560    /// use burn_tensor::Tensor;
1561    ///
1562    /// fn example<B: Backend>() {
1563    ///     let device = Default::default();
1564    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
1565    ///     // Convert the tensor with a single element into a scalar.
1566    ///     let scalar = tensor.into_scalar();
1567    ///     println!("{scalar}");
1568    /// }
1569    /// ```
1570    pub fn into_scalar(self) -> K::Elem {
1571        crate::try_read_sync(self.into_scalar_async()).expect(
1572            "Failed to read tensor data synchronously. This can happen on platforms
1573            that don't support blocking futures like WASM. Try into_scalar_async instead.",
1574        )
1575    }
1576
1577    /// Convert the tensor into a scalar.
1578    ///
1579    /// # Panics
1580    ///
1581    /// If the tensor doesn't contain exactly one element.
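    ///
    /// # Example
    ///
    /// A minimal async sketch, mirroring the `into_scalar` example above:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// async fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
    ///     // Await the value instead of blocking; suited to platforms like WASM.
    ///     let scalar = tensor.into_scalar_async().await;
    ///     println!("{scalar}");
    /// }
    /// ```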
1582    pub async fn into_scalar_async(self) -> K::Elem {
1583        check!(TensorCheck::into_scalar::<D>(&self.shape()));
1584        self.into_data_async().await.iter().next().unwrap()
1585    }
1586
1587    /// Broadcast the tensor to the given shape.
1588    ///
1589    /// Only singleton dimensions can be expanded to a larger size. Other dimensions must keep their
1590    /// original size, which can be written as `-1` to infer it.
1591    ///
1592    /// # Arguments
1593    ///
1594    /// * `shape` - The shape to broadcast the tensor to.
1595    ///   Can contain -1 for dimensions that should be inferred.
1596    ///   The number of elements in the shape must be greater than or equal to
1597    ///   the number of dimensions of the tensor.
1598    ///
1599    /// # Panics
1600    ///
1601    /// If the tensor cannot be broadcast to the given shape.
1602    ///
1603    /// # Returns
1604    ///
1605    /// A new tensor with the given shape.
1606    ///
1607    /// # Example
1608    ///
1609    /// ```rust
1610    /// use burn_tensor::backend::Backend;
1611    /// use burn_tensor::Tensor;
1612    ///
1613    /// fn example<B: Backend>() {
1614    ///     let device = Default::default();
1615    ///     // Create a 2D tensor with dimensions [3, 1]
1616    ///     let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
1617    ///     // Expand the tensor to a new shape [3, 4]
1618    ///     // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
1619    ///     let expanded = tensor.expand([3, 4]);
1620    ///     println!("{}", expanded);
1621    /// }
1622    /// ```
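    ///
    /// Dimensions that keep their size can be written as `-1`, as mentioned above. A brief sketch,
    /// assuming the shape argument accepts signed integers:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
    ///     // -1 keeps the size of dimension 0 (here 3); dimension 1 is expanded to 4.
    ///     let expanded = tensor.expand([-1, 4]);
    ///     println!("{:?}", expanded.dims()); // [3, 4]
    /// }
    /// ```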
1623    pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
1624        let shape = shape.into_shape(&self.shape());
1625        check!(TensorCheck::expand::<D, D2>(
1626            "expand",
1627            &self.shape(),
1628            &shape,
1629        ));
1630
1631        Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
1632    }
1633}
1634
1635    /// Iterator returned by [`iter_dim`](Tensor::iter_dim).
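///
/// A minimal sketch of iterating over one dimension (each item keeps the full rank, with size 1
/// along the iterated dimension):
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
///     // Yields two tensors of shape [1, 2].
///     for slice in tensor.iter_dim(0) {
///         println!("{slice}");
///     }
/// }
/// ```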
1636pub struct DimIter<B, const D: usize, K>
1637where
1638    B: Backend,
1639    K: BasicOps<B>,
1640{
1641    start: usize,
1642    end: usize,
1643    dim: usize,
1644    ranges: [Range<usize>; D],
1645    tensor: Tensor<B, D, K>,
1646}
1647
1648impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
1649    type Item = Tensor<B, D, K>;
1650
1651    fn next(&mut self) -> Option<Self::Item> {
1652        if self.start >= self.end {
1653            return None;
1654        }
1655
1656        let mut ranges = self.ranges.clone();
1657        ranges[self.dim] = self.start..(self.start + 1);
1658
1659        let slice = self.tensor.clone().slice(ranges);
1660        self.start += 1;
1661
1662        Some(slice)
1663    }
1664}
1665
1666impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
1667    fn next_back(&mut self) -> Option<Self::Item> {
1668        if self.start >= self.end {
1669            return None;
1670        }
1671
1672        let mut ranges = self.ranges.clone();
1673        ranges[self.dim] = (self.end - 1)..self.end;
1674
1675        let slice = self.tensor.clone().slice(ranges);
1676        self.end = self.end.saturating_sub(1);
1677
1678        Some(slice)
1679    }
1680}
1681
1682impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
1683    fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
1684        let dims = tensor.dims();
1685        let ranges = dims
1686            .iter()
1687            .map(|&dim| 0..dim)
1688            .collect::<Vec<Range<usize>>>();
1689        let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
1690        Self {
1691            end: dims[dim],
1692            ranges,
1693            start: 0,
1694            dim,
1695            tensor,
1696        }
1697    }
1698}
1699
1700impl<B, const D: usize, K> Tensor<B, D, K>
1701where
1702    B: Backend,
1703    K: BasicOps<B>,
1704    <K as BasicOps<B>>::Elem: Debug,
1705{
1706    #[inline]
1707    fn push_newline_indent(acc: &mut String, indent: usize) {
1708        acc.push('\n');
1709        for _ in 0..indent {
1710            acc.push(' ');
1711        }
1712    }
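    /// Formats the scalar elements of the innermost dimension for indices in `range`,
    /// reading each element synchronously from the backend via a single-element slice.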
1713    fn fmt_inner_tensor(
1714        &self,
1715        acc: &mut String,
1716        depth: usize,
1717        multi_index: &mut [usize],
1718        range: (usize, usize),
1719        precision: Option<usize>,
1720    ) {
1721        let (start, end) = range;
1722        for i in start..end {
1723            if i > 0 {
1724                acc.push_str(", ");
1725            }
1726            multi_index[depth] = i;
1727            let range: [Range<usize>; D] =
1728                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
1729
1730            let data =
1731                burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());
1732
1733            if let Some(data) = data {
1734                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
1735                match (precision, K::name()) {
1736                    (Some(p), "Float") => acc.push_str(&format!("{:.1$}", elem, p)),
1737                    (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
1738                    _ => acc.push_str(&format!("{:?}", elem)),
1739                }
1740            } else {
1741                acc.push_str("<Tensor data not available>");
1742            }
1743        }
1744    }
1745
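    /// Formats one bracketed sub-tensor per index in `range`, recursing through
    /// `display_recursive` for the remaining inner dimensions.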
1746    fn fmt_outer_tensor(
1747        &self,
1748        acc: &mut String,
1749        depth: usize,
1750        multi_index: &mut [usize],
1751        print_options: &PrintOptions,
1752        summarize: bool,
1753        range: (usize, usize),
1754    ) {
1755        let (start, end) = range;
1756        for i in start..end {
1757            if i > start {
1758                acc.push(',');
1759                Self::push_newline_indent(acc, depth + 1);
1760            }
1761            acc.push('[');
1762            multi_index[depth] = i;
1763            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
1764            acc.push(']');
1765        }
1766    }
1767
1768    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
1769    ///
1770    /// This function is designed to work with tensors of any dimensionality.
1771    /// It traverses the tensor dimensions recursively, converting the elements
1772    /// to strings and appending them to the accumulator string with the
1773    /// appropriate formatting.
1774    ///
1775    /// # Arguments
1776    ///
1777    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
1778    /// * `depth` - The current depth of the tensor dimensions being processed.
1779    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
1780    fn display_recursive(
1781        &self,
1782        acc: &mut String,
1783        depth: usize,
1784        multi_index: &mut [usize],
1785        print_options: &PrintOptions,
1786        summarize: bool,
1787    ) {
1788        let edge_items = print_options.edge_items;
1789
1790        if depth == 0 {
1791            acc.push('[');
1792        }
1793
1794        if depth == self.dims().len() - 1 {
1795            // if we are at the innermost dimension, just push its elements into the accumulator
1796            if summarize && self.dims()[depth] > 2 * edge_items {
1797                // print the starting `edge_items` elements
1798                self.fmt_inner_tensor(
1799                    acc,
1800                    depth,
1801                    multi_index,
1802                    (0, edge_items),
1803                    print_options.precision,
1804                );
1805                acc.push_str(", ...");
1806                // print the last `edge_items` elements
1807                self.fmt_inner_tensor(
1808                    acc,
1809                    depth,
1810                    multi_index,
1811                    (self.dims()[depth] - edge_items, self.dims()[depth]),
1812                    print_options.precision,
1813                );
1814            } else {
1815                // print all the elements
1816                self.fmt_inner_tensor(
1817                    acc,
1818                    depth,
1819                    multi_index,
1820                    (0, self.dims()[depth]),
1821                    print_options.precision,
1822                );
1823            }
1824        } else {
1825            // otherwise, iterate through the current dimension and recursively display the inner tensors
1826            if summarize && self.dims()[depth] > 2 * edge_items {
1827                self.fmt_outer_tensor(
1828                    acc,
1829                    depth,
1830                    multi_index,
1831                    print_options,
1832                    summarize,
1833                    (0, edge_items),
1834                );
1835
1836                acc.push(',');
1837                Self::push_newline_indent(acc, depth + 1);
1838                acc.push_str("...");
1839                Self::push_newline_indent(acc, depth + 1);
1840
1841                self.fmt_outer_tensor(
1842                    acc,
1843                    depth,
1844                    multi_index,
1845                    print_options,
1846                    summarize,
1847                    (self.dims()[depth] - edge_items, self.dims()[depth]),
1848                );
1849            } else {
1850                self.fmt_outer_tensor(
1851                    acc,
1852                    depth,
1853                    multi_index,
1854                    print_options,
1855                    summarize,
1856                    (0, self.dims()[depth]),
1857                );
1858            }
1859        }
1860
1861        if depth == 0 {
1862            acc.push(']');
1863        }
1864    }
1865}
1866
1867#[derive(Clone, Debug)]
1868/// Options for Tensor pretty printing
1869pub struct PrintOptions {
1870    /// Maximum number of elements before the printed output is summarized
1871    pub threshold: usize,
1872
1873    /// Number of elements to display at the start and end of each dimension when summarizing
1874    pub edge_items: usize,
1875
1876    /// Precision for floating point numbers
1877    pub precision: Option<usize>,
1878}
1879
1880static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
1881
1882impl PrintOptions {
1883    /// Print options with default values
1884    pub const fn const_default() -> Self {
1885        Self {
1886            threshold: 1000,
1887            edge_items: 3,
1888            precision: None,
1889        }
1890    }
1891}
1892
1893impl Default for PrintOptions {
1894    fn default() -> Self {
1895        Self::const_default()
1896    }
1897}
1898
1899/// Set print options
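///
/// A small usage sketch (the field values here are arbitrary; this assumes `PrintOptions` and
/// `set_print_options` are re-exported at the crate root like the other items in the examples):
///
/// ```rust
/// use burn_tensor::{PrintOptions, set_print_options};
///
/// let options = PrintOptions {
///     threshold: 100,
///     edge_items: 2,
///     precision: Some(3),
/// };
/// set_print_options(options);
/// ```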
1900pub fn set_print_options(options: PrintOptions) {
1901    let mut print_opts = PRINT_OPTS.write().unwrap();
1902    *print_opts = options;
1903}
1904
1905/// Pretty print tensors
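///
/// A small sketch of the output (the exact shape/device/backend/dtype lines depend on the backend):
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
///     // Prints the data followed by shape, device, backend, kind and dtype.
///     println!("{tensor}");
///     // The precision can also be set through the formatter:
///     println!("{tensor:.2}");
/// }
/// ```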
1906impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
1907where
1908    B: Backend,
1909    B::IntElem: core::fmt::Display,
1910    K: BasicOps<B>,
1911    <K as BasicOps<B>>::Elem: Debug,
1912{
1913    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1914        writeln!(f, "Tensor {{")?;
1915
1916        {
1917            // Do not hold the read lock for the whole function
1918            let mut po = { PRINT_OPTS.read().unwrap().clone() };
1919
1920            // Override the precision if it is set from the formatter
1921            // This will be possible when the tensor is printed using the `{:.*}` syntax
1922            if let Some(precision) = f.precision() {
1923                po.precision = Some(precision);
1924            }
1925
1926            let mut acc = String::new();
1927            let mut multi_index = vec![0; D];
1928            let summarize = self.shape().num_elements() > po.threshold;
1929
1930            self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);
1931
1932            writeln!(f, "  data:")?;
1933            write!(f, "{acc}")?;
1934            writeln!(f, ",")?;
1935        }
1936
1937        writeln!(f, "  shape:  {:?},", self.dims())?;
1938        writeln!(f, "  device:  {:?},", self.device())?;
1939        writeln!(f, "  backend:  {:?},", B::name(&self.device()))?;
1940        writeln!(f, "  kind:  {:?},", K::name())?;
1941
1942        let dtype = self.primitive.dtype();
1943
1944        writeln!(f, "  dtype:  {:?},", dtype.name())?;
1945        write!(f, "}}")
1946    }
1947}
1948
1949/// Transpose marker (zero-size type). Used to sugar the transpose of a tensor, e.g.
1950/// ```rust
1951/// use burn_tensor::backend::Backend;
1952/// use burn_tensor::{Tensor, T};
1953///
1954/// fn example<B: Backend>() {
1955///     let device = Default::default();
1956///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
1957///     let transposed = tensor^T;
1958/// }
1959/// ```
1960pub struct T;
1961
1962impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
1963    type Output = Self;
1964    fn bitxor(self, _: T) -> Self::Output {
1965        self.transpose()
1966    }
1967}
1968
1969    /// Trait that lists all operations that can be applied to all tensors.
1970///
1971/// # Warnings
1972///
1973/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
1974pub trait BasicOps<B: Backend>: TensorKind<B> {
1975    /// The type of the tensor elements.
1976    type Elem: Element;
1977
1978    /// Creates an empty tensor with the given shape.
1979    ///
1980    /// # Arguments
1981    ///
1982    /// * `shape` - The shape of the tensor.
1983    /// * `device` - The device on which the tensor will be allocated.
1984    ///
1985    /// # Returns
1986    ///
1987    /// The empty tensor.
1988    ///
1989    /// # Remarks
1990    ///
1991    /// This is a low-level function used internally by the library to call different backend functions
1992    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1993    /// or use this function directly.
1994    ///
1995    /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function,
1996    /// which is more high-level and designed for public use.
1997    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive;
1998
1999    /// Reshapes the tensor.
2000    ///
2001    /// # Arguments
2002    ///
2003    /// * `tensor` - The tensor.
2004    /// * `shape` - The new shape of the tensor.
2005    ///
2006    /// # Returns
2007    ///
2008    /// The reshaped tensor.
2009    ///
2010    /// # Remarks
2011    ///
2012    /// This is a low-level function used internally by the library to call different backend functions
2013    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2014    /// or use this function directly.
2015    ///
2016    /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function,
2017    /// which is more high-level and designed for public use.
2018    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
2019
2020    /// Transposes a tensor.
2021    ///
2022    /// # Arguments
2023    ///
2024    /// * `tensor` - The tensor to transpose.
2025    ///
2026    /// # Returns
2027    ///
2028    /// The transposed tensor.
2029    fn transpose(tensor: Self::Primitive) -> Self::Primitive;
2030
2031    /// Swaps two dimensions of a tensor.
2032    ///
2033    /// # Arguments
2034    ///
2035    /// * `tensor` - The tensor to swap the dimensions of.
2036    /// * `dim1` - The first dimension to swap.
2037    /// * `dim2` - The second dimension to swap.
2038    ///
2039    /// # Returns
2040    ///
2041    /// The tensor with the dimensions swapped.
2042    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
2043
2044    /// Permutes the dimensions of a tensor.
2045    ///
2046    /// # Arguments
2047    ///
2048    /// * `tensor` - The tensor to permute the dimensions of.
2049    /// * `axes` - The new order of the dimensions.
2050    ///
2051    /// # Returns
2052    ///
2053    /// The tensor with the dimensions permuted.
2054    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2055
2056    /// Flips the tensor along the given axes.
2057    ///
2058    /// # Arguments
2059    ///
2060    /// * `tensor` - The tensor to flip.
2061    /// * `axes` - The axes to flip the tensor along.
2062    ///
2063    /// # Returns
2064    ///
2065    /// The tensor with the axes flipped.
2066    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2067
2068    /// Selects tensor elements corresponding to the given ranges.
2069    ///
2070    /// # Arguments
2071    ///
2072    /// * `tensor` - The tensor.
2073    /// * `ranges` - The ranges of the elements to select.
2074    ///
2075    /// # Returns
2076    ///
2077    /// The selected elements.
2078    ///
2079    /// # Remarks
2080    ///
2081    /// This is a low-level function used internally by the library to call different backend functions
2082    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2083    /// or use this function directly.
2084    ///
2085    /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function,
2086    /// which is more high-level and designed for public use.
2087    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive;
2088
2089    /// Assigns the given value to the tensor elements corresponding to the given ranges.
2090    ///
2091    /// # Arguments
2092    ///
2093    /// * `tensor` - The tensor.
2094    /// * `ranges` - The ranges of the elements to select.
2095    /// * `value` - The value to assign.
2096    ///
2097    /// # Returns
2098    ///
2099    /// The tensor with the assigned values.
2100    ///
2101    /// # Remarks
2102    ///
2103    /// This is a low-level function used internally by the library to call different backend functions
2104    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2105    /// or use this function directly.
2106    ///
2107    /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function,
2108    /// which is more high-level and designed for public use.
2109    fn slice_assign(
2110        tensor: Self::Primitive,
2111        ranges: &[Range<usize>],
2112        value: Self::Primitive,
2113    ) -> Self::Primitive;
2114
2115    /// Returns the device on which the tensor is allocated.
2116    ///
2117    /// # Arguments
2118    ///
2119    /// * `tensor` - The tensor.
2120    ///
2121    /// # Returns
2122    ///
2123    /// The device on which the tensor is allocated.
2124    ///
2125    /// # Remarks
2126    ///
2127    /// This is a low-level function used internally by the library to call different backend functions
2128    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2129    /// or use this function directly.
2130    ///
2131    /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function,
2132    /// which is more high-level and designed for public use.
2133    fn device(tensor: &Self::Primitive) -> B::Device;
2134
2135    /// Moves the tensor to the given device.
2136    ///
2137    /// # Arguments
2138    ///
2139    /// * `tensor` - The tensor.
2140    /// * `device` - The device on which the tensor will be moved.
2141    ///
2142    /// # Returns
2143    ///
2144    /// The tensor on the given device.
2145    ///
2146    /// # Remarks
2147    ///
2148    /// This is a low-level function used internally by the library to call different backend functions
2149    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2150    /// or use this function directly.
2151    ///
2152    /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function,
2153    /// which is more high-level and designed for public use.
2154    fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
2155
2156    /// Extracts the data from the tensor asynchronously.
2157    ///
2158    /// # Arguments
2159    ///
2160    /// * `tensor` - The tensor.
2161    ///
2162    /// # Returns
2163    ///
2164    /// The data of the tensor.
2165    ///
2166    /// # Remarks
2167    ///
2168    /// This is a low-level function used internally by the library to call different backend functions
2169    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2170    /// or use this function directly.
2171    ///
2172    /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
2173    /// which is more high-level and designed for public use.
2174    fn into_data_async(
2175        tensor: Self::Primitive,
2176    ) -> impl Future<Output = TensorData> + 'static + Send;
2177
2178    /// Read the data from the tensor using a transaction.
2179    ///
2180    /// # Remarks
2181    ///
2182    /// This is a low-level function used internally by the library to call different backend functions
2183    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2184    /// or use this function directly.
2185    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive);
2186
2187    /// Creates a tensor from the given data.
2188    ///
2189    /// # Arguments
2190    ///
2191    /// * `data` - The data of the tensor.
2192    /// * `device` - The device on which the tensor will be allocated.
2193    ///
2194    /// # Returns
2195    ///
2196    /// The tensor.
2197    ///
2198    /// # Remarks
2199    ///
2200    /// This is a low-level function used internally by the library to call different backend functions
2201    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2202    /// or use this function directly.
2203    ///
2204    /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function,
2205    /// which is more high-level and designed for public use.
2206    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;
2207    /// Creates a tensor from the given data enforcing the given data type.
2208    ///
2209    /// # Remarks
2210    ///
2211    /// This is a low-level function used internally by the library to call different backend functions
2212    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2213    /// or use this function directly.
2214    ///
2215    /// For creating a tensor from data, users should prefer the [Tensor::from_data_dtype](Tensor::from_data_dtype)
2216    /// function, which is more high-level and designed for public use.
2217    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;
2218
2219    /// Repeat the tensor along the given dimension.
2220    ///
2221    /// # Arguments
2222    ///
2223    /// * `tensor` - The tensor.
2224    /// * `dim` - The dimension along which the tensor will be repeated.
2225    /// * `times` - The number of times the tensor will be repeated.
2226    ///
2227    /// # Returns
2228    ///
2229    /// The repeated tensor.
2230    ///
2231    /// # Remarks
2232    ///
2233    /// This is a low-level function used internally by the library to call different backend functions
2234    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2235    /// or use this function directly.
2236    ///
2237    /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
2238    /// which is more high-level and designed for public use.
2239    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
2240
2241    /// Concatenates the given tensors along the given dimension.
2242    ///
2243    /// # Arguments
2244    ///
2245    /// * `vectors` - The tensors to concatenate.
2246    /// * `dim` - The dimension along which the tensors will be concatenated.
2247    ///
2248    /// # Returns
2249    ///
2250    /// The concatenated tensor.
2251    ///
2252    /// # Remarks
2253    ///
2254    /// This is a low-level function used internally by the library to call different backend functions
2255    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2256    /// or use this function directly.
2257    ///
2258    /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function,
2259    /// which is more high-level and designed for public use.
2260    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
2261
2262    /// Attempts to split the tensor along the given dimension into chunks.
2263    /// May return fewer chunks than requested if the tensor size is not divisible by the number of chunks.
2264    ///
2265    /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
2266    /// Otherwise all chunks will be of equal size except for the last one.
2267    ///
2268    /// # Panics
2269    ///
2270    ///  If the dimension is greater than the number of dimensions of the tensor.
2271    ///
2272    /// # Returns
2273    /// A vector of tensors.
2274    ///
2275    /// # Remarks
2276    ///
2277    /// This is a low-level function used internally by the library to call different backend functions
2278    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2279    /// or use this function directly.
2280    ///
2281    /// To chunk a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function,
2282    /// which is more high-level and designed for public use.
2283    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive>;
2284
2285    /// Splits the tensor into chunks of a specified size along a given dimension.
2286    /// Each chunk is a view of the original tensor.
2287    ///
2288    /// # Panics
2289    ///
2290    /// If the dimension to split along is greater than the number of dimensions of the tensor.
2291    ///
2292    /// # Returns
2293    ///
2294    /// A vector of tensors.
2295    ///
2296    /// # Remarks
2297    /// This is a low-level function used internally by the library to call different backend functions
2298    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2299    /// or use this function directly.
2300    ///
2301    /// To split a tensor, users should prefer the [Tensor::split](Tensor::split) function,
2302    /// which is more high-level and designed for public use.
2303    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive>;
2304
2305    /// Splits the tensor into chunks with the specified sizes along a given dimension.
2306    /// Each chunk is a view of the original tensor.
2307    ///
2308    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2309    /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2310    ///
2311    /// # Panics
2312    ///
2313    /// If the dimension to split along is greater than the number of dimensions of the tensor or
2314    /// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
2315    ///
2316    /// # Returns
2317    ///
2318    /// A vector of tensors.
2319    ///
2320    /// # Remarks
2321    /// This is a low-level function used internally by the library to call different backend functions
2322    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2323    /// or use this function directly.
2324    ///
2325    /// To split a tensor, users should prefer the [Tensor::split_with_sizes](Tensor::split_with_sizes) function,
2326    /// which is more high-level and designed for public use.
2327    fn split_with_sizes(
2328        tensor: Self::Primitive,
2329        split_sizes: Vec<usize>,
2330        dim: usize,
2331    ) -> Vec<Self::Primitive>;
2332
2333    /// Equates the given tensors.
2334    ///
2335    /// # Arguments
2336    ///
2337    /// * `lhs` - The left hand side tensor.
2338    /// * `rhs` - The right hand side tensor.
2339    ///
2340    /// # Returns
2341    ///
2342    /// The tensor of booleans indicating whether the corresponding elements are equal.
2343    ///
2344    /// # Remarks
2345    ///
2346    /// This is a low-level function used internally by the library to call different backend functions
2347    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2348    /// or use this function directly.
2349    ///
2350    /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function,
2351    /// which is more high-level and designed for public use.
2352    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2353
2354    /// Applies element-wise non-equality comparison between the given tensors.
2355    ///
2356    /// # Arguments
2357    ///
2358    /// * `lhs` - The left hand side tensor.
2359    /// * `rhs` - The right hand side tensor.
2360    ///
2361    /// # Returns
2362    ///
2363    /// The tensor of booleans indicating whether the corresponding elements are not equal.
2364    ///
2365    /// # Remarks
2366    ///
2367    /// This is a low-level function used internally by the library to call different backend functions
2368    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2369    /// or use this function directly.
2370    ///
2371    /// For non-equality comparison of tensors, users should prefer the [Tensor::not_equal](Tensor::not_equal)
2372    /// function, which is more high-level and designed for public use.
2373    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2374
2375    /// Returns the name of the element type.
2376    fn elem_type_name() -> &'static str {
2377        core::any::type_name::<Self::Elem>()
2378    }
2379
2380    /// Returns the tensor data type.
2381    fn dtype(tensor: &Self::Primitive) -> DType {
2382        tensor.dtype()
2383    }
2384
2385    /// Tests if any element in the `tensor` evaluates to True.
2386    ///
2387    /// # Arguments
2388    ///
2389    /// * `tensor` - The tensor to test.
2390    ///
2391    /// # Returns
2392    ///
2393    /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
2394    ///
2395    /// # Remarks
2396    ///
2397    /// This is a low-level function used internally by the library to call different backend functions
2398    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2399    /// or use this function directly. Users should prefer the [Tensor::any](Tensor::any) function
2400    /// which is more high-level and designed for public use.
2401    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
2402
2403    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
2404    ///
2405    /// # Arguments
2406    ///
2407    /// * `tensor` - The tensor to test.
2408    /// * `dim` - The axis along which to test.
2409    ///
2410    /// # Returns
2411    ///
2412    /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
2413    /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
2414    ///
2415    /// # Remarks
2416    ///
2417    /// This is a low-level function used internally by the library to call different backend functions
2418    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2419    /// or use this function directly. Users should prefer the [Tensor::any_dim](Tensor::any_dim) function,
2420    /// which is more high-level and designed for public use.
2421    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
2422
2423    /// Tests if all elements in the `tensor` evaluate to True.
2424    ///
2425    /// # Arguments
2426    ///
2427    /// * `tensor` - The tensor to test.
2428    ///
2429    /// # Returns
2430    ///
2431    /// A boolean tensor with a single element, True if all elements in the input tensor evaluate to True, False otherwise.
2432    ///
2433    /// # Remarks
2434    ///
2435    /// This is a low-level function used internally by the library to call different backend functions
2436    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2437    /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function,
2438    /// which is more high-level and designed for public use.
2439    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
2440
2441    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2442    ///
2443    /// # Arguments
2444    ///
2445    /// * `tensor` - The tensor to test.
    /// * `dim` - The axis along which to test.
2446    ///
2447    /// # Returns
2448    ///
2449    /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
2450    /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
2451    ///
2452    /// # Remarks
2453    ///
2454    /// This is a low-level function used internally by the library to call different backend functions
2455    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2456    /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function,
2457    /// which is more high-level and designed for public use.
2458    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
2459
2460    /// Broadcasts the given tensor to the specified shape.
2461    ///
2462    /// # Arguments
2463    ///
2464    /// * `tensor` - The tensor to broadcast.
2465    /// * `shape` - The shape to broadcast to.
2466    ///
2467    /// # Returns
2468    ///
2469    /// The broadcasted tensor.
2470    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
2471}
2472
2473impl<B: Backend> BasicOps<B> for Float {
2474    type Elem = B::FloatElem;
2475
2476    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2477        TensorPrimitive::Float(B::float_empty(shape, device))
2478    }
2479
2480    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2481        tr.register_float(tensor);
2482    }
2483
2484    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2485        match tensor {
2486            TensorPrimitive::Float(tensor) => {
2487                TensorPrimitive::Float(B::float_reshape(tensor, shape))
2488            }
2489            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
2490        }
2491    }
2492
2493    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2494        match tensor {
2495            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
2496            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
2497        }
2498    }
2499
2500    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2501        match tensor {
2502            TensorPrimitive::Float(tensor) => {
2503                TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
2504            }
2505            TensorPrimitive::QFloat(tensor) => {
2506                TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
2507            }
2508        }
2509    }
2510
2511    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2512        match tensor {
2513            TensorPrimitive::Float(tensor) => {
2514                TensorPrimitive::Float(B::float_slice(tensor, ranges))
2515            }
2516            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, ranges)),
2517        }
2518    }
2519
2520    fn slice_assign(
2521        tensor: Self::Primitive,
2522        ranges: &[Range<usize>],
2523        value: Self::Primitive,
2524    ) -> Self::Primitive {
2525        match (tensor, value) {
2526            (TensorPrimitive::Float(tensor), TensorPrimitive::Float(value)) => {
2527                TensorPrimitive::Float(B::float_slice_assign(tensor, ranges, value))
2528            }
2529            (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(value)) => {
2530                TensorPrimitive::QFloat(B::q_slice_assign(tensor, ranges, value))
2531            }
2532            _ => panic!("Primitive type mismatch for tensor and value"),
2533        }
2534    }
2535
2536    fn device(tensor: &Self::Primitive) -> Device<B> {
2537        match tensor {
2538            TensorPrimitive::Float(tensor) => B::float_device(tensor),
2539            TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
2540        }
2541    }
2542
2543    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2544        match tensor {
2545            TensorPrimitive::Float(tensor) => {
2546                TensorPrimitive::Float(B::float_to_device(tensor, device))
2547            }
2548            TensorPrimitive::QFloat(tensor) => {
2549                TensorPrimitive::QFloat(B::q_to_device(tensor, device))
2550            }
2551        }
2552    }
2553
2554    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2555        match tensor {
2556            TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
2557            TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
2558        }
2559    }
2560
2561    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2562        match data.dtype {
2563            DType::QFloat(_strategy) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
2564            _ => TensorPrimitive::Float(B::float_from_data(data.convert::<B::FloatElem>(), device)),
2565        }
2566    }
2567
2568    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
2569        match dtype {
2570            DType::QFloat(_strategy) => {
2571                TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device))
2572            }
2573            _ if dtype.is_float() => {
2574                TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
2575            }
2576            _ => panic!("Expected float dtype, got {dtype:?}"),
2577        }
2578    }
2579
2580    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2581        match tensor {
2582            TensorPrimitive::Float(tensor) => {
2583                TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
2584            }
2585            TensorPrimitive::QFloat(tensor) => {
2586                TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
2587            }
2588        }
2589    }
2590
2591    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2592        match vectors.first().unwrap() {
2593            TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
2594                vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
2595                dim,
2596            )),
2597            TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
2598                vectors
2599                    .into_iter()
2600                    .map(|tensor| {
2601                        if let TensorPrimitive::QFloat(t) = tensor {
2602                            t
2603                        } else {
2604                            panic!("Concatenation only works with a vector of QFloat tensors")
2605                        }
2606                    })
2607                    .collect(),
2608                dim,
2609            )),
2610        }
2611    }
2612
2613    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2614        B::float_equal(lhs.tensor(), rhs.tensor())
2615    }
2616
2617    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2618        B::float_not_equal(lhs.tensor(), rhs.tensor())
2619    }
2620
2621    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2622        B::float_any(tensor.tensor())
2623    }
2624
2625    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2626        B::float_any_dim(tensor.tensor(), dim)
2627    }
2628
2629    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2630        B::float_all(tensor.tensor())
2631    }
2632
2633    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2634        B::float_all_dim(tensor.tensor(), dim)
2635    }
2636
2637    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2638        match tensor {
2639            TensorPrimitive::Float(tensor) => {
2640                TensorPrimitive::Float(B::float_permute(tensor, axes))
2641            }
2642            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
2643        }
2644    }
2645
2646    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2647        TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
2648    }
2649
2650    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2651        match tensor {
2652            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
2653            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
2654        }
2655    }
2656
2657    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2658        match tensor {
2659            TensorPrimitive::Float(tensor) => B::float_chunk(tensor, chunks, dim)
2660                .into_iter()
2661                .map(TensorPrimitive::Float)
2662                .collect(),
2663            TensorPrimitive::QFloat(tensor) => B::q_chunk(tensor, chunks, dim)
2664                .into_iter()
2665                .map(TensorPrimitive::QFloat)
2666                .collect(),
2667        }
2668    }
2669
2670    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2671        match tensor {
2672            TensorPrimitive::Float(tensor) => B::float_split(tensor, split_size, dim)
2673                .into_iter()
2674                .map(TensorPrimitive::Float)
2675                .collect(),
2676            TensorPrimitive::QFloat(tensor) => B::q_split(tensor, split_size, dim)
2677                .into_iter()
2678                .map(TensorPrimitive::QFloat)
2679                .collect(),
2680        }
2681    }
2682
2683    fn split_with_sizes(
2684        tensor: Self::Primitive,
2685        split_sizes: Vec<usize>,
2686        dim: usize,
2687    ) -> Vec<Self::Primitive> {
2688        match tensor {
2689            TensorPrimitive::Float(tensor) => B::float_split_with_sizes(tensor, split_sizes, dim)
2690                .into_iter()
2691                .map(TensorPrimitive::Float)
2692                .collect(),
2693            TensorPrimitive::QFloat(tensor) => B::q_split_with_sizes(tensor, split_sizes, dim)
2694                .into_iter()
2695                .map(TensorPrimitive::QFloat)
2696                .collect(),
2697        }
2698    }
2699}
2700
2701impl<B: Backend> BasicOps<B> for Int {
2702    type Elem = B::IntElem;
2703
2704    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2705        B::int_empty(shape, device)
2706    }
2707
2708    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2709        tr.register_int(tensor);
2710    }
2711
2712    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2713        B::int_reshape(tensor, shape)
2714    }
2715
2716    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2717        B::int_transpose(tensor)
2718    }
2719
2720    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2721        B::int_swap_dims(tensor, dim1, dim2)
2722    }
2723
2724    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2725        B::int_slice(tensor, ranges)
2726    }
2727
2728    fn slice_assign(
2729        tensor: Self::Primitive,
2730        ranges: &[Range<usize>],
2731        value: Self::Primitive,
2732    ) -> Self::Primitive {
2733        B::int_slice_assign(tensor, ranges, value)
2734    }
2735
2736    fn device(tensor: &Self::Primitive) -> Device<B> {
2737        B::int_device(tensor)
2738    }
2739
2740    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2741        B::int_to_device(tensor, device)
2742    }
2743
2744    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2745        B::int_into_data(tensor).await
2746    }
2747
2748    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2749        B::int_from_data(data.convert::<B::IntElem>(), device)
2750    }
2751
2752    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
2753        if !dtype.is_int() {
2754            panic!("Expected int dtype, got {dtype:?}")
2755        }
2756
2757        B::int_from_data(data.convert_dtype(dtype), device)
2758    }
2759
2760    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2761        B::int_repeat_dim(tensor, dim, times)
2762    }
2763
2764    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2765        B::int_equal(lhs, rhs)
2766    }
2767
2768    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2769        B::int_not_equal(lhs, rhs)
2770    }
2771
2772    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2773        B::int_cat(vectors, dim)
2774    }
2775
2776    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2777        B::int_any(tensor)
2778    }
2779
2780    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2781        B::int_any_dim(tensor, dim)
2782    }
2783
2784    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2785        B::int_all(tensor)
2786    }
2787
2788    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2789        B::int_all_dim(tensor, dim)
2790    }
2791
2792    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2793        B::int_permute(tensor, axes)
2794    }
2795
2796    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2797        B::int_expand(tensor, shape)
2798    }
2799
2800    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2801        B::int_flip(tensor, axes)
2802    }
2803
2804    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2805        B::int_chunk(tensor, chunks, dim)
2806    }
2807
2808    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2809        B::int_split(tensor, split_size, dim)
2810    }
2811
2812    fn split_with_sizes(
2813        tensor: Self::Primitive,
2814        split_sizes: Vec<usize>,
2815        dim: usize,
2816    ) -> Vec<Self::Primitive> {
2817        B::int_split_with_sizes(tensor, split_sizes, dim)
    }
}

impl<B: Backend> BasicOps<B> for Bool {
    type Elem = B::BoolElem;

    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
        B::bool_empty(shape, device)
    }

    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
        tr.register_bool(tensor);
    }

    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
        B::bool_reshape(tensor, shape)
    }

    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
        B::bool_transpose(tensor)
    }

    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
        B::bool_swap_dims(tensor, dim1, dim2)
    }

    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
        B::bool_slice(tensor, ranges)
    }

    fn slice_assign(
        tensor: Self::Primitive,
        ranges: &[Range<usize>],
        value: Self::Primitive,
    ) -> Self::Primitive {
        B::bool_slice_assign(tensor, ranges, value)
    }

    fn device(tensor: &Self::Primitive) -> Device<B> {
        B::bool_device(tensor)
    }

    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
        B::bool_to_device(tensor, device)
    }

    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
        B::bool_into_data(tensor).await
    }

    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
        B::bool_from_data(data.convert::<B::BoolElem>(), device)
    }

    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
        // Backends only use one bool representation dtype
        if dtype != B::BoolElem::dtype() {
            panic!("Expected bool dtype, got {dtype:?}")
        }
        B::bool_from_data(data.convert_dtype(dtype), device)
    }

    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
        B::bool_repeat_dim(tensor, dim, times)
    }

    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_equal(lhs, rhs)
    }

    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_not_equal(lhs, rhs)
    }

    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
        B::bool_cat(vectors, dim)
    }

    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_any(tensor)
    }

    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
        B::bool_any_dim(tensor, dim)
    }

    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_all(tensor)
    }

    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
        B::bool_all_dim(tensor, dim)
    }

    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
        B::bool_permute(tensor, axes)
    }

    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
        B::bool_expand(tensor, shape)
    }

    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
        B::bool_flip(tensor, axes)
    }

    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
        B::bool_chunk(tensor, chunks, dim)
    }

    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
        B::bool_split(tensor, split_size, dim)
    }

    fn split_with_sizes(
        tensor: Self::Primitive,
        split_sizes: Vec<usize>,
        dim: usize,
    ) -> Vec<Self::Primitive> {
        B::bool_split_with_sizes(tensor, split_sizes, dim)
    }
}

/// Trait used for `movedim` arguments.
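///
/// # Example
///
/// A minimal sketch of how these arguments are used (assuming a backend `B`
/// with a default device); dimensions may be given as `usize` or `i32`, where
/// negative values count from the last dimension:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
///
///     // Move dimension 0 to position 2: the result has shape [3, 4, 2].
///     let moved = tensor.clone().movedim(0, 2);
///
///     // Negative indices count from the end: -1 is the last dimension.
///     let moved = tensor.movedim(0, -1);
/// }
/// ```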
pub trait MovedimArgs {
    /// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function.
    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}

impl MovedimArgs for Vec<i32> {
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        let set = self
            .iter()
            .map(|&dim| {
                if dim < 0 {
                    // Negative dimensions count from the end.
                    (D as i32 + dim) as usize
                } else {
                    dim as usize
                }
            })
            .collect::<Vec<usize>>();
        check!(TensorCheck::movedim_args_vec::<D>(&set));

        set
    }
}

impl MovedimArgs for Vec<usize> {
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        check!(TensorCheck::movedim_args_vec::<D>(&self));
        self
    }
}

impl MovedimArgs for usize {
    #[allow(clippy::vec_init_then_push)]
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        check!(TensorCheck::movedim_args_usize::<D>(self));

        let mut set = Vec::with_capacity(1);
        set.push(self);

        set
    }
}

impl MovedimArgs for i32 {
    #[allow(clippy::vec_init_then_push)]
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        check!(TensorCheck::movedim_args_i32::<D>(self));

        // Negative dimensions count from the end.
        let dim = if self < 0 {
            (D as i32 + self) as usize
        } else {
            self as usize
        };

        let mut set = Vec::with_capacity(1);
        set.push(dim);

        set
    }
}

/// Trait used for slice arguments.
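///
/// # Example
///
/// A minimal sketch of the conversion itself (assuming `RangesArg` and `Shape`
/// are in scope); bounds are clamped to the dimension size and negative
/// indices count from the end:
///
/// ```rust,ignore
/// let shape = Shape::new([8, 4]);
///
/// // Plain ranges pass through, clamped to the shape.
/// assert_eq!([0..5, 0..4], [0..5, 0..4].into_ranges(shape.clone()));
///
/// // Negative bounds count from the end of each dimension.
/// assert_eq!([5..7, 0..4], [-3..-1, 0..4].into_ranges(shape));
/// ```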
pub trait RangesArg<const D2: usize> {
    /// Converts into an array of ranges `[Range<usize>; D2]` for the `tensor.slice()` function.
    fn into_ranges(self, shape: Shape) -> [Range<usize>; D2];
}

impl<const D2: usize, T: Into<Slice>> RangesArg<D2> for [T; D2] {
    fn into_ranges(self, shape: Shape) -> [Range<usize>; D2] {
        // Clamp the ranges to the shape dimensions.
        let ranges = self
            .into_iter()
            .enumerate()
            .map(|(i, range)| range.into().into_range(shape.dims[i]))
            .collect::<Vec<_>>();
        ranges.try_into().unwrap()
    }
}

impl<T: Into<Slice>> RangesArg<1> for T {
    fn into_ranges(self, shape: Shape) -> [Range<usize>; 1] {
        [self.into().into_range(shape.dims[0])]
    }
}

/// Trait used for reshape arguments.
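///
/// # Example
///
/// A minimal sketch (assuming a backend `B` with a default device); in the
/// `i32` form, a `0` entry keeps the corresponding dimension and a single
/// `-1` entry is inferred from the element count:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::ones([4, 6], &device);
///
///     // Explicit sizes: [4, 6] -> [2, 12].
///     let reshaped = tensor.clone().reshape([2, 12]);
///
///     // Keep dim 0 (`0`) and infer dim 1 (`-1`): [4, 6] -> [4, 6].
///     let reshaped = tensor.reshape([0, -1]);
/// }
/// ```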
pub trait ReshapeArgs<const D2: usize> {
    /// Converts to a shape.
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape;
}

impl<const D2: usize> ReshapeArgs<D2> for Shape {
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape {
        check!(TensorCheck::reshape_args_usize::<D, D2>(
            &tensor.shape(),
            &self
        ));

        self
    }
}

impl<const D2: usize> ReshapeArgs<D2> for [usize; D2] {
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape {
        let shape = Shape::from(self);

        check!(TensorCheck::reshape_args_usize::<D, D2>(
            &tensor.shape(),
            &shape
        ));

        shape
    }
}

impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape {
        // Validate the reshape arguments
        check!(TensorCheck::reshape_args_i32(&self));

        // Temporary shape
        let mut new_shape: [i32; D2] = [1; D2];

        // A `0` entry means "keep the size of that dimension", so replace
        // zeros with the actual dimension values.
        for (i, &s) in self.iter().enumerate() {
            if s != 0 {
                new_shape[i] = s;
            } else {
                new_shape[i] = tensor.dims()[i] as i32;
            }
        }

        // Find the index of the inferred dimension (-1), if any.
        let infer_index = new_shape.iter().position(|x| x == &-1);

        // Infer the marked dimension from the remaining element count.
        if let Some(index) = infer_index {
            let mut product = 1;
            for (i, &s) in new_shape.iter().enumerate() {
                if i != index {
                    product *= s;
                }
            }
            let product_current = tensor.shape().num_elements() as i32;

            // The known dimensions must divide the element count exactly,
            // otherwise no valid size exists for the inferred dimension.
            if product_current % product != 0 {
                panic!(
                    "Cannot reshape tensor of shape {:?} to shape {:?}",
                    tensor.shape(),
                    new_shape
                );
            }

            new_shape[index] = product_current / product;
        }

        // Convert each element to usize
        let new_shape: [usize; D2] = new_shape.map(|x| x as usize);

        Shape::from(new_shape)
    }
}

/// Trait used for broadcast arguments.
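///
/// # Example
///
/// A minimal sketch (assuming a backend `B` with a default device); a `-1`
/// entry keeps the size of the corresponding dimension:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::ones([1, 4], &device);
///
///     // Broadcast [1, 4] to [3, 4]; `-1` keeps the last dimension as-is.
///     let expanded = tensor.expand([3, -1]);
/// }
/// ```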
pub trait BroadcastArgs<const D1: usize, const D2: usize> {
    /// Converts to a shape.
    fn into_shape(self, shape: &Shape) -> Shape;
}

impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
    fn into_shape(self, _shape: &Shape) -> Shape {
        self
    }
}

impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
    fn into_shape(self, _shape: &Shape) -> Shape {
        Shape::from(self)
    }
}

impl<const D1: usize, const D2: usize, E: Element> BroadcastArgs<D1, D2> for [E; D2] {
    // Passing -1 as the size for a dimension means not changing the size of that dimension.
    fn into_shape(self, shape: &Shape) -> Shape {
        if self.len() < shape.num_dims() {
            panic!("The number of broadcast arguments cannot be fewer than the number of dimensions");
        }

        // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
        let new_shape: Vec<_> = self
            .iter()
            .rev()
            .map(|x| {
                let primitive = x.to_i64();
                if primitive < -1 || primitive == 0 {
                    panic!("Broadcast arguments must be positive or -1");
                }
                primitive
            })
            .zip(shape.dims.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
            .map(|(x, &y)| if x == -1 { y } else { x as usize })
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .collect();

        if new_shape.contains(&0) {
            panic!("Cannot substitute -1 for a non-existing dimension");
        }

        let new_shape: [usize; D2] = new_shape.try_into().unwrap();

        Shape::from(new_shape)
    }
}
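
// The serde implementations below round-trip a tensor through `TensorData`.
// A minimal usage sketch (assuming a serde format crate such as `serde_json`,
// which is not a dependency of this crate):
//
//     let tensor = Tensor::<B, 2>::ones([2, 2], &device);
//     let json = serde_json::to_string(&tensor)?;
//     let tensor: Tensor<B, 2> = serde_json::from_str(&json)?;
//
// Deserialization places the tensor on the backend's default device.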
impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    K::Elem: Debug + Copy + Serialize,
{
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let data = self.to_data();
        data.serialize(serializer)
    }
}

impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    K::Elem: Debug + Copy + Deserialize<'de>,
{
    fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
        // The tensor is created on the backend's default device.
        let tensor = Tensor::from_data(
            TensorData::deserialize(deserializer)?,
            &<B::Device as Default>::default(),
        );
        Ok(tensor)
    }
}

#[cfg(test)]
mod tests {
    use crate::Shape;
    use crate::s;

    use super::*;

    #[test]
    fn slice_range_single_dim_leading() {
        let shape = Shape::new([8, 4]);

        // Half-open range
        assert_eq!([0..5], (0..5).into_ranges(shape.clone()));
        assert_eq!([0..5], [0..5].into_ranges(shape.clone()));
        assert_eq!([5..7], [-3..-1].into_ranges(shape.clone()));

        // Inclusive range
        assert_eq!([0..5], (0..=4).into_ranges(shape.clone()));
        assert_eq!([0..5], [0..=4].into_ranges(shape.clone()));
        assert_eq!([6..8], [-2..=-1].into_ranges(shape.clone()));

        // Unbounded start
        assert_eq!([0..3], (..3).into_ranges(shape.clone()));
        assert_eq!([0..3], [..3].into_ranges(shape.clone()));
        assert_eq!([0..3], [..-5].into_ranges(shape.clone()));

        // Unbounded end
        assert_eq!([5..8], (5..).into_ranges(shape.clone()));
        assert_eq!([5..8], [5..].into_ranges(shape.clone()));
        assert_eq!([5..8], [-3..].into_ranges(shape.clone()));

        // Full range
        assert_eq!([0..8], [..].into_ranges(shape));
    }

    #[test]
    fn slice_range_multi_dim() {
        let shape = Shape::new([8, 4]);

        // Multiple ways to provide ranges
        assert_eq!([0..5, 0..4], [0..5, 0..4].into_ranges(shape.clone()));
        assert_eq!([0..8, 0..4], [0.., 0..].into_ranges(shape.clone()));
        assert_eq!([0..8, 0..4], [0..=7, 0..=3].into_ranges(shape.clone()));

        assert_eq!([0..5, 0..3], [0..5, 0..3].into_ranges(shape));
    }

    #[test]
    fn slice_range_multi_dim_index() {
        let shape = Shape::new([8, 4]);

        // Single integer indices should also convert to the correct ranges.
        assert_eq!([0..1, 2..3], [0, 2].into_ranges(shape.clone()));
        assert_eq!([7..8, 3..4], [-1, -1].into_ranges(shape.clone()));
        assert_eq!([7..8], (-1).into_ranges(shape.clone()));
        assert_eq!([7..8], 7.into_ranges(shape));
    }

    #[test]
    fn slice_range_multi_dim_heterogeneous() {
        // The slice macro `s![]` can be used to mix different range types.
        let shape = Shape::new([8, 4, 2]);
        let slice = s![0..5, .., -1];
        assert_eq!([0..5, 0..4, 1..2], slice.into_ranges(shape));

        let shape = Shape::new([8, 4, 2, 3]);
        let slice = s![..=4, 0..=3, .., -2..];
        assert_eq!([0..5, 0..4, 0..2, 1..3], slice.into_ranges(shape));

        let shape = Shape::new([3, 4]);
        let slice = s![1..-1, ..];
        assert_eq!([1..2, 0..4], slice.into_ranges(shape));
    }
}