burn_tensor/tensor/api/base.rs

#![allow(clippy::single_range_in_vec_init)]

use alloc::format;
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;

use burn_common::stub::RwLock;
use core::any::TypeId;
use core::future::Future;
use core::iter::repeat;
use core::{fmt::Debug, ops::Range};
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::check::TensorCheck;
use crate::tensor::api::narrow::narrow;
use crate::{
    backend::Backend, check, ops::Device, Bool, Float, Int, Shape, TensorData, TensorKind,
};
use crate::{DType, Element, TensorPrimitive};

use super::{TensorMetadata, Transaction};
/// A tensor with a given backend, shape and data type.
///
/// # Indexing
/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
/// or [`select`](Tensor::select) for numeric types.
///
/// ## Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
/// use burn_tensor::Int;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///
///     let tensor = Tensor::<B, 2>::from_data(
///         [
///             [3.0, 4.9, 2.0],
///             [2.0, 1.9, 3.0],
///             [6.0, 1.5, 7.0],
///             [3.0, 4.9, 9.0],
///         ],
///         &device,
///     );
///
///     // Slice the tensor to get the second and third rows:
///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
///     // The resulting tensor will have dimensions [2, 3].
///     let slice = tensor.clone().slice([1..3]);
///     println!("{slice}");
///
///     // Slice the tensor to get the first two rows and the first 2 columns:
///     // [[3.0, 4.9], [2.0, 1.9]]
///     // The resulting tensor will have dimensions [2, 2].
///     let slice = tensor.clone().slice([0..2, 0..2]);
///     println!("{slice}");
///
///     // Index the tensor along the dimension 1 to get the elements 0 and 2:
///     // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
///     // The resulting tensor will have dimensions [4, 2].
///     let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
///     let indexed = tensor.select(1, indices);
///     println!("{indexed}");
/// }
/// ```
#[derive(new, Clone, Debug)]
pub struct Tensor<B, const D: usize, K = Float>
where
    B: Backend,
    K: TensorKind<B>,
{
    pub(crate) primitive: K::Primitive,
}

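/// Creates a tensor on the backend's default device from anything convertible into
/// [`TensorData`], such as nested array literals.
///
/// A minimal sketch; the `f32` element type here is just for illustration:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     // The tensor is created on the backend's default device.
///     let tensor: Tensor<B, 2> = [[1.0f32, 2.0], [3.0, 4.0]].into();
///     println!("{tensor}");
/// }
/// ```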
impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    T: Into<TensorData>,
{
    fn from(value: T) -> Self {
        Tensor::from_data(value.into(), &Default::default())
    }
}

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
{
    /// Converts the tensor into a primitive tensor.
    pub fn into_primitive(self) -> K::Primitive {
        self.primitive
    }

    /// Converts from a primitive tensor into a tensor.
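    ///
    /// # Example
    ///
    /// A minimal round-trip sketch between the public tensor type and its backend primitive:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::ones([2, 2], &device);
    ///     // Unwrap the backend primitive, then wrap it back into a tensor.
    ///     let primitive = tensor.into_primitive();
    ///     let tensor = Tensor::<B, 2>::from_primitive(primitive);
    ///     println!("{tensor}");
    /// }
    /// ```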
    pub fn from_primitive(tensor: K::Primitive) -> Self {
        Self::new(tensor)
    }

    /// Returns the tensor primitive data type.
    ///
    /// # Note
    /// Some element types are encoded in different primitive types depending on the backend
    /// (e.g., bool could be encoded as `u8` or `u32`).
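    ///
    /// # Example
    ///
    /// A short sketch; the concrete variant (e.g., `DType::F32`) depends on the backend:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::ones([2, 2], &device);
    ///     // Prints the backend-dependent element data type.
    ///     println!("{:?}", tensor.dtype());
    /// }
    /// ```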
    pub fn dtype(&self) -> DType {
        self.primitive.dtype()
    }

    /// Create an empty tensor of the given shape.
    ///
    /// # Arguments
    ///
    /// - `shape`: The shape of the tensor.
    /// - `device`: The device where the tensor will be created.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    // Create an empty tensor with dimensions [2, 3, 4].
    ///    let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
    /// }
    /// ```
    pub fn empty<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
        Self::new(K::empty(shape, device))
    }

    /// Returns the dimensions of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///   let device = Default::default();
    ///   let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///   let dims = tensor.dims(); // [2, 3, 4]
    ///   println!("{dims:?}");
    /// }
    /// ```
    pub fn dims(&self) -> [usize; D] {
        Self::shape(self).dims()
    }

    /// Returns the shape of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Shape { dims: [2, 3, 4] }
    ///    let shape = tensor.shape();
    /// }
    /// ```
    pub fn shape(&self) -> Shape {
        self.primitive.shape()
    }

    /// Reshape the tensor to have the given shape.
    ///
    /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
    /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
    ///
    /// A `0` in the shape instructs to keep the current dimension from the original tensor,
    /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
    /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
    /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
    /// with [1, 3, 4] dimensions to [1, 12].
    ///
    /// # Arguments
    /// - `shape`: The new shape of the tensor.
    ///
    /// # Panics
    /// - If the shape contains more than one `-1`.
    /// - If the shape contains values that are not positive (other than `-1`).
    /// - If the new shape does not match the number of elements in the original shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    // Create a tensor with dimensions [2, 3, 4]
    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///    // Reshape it to [2, 12], where 12 is inferred from the number of elements.
    ///    let reshaped = tensor.reshape([2, -1]);
    ///    println!("{reshaped}");
    /// }
    /// ```
    pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
        // Convert reshape args to shape
        let shape = shape.into_shape(&self);
        Tensor::new(K::reshape(self.primitive, shape))
    }

    /// Transpose the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to transpose.
    ///
    /// # Returns
    ///
    /// The transposed tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Transpose the tensor:
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let transposed = tensor.transpose();
    ///     println!("{transposed}");
    /// }
    /// ```
    pub fn transpose(self) -> Tensor<B, D, K> {
        Tensor::new(K::transpose(self.primitive))
    }

    /// Swaps two dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Swap the dimensions 0 and 1 (equivalent to `tensor.transpose()`):
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let swapped = tensor.swap_dims(0, 1);
    ///     println!("{swapped}");
    /// }
    /// ```
    pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor<B, D, K> {
        check!(TensorCheck::swap_dims::<D>(dim1, dim2));
        Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
    }

    /// Permute the dimensions of the tensor.
    ///
    /// # Arguments
    ///
    /// * `axes` - The new order of the dimensions. The length of the axes
    ///            must be equal to the number of dimensions of the tensor.
    ///            The values must be unique and in the range of the number of dimensions.
    ///            The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
    ///
    ///     // Permute the dimensions 1 and 0:
    ///     // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
    ///     // The resulting tensor will have dimensions [2, 3].
    ///     let permuted = tensor.permute([1, 0]);
    ///     println!("{permuted}");
    /// }
    /// ```
    pub fn permute(self, axes: [isize; D]) -> Tensor<B, D, K> {
        // Convert the axes to usize and handle negative values without using vector
        let mut transformed_axes: [usize; D] = [0; D];
        for (i, &x) in axes.iter().enumerate() {
            transformed_axes[i] = if x < 0 {
                (D as isize + x) as usize
            } else {
                x as usize
            };
        }

        // Check if the axes are valid after the transformation
        check!(TensorCheck::permute(transformed_axes));

        Tensor::new(K::permute(self.primitive, &transformed_axes))
    }

    /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
    ///
    /// Other dimensions of input that are not explicitly moved remain in their original order and appear
    /// at the positions not specified in destination.
    ///
    /// # Arguments
    ///
    /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
    ///              The values can be negative, in which case they are used as an offset from the end.
    ///
    /// * `dst` - Destination positions for each of the original dims. These must also be unique.
    ///
    /// # Panics
    ///
    /// - If the source and destination dimensions are not of the same length.
    /// - If the source and destination vectors contain duplicate values.
    /// - If the source and destination vectors contain values that are out of bounds.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions moved.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor of shape [3, 2, 1]
    ///     let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
    ///
    ///     // Move the dimensions 0 and 1:
    ///     // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
    ///     // The resulting tensor will have dimensions [2, 3, 1].
    ///     let moved = tensor.movedim(1, 0);
    ///     println!("{moved}");
    /// }
    /// ```
    // This is syntactic sugar for `permute`. It is used widely enough that we define a
    // separate op for it.
    pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
        let source_dims = src.into_dim_vec::<D>();
        let destination_dims = dst.into_dim_vec::<D>();

        check!(TensorCheck::movedim_args_length(
            &source_dims,
            &destination_dims
        ));

        let mut m = [-1; D];
        for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
            m[d] = s as isize;
        }
        let mut axes: [isize; D] = [0; D];
        let mut source_i = 0;
        for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
            *item = if m[dest_i] != -1 {
                m[dest_i]
            } else {
                while source_dims.contains(&source_i) {
                    source_i += 1;
                }
                let result = source_i as isize;
                source_i += 1;
                result
            };
        }

        self.permute(axes)
    }

    /// Reverse the order of elements in the tensor along the given dimensions.
    ///
    /// # Arguments
    ///
    /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
    ///            The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the axes flipped.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [4.0, 5.9, 8.0],
    ///             [1.4, 5.8, 6.0],
    ///         ],
    ///         &device,
    ///     );
    ///
    ///     // Flip the elements in dimensions 0 and 1:
    ///     // [[6.0, 5.8, 1.4],
    ///     //  [8.0, 5.9, 4.0],
    ///     //  [3.0, 1.9, 2.0],
    ///     //  [2.0, 4.9, 3.0]]
    ///     // The resulting tensor will have dimensions [4, 3].
    ///     let flipped = tensor.flip([0, 1]);
    ///     println!("{flipped}");
    /// }
    /// ```
    pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
        // Convert the axes to usize and handle negative values without using vector
        let mut transformed_axes: [usize; N] = [0; N];
        for (i, &x) in axes.iter().enumerate() {
            transformed_axes[i] = if x < 0 {
                (D as isize + x) as usize
            } else {
                x as usize
            };
        }

        // Check if the axes are valid
        check!(TensorCheck::flip(D, &transformed_axes));

        Tensor::new(K::flip(self.primitive, &transformed_axes))
    }

    /// Flatten the tensor along a given range of dimensions.
    ///
    /// This function collapses the specified range of dimensions into a single dimension,
    /// effectively flattening the tensor in that range.
    ///
    /// # Arguments
    ///
    /// - `start_dim`: The starting dimension of the range to be flattened.
    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive).
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the flattened tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [2, 3, 4]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
    ///
    ///     // Flatten the tensor from dimensions 1 to 2 (inclusive).
    ///     // The resulting tensor will have dimensions [2, 12]
    ///     let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
    ///     println!("{flattened}");
    /// }
    /// ```
    pub fn flatten<const D2: usize>(self, start_dim: usize, end_dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));

        let current_dims = self.shape().dims;
        let mut new_dims: [usize; D2] = [0; D2];
        let mut flatten_dims = 1;

        for i in current_dims[start_dim..=end_dim].iter() {
            flatten_dims *= i;
        }

        new_dims[..start_dim].copy_from_slice(&current_dims[..start_dim]);
        new_dims[start_dim] = flatten_dims;
        new_dims[start_dim + 1..].copy_from_slice(&current_dims[end_dim + 1..]);

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Squeeze the tensor along the given dimension, removing the specified dimension
    /// of size one, and effectively reducing the rank of the tensor by one.
    ///
    /// # Arguments
    ///
    /// - `dim`: The dimension to be squeezed.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 1, 3]
    ///     let tensor = Tensor::<B, 3>::from_data(
    ///         [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
    ///         &device,
    ///     );
    ///
    ///     // Squeeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let squeezed = tensor.squeeze::<2>(1);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));

        let current_dims = self.shape().dims;
        let mut new_dims: [usize; D2] = [0; D2];

        new_dims[..dim].copy_from_slice(&current_dims[..dim]);
        new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
    /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
    /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
    /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
    /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
    /// from the back.
    ///
    /// # Arguments
    ///
    /// - `dims`: The dimension(s) to be squeezed.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 4D tensor with dimensions [2, 1, 4, 1]
    ///     let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
    ///
    ///     // Squeeze the dimensions 1 and 3.
    ///     // The resulting tensor will have dimensions [2, 4].
    ///     let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
        let current_dims = self.shape().dims;
        let mut dim_indices: Vec<usize>;

        // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
        if dims.is_empty() {
            dim_indices = current_dims
                .iter()
                .enumerate()
                .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
                .collect();
        } else {
            // If negative dims, count from the back
            dim_indices = dims
                .iter()
                .map(|&d| {
                    if d < 0 {
                        (current_dims.len() as isize + d) as usize
                    } else {
                        d as usize
                    }
                })
                .collect();
        }

        // Sort indices and remove duplicates
        dim_indices.sort_unstable();
        dim_indices.dedup();

        // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
        check!(TensorCheck::squeeze_dims_input::<D2>(
            &dim_indices,
            &current_dims
        ));

        // Calculate new dimensions
        let mut new_dims = Vec::new();
        for (index, &dim_size) in current_dims.iter().enumerate() {
            // Exclude the dimension if it's explicitly marked for squeezing
            if dim_indices.contains(&index) {
                check!(TensorCheck::squeeze::<D2>(index, &current_dims));
                continue;
            }
            new_dims.push(dim_size);
        }

        // Check that after squeezing, we still respect the D2 size
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Unsqueeze the current tensor. Creates new leading dimensions of size one to fit the
    /// given output rank `D2`, which must be greater than or equal to the current rank `D`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the tensor up to 4 dimensions.
    ///     // The resulting tensor will have dimensions [1, 1, 3, 3].
    ///     let unsqueezed = tensor.unsqueeze::<4>();
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze::<D, D2>());

        let mut dims = [1; D2];
        let num_ones = D2 - D;
        let shape = self.shape();

        dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]);

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with a dimension of size one inserted at the specified position.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 1, 3].
    ///     let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));

        let mut dims = [1; D2];
        let shape = self.shape();

        dims[0..dim].copy_from_slice(&shape.dims[0..dim]);

        // `dims` is initialized with ones, so the inserted dimension is already of size 1;
        // copy the remaining dimensions after the insertion point.
        if dim < D {
            dims[(dim + 1)..].copy_from_slice(&shape.dims[dim..]);
        }

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
    /// The indices can be negative, in which case they are counted from the last to the first dimension.
    /// The axes can contain duplicates, in which case the number of dimensions inserted at that index
    /// equals the number of duplicates.
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 4, 5]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
    ///     // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
    ///     // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
    ///     let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K> {
        let mut new_dims = [1; D2];
        let old_dims = self.shape().dims;

        // part 1: convert the negative indices to positive
        let mut neg_offset = D2;
        let mut dim_indices = axes
            .iter()
            .map(|d| {
                // check if the dimension is in the acceptable range
                check!(TensorCheck::unsqueeze_dims::<{ D2 }>(*d));
                (if *d < 0 {
                    neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)
                    d + neg_offset as isize + 1
                } else {
                    *d
                }) as usize
            })
            .collect::<Vec<usize>>();

        // sort the indices
        dim_indices.sort_unstable();

        // now use this to copy the chunks of the dims
        let mut prev_idx: usize = 0;
        let mut current_left_b: usize = 0;
        let mut current_right_b: usize = 0;
        let mut offset: usize = 0;
        dim_indices.iter().for_each(|d| {
            // check if there is space for at least one dimension
            if prev_idx < *d {
                current_right_b = *d - offset;
                // copy the chunks of the dims
                if current_right_b < D {
                    new_dims[prev_idx..*d]
                        .copy_from_slice(&old_dims[current_left_b..current_right_b]);
                } else {
                    new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
                }
                prev_idx = *d + 1;
                // offset is equal to the number of extracted elements from the original shape
                offset += current_right_b - current_left_b;
                current_left_b = current_right_b;
            } else {
                // it's sorted, so the only reason this would happen
                // is if multiple indices are the same
                prev_idx += 1;
            }
        });
        // copy over anything past the index of the last new dimension
        if current_left_b < D {
            new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
        }

        // lastly, create the shape and reshape
        let shape = Shape::new(new_dims);
        self.reshape(shape)
    }

    /// Returns a tensor containing the elements selected from the given ranges.
    ///
    /// # Arguments
    ///
    /// * `ranges` - A type implementing the `RangesArg` trait, which can be:
    ///   - An array of `core::ops::Range<usize>`
    ///   - An array of `Option<(i64, i64)>`
    ///   - An array of `(i64, i64)` tuples
    ///
    /// # Behavior
    ///
    /// - Supports partial and full slicing in any number of dimensions.
    /// - Missing ranges are treated as full slices if D > D2.
    /// - Handles negative indices by wrapping around from the end of the dimension.
    /// - Clamps ranges to the tensor's dimensions if they exceed the bounds.
    /// - For `Option<(i64, i64)>` ranges, `None` selects the full range of that dimension.
    ///
    /// # Panics
    ///
    /// - If the number of ranges provided exceeds the tensor's dimensions.
    /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1).
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // 1D slicing
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
    ///     let slice = tensor.slice([1..4]);
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
    ///
    ///     // 2D slicing
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 4]), &device);
    ///     let slice = tensor.slice([1..3, 0..2]);
    ///     assert_eq!(slice.dims(), [2, 2]);
    ///
    ///     // Using negative indices
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
    ///     let slice = tensor.slice([(1, -1)]); // Equivalent to 1..4
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
    ///
    ///     // Using Option<(i64, i64)>
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device).reshape([3, 4]);
    ///     let slice = tensor.slice([Some((1, -1)), None]); // Select rows 1 and 2, all columns
    ///     assert_eq!(slice.dims(), [2, 4]);
    /// }
    /// ```
    ///
    /// # Note
    ///
    /// This function uses the `RangesArg` trait for flexible range specification. The trait
    /// handles the conversion of various range formats and applies clamping and negative
    /// index handling internally.
    pub fn slice<const D2: usize, R: RangesArg<D2>>(self, ranges: R) -> Self {
        let ranges = ranges.into_ranges(self.shape());

        check!(TensorCheck::slice::<D, D2>(&self.shape(), &ranges));
        Self::new(K::slice(self.primitive, &ranges))
    }

    /// Returns a copy of the current tensor with the selected elements changed to the new ones at
    /// the selected indices.
    ///
    /// # Panics
    ///
    /// - If a range exceeds the number of elements on a dimension.
    /// - If the given values don't match the given ranges.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 3>::ones([2, 3, 3], &device);
    ///     let values = Tensor::<B, 3>::zeros([1, 1, 1], &device);
    ///     let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values);
    ///     println!("{:?}", tensor_sliced.dims()); // [2, 3, 3]
    /// }
    /// ```
    pub fn slice_assign<const D2: usize>(
        self,
        ranges: [core::ops::Range<usize>; D2],
        values: Self,
    ) -> Self {
        check!(TensorCheck::slice_assign::<D, D2>(
            &self.shape(),
            &values.shape(),
            &ranges
        ));
        Self::new(K::slice_assign(self.primitive, &ranges, values.primitive))
    }

    /// Returns the device of the current tensor.
    pub fn device(&self) -> B::Device {
        K::device(&self.primitive)
    }

    /// Returns a new tensor on the given device.
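    ///
    /// # Example
    ///
    /// A small sketch; the `target` device parameter is hypothetical and stands in for any
    /// device handle of the backend:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>(target: &B::Device) {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::ones([3], &device);
    ///     // Move the tensor to the target device.
    ///     let tensor = tensor.to_device(target);
    ///     // `device()` now reports the target device.
    ///     let _ = tensor.device();
    /// }
    /// ```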
    pub fn to_device(self, device: &B::Device) -> Self {
        Self::new(K::to_device(self.primitive, device))
    }

    /// Converts the data of the current tensor.
    ///
    /// # Note
    ///
    /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
    /// tensors at once. This may improve laziness, especially if executed on a different
    /// thread in native environments.
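    ///
    /// # Example
    ///
    /// A short sketch of reading tensor values back on the host:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::ones([2, 2], &device);
    ///     // Moves the tensor values into a host-side `TensorData`.
    ///     let data = tensor.into_data();
    ///     println!("{data:?}");
    /// }
    /// ```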
    pub fn into_data(self) -> TensorData {
        crate::try_read_sync(self.into_data_async()).expect(
            "Failed to read tensor data synchronously.
        This can happen on platforms that don't support blocking futures like WASM.
        If possible, try using into_data_async instead.",
        )
    }

    /// Converts the data of the current tensor.
    ///
    /// # Note
    ///
    /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
    /// tensors at once. This may improve laziness, especially if executed on a different
    /// thread in native environments.
    pub fn to_data(&self) -> TensorData {
        self.clone().into_data()
    }

    /// Returns the data of the current tensor.
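    ///
    /// # Example
    ///
    /// A sketch of the asynchronous read path, useful on targets like WASM where blocking
    /// reads are not supported:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// async fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::ones([2, 2], &device);
    ///     let data = tensor.into_data_async().await;
    ///     println!("{data:?}");
    /// }
    /// ```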
    pub async fn into_data_async(self) -> TensorData {
        K::into_data_async(self.primitive).await
    }

    /// Returns the data of the current tensor.
    pub async fn to_data_async(&self) -> TensorData {
        self.clone().into_data_async().await
    }

    /// Create a tensor from the given data on the given device.
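    ///
    /// # Example
    ///
    /// A minimal sketch using a nested array literal:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [2, 2].
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     println!("{tensor}");
    /// }
    /// ```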
    pub fn from_data<T>(data: T, device: &B::Device) -> Self
    where
        T: Into<TensorData>,
    {
        let data = data.into();
        check!(TensorCheck::creation_ops::<D>(
            "From Data",
            data.shape.as_slice()
        ));
        Self::new(K::from_data(data, device))
    }

    /// Repeat the tensor along the given dimension.
    ///
    /// # Arguments
    /// - `dim`: The dimension to repeat.
    /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
    ///
    /// # Returns
    ///
    /// A new tensor with the given dimension repeated `times` times.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///
    ///     // Repeat the tensor along the dimension 0 twice.
    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
    ///     // The resulting tensor will have dimensions [6, 2].
    ///     let repeated = tensor.repeat_dim(0, 2);
    ///     println!("{repeated}");
    /// }
    /// ```
    pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
        Self::new(K::repeat_dim(self.primitive, dim, times))
    }

    /// Repeat the tensor along the given dimensions.
    ///
    /// # Arguments
    /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
    ///
    /// # Returns
    ///
    /// A new tensor with each dimension repeated the corresponding number of times.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///
    ///     // Repeat the tensor along the dimension 0 twice and the dimension 1 once.
    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
    ///     // The resulting tensor will have dimensions [6, 2].
    ///     let repeated = tensor.repeat(&[2, 1]);
    /// }
    /// ```
    pub fn repeat(self, sizes: &[usize]) -> Self {
        let mut tensor = self;
        for (dim, &times) in sizes.iter().enumerate() {
            if times > 1 {
                tensor = tensor.repeat_dim(dim, times);
            }
        }
        tensor
    }

    /// Applies element-wise equal comparison and returns a boolean tensor.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     // Compare the elements of the two 2D tensors with dimensions [3, 2].
    ///     // [[false, true], [true, true], [true, true]]
    ///     let equal = t1.equal(t2);
    ///     println!("{equal}");
    /// }
    /// ```
    pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
        Tensor::new(K::equal(self.primitive, other.primitive))
    }

    /// Applies element-wise non-equality comparison and returns a boolean tensor.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
    ///     // Compare the elements of the two 2D tensors for inequality.
    ///     // [[true, false], [false, false], [false, false]]
    ///     let not_equal = t1.not_equal(t2);
    ///     println!("{not_equal}");
    /// }
    /// ```
    pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
        Tensor::new(K::not_equal(self.primitive, other.primitive))
    }

    /// Concatenates all tensors into a new one along the given dimension.
    ///
    /// # Panics
    ///
    /// If the tensors don't have the same shape (except along the concatenation dimension).
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
    ///
    ///     // Concatenate the two tensors with shape [2, 3] along the dimension 1.
    ///     // [[3.0, 4.9, 2.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.4, 5.8, 6.0]]
    ///     // The resulting tensor will have shape [2, 6].
    ///     let concat = Tensor::cat(vec![t1, t2], 1);
    ///     println!("{concat}");
    /// }
    /// ```
    pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
        check!(TensorCheck::cat(&tensors, dim));

        Self::new(K::cat(
            tensors.into_iter().map(|vector| vector.primitive).collect(),
            dim,
        ))
    }

    /// Concatenates all tensors into a new one along a new dimension.
    ///
    /// # Panics
    ///
    /// If all tensors don't have the same shape.
    /// If the given dimension is not within the range `0..D2`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
    ///     let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
    ///
    ///     // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
    ///     // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
    ///     // The resulting tensor will have shape [3, 2, 3].
    ///     let stacked = Tensor::stack::<3>(vec![t1, t2, t3], 0);
    ///     println!("{stacked}");
    /// }
    /// ```
    pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
        let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
        Tensor::<B, D2, K>::cat(tensors, dim)
    }

    /// Iterate over slices of tensors along a given dimension.
    ///
    /// # Panics
    ///
    /// If the given dimension is greater than or equal to the tensor rank.
    ///
    /// # Returns
    ///
    /// A tensor iterator.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    /// fn example<B: Backend>() {
    ///   let device = Default::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
    ///   // Given a 2D tensor with dimensions (2, 3), iterate over slices of tensors along the dimension 0.
    ///   let iter = tensor.iter_dim(0);
    ///   for (i, tensor) in iter.enumerate() {
    ///     println!("Tensor {}: {}", i, tensor);
    ///     // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
    ///     // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
    ///   }
    /// }
    /// ```
    pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
        check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
        DimIter::new(self, dim)
    }

    /// Returns a new tensor with the given dimension narrowed to the given range.
    ///
    /// # Panics
    ///
    /// - If the dimension is greater than or equal to the number of dimensions of the tensor.
    /// - If the given range exceeds the number of elements on the given dimension.
    ///
    /// # Returns
    ///
    /// A new tensor with the given dimension narrowed to the given range.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [6.0, 1.5, 7.0],
    ///             [3.0, 4.9, 9.0],
    ///         ],
    ///         &device,
    ///     );
    ///     // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
    ///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let narrowed = tensor.narrow(0, 1, 3);
    ///     println!("{narrowed}");
    /// }
    /// ```
    pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
        check!(TensorCheck::dim_ops::<D>("narrow", dim));
        check!(TensorCheck::narrow(&self, dim, start, length));
        Self::new(narrow::<B, K>(self.primitive, dim, start, length))
    }

    /// Attempts to split the tensor into a specified number of chunks along a given dimension.
    /// May return fewer chunks than requested if the tensor size is not divisible by the number of chunks.
    ///
    /// When the size of the given dimension is evenly divisible by the number of chunks, all chunks will be of equal size.
    /// Otherwise, all chunks will be of equal size except for the last one.
    ///
    /// # Panics
    ///
    ///  If the dimension is greater than or equal to the number of dimensions of the tensor.
    ///
    /// # Returns
    /// A vector of tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [6.0, 1.5, 7.0],
    ///             [3.0, 4.9, 9.0],
    ///         ],
    ///         &device,
    ///     );
    ///     // Split the tensor along the dimension 1 into 2 chunks.
    ///     // The first chunk will have shape [4, 2]:
    ///     // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
    ///     // The second chunk will have shape [4, 1]:
    ///     // [[2.0], [3.0], [7.0], [9.0]]
    ///     let chunks = tensor.chunk(2, 1);
    ///     println!("{chunks:?}");
    /// }
    /// ```
    pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
        check!(TensorCheck::dim_ops::<D>("chunk", dim));
        K::chunk(self.primitive, chunks, dim)
            .into_iter()
            .map(Self::new)
            .collect()
    }

    /// Splits the tensor into chunks of a specified size along a given dimension.
    /// Each chunk is a view of the original tensor.
    ///
    /// If the tensor size along the given dimension is not divisible by `split_size`,
    /// then the last chunk will be smaller.
    ///
    /// # Panics
    ///
    /// If the specified dimension to split along is greater than or equal to the number of dimensions of the tensor.
    ///
    /// # Returns
    ///
    /// A vector of tensors.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 1D tensor with 5 elements
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
    ///     // Split the tensor into chunks of size 2 along dimension 0
    ///     let chunks = tensor.split(2, 0);
    ///     // The result is a vector of tensors:
    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
    ///     println!("{:?}", chunks);
    /// }
    /// ```
    pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
        check!(TensorCheck::split::<D>(
            self.shape().dims.as_ref(),
            split_size,
            dim
        ));
        K::split(self.primitive, split_size, dim)
            .into_iter()
            .map(Self::new)
            .collect()
    }

    /// Splits the tensor into chunks with the specified sizes along a given dimension.
    /// Each chunk is a view of the original tensor.
    ///
    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
    /// in `split_sizes` must equal the size of the tensor along the specified dimension.
    ///
    /// # Panics
    ///
    /// If the specified dimension to split along is greater than or equal to the number of dimensions of the tensor, or
    /// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
    ///
    /// # Returns
    ///
    /// A vector of tensors.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 1D tensor with 5 elements
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
    ///     // Split the tensor into chunks with sizes [2, 3] along dimension 0
    ///     let chunks = tensor.split_with_sizes(vec![2, 3], 0);
    ///     // The result is a vector of tensors:
    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
    ///     println!("{:?}", chunks);
    /// }
    /// ```
    pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
        check!(TensorCheck::split_with_sizes::<D>(
            self.shape().dims.as_ref(),
            &split_sizes,
            dim
        ));
        K::split_with_sizes(self.primitive, split_sizes, dim)
            .into_iter()
            .map(Self::new)
            .collect()
    }

    /// Tests if any element in the `tensor` evaluates to True.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
    ///
    /// # Returns
    ///
    /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
    /// evaluates to True, False otherwise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool};
    ///
    /// fn example<B: Backend>() {
    ///   let device = Default::default();
    ///   let tensor = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
    ///   let tensor_two = Tensor::<B, 2, Bool>::from_data([[false, false, false], [false, false, false]], &device);
    ///
    ///   // Given a 2D tensor with dimensions (2, 3), test if any element in the tensor evaluates to True.
    ///   let any_tensor = tensor.any();
    ///   println!("{}", any_tensor);
    ///   // Tensor { data: [true], ... }
    ///
    ///   // Given a 2D tensor with dimensions (2, 3), test if any element in the tensor evaluates to True.
    ///   let any_tensor_two = tensor_two.any();
    ///   println!("{}", any_tensor_two);
    ///   // Tensor { data: [false], ... }
    /// }
    /// ```
    pub fn any(self) -> Tensor<B, 1, Bool> {
        Tensor::new(K::any(self.primitive))
    }

    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
    /// * `dim` - The axis along which to test.
    ///
    /// # Returns
    ///
    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
    /// evaluates to True, False otherwise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor =
    ///         Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
    ///     // Check if any element in the tensor evaluates to True along the dimension 1.
    ///     // [[true], [true]]
    ///     let any_dim = tensor.clone().any_dim(1);
    ///     println!("{any_dim}");
    /// }
    /// ```
    pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
        Tensor::new(K::any_dim(self.primitive, dim))
    }

    /// Tests if all elements in the `tensor` evaluate to True.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
    ///
    /// # Returns
    ///
    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
    /// evaluate to True, False otherwise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor =
    ///         Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
    ///     // Check if all elements in the tensor evaluate to True (which is not the case).
    ///     // [false]
    ///     let all = tensor.all();
    ///     println!("{all}");
    /// }
    /// ```
    pub fn all(self) -> Tensor<B, 1, Bool> {
        Tensor::new(K::all(self.primitive))
    }

    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
    /// * `dim` - The axis along which to test.
    ///
    /// # Returns
    ///
1469    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
    /// where the size is 1. The element in the `dim` axis is True if all elements along this dimension in the input
    /// evaluate to True, False otherwise.
1472    ///
1473    /// # Example
1474    ///
1475    /// ```rust
1476    /// use burn_tensor::backend::Backend;
1477    /// use burn_tensor::{Tensor, Bool};
1478    ///
1479    /// fn example<B: Backend>() {
1480    ///     let device = Default::default();
1481    ///     let tensor =
1482    ///         Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
    ///     // Check if all elements in the tensor evaluate to True along the dimension 0.
1484    ///     // [[true, true, false]]
1485    ///     let all_dim = tensor.clone().all_dim(0);
1486    ///     println!("{all_dim}");
1487    /// }
1488    /// ```
1489    pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
1490        Tensor::new(K::all_dim(self.primitive, dim))
1491    }
1492
1493    /// Convert the tensor into a scalar.
1494    ///
1495    /// # Panics
1496    ///
    /// If the tensor does not contain exactly one element.
1498    /// If the backend fails to read the tensor data synchronously.
1499    ///
1500    /// # Returns
1501    ///
1502    /// The scalar value of the tensor.
1503    ///
1504    /// # Example
1505    ///
1506    /// ```rust
1507    /// use burn_tensor::backend::Backend;
1508    /// use burn_tensor::Tensor;
1509    ///
1510    /// fn example<B: Backend>() {
1511    ///     let device = Default::default();
1512    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
1513    ///     // Convert the tensor with a single element into a scalar.
1514    ///     let scalar = tensor.into_scalar();
1515    ///     println!("{scalar}");
1516    /// }
1517    /// ```
1518    pub fn into_scalar(self) -> K::Elem {
1519        crate::try_read_sync(self.into_scalar_async()).expect(
1520            "Failed to read tensor data synchronously. This can happen on platforms
1521            that don't support blocking futures like WASM. Try into_scalar_async instead.",
1522        )
1523    }
1524
1525    /// Convert the tensor into a scalar.
1526    ///
1527    /// # Panics
1528    ///
    /// If the tensor does not contain exactly one element.
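    ///
    /// # Example
    ///
    /// A minimal sketch of asynchronous usage, mirroring the [`into_scalar`](Tensor::into_scalar)
    /// example above:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// async fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
    ///     // Await the data read instead of blocking; useful on platforms like WASM.
    ///     let scalar = tensor.into_scalar_async().await;
    ///     println!("{scalar}");
    /// }
    /// ```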
1530    pub async fn into_scalar_async(self) -> K::Elem {
1531        check!(TensorCheck::into_scalar::<D>(&self.shape()));
        self.into_data_async().await.iter().next().unwrap()
1534    }
1535
1536    /// Broadcast the tensor to the given shape.
1537    ///
1538    /// # Arguments
1539    ///
1540    /// * `shape` - The shape to broadcast the tensor to.
1541    ///             Can contain -1 for dimensions that should be inferred.
    ///             The number of dimensions in the shape must be greater than or equal to
    ///             the number of dimensions of the tensor.
1544    ///
1545    /// # Panics
1546    ///
    /// If the tensor cannot be broadcast to the given shape.
1548    ///
1549    /// # Returns
1550    ///
1551    /// A new tensor with the given shape.
1552    ///
1553    /// # Example
1554    ///
1555    /// ```rust
1556    /// use burn_tensor::backend::Backend;
1557    /// use burn_tensor::Tensor;
1558    ///
1559    /// fn example<B: Backend>() {
1560    ///     let device = Default::default();
1561    ///     // Create a 2D tensor with dimensions [3, 1]
1562    ///     let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
1563    ///     // Expand the tensor to a new shape [3, 4]
1564    ///     // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
1565    ///     let expanded = tensor.expand([3, 4]);
1566    ///     println!("{}", expanded);
1567    /// }
1568    /// ```
1569    pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
1570        let shape = shape.into_shape(&self.shape());
1571        check!(TensorCheck::expand::<D, D2>(
1572            "expand",
1573            &self.shape(),
1574            &shape,
1575        ));
1576
1577        Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
1578    }
1579}
1580
/// Iterator returned by [`Tensor::iter_dim`](Tensor::iter_dim).
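///
/// # Example
///
/// A minimal sketch of iterating over one dimension, using the
/// [`Tensor::iter_dim`](Tensor::iter_dim) method that produces this iterator:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
///     // Each item is a slice along dimension 0 that keeps the full rank: shape [1, 2] here.
///     for item in tensor.iter_dim(0) {
///         println!("{item}");
///     }
/// }
/// ```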
1582pub struct DimIter<B, const D: usize, K>
1583where
1584    B: Backend,
1585    K: BasicOps<B>,
1586{
1587    start: usize,
1588    end: usize,
1589    dim: usize,
1590    ranges: [Range<usize>; D],
1591    tensor: Tensor<B, D, K>,
1592}
1593
1594impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
1595    type Item = Tensor<B, D, K>;
1596
1597    fn next(&mut self) -> Option<Self::Item> {
1598        if self.start >= self.end {
1599            return None;
1600        }
1601
1602        let mut ranges = self.ranges.clone();
1603        ranges[self.dim] = self.start..(self.start + 1);
1604
1605        let slice = self.tensor.clone().slice(ranges);
1606        self.start += 1;
1607
1608        Some(slice)
1609    }
1610}
1611
1612impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
1613    fn next_back(&mut self) -> Option<Self::Item> {
1614        if self.start >= self.end {
1615            return None;
1616        }
1617
1618        let mut ranges = self.ranges.clone();
1619        ranges[self.dim] = (self.end - 1)..self.end;
1620
1621        let slice = self.tensor.clone().slice(ranges);
1622        self.end = self.end.saturating_sub(1);
1623
1624        Some(slice)
1625    }
1626}
1627
1628impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
1629    fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
1630        let dims = tensor.dims();
1631        let ranges = dims
1632            .iter()
1633            .map(|&dim| 0..dim)
1634            .collect::<Vec<Range<usize>>>();
1635        let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
1636        Self {
1637            end: dims[dim],
1638            ranges,
1639            start: 0,
1640            dim,
1641            tensor,
1642        }
1643    }
1644}
1645
1646impl<B, const D: usize, K> Tensor<B, D, K>
1647where
1648    B: Backend,
1649    K: BasicOps<B>,
1650    <K as BasicOps<B>>::Elem: Debug,
1651{
1652    #[inline]
1653    fn push_newline_indent(acc: &mut String, indent: usize) {
1654        acc.push('\n');
1655        for _ in 0..indent {
1656            acc.push(' ');
1657        }
1658    }
1659    fn fmt_inner_tensor(
1660        &self,
1661        acc: &mut String,
1662        depth: usize,
1663        multi_index: &mut [usize],
1664        range: (usize, usize),
1665        precision: Option<usize>,
1666    ) {
1667        let (start, end) = range;
1668        for i in start..end {
1669            if i > 0 {
1670                acc.push_str(", ");
1671            }
1672            multi_index[depth] = i;
1673            let range: [core::ops::Range<usize>; D] =
1674                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
1675
1676            let data =
1677                burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());
1678
1679            if let Some(data) = data {
1680                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
1681                match (precision, K::name()) {
1682                    (Some(p), "Float") => acc.push_str(&format!("{:.1$}", elem, p)),
1683                    _ => acc.push_str(&format!("{:?}", elem)),
1684                }
1685            } else {
1686                acc.push_str("<Tensor data not available>");
1687            }
1688        }
1689    }
1690
1691    fn fmt_outer_tensor(
1692        &self,
1693        acc: &mut String,
1694        depth: usize,
1695        multi_index: &mut [usize],
1696        print_options: &PrintOptions,
1697        summarize: bool,
1698        range: (usize, usize),
1699    ) {
1700        let (start, end) = range;
1701        for i in start..end {
1702            if i > start {
1703                acc.push(',');
1704                Self::push_newline_indent(acc, depth + 1);
1705            }
1706            acc.push('[');
1707            multi_index[depth] = i;
1708            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
1709            acc.push(']');
1710        }
1711    }
1712
1713    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
1714    ///
1715    /// This function is designed to work with tensors of any dimensionality.
1716    /// It traverses the tensor dimensions recursively, converting the elements
1717    /// to strings and appending them to the accumulator string with the
1718    /// appropriate formatting.
1719    ///
1720    /// # Arguments
1721    ///
1722    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
1723    /// * `depth` - The current depth of the tensor dimensions being processed.
1724    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
1725    fn display_recursive(
1726        &self,
1727        acc: &mut String,
1728        depth: usize,
1729        multi_index: &mut [usize],
1730        print_options: &PrintOptions,
1731        summarize: bool,
1732    ) {
1733        let edge_items = print_options.edge_items;
1734
1735        if depth == 0 {
1736            acc.push('[');
1737        }
1738
1739        if depth == self.dims().len() - 1 {
1740            // if we are at the innermost dimension, just push its elements into the accumulator
1741            if summarize && self.dims()[depth] > 2 * edge_items {
1742                // print the starting `edge_items` elements
1743                self.fmt_inner_tensor(
1744                    acc,
1745                    depth,
1746                    multi_index,
1747                    (0, edge_items),
1748                    print_options.precision,
1749                );
1750                acc.push_str(", ...");
1751                // print the last `edge_items` elements
1752                self.fmt_inner_tensor(
1753                    acc,
1754                    depth,
1755                    multi_index,
1756                    (self.dims()[depth] - edge_items, self.dims()[depth]),
1757                    print_options.precision,
1758                );
1759            } else {
1760                // print all the elements
1761                self.fmt_inner_tensor(
1762                    acc,
1763                    depth,
1764                    multi_index,
1765                    (0, self.dims()[depth]),
1766                    print_options.precision,
1767                );
1768            }
1769        } else {
1770            // otherwise, iterate through the current dimension and recursively display the inner tensors
1771            if summarize && self.dims()[depth] > 2 * edge_items {
1772                self.fmt_outer_tensor(
1773                    acc,
1774                    depth,
1775                    multi_index,
1776                    print_options,
1777                    summarize,
1778                    (0, edge_items),
1779                );
1780
1781                acc.push(',');
1782                Self::push_newline_indent(acc, depth + 1);
1783                acc.push_str("...");
1784                Self::push_newline_indent(acc, depth + 1);
1785
1786                self.fmt_outer_tensor(
1787                    acc,
1788                    depth,
1789                    multi_index,
1790                    print_options,
1791                    summarize,
1792                    (self.dims()[depth] - edge_items, self.dims()[depth]),
1793                );
1794            } else {
1795                self.fmt_outer_tensor(
1796                    acc,
1797                    depth,
1798                    multi_index,
1799                    print_options,
1800                    summarize,
1801                    (0, self.dims()[depth]),
1802                );
1803            }
1804        }
1805
1806        if depth == 0 {
1807            acc.push(']');
1808        }
1809    }
1810}
1811
/// Options for Tensor pretty printing
#[derive(Clone, Debug)]
1814pub struct PrintOptions {
    /// Maximum number of elements to display before summarizing the tensor
1816    pub threshold: usize,
1817
    /// Number of elements to display at the start and end of each summarized dimension
1819    pub edge_items: usize,
1820
1821    /// Precision for floating point numbers
1822    pub precision: Option<usize>,
1823}
1824
1825static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
1826
1827impl PrintOptions {
1828    /// Print options with default values
1829    pub const fn const_default() -> Self {
1830        Self {
1831            threshold: 1000,
1832            edge_items: 3,
1833            precision: None,
1834        }
1835    }
1836}
1837
1838impl Default for PrintOptions {
1839    fn default() -> Self {
1840        Self::const_default()
1841    }
1842}
1843
/// Sets the global print options used when formatting tensors.
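///
/// A minimal sketch (assuming `PrintOptions` and `set_print_options` are re-exported
/// at the crate root, like the other items in this module):
///
/// ```rust
/// use burn_tensor::{set_print_options, PrintOptions};
///
/// // Summarize tensors with more than 100 elements, show 2 items per edge,
/// // and print floats with 3 decimal places.
/// let options = PrintOptions {
///     threshold: 100,
///     edge_items: 2,
///     precision: Some(3),
/// };
/// set_print_options(options);
/// ```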
1845pub fn set_print_options(options: PrintOptions) {
1846    let mut print_opts = PRINT_OPTS.write().unwrap();
1847    *print_opts = options;
1848}
1849
1850/// Pretty print tensors
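///
/// A precision passed through the formatter (the `{:.N}` syntax) overrides the
/// global print options for that call. A minimal sketch:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 1>::from_data([1.23456, 2.34567], &device);
///     // Prints the elements with 2 decimal places.
///     println!("{:.2}", tensor);
/// }
/// ```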
1851impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
1852where
1853    B: Backend,
1854    B::IntElem: core::fmt::Display,
1855    K: BasicOps<B>,
1856    <K as BasicOps<B>>::Elem: Debug,
1857{
1858    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1859        writeln!(f, "Tensor {{")?;
1860
1861        {
            // Clone the options inside a block so the read lock is released
            // immediately instead of being held for the whole function.
            let mut po = { PRINT_OPTS.read().unwrap().clone() };
1864
1865            // Override the precision if it is set from the formatter
1866            // This will be possible when the tensor is printed using the `{:.*}` syntax
1867            if let Some(precision) = f.precision() {
1868                po.precision = Some(precision);
1869            }
1870
1871            let mut acc = String::new();
1872            let mut multi_index = vec![0; D];
1873            let summarize = self.shape().num_elements() > po.threshold;
1874
1875            self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);
1876
1877            writeln!(f, "  data:")?;
1878            write!(f, "{acc}")?;
1879            writeln!(f, ",")?;
1880        }
1881
1882        writeln!(f, "  shape:  {:?},", self.dims())?;
1883        writeln!(f, "  device:  {:?},", self.device())?;
1884        writeln!(f, "  backend:  {:?},", B::name())?;
1885        writeln!(f, "  kind:  {:?},", K::name())?;
1886
1887        // Bool tensors might be encoded in a different type, which we abstract for the display
1888        let dtype = if TypeId::of::<K::Elem>() == TypeId::of::<bool>() {
1889            DType::Bool
1890        } else {
1891            self.primitive.dtype()
1892        };
1893
1894        writeln!(f, "  dtype:  {:?},", dtype.name())?;
1895        write!(f, "}}")
1896    }
1897}
1898
/// Transpose marker (zero-sized type). Used to sugar the transpose of a tensor, e.g.
1900/// ```rust
1901/// use burn_tensor::backend::Backend;
1902/// use burn_tensor::{Tensor, T};
1903///
1904/// fn example<B: Backend>() {
1905///     let device = Default::default();
1906///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
1907///     let transposed = tensor^T;
1908/// }
1909/// ```
1910pub struct T;
1911
1912impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
1913    type Output = Self;
1914    fn bitxor(self, _: T) -> Self::Output {
1915        self.transpose()
1916    }
1917}
1918
/// Trait that lists all operations that can be applied to all tensors.
1920///
1921/// # Warnings
1922///
1923/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
1924pub trait BasicOps<B: Backend>: TensorKind<B> {
1925    /// The type of the tensor elements.
1926    type Elem: Element;
1927
1928    /// Creates an empty tensor with the given shape.
1929    ///
1930    /// # Arguments
1931    ///
1932    /// * `shape` - The shape of the tensor.
1933    /// * `device` - The device on which the tensor will be allocated.
1934    ///
1935    /// # Returns
1936    ///
1937    /// The empty tensor.
1938    ///
1939    /// # Remarks
1940    ///
1941    /// This is a low-level function used internally by the library to call different backend functions
1942    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1943    /// or use this function directly.
1944    ///
1945    /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function,
1946    /// which is more high-level and designed for public use.
1947    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive;
1948
1949    /// Reshapes the tensor.
1950    ///
1951    /// # Arguments
1952    ///
1953    /// * `tensor` - The tensor.
1954    /// * `shape` - The new shape of the tensor.
1955    ///
1956    /// # Returns
1957    ///
1958    /// The reshaped tensor.
1959    ///
1960    /// # Remarks
1961    ///
1962    /// This is a low-level function used internally by the library to call different backend functions
1963    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1964    /// or use this function directly.
1965    ///
1966    /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function,
1967    /// which is more high-level and designed for public use.
1968    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
1969
1970    /// Transposes a tensor.
1971    ///
1972    /// # Arguments
1973    ///
1974    /// * `tensor` - The tensor to transpose.
1975    ///
1976    /// # Returns
1977    ///
1978    /// The transposed tensor.
1979    fn transpose(tensor: Self::Primitive) -> Self::Primitive;
1980
1981    /// Swaps two dimensions of a tensor.
1982    ///
1983    /// # Arguments
1984    ///
1985    /// * `tensor` - The tensor to swap the dimensions of.
1986    /// * `dim1` - The first dimension to swap.
1987    /// * `dim2` - The second dimension to swap.
1988    ///
1989    /// # Returns
1990    ///
1991    /// The tensor with the dimensions swapped.
1992    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
1993
1994    /// Permutes the dimensions of a tensor.
1995    ///
1996    /// # Arguments
1997    ///
1998    /// * `tensor` - The tensor to permute the dimensions of.
1999    /// * `axes` - The new order of the dimensions.
2000    ///
2001    /// # Returns
2002    ///
2003    /// The tensor with the dimensions permuted.
2004    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2005
2006    /// Flips the tensor along the given axes.
2007    ///
2008    /// # Arguments
2009    ///
2010    /// * `tensor` - The tensor to flip.
2011    /// * `axes` - The axes to flip the tensor along.
2012    ///
2013    /// # Returns
2014    ///
2015    /// The tensor with the axes flipped.
2016    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2017
    /// Selects tensor elements corresponding to the given ranges.
2019    ///
2020    /// # Arguments
2021    ///
2022    /// * `tensor` - The tensor.
2023    /// * `ranges` - The ranges of the elements to select.
2024    ///
2025    /// # Returns
2026    ///
2027    /// The selected elements.
2028    ///
2029    /// # Remarks
2030    ///
2031    /// This is a low-level function used internally by the library to call different backend functions
2032    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2033    /// or use this function directly.
2034    ///
2035    /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function,
2036    /// which is more high-level and designed for public use.
    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive;
2038
    /// Assigns the given value to the tensor elements corresponding to the given ranges.
2040    ///
2041    /// # Arguments
2042    ///
2043    /// * `tensor` - The tensor.
2044    /// * `ranges` - The ranges of the elements to select.
2045    /// * `value` - The value to assign.
2046    ///
2047    /// # Returns
2048    ///
2049    /// The tensor with the assigned values.
2050    ///
2051    /// # Remarks
2052    ///
2053    /// This is a low-level function used internally by the library to call different backend functions
2054    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2055    /// or use this function directly.
2056    ///
2057    /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function,
2058    /// which is more high-level and designed for public use.
2059    fn slice_assign(
2060        tensor: Self::Primitive,
2061        ranges: &[Range<usize>],
2062        value: Self::Primitive,
2063    ) -> Self::Primitive;
2064
2065    /// Returns the device on which the tensor is allocated.
2066    ///
2067    /// # Arguments
2068    ///
2069    /// * `tensor` - The tensor.
2070    ///
2071    /// # Returns
2072    ///
2073    /// The device on which the tensor is allocated.
2074    ///
2075    /// # Remarks
2076    ///
2077    /// This is a low-level function used internally by the library to call different backend functions
2078    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2079    /// or use this function directly.
2080    ///
2081    /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function,
2082    /// which is more high-level and designed for public use.
2083    fn device(tensor: &Self::Primitive) -> B::Device;
2084
2085    /// Moves the tensor to the given device.
2086    ///
2087    /// # Arguments
2088    ///
2089    /// * `tensor` - The tensor.
2090    /// * `device` - The device on which the tensor will be moved.
2091    ///
2092    /// # Returns
2093    ///
2094    /// The tensor on the given device.
2095    ///
2096    /// # Remarks
2097    ///
2098    /// This is a low-level function used internally by the library to call different backend functions
2099    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2100    /// or use this function directly.
2101    ///
2102    /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function,
2103    /// which is more high-level and designed for public use.
2104    fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
2105
2106    /// Extracts the data from the tensor asynchronously.
2107    ///
2108    /// # Arguments
2109    ///
2110    /// * `tensor` - The tensor.
2111    ///
2112    /// # Returns
2113    ///
2114    /// The data of the tensor.
2115    ///
2116    /// # Remarks
2117    ///
2118    /// This is a low-level function used internally by the library to call different backend functions
2119    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2120    /// or use this function directly.
2121    ///
2122    /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
2123    /// which is more high-level and designed for public use.
2124    fn into_data_async(
2125        tensor: Self::Primitive,
2126    ) -> impl Future<Output = TensorData> + 'static + Send;
2127
    /// Registers the tensor to be read as part of a transaction.
2129    ///
2130    /// # Remarks
2131    ///
2132    /// This is a low-level function used internally by the library to call different backend functions
2133    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2134    /// or use this function directly.
2135    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive);
2136
2137    /// Creates a tensor from the given data.
2138    ///
2139    /// # Arguments
2140    ///
2141    /// * `data` - The data of the tensor.
2142    /// * `device` - The device on which the tensor will be allocated.
2143    ///
2144    /// # Returns
2145    ///
2146    /// The tensor.
2147    ///
2148    /// # Remarks
2149    ///
2150    /// This is a low-level function used internally by the library to call different backend functions
2151    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2152    /// or use this function directly.
2153    ///
2154    /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function,
2155    /// which is more high-level and designed for public use.
2156    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;
2157
2158    /// Repeat the tensor along the given dimension.
2159    ///
2160    /// # Arguments
2161    ///
2162    /// * `tensor` - The tensor.
2163    /// * `dim` - The dimension along which the tensor will be repeated.
2164    /// * `times` - The number of times the tensor will be repeated.
2165    ///
2166    /// # Returns
2167    ///
2168    /// The repeated tensor.
2169    ///
2170    /// # Remarks
2171    ///
2172    /// This is a low-level function used internally by the library to call different backend functions
2173    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2174    /// or use this function directly.
2175    ///
2176    /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
2177    /// which is more high-level and designed for public use.
2178    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
2179
2180    /// Concatenates the given tensors along the given dimension.
2181    ///
2182    /// # Arguments
2183    ///
2184    /// * `vectors` - The tensors to concatenate.
2185    /// * `dim` - The dimension along which the tensors will be concatenated.
2186    ///
2187    /// # Returns
2188    ///
2189    /// The concatenated tensor.
2190    ///
2191    /// # Remarks
2192    ///
2193    /// This is a low-level function used internally by the library to call different backend functions
2194    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2195    /// or use this function directly.
2196    ///
2197    /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function,
2198    /// which is more high-level and designed for public use.
2199    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
2200
2201    /// Attempts to split the tensor along the given dimension into chunks.
    /// May return fewer chunks than requested if the tensor size is not divisible by the number of chunks.
2203    ///
2204    /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
2205    /// Otherwise all chunks will be of equal size except for the last one.
2206    ///
2207    /// # Panics
2208    ///
    /// If the dimension is greater than the number of dimensions of the tensor.
2210    ///
2211    /// # Returns
2212    /// A vector of tensors.
2213    ///
2214    /// # Remarks
2215    ///
2216    /// This is a low-level function used internally by the library to call different backend functions
2217    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2218    /// or use this function directly.
2219    ///
2220    /// To chunk a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function,
2221    /// which is more high-level and designed for public use.
2222    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive>;
2223
2224    /// Splits the tensor into chunks of a specified size along a given dimension.
2225    /// Each chunk is a view of the original tensor.
2226    ///
2227    /// # Panics
2228    ///
2229    /// If the dimension to split along is greater than the number of dimensions of the tensor.
2230    ///
2231    /// # Returns
2232    ///
2233    /// A vector of tensors.
2234    ///
2235    /// # Remarks
2236    /// This is a low-level function used internally by the library to call different backend functions
2237    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2238    /// or use this function directly.
2239    ///
2240    /// To split a tensor, users should prefer the [Tensor::split](Tensor::split) function,
2241    /// which is more high-level and designed for public use.
2242    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive>;
2243
2244    /// Splits the tensor into chunks with the specified sizes along a given dimension.
2245    /// Each chunk is a view of the original tensor.
2246    ///
2247    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2248    /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2249    ///
2250    /// # Panics
2251    ///
2252    /// If the dimension to split along is greater than the number of dimensions of the tensor or
    /// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
2254    ///
2255    /// # Returns
2256    ///
2257    /// A vector of tensors.
2258    ///
2259    /// # Remarks
2260    /// This is a low-level function used internally by the library to call different backend functions
2261    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2262    /// or use this function directly.
2263    ///
2264    /// To split a tensor, users should prefer the [Tensor::split_with_sizes](Tensor::split_with_sizes) function,
2265    /// which is more high-level and designed for public use.
2266    fn split_with_sizes(
2267        tensor: Self::Primitive,
2268        split_sizes: Vec<usize>,
2269        dim: usize,
2270    ) -> Vec<Self::Primitive>;
2271
    /// Applies element-wise equality comparison between the given tensors.
2273    ///
2274    /// # Arguments
2275    ///
2276    /// * `lhs` - The left hand side tensor.
2277    /// * `rhs` - The right hand side tensor.
2278    ///
2279    /// # Returns
2280    ///
2281    /// The tensor of booleans indicating whether the corresponding elements are equal.
2282    ///
2283    /// # Remarks
2284    ///
2285    /// This is a low-level function used internally by the library to call different backend functions
2286    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2287    /// or use this function directly.
2288    ///
2289    /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function,
2290    /// which is more high-level and designed for public use.
2291    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2292
2293    /// Applies element-wise non-equality comparison between the given tensors.
2294    ///
2295    /// # Arguments
2296    ///
2297    /// * `lhs` - The left hand side tensor.
2298    /// * `rhs` - The right hand side tensor.
2299    ///
2300    /// # Returns
2301    ///
    /// The tensor of booleans indicating whether the corresponding elements are not equal.
2303    ///
2304    /// # Remarks
2305    ///
2306    /// This is a low-level function used internally by the library to call different backend functions
2307    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2308    /// or use this function directly.
2309    ///
2310    /// For non-equality comparison of tensors, users should prefer the [Tensor::not_equal](Tensor::not_equal)
2311    /// function, which is more high-level and designed for public use.
2312    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2313
2314    /// Returns the name of the element type.
2315    fn elem_type_name() -> &'static str {
2316        core::any::type_name::<Self::Elem>()
2317    }
2318
2319    /// Returns the tensor data type.
2320    fn dtype(tensor: &Self::Primitive) -> DType {
2321        tensor.dtype()
2322    }
2323
2324    /// Tests if any element in the `tensor` evaluates to True.
2325    ///
2326    /// # Arguments
2327    ///
2328    /// * `tensor` - The tensor to test.
2329    ///
2330    /// # Returns
2331    ///
2332    /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
2333    ///
2334    /// # Remarks
2335    ///
2336    /// This is a low-level function used internally by the library to call different backend functions
2337    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2338    /// or use this function directly. Users should prefer the [Tensor::any](Tensor::any) function
2339    /// which is more high-level and designed for public use.
2340    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
2341
    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test.
    /// * `dim` - The axis along which to test.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
    /// The element is True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
2353    ///
2354    /// # Remarks
2355    ///
2356    /// This is a low-level function used internally by the library to call different backend functions
2357    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2358    /// or use this function directly. Users should prefer the [Tensor::any_dim](Tensor::any_dim) function,
2359    /// which is more high-level and designed for public use.
2360    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
2361
2362    /// Tests if all elements in the `tensor` evaluate to True.
2363    ///
2364    /// # Arguments
2365    ///
2366    /// * `tensor` - The tensor to test.
2367    ///
2368    /// # Returns
2369    ///
    /// A boolean tensor with a single element, True if all elements in the input tensor evaluate to True, False otherwise.
2371    ///
2372    /// # Remarks
2373    ///
2374    /// This is a low-level function used internally by the library to call different backend functions
2375    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2376    /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function,
2377    /// which is more high-level and designed for public use.
2378    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
2379
2380    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2381    ///
2382    /// # Arguments
2383    ///
    /// * `tensor` - The tensor to test.
    /// * `dim` - The axis along which to test.
2385    ///
2386    /// # Returns
2387    ///
2388    /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
2389    /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
2390    ///
2391    /// # Remarks
2392    ///
2393    /// This is a low-level function used internally by the library to call different backend functions
2394    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2395    /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function,
2396    /// which is more high-level and designed for public use.
2397    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
2398
2399    /// Broadcasts the given tensor to the specified shape.
2400    ///
2401    /// # Arguments
2402    ///
2403    /// * `tensor` - The tensor to broadcast.
2404    /// * `shape` - The shape to broadcast to.
2405    ///
2406    /// # Returns
2407    ///
2408    /// The broadcasted tensor.
2409    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
2410}
2411
2412impl<B: Backend> BasicOps<B> for Float {
2413    type Elem = B::FloatElem;
2414
2415    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2416        TensorPrimitive::Float(B::float_empty(shape, device))
2417    }
2418
2419    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2420        tr.register_float(tensor);
2421    }
2422
2423    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2424        match tensor {
2425            TensorPrimitive::Float(tensor) => {
2426                TensorPrimitive::Float(B::float_reshape(tensor, shape))
2427            }
2428            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
2429        }
2430    }
2431
2432    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2433        match tensor {
2434            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
2435            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
2436        }
2437    }
2438
2439    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2440        match tensor {
2441            TensorPrimitive::Float(tensor) => {
2442                TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
2443            }
2444            TensorPrimitive::QFloat(tensor) => {
2445                TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
2446            }
2447        }
2448    }
2449
2450    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2451        match tensor {
2452            TensorPrimitive::Float(tensor) => {
2453                TensorPrimitive::Float(B::float_slice(tensor, ranges))
2454            }
2455            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, ranges)),
2456        }
2457    }
2458
2459    fn slice_assign(
2460        tensor: Self::Primitive,
2461        ranges: &[Range<usize>],
2462        value: Self::Primitive,
2463    ) -> Self::Primitive {
2464        match (tensor, value) {
2465            (TensorPrimitive::Float(tensor), TensorPrimitive::Float(value)) => {
2466                TensorPrimitive::Float(B::float_slice_assign(tensor, ranges, value))
2467            }
2468            (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(value)) => {
2469                TensorPrimitive::QFloat(B::q_slice_assign(tensor, ranges, value))
2470            }
2471            _ => panic!("Primitive type mismatch for tensor and value"),
2472        }
2473    }
2474
2475    fn device(tensor: &Self::Primitive) -> Device<B> {
2476        match tensor {
2477            TensorPrimitive::Float(tensor) => B::float_device(tensor),
2478            TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
2479        }
2480    }
2481
2482    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2483        match tensor {
2484            TensorPrimitive::Float(tensor) => {
2485                TensorPrimitive::Float(B::float_to_device(tensor, device))
2486            }
2487            TensorPrimitive::QFloat(tensor) => {
2488                TensorPrimitive::QFloat(B::q_to_device(tensor, device))
2489            }
2490        }
2491    }
2492
2493    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2494        match tensor {
2495            TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
2496            TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
2497        }
2498    }
2499
2500    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2501        match data.dtype {
2502            DType::QFloat(_strategy) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
2503            _ => TensorPrimitive::Float(B::float_from_data(data, device)),
2504        }
2505    }
2506
2507    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2508        match tensor {
2509            TensorPrimitive::Float(tensor) => {
2510                TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
2511            }
2512            TensorPrimitive::QFloat(tensor) => {
2513                TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
2514            }
2515        }
2516    }
2517
2518    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2519        match vectors.first().unwrap() {
2520            TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
2521                vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
2522                dim,
2523            )),
2524            TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
2525                vectors
2526                    .into_iter()
2527                    .map(|tensor| {
2528                        if let TensorPrimitive::QFloat(t) = tensor {
2529                            t
2530                        } else {
2531                            panic!("Concatenation only works with vector of QFloat")
2532                        }
2533                    })
2534                    .collect(),
2535                dim,
2536            )),
2537        }
2538    }
2539
2540    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2541        B::float_equal(lhs.tensor(), rhs.tensor())
2542    }
2543
2544    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2545        B::float_not_equal(lhs.tensor(), rhs.tensor())
2546    }
2547
2548    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2549        B::float_any(tensor.tensor())
2550    }
2551
2552    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2553        B::float_any_dim(tensor.tensor(), dim)
2554    }
2555
2556    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2557        B::float_all(tensor.tensor())
2558    }
2559
2560    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2561        B::float_all_dim(tensor.tensor(), dim)
2562    }
2563
2564    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2565        match tensor {
2566            TensorPrimitive::Float(tensor) => {
2567                TensorPrimitive::Float(B::float_permute(tensor, axes))
2568            }
2569            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
2570        }
2571    }
2572
2573    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2574        TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
2575    }
2576
2577    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2578        match tensor {
2579            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
2580            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
2581        }
2582    }
2583
2584    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2585        match tensor {
2586            TensorPrimitive::Float(tensor) => B::float_chunk(tensor, chunks, dim)
2587                .into_iter()
2588                .map(TensorPrimitive::Float)
2589                .collect(),
2590            TensorPrimitive::QFloat(tensor) => B::q_chunk(tensor, chunks, dim)
2591                .into_iter()
2592                .map(TensorPrimitive::QFloat)
2593                .collect(),
2594        }
2595    }
2596
2597    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2598        match tensor {
2599            TensorPrimitive::Float(tensor) => B::float_split(tensor, split_size, dim)
2600                .into_iter()
2601                .map(TensorPrimitive::Float)
2602                .collect(),
2603            TensorPrimitive::QFloat(tensor) => B::q_split(tensor, split_size, dim)
2604                .into_iter()
2605                .map(TensorPrimitive::QFloat)
2606                .collect(),
2607        }
2608    }
2609
2610    fn split_with_sizes(
2611        tensor: Self::Primitive,
2612        split_sizes: Vec<usize>,
2613        dim: usize,
2614    ) -> Vec<Self::Primitive> {
2615        match tensor {
2616            TensorPrimitive::Float(tensor) => B::float_split_with_sizes(tensor, split_sizes, dim)
2617                .into_iter()
2618                .map(TensorPrimitive::Float)
2619                .collect(),
2620            TensorPrimitive::QFloat(tensor) => B::q_split_with_sizes(tensor, split_sizes, dim)
2621                .into_iter()
2622                .map(TensorPrimitive::QFloat)
2623                .collect(),
2624        }
2625    }
2626}
2627
2628impl<B: Backend> BasicOps<B> for Int {
2629    type Elem = B::IntElem;
2630
2631    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2632        B::int_empty(shape, device)
2633    }
2634
2635    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2636        tr.register_int(tensor);
2637    }
2638
2639    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2640        B::int_reshape(tensor, shape)
2641    }
2642
2643    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2644        B::int_transpose(tensor)
2645    }
2646
2647    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2648        B::int_swap_dims(tensor, dim1, dim2)
2649    }
2650
2651    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2652        B::int_slice(tensor, ranges)
2653    }
2654
2655    fn slice_assign(
2656        tensor: Self::Primitive,
2657        ranges: &[Range<usize>],
2658        value: Self::Primitive,
2659    ) -> Self::Primitive {
2660        B::int_slice_assign(tensor, ranges, value)
2661    }
2662
2663    fn device(tensor: &Self::Primitive) -> Device<B> {
2664        B::int_device(tensor)
2665    }
2666
2667    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2668        B::int_to_device(tensor, device)
2669    }
2670
2671    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2672        B::int_into_data(tensor).await
2673    }
2674
2675    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2676        B::int_from_data(data, device)
2677    }
2678
2679    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2680        B::int_repeat_dim(tensor, dim, times)
2681    }
2682
2683    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2684        B::int_equal(lhs, rhs)
2685    }
2686
2687    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2688        B::int_not_equal(lhs, rhs)
2689    }
2690
2691    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2692        B::int_cat(vectors, dim)
2693    }
2694
2695    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2696        B::int_any(tensor)
2697    }
2698
2699    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2700        B::int_any_dim(tensor, dim)
2701    }
2702
2703    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2704        B::int_all(tensor)
2705    }
2706
2707    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2708        B::int_all_dim(tensor, dim)
2709    }
2710
2711    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2712        B::int_permute(tensor, axes)
2713    }
2714
2715    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2716        B::int_expand(tensor, shape)
2717    }
2718
2719    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2720        B::int_flip(tensor, axes)
2721    }
2722
2723    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2724        B::int_chunk(tensor, chunks, dim)
2725    }
2726
2727    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2728        B::int_split(tensor, split_size, dim)
2729    }
2730
2731    fn split_with_sizes(
2732        tensor: Self::Primitive,
2733        split_sizes: Vec<usize>,
2734        dim: usize,
2735    ) -> Vec<Self::Primitive> {
2736        B::int_split_with_sizes(tensor, split_sizes, dim)
2737    }
2738}
2739
2740impl<B: Backend> BasicOps<B> for Bool {
2741    type Elem = bool;
2742
2743    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2744        B::bool_empty(shape, device)
2745    }
2746
2747    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2748        tr.register_bool(tensor);
2749    }
2750
2751    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2752        B::bool_reshape(tensor, shape)
2753    }
2754
2755    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2756        B::bool_transpose(tensor)
2757    }
2758
2759    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2760        B::bool_swap_dims(tensor, dim1, dim2)
2761    }
2762
2763    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2764        B::bool_slice(tensor, ranges)
2765    }
2766
2767    fn slice_assign(
2768        tensor: Self::Primitive,
2769        ranges: &[Range<usize>],
2770        value: Self::Primitive,
2771    ) -> Self::Primitive {
2772        B::bool_slice_assign(tensor, ranges, value)
2773    }
2774
2775    fn device(tensor: &Self::Primitive) -> Device<B> {
2776        B::bool_device(tensor)
2777    }
2778
2779    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2780        B::bool_to_device(tensor, device)
2781    }
2782
2783    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2784        B::bool_into_data(tensor).await
2785    }
2786
2787    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2788        B::bool_from_data(data, device)
2789    }
2790
2791    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2792        B::bool_repeat_dim(tensor, dim, times)
2793    }
2794
2795    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2796        B::bool_equal(lhs, rhs)
2797    }
2798
2799    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2800        B::bool_not_equal(lhs, rhs)
2801    }
2802
2803    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2804        B::bool_cat(vectors, dim)
2805    }
2806
2807    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2808        B::bool_any(tensor)
2809    }
2810
2811    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2812        B::bool_any_dim(tensor, dim)
2813    }
2814
2815    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2816        B::bool_all(tensor)
2817    }
2818
2819    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2820        B::bool_all_dim(tensor, dim)
2821    }
2822
2823    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2824        B::bool_permute(tensor, axes)
2825    }
2826
2827    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2828        B::bool_expand(tensor, shape)
2829    }
2830
2831    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2832        B::bool_flip(tensor, axes)
2833    }
2834
2835    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2836        B::bool_chunk(tensor, chunks, dim)
2837    }
2838
2839    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2840        B::bool_split(tensor, split_size, dim)
2841    }
2842
2843    fn split_with_sizes(
2844        tensor: Self::Primitive,
2845        split_sizes: Vec<usize>,
2846        dim: usize,
2847    ) -> Vec<Self::Primitive> {
2848        B::bool_split_with_sizes(tensor, split_sizes, dim)
2849    }
2850}
2851
2852/// Trait used for movedim arguments
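///
/// Dimensions can be given as `usize` or `i32` values (or `Vec`s of either); negative
/// values count from the last dimension. A minimal sketch, assuming the two-argument
/// `Tensor::movedim(source, destination)` form that consumes these arguments:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     // Shape [1, 2, 2].
///     let tensor = Tensor::<B, 3>::from_data([[[1.0, 2.0], [3.0, 4.0]]], &device);
///     // Move dimension 0 to the last position: the result has shape [2, 2, 1].
///     let moved = tensor.movedim(0, 2);
///     println!("{moved}");
/// }
/// ```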
2853pub trait MovedimArgs {
    /// Converts `self` into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function
2855    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
2856}
2857
2858impl MovedimArgs for Vec<i32> {
2859    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2860        let set = self
2861            .iter()
2862            .map(|&dim| {
2863                if dim < 0 {
2864                    (D as i32 + dim) as usize
2865                } else {
2866                    dim as usize
2867                }
2868            })
2869            .collect::<Vec<usize>>();
2870        check!(TensorCheck::movedim_args_vec::<D>(&set));
2871
2872        set
2873    }
2874}
2875
2876impl MovedimArgs for Vec<usize> {
2877    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2878        check!(TensorCheck::movedim_args_vec::<D>(&self));
2879        self
2880    }
2881}
2882
2883impl MovedimArgs for usize {
2884    #[allow(clippy::vec_init_then_push)]
2885    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2886        check!(TensorCheck::movedim_args_usize::<D>(self));
2887
2888        let mut set = Vec::with_capacity(1);
2889        set.push(self);
2890
2891        set
2892    }
2893}
2894
2895impl MovedimArgs for i32 {
2896    #[allow(clippy::vec_init_then_push)]
2897    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2898        check!(TensorCheck::movedim_args_i32::<D>(self));
2899
2900        let dim = if self < 0 {
2901            (D as i32 + self) as usize
2902        } else {
2903            self as usize
2904        };
2905
2906        let mut set = Vec::with_capacity(1);
2907        set.push(dim);
2908
2909        set
2910    }
2911}
2912
2913/// Trait used for slice arguments
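///
/// Ranges can be given as `usize` ranges, as `(i64, i64)` tuples where negative values
/// count from the end of the dimension, or as `Option<(i64, i64)>` where `None` keeps
/// the full dimension. A minimal sketch using the tuple form:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
///     // Keep both rows and drop the last column: -1 counts from the end.
///     // [[1.0, 2.0], [4.0, 5.0]]
///     let slice = tensor.slice([(0, 2), (0, -1)]);
///     println!("{slice}");
/// }
/// ```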
2914pub trait RangesArg<const D2: usize> {
    /// Converts `self` into an array of ranges `[core::ops::Range<usize>; D2]` for the `tensor.slice()` function
2916    fn into_ranges(self, shape: Shape) -> [core::ops::Range<usize>; D2];
2917
2918    /// Handles negative index values
2919    fn handle_negative_index(start: i64, end: i64, dim: usize) -> (usize, usize) {
2920        let start = if start < 0 {
2921            (dim as i64 + start) as usize
2922        } else {
2923            start as usize
2924        };
2925        let end = if end < 0 {
2926            (dim as i64 + end) as usize
2927        } else {
2928            end as usize
2929        };
2930        (start, end)
2931    }
2932
2933    /// Clamps the range to the shape dimensions
2934    fn clamp_range(start: usize, end: usize, dim: usize) -> (usize, usize) {
2935        let start = start.clamp(0, dim);
2936        let end = end.clamp(0, dim);
2937        (start, end)
2938    }
2939}

impl<const D2: usize> RangesArg<D2> for [core::ops::Range<usize>; D2] {
    fn into_ranges(self, shape: Shape) -> [core::ops::Range<usize>; D2] {
        // Clamp the ranges to the shape dimensions
        let ranges = self
            .iter()
            .enumerate()
            .map(|(i, range)| {
                let (start, end) = Self::clamp_range(range.start, range.end, shape.dims[i]);
                start..end
            })
            .collect::<Vec<_>>();
        ranges.try_into().unwrap()
    }
}

impl<const D2: usize> RangesArg<D2> for [Option<(i64, i64)>; D2] {
    fn into_ranges(self, shape: Shape) -> [core::ops::Range<usize>; D2] {
        let ranges = self
            .iter()
            .enumerate()
            .map(|(i, range)| match range {
                Some((start, end)) => {
                    let (start, end) = Self::handle_negative_index(*start, *end, shape.dims[i]);
                    let (start, end) = Self::clamp_range(start, end, shape.dims[i]);
                    start..end
                }
                None => 0..shape.dims[i], // if None, use the full range
            })
            .collect::<Vec<_>>();

        ranges.try_into().unwrap()
    }
}

impl<const D2: usize> RangesArg<D2> for [(i64, i64); D2] {
    fn into_ranges(self, shape: Shape) -> [core::ops::Range<usize>; D2] {
        let ranges = self
            .iter()
            .enumerate()
            .map(|(i, &(start, end))| {
                let (start, end) = Self::handle_negative_index(start, end, shape.dims[i]);
                let (start, end) = Self::clamp_range(start, end, shape.dims[i]);
                start..end
            })
            .collect::<Vec<_>>();

        ranges.try_into().unwrap()
    }
}

/// Trait used for reshape arguments.
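///
/// ## Example
///
/// A minimal sketch of the argument forms accepted by `tensor.reshape()`:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::ones([4, 3], &device);
///
///     // With signed arguments, -1 infers a dimension from the element count:
///     // 12 / 2 = 6, so the resulting shape is [2, 6].
///     let reshaped = tensor.clone().reshape([2, -1]);
///     println!("{reshaped}");
///
///     // A 0 keeps the corresponding dimension of the original shape.
///     // The resulting shape is [4, 3].
///     let reshaped = tensor.reshape([0, 3]);
///     println!("{reshaped}");
/// }
/// ```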
pub trait ReshapeArgs<const D2: usize> {
    /// Converts to a shape.
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape;
}

impl<const D2: usize> ReshapeArgs<D2> for Shape {
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape {
        check!(TensorCheck::reshape_args_usize::<D, D2>(
            &tensor.shape(),
            &self
        ));

        self
    }
}

impl<const D2: usize> ReshapeArgs<D2> for [usize; D2] {
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape {
        let shape = Shape::from(self);

        check!(TensorCheck::reshape_args_usize::<D, D2>(
            &tensor.shape(),
            &shape
        ));

        shape
    }
}

impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape {
        // Validate the reshape arguments
        check!(TensorCheck::reshape_args_i32(&self));

        // Build a temporary shape: keep every non-zero entry as-is and
        // replace each 0 with the corresponding dimension of the current shape.
        let mut new_shape: [i32; D2] = [1; D2];
        for (i, &s) in self.iter().enumerate() {
            if s != 0 {
                new_shape[i] = s;
            } else {
                new_shape[i] = tensor.dims()[i] as i32;
            }
        }

        // Infer the dimension marked with -1, if any, from the total number
        // of elements in the tensor.
        if let Some(index) = new_shape.iter().position(|&x| x == -1) {
            let mut product = 1;
            for (i, &s) in new_shape.iter().enumerate() {
                if i != index {
                    product *= s;
                }
            }
            let num_elements = tensor.shape().num_elements() as i32;

            // The known dimensions must divide the element count evenly for
            // the reshape to be valid.
            if num_elements % product != 0 {
                panic!(
                    "Cannot reshape tensor of shape {:?} to shape {:?}",
                    tensor.shape(),
                    self
                );
            }

            new_shape[index] = num_elements / product;
        }

        // Convert each element to usize
        let new_shape: [usize; D2] = new_shape.map(|x| x as usize);

        Shape::from(new_shape)
    }
}

/// Trait used for broadcast arguments.
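///
/// ## Example
///
/// A minimal sketch, assuming the arguments are consumed by `tensor.expand()`:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::ones([1, 3], &device);
///
///     // Explicit target shape: broadcast the singleton dimension.
///     // The resulting shape is [4, 3].
///     let expanded = tensor.clone().expand([4, 3]);
///     println!("{expanded}");
///
///     // With signed arguments, -1 keeps the size of the matching dimension.
///     // The resulting shape is [4, 3].
///     let expanded = tensor.expand([4, -1]);
///     println!("{expanded}");
/// }
/// ```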
pub trait BroadcastArgs<const D1: usize, const D2: usize> {
    /// Converts to a shape.
    fn into_shape(self, shape: &Shape) -> Shape;
}

impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
    fn into_shape(self, _shape: &Shape) -> Shape {
        self
    }
}

impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
    fn into_shape(self, _shape: &Shape) -> Shape {
        Shape::from(self)
    }
}

impl<const D1: usize, const D2: usize, E: Element> BroadcastArgs<D1, D2> for [E; D2] {
    // Passing -1 as the size for a dimension means not changing the size of that dimension.
    fn into_shape(self, shape: &Shape) -> Shape {
        if self.len() < shape.num_dims() {
            panic!("The number of broadcast arguments must be greater than or equal to the number of dimensions");
        }

        // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
        let new_shape: Vec<_> = self
            .iter()
            .rev()
            .map(|x| {
                let primitive = x.to_i64();
                if primitive < -1 || primitive == 0 {
                    panic!("Broadcast arguments must be positive or -1");
                }
                primitive
            })
            .zip(shape.dims.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
            .map(|(x, &y)| if x == -1 { y } else { x as usize })
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .collect();

        if new_shape.iter().any(|&x| x == 0) {
            panic!("Cannot substitute -1 for a dimension that does not exist in the original shape");
        }

        let new_shape: [usize; D2] = new_shape.try_into().unwrap();

        Shape::from(new_shape)
    }
}

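// `Tensor` (de)serializes through its `TensorData`, so any serde data format
// can be used. A minimal round-trip sketch, assuming `serde_json` is available
// as a dependency (note that deserialization places the tensor on the default
// device):
//
// fn roundtrip<B: Backend>(tensor: Tensor<B, 2>) -> Tensor<B, 2> {
//     let json = serde_json::to_string(&tensor).expect("should serialize");
//     serde_json::from_str(&json).expect("should deserialize")
// }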
impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    K::Elem: Debug + Copy + Serialize,
{
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let data = self.to_data();
        data.serialize(serializer)
    }
}

impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    K::Elem: Debug + Copy + Deserialize<'de>,
{
    fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
        let tensor = Tensor::from_data(
            TensorData::deserialize(deserializer)?,
            &<B::Device as Default>::default(),
        );
        Ok(tensor)
    }
}