burn_tensor/tensor/api/numeric.rs

use alloc::vec::Vec;

use crate::alloc::borrow::ToOwned;

use crate::indexing::canonicalize_dim;
use crate::{
    AsIndex, BasicOps, Bool, Distribution, Element, ElementConversion, Float, Int, Shape, Tensor,
    TensorKind,
    backend::Backend,
    check,
    check::TensorCheck,
    ops::{Device, IntTensor},
};
use crate::{DType, TensorPrimitive};

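// Dispatches a binary float op across quantized/non-quantized primitive pairs:
// both-quantized inputs use the backend's quantized kernel, while mixed pairs
// dequantize the quantized side and fall back to the float kernel.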
macro_rules! q_bin_ops {
    ($lhs:ident, $rhs:ident, $op:ident, $q_op:ident) => {
        match ($lhs, $rhs) {
            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
                TensorPrimitive::Float(B::$op(lhs, rhs))
            }
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::$q_op(lhs, rhs),
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {
                TensorPrimitive::Float(B::$op(B::dequantize(lhs), rhs))
            }
            (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
                TensorPrimitive::Float(B::$op(lhs, B::dequantize(rhs)))
            }
        }
    };
}

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    /// Applies element wise addition operation.
    ///
    /// `y = x1 + x2`
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to add.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1 + tensor2;
    ///    println!("{tensor}");
    ///    // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, other: Self) -> Self {
        check!(TensorCheck::binary_ops_ew("Add", &self, &other));
        Self::new(K::add(self.primitive, other.primitive))
    }

    /// Applies element wise addition operation with a scalar.
    ///
    /// `y = x + s`
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to add, element wise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let scalar = 2.0;
    ///   let tensor = tensor + scalar;
    ///   println!("{tensor}");
    ///   // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
    /// }
    /// ```
    pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::add_scalar::<E>(self.primitive, other))
    }

    /// Applies element wise subtraction operation.
    ///
    /// `y = x1 - x2`
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to subtract.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///   let tensor = tensor1 - tensor2;
    ///   println!("{tensor}");
    ///   // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn sub(self, other: Self) -> Self {
        check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
        Self::new(K::sub(self.primitive, other.primitive))
    }

    /// Applies element wise subtraction operation with a scalar.
    ///
    /// `y = x - s`
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to subtract, element wise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let scalar = 2.0;
    ///    let tensor = tensor - scalar;
    ///    println!("{tensor}");
    ///    // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
    /// }
    /// ```
    pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::sub_scalar::<E>(self.primitive, other))
    }

    /// Applies element wise division operation.
    ///
    /// `y = x1 / x2`
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to divide by.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1 / tensor2;
    ///    println!("{tensor}");
    ///    // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn div(self, other: Self) -> Self {
        check!(TensorCheck::binary_ops_ew("Div", &self, &other));
        Self::new(K::div(self.primitive, other.primitive))
    }

    /// Applies element wise division operation with a scalar.
    ///
    /// `y = x / s`
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to divide by, element wise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let scalar = 2.0;
    ///    let tensor = tensor / scalar;
    ///    println!("{tensor}");
    ///    // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
    /// }
    /// ```
    pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::div_scalar::<E>(self.primitive, other))
    }

    /// Applies element wise remainder operation.
    ///
    /// `y = x1 % x2`
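    ///
    /// # Example
    ///
    /// A minimal sketch with hypothetical values (only non-negative operands are
    /// shown, so the result does not depend on the sign convention):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[4.0, 7.0, 9.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [3.0, 4.0, 4.0]], &device);
    ///    let tensor = tensor1.remainder(tensor2);
    ///    println!("{tensor}");
    ///    // [[0.0, 1.0, 1.0], [2.0, 1.0, 2.0]]
    /// }
    /// ```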
    pub fn remainder(self, other: Self) -> Self {
        Self::new(K::remainder(self.primitive, other.primitive))
    }

    /// Applies element wise remainder operation with a scalar.
    ///
    /// `y = x % s`
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to compute the remainder with, element wise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let scalar = 2.0;
    ///    let tensor = tensor1 % scalar;
    ///    println!("{tensor}");
    ///    // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
    /// }
    /// ```
    pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::remainder_scalar::<E>(self.primitive, other))
    }

    /// Applies element wise multiplication operation.
    ///
    /// `y = x1 * x2`
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to multiply.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1 * tensor2;
    ///    println!("{tensor}");
    ///    // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn mul(self, other: Self) -> Self {
        check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
        Self::new(K::mul(self.primitive, other.primitive))
    }

    /// Applies element wise multiplication operation with a scalar.
    ///
    /// `y = x * s`
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to multiply, element wise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let scalar = 2.0;
    ///    let tensor = tensor * scalar;
    ///    println!("{tensor}");
    ///    // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
    /// }
    /// ```
    pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::mul_scalar::<E>(self.primitive, other))
    }

    /// Switch sign of each element in the tensor.
    ///
    /// `y = -x`
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = -tensor;
    ///    println!("{tensor}");
    ///    // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn neg(self) -> Self {
        Self::new(K::neg(self.primitive))
    }

    /// Returns the signs of the elements of the input tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.sign();
    ///    println!("{tensor}");
    ///    // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
    /// }
    /// ```
    pub fn sign(self) -> Self {
        Self::new(K::sign(self.primitive))
    }

    /// Create a tensor of the given shape where each element is zero.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
    ///    println!("{tensor}");
    ///    // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
    /// }
    /// ```
    pub fn zeros<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
        Self::new(K::zeros(shape, device, K::Elem::dtype()))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with zeros.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.zeros_like();
    ///   println!("{tensor}");
    ///   // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
    /// }
    /// ```
    pub fn zeros_like(&self) -> Self {
        Self::new(K::zeros(self.shape(), &self.device(), self.dtype()))
    }

    /// Create a tensor of the given shape where each element is one.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
    ///   println!("{tensor}");
    ///   // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    /// }
    /// ```
    pub fn ones<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
        Self::new(K::ones(shape, device, K::Elem::dtype()))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with ones.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.ones_like();
    ///    println!("{tensor}");
    ///    // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    /// }
    /// ```
    pub fn ones_like(&self) -> Self {
        Self::new(K::ones(self.shape(), &self.device(), self.dtype()))
    }

    /// Aggregate all elements in the tensor with the mean operation.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.mean();
    ///    println!("{tensor}");
    ///    // [3.6666667]
    /// }
    /// ```
    pub fn mean(self) -> Tensor<B, 1, K> {
        Tensor::new(K::mean(self.primitive))
    }

    /// Aggregate all elements in the tensor with the sum operation.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.sum();
    ///   println!("{tensor}");
    ///   // [22.0]
    /// }
    /// ```
    pub fn sum(self) -> Tensor<B, 1, K> {
        Tensor::new(K::sum(self.primitive))
    }

    /// Aggregate all elements along the given *dimension* or *axis*
    /// in the tensor with the mean operation.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements;
    ///   supports negative indexing.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let result = tensor.clone().mean_dim(0);
    ///   println!("{result}");
    ///   // [[3.0, 3.5, 4.5]]
    ///   let result = tensor.mean_dim(1);
    ///   println!("{result}");
    ///   // [[0.6666667], [6.6666665]]
    /// }
    /// ```
    pub fn mean_dim<I: AsIndex>(self, dim: I) -> Self {
        let dim = canonicalize_dim(dim, D, false);
        check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
        Self::new(K::mean_dim(self.primitive, dim))
    }

    /// Aggregate all elements along the given *axes*
    /// in the tensor with the mean operation.
    ///
    /// # Arguments
    ///
    /// * `dims` - the dimensions to aggregate; supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have the same rank,
    /// but the aggregated dimensions will have size 1.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[2.0, 4.0], [6.0, -4.0]], &device);
    ///    let tensor = tensor.mean_dims(&[0, 1]);
    ///    println!("{tensor}");
    ///    // [[2.0]]
    /// }
    /// ```
    pub fn mean_dims<I: AsIndex>(self, dims: &[I]) -> Self {
        dims.iter().fold(self, |tensor, &dim| tensor.mean_dim(dim))
    }

    /// Aggregate all elements along the given *dimension* or *axis*
    /// in the tensor with the sum operation.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements;
    ///   supports negative indexing.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let result = tensor.clone().sum_dim(0);
    ///    println!("{result}");
    ///    // [[6.0, 7.0, 9.0]]
    ///    let result = tensor.sum_dim(1);
    ///    println!("{result}");
    ///    // [[2.0], [20.0]]
    /// }
    /// ```
    pub fn sum_dim<I: AsIndex>(self, dim: I) -> Self {
        let dim = canonicalize_dim(dim, D, false);
        check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
        Self::new(K::sum_dim(self.primitive, dim))
    }

    /// Aggregate all elements along the given *axes*
    /// in the tensor with the sum operation.
    ///
    /// # Arguments
    ///
    /// * `dims` - the dimensions to aggregate; supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have the same rank,
    /// but the aggregated dimensions will have size 1.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.sum_dims(&[0, 1]);
    ///    println!("{tensor}");
    ///    // [[22.0]]
    /// }
    /// ```
    pub fn sum_dims<I: AsIndex>(self, dims: &[I]) -> Self {
        dims.iter().fold(self, |tensor, &dim| tensor.sum_dim(dim))
    }

    /// Aggregate and squeeze along the given dimensions.
    ///
    /// This is equivalent to `tensor.sum_dims(dims).squeeze_dims(dims)`.
    ///
    /// # Arguments
    ///
    /// * `dims` - the dimensions to aggregate; supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have rank `D2`,
    /// with the aggregated dimensions squeezed out.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 3>::from_data([
    ///         [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]],
    ///         [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]],
    ///     ], &device);
    ///     let tensor = tensor.sum_dims_squeeze::<1, _>(&[0, 1]);
    ///     println!("{tensor}");
    ///     // [20.0, 16.0, 21.0]
    /// }
    /// ```
    pub fn sum_dims_squeeze<const D2: usize, I: AsIndex>(self, dims: &[I]) -> Tensor<B, D2, K> {
        // TODO: remove idims when squeeze_dims uses AsIndex.
        let idims = dims
            .iter()
            .map(|&dim| canonicalize_dim(dim, D, false) as isize)
            .collect::<Vec<_>>();
        self.sum_dims(dims).squeeze_dims::<D2>(&idims)
    }

    /// Aggregate all elements in the tensor with the product operation.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.prod();
    ///    println!("{tensor}");
    ///    // [-1620.0]
    /// }
    /// ```
    pub fn prod(self) -> Tensor<B, 1, K> {
        Tensor::new(K::prod(self.primitive))
    }

    /// Aggregate all elements along the given *dimension* or *axis*
    /// in the tensor with the product operation.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements,
    ///   supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have the same rank,
    /// but the aggregated dimension will have size 1.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let result = tensor.clone().prod_dim(0);
    ///    println!("{result}");
    ///    // [[5.0, -18.0, 18.0]]
    ///    let result = tensor.prod_dim(1);
    ///    println!("{result}");
    ///    // [[-6.0], [270.0]]
    /// }
    /// ```
    pub fn prod_dim<I: AsIndex>(self, dim: I) -> Self {
        let dim = canonicalize_dim(dim, D, false);
        check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
        Self::new(K::prod_dim(self.primitive, dim))
    }

    /// Aggregate all elements along the given *axes*
    /// in the tensor with the product operation.
    ///
    /// # Arguments
    ///
    /// * `dims` - the dimensions to aggregate, supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have the same rank,
    /// but the aggregated dimensions will have size 1.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.prod_dims(&[0, 1]);
    ///    println!("{tensor}");
    ///    // [[-1620.0]]
    /// }
    /// ```
    pub fn prod_dims<I: AsIndex>(self, dims: &[I]) -> Self {
        dims.iter().fold(self, |tensor, &dim| tensor.prod_dim(dim))
    }

    /// Computes the cumulative sum of elements along the given *dimension* or *axis*.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to compute the cumulative sum.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///    let result = tensor.clone().cumsum(0);
    ///    println!("{result}");
    ///    // [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]
    ///    let result = tensor.cumsum(1);
    ///    println!("{result}");
    ///    // [[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]
    /// }
    /// ```
    pub fn cumsum(self, dim: usize) -> Self {
        check!(TensorCheck::aggregate_dim::<D>("CumSum", dim));
        Self::new(K::cumsum(self.primitive, dim))
    }

    /// Computes the cumulative product of elements along the given *dimension* or *axis*.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to compute the cumulative product.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///    let result = tensor.clone().cumprod(0);
    ///    println!("{result}");
    ///    // [[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]
    ///    let result = tensor.cumprod(1);
    ///    println!("{result}");
    ///    // [[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]
    /// }
    /// ```
    pub fn cumprod(self, dim: usize) -> Self {
        check!(TensorCheck::aggregate_dim::<D>("CumProd", dim));
        Self::new(K::cumprod(self.primitive, dim))
    }

    /// Computes the cumulative minimum of elements along the given *dimension* or *axis*.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to compute the cumulative minimum.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);
    ///    let result = tensor.clone().cummin(0);
    ///    println!("{result}");
    ///    // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]
    ///    let result = tensor.cummin(1);
    ///    println!("{result}");
    ///    // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]
    /// }
    /// ```
    pub fn cummin(self, dim: usize) -> Self {
        check!(TensorCheck::aggregate_dim::<D>("CumMin", dim));
        Self::new(K::cummin(self.primitive, dim))
    }

    /// Computes the cumulative maximum of elements along the given *dimension* or *axis*.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to compute the cumulative maximum.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device);
    ///    let result = tensor.clone().cummax(0);
    ///    println!("{result}");
    ///    // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]]
    ///    let result = tensor.cummax(1);
    ///    println!("{result}");
    ///    // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]]
    /// }
    /// ```
    pub fn cummax(self, dim: usize) -> Self {
        check!(TensorCheck::aggregate_dim::<D>("CumMax", dim));
        Self::new(K::cummax(self.primitive, dim))
    }

    /// Applies element wise equality comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.equal_elem(3.0);
    ///    println!("{tensor}");
    ///    // [[false, false, true], [false, false, false]]
    /// }
    /// ```
    pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::equal_elem(self.primitive, other.elem()))
    }

    /// Applies element wise non-equality comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.not_equal_elem(3.0);
    ///    println!("{tensor}");
    ///    // [[true, true, false], [true, true, true]]
    /// }
    /// ```
    pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::not_equal_elem(self.primitive, other.elem()))
    }

    /// Applies element wise greater comparison and returns a boolean tensor.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///   let tensor = tensor1.greater(tensor2);
    ///   println!("{tensor}");
    ///   // [[false, false, false], [true, true, true]]
    /// }
    /// ```
    pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
        Tensor::new(K::greater(self.primitive, other.primitive))
    }

    /// Applies element wise greater-equal comparison and returns a boolean tensor.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1.greater_equal(tensor2);
    ///    println!("{tensor}");
    ///    // [[true, false, false], [true, true, true]]
    /// }
    /// ```
    pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
        Tensor::new(K::greater_equal(self.primitive, other.primitive))
    }

    /// Applies element wise lower comparison and returns a boolean tensor.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1.lower(tensor2);
    ///    println!("{tensor}");
    ///    // [[false, true, true], [false, false, false]]
    /// }
    /// ```
    pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
        Tensor::new(K::lower(self.primitive, other.primitive))
    }

    /// Applies element wise lower-equal comparison and returns a boolean tensor.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1.lower_equal(tensor2);
    ///    println!("{tensor}");
    ///    // [[true, true, true], [false, false, false]]
    /// }
    /// ```
    pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
        Tensor::new(K::lower_equal(self.primitive, other.primitive))
    }

    /// Applies greater than `other` comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.greater_elem(3.0);
    ///    println!("{tensor}");
    ///    // [[false, false, false], [true, true, true]]
    /// }
    /// ```
    pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::greater_elem(self.primitive, other.elem()))
    }

    /// Applies greater-equal than `other` comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.greater_equal_elem(3.0);
    ///    println!("{tensor}");
    ///    // [[false, false, true], [true, true, true]]
    /// }
    /// ```
    pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::greater_equal_elem(self.primitive, other.elem()))
    }

    /// Applies lower than `other` comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.lower_elem(3.0);
    ///     println!("{tensor}");
    ///     // [[true, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::lower_elem(self.primitive, other.elem()))
    }

    /// Applies lower-equal than `other` comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.lower_equal_elem(3.0);
    ///    println!("{tensor}");
    ///    // [[true, true, true], [false, false, false]]
    /// }
    /// ```
    pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::lower_equal_elem(self.primitive, other.elem()))
    }

    /// Update the given tensor with the value tensor where the mask is true.
    ///
    /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
    /// a scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Bool};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
    ///   let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///   let tensor = tensor.mask_where(mask, value);
    ///   println!("{tensor}");
    ///   // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
    /// }
    /// ```
    pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
        Self::new(K::mask_where(
            self.primitive,
            mask.primitive,
            value.primitive,
        ))
    }

    /// Update the given tensor with the value where the mask is true.
    ///
    /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
    /// a tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Bool};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
    ///   let tensor = tensor.mask_fill(mask, 3.0);
    ///   println!("{tensor}");
    ///   // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
    /// }
    /// ```
    pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
        Self::new(K::mask_fill(self.primitive, mask.primitive, value.elem()))
    }

    /// Gather tensor elements corresponding to the given indices from the specified dim.
    ///
    /// Example using a 3D tensor:
    ///
    /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
    /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
    /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
    ///
    /// # Notes
    ///
    /// The index tensor should have the same shape as the original tensor except for the dim
    /// specified.
    ///
    /// # Warning
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
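    ///
    /// # Example
    ///
    /// A minimal sketch with hypothetical values, gathering along `dim = 0` per the
    /// indexing rule above:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let input = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///    let indices = Tensor::<B, 2, Int>::from_data([[1, 0, 1], [0, 1, 0]], &device);
    ///    // output[i, j] = input[indices[i, j], j]
    ///    let output = input.gather(0, indices);
    ///    println!("{output}");
    ///    // [[4.0, 2.0, 6.0], [1.0, 5.0, 3.0]]
    /// }
    /// ```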
    pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
        check!(TensorCheck::gather::<D>(
            dim,
            &self.shape(),
            &indices.shape()
        ));

        Self::new(K::gather(dim, self.primitive, indices.primitive))
    }

    /// Assign the gathered elements corresponding to the given indices along the specified dimension
    /// from the value tensor to the original tensor using sum reduction.
    ///
    /// Example using a 3D tensor:
    ///
    /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
    /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
    /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
    ///
    /// # Notes
    ///
    /// The index tensor should have the same shape as the original tensor except for the specified
    /// dimension. The value and index tensors should have the same shape.
    ///
    /// Other references to the input tensor will not be modified by this operation.
    ///
    /// # Warning
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
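    ///
    /// # Example
    ///
    /// A minimal sketch with hypothetical values, scattering into a zero tensor along
    /// `dim = 0` per the indexing rule above:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let input = Tensor::<B, 2>::zeros([2, 3], &device);
    ///    let indices = Tensor::<B, 2, Int>::from_data([[0, 1, 0], [1, 0, 1]], &device);
    ///    let values = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///    // input[indices[i, j], j] += values[i, j]
    ///    let output = input.scatter(0, indices, values);
    ///    println!("{output}");
    ///    // [[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]]
    /// }
    /// ```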
    pub fn scatter(self, dim: usize, indices: Tensor<B, D, Int>, values: Self) -> Self {
        check!(TensorCheck::scatter::<D>(
            dim,
            &self.shape(),
            &indices.shape(),
            &values.shape()
        ));

        Self::new(K::scatter(
            dim,
            self.primitive,
            indices.primitive,
            values.primitive,
        ))
    }

    /// Applies the argmax function along the given dimension and returns an integer tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
    ///     let tensor = tensor.argmax(1);
    ///     println!("{:?}", tensor.shape());
    ///     // Shape { dims: [2, 1, 3] }
    /// }
    /// ```
    pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
        Tensor::new(K::argmax(self.primitive, dim))
    }

    /// Find the maximum value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.max();
    ///   println!("{tensor}");
    ///   // [9.0]
    /// }
    /// ```
    pub fn max(self) -> Tensor<B, 1, K> {
        Tensor::new(K::max(self.primitive))
    }

    /// Find the maximum value along the given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements;
    ///   supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have the same rank,
    /// but the aggregated dimension will have size 1.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.max_dim(0);
    ///   println!("{tensor}");
    ///   // [[5.0, 9.0, 6.0]]
    /// }
    /// ```
    pub fn max_dim<I: AsIndex>(self, dim: I) -> Self {
        let dim = canonicalize_dim(dim, D, false);
        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
        Tensor::new(K::max_dim(self.primitive, dim))
    }

    /// Find the maximum value along the given dimensions.
    ///
    /// # Arguments
    ///
    /// * `dims` - The dimensions or axes along which to aggregate the elements;
    ///   supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have the same rank,
    /// but the aggregated dimensions will have size 1.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.max_dims(&[0, 1]);
    ///   println!("{tensor}");
    ///   // [[9.0]]
    /// }
    /// ```
    pub fn max_dims<I: AsIndex>(self, dims: &[I]) -> Self {
        dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim))
    }

    /// Find the maximum value along the given dimension.
    ///
    /// Also returns the indices.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let (tensor, index) = tensor.max_dim_with_indices(0);
    ///    println!("{tensor}");
    ///    // [[5.0, 9.0, 6.0]]
    ///    println!("{index}");
    ///    // [[1, 1, 1]]
    /// }
    /// ```
    pub fn max_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
        let dim = canonicalize_dim(dim, D, false);
        check!(TensorCheck::aggregate_dim::<D>("Max", dim));

        let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);

        let tensor = Tensor::new(tensor);
        let index = Tensor::new(index);

        (tensor, index)
    }

    /// Finds the maximum pair wise values with another tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - Other tensor to find maximum elements with
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensors containing the maximum value found
    /// in the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1.max_pair(tensor2);
    ///    println!("{tensor}");
    ///    // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
    /// }
    /// ```
    pub fn max_pair(self, other: Self) -> Self {
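        // Where self is lower than other, take other's element; elsewhere keep self's.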
        let mask = self.clone().lower(other.clone());
        self.mask_where(mask, other)
    }

    /// Find the maximum absolute value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
    ///   let tensor = tensor.max_abs();
    ///   println!("{tensor}");
    ///   // [7.0]
    /// }
    /// ```
    pub fn max_abs(self) -> Tensor<B, 1, K> {
        Tensor::new(K::max_abs(self.primitive))
    }

    /// Find the maximum absolute value along the given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements,
    ///   supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have the same rank,
    /// but the aggregated dimension will have size 1.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.max_abs_dim(0);
    ///   println!("{tensor}");
    ///   // [[5.0, 9.0, 6.0]]
    /// }
    /// ```
    pub fn max_abs_dim<I: AsIndex>(self, dim: I) -> Self {
        let dim = canonicalize_dim(dim, D, false);
        check!(TensorCheck::aggregate_dim::<D>("MaxAbs", dim));

        Tensor::new(K::max_abs_dim(self.primitive, dim))
    }

    /// Find the maximum absolute value along the given dimensions.
    ///
    /// # Arguments
    ///
    /// * `dims` - The dimensions or axes along which to aggregate the elements,
    ///   supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have the same rank,
    /// but the aggregated dimensions will have size 1.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.max_abs_dims(&[0, 1]);
    ///   println!("{tensor}");
    ///   // [[9.0]]
    /// }
    /// ```
    pub fn max_abs_dims<I: AsIndex>(self, dims: &[I]) -> Self {
        dims.iter()
            .fold(self, |tensor, &dim| tensor.max_abs_dim(dim))
    }

    /// Applies the argmin function along the given dimension and returns an integer tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
    ///     let tensor = tensor.argmin(1);
    ///     println!("{:?}", tensor.shape());
    ///     // Shape { dims: [2, 1, 3] }
    /// }
    /// ```
    pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
        Tensor::new(K::argmin(self.primitive, dim))
    }

    /// Find the minimum value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.min();
    ///    println!("{tensor}");
    ///    // [-2.0]
    /// }
    /// ```
    pub fn min(self) -> Tensor<B, 1, K> {
        Tensor::new(K::min(self.primitive))
    }

    /// Find the minimum value along the given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements;
    ///   supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have the same rank,
    /// but the aggregated dimension will have size 1.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.min_dim(0);
    ///    println!("{tensor}");
    ///    // [[1.0, -2.0, 3.0]]
    /// }
    /// ```
    pub fn min_dim<I: AsIndex>(self, dim: I) -> Self {
        let dim = canonicalize_dim(dim, D, false);
        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
        Tensor::new(K::min_dim(self.primitive, dim))
    }

    /// Find the minimum value along the given dimensions.
    ///
    /// # Arguments
    ///
    /// * `dims` - The dimensions or axes along which to aggregate the elements;
    ///   supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have the same rank,
    /// but the aggregated dimensions will have size 1.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.min_dims(&[0, 1]);
    ///   println!("{tensor}");
    ///   // [[-2.0]]
    /// }
    /// ```
    pub fn min_dims<I: AsIndex>(self, dims: &[I]) -> Self {
        dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim))
    }

    /// Find the minimum value along the given dimension.
    ///
    /// Also returns the indices.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let (tensor, index) = tensor.min_dim_with_indices(0);
    ///    println!("{tensor}");
    ///    // [[5.0, -2.0, 3.0]]
    ///    println!("{index}");
    ///    // [[1, 0, 0]]
    /// }
    /// ```
    pub fn min_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
        let dim = canonicalize_dim(dim, D, false);
        check!(TensorCheck::aggregate_dim::<D>("Min", dim));

        let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);

        let tensor = Tensor::new(tensor);
        let index = Tensor::new(index);

        (tensor, index)
    }
1564
1565    /// Finds the minimum pair wise values with another tensor.
1566    ///
1567    /// # Arguments
1568    ///
1569    /// * `other` - Other tensor to find minimum elements with
1570    ///
1571    /// # Returns
1572    ///
1573    /// A tensor with the same shape as the input tensors containing the minimum value found
1574    /// between each element of the two source tensors.
1575    ///
1576    /// # Example
1577    ///
1578    /// ```rust
1579    /// use burn_tensor::backend::Backend;
1580    /// use burn_tensor::{Tensor, Shape};
1581    ///
1582    /// fn example<B: Backend>() {
1583    ///    let device = B::Device::default();
1584    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1585    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1586    ///    let tensor = tensor1.min_pair(tensor2);
1587    ///    println!("{tensor}");
1588    ///    // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
1589    /// }
1590    /// ```
1590    pub fn min_pair(self, other: Self) -> Self {
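1591        // Keep `self` everywhere, except where `other` is strictly lower: there, take `other`.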
1591        let mask = other.clone().lower(self.clone());
1592        self.mask_where(mask, other)
1593    }
1594
1595    /// Clamp element wise between the given min and max values.
1596    ///
1597    /// # Arguments
1598    ///
1599    /// * `min` - The minimum value.
1600    /// * `max` - The maximum value.
1601    ///
1602    /// # Returns
1603    ///
1604    /// A new tensor with the values clamped between the given min and max values.
1605    ///
1606    /// # Example
1607    ///
1608    /// ```rust
1609    /// use burn_tensor::backend::Backend;
1610    /// use burn_tensor::{Int, Tensor};
1611    ///
1612    /// fn example<B: Backend>() {
1613    ///   let device = Default::default();
1614    ///   let tensor = Tensor::<B, 2, Int>::from_ints(
1615    ///     [
1616    ///       [1, 2, 3],
1617    ///       [4, 5, 6],
1618    ///       [7, 8, 9]
1619    ///     ],
1620    ///     &device);
1621    ///   let tensor = tensor.clamp(2, 6);
1622    ///   println!("{tensor}");
1623    ///   // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
1624    /// }
1625    /// ```
1626    pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
1627        Self::new(K::clamp(self.primitive, min.elem(), max.elem()))
1628    }
1629
1630    /// Clamp element wise to a minimum value.
1631    ///
1632    /// # Arguments
1633    ///
1634    /// * `min` - The minimum value.
1635    ///
1636    /// # Returns
1637    ///
1638    /// A new tensor where every element smaller than `min` is replaced by `min`.
1640    ///
1641    /// # Example
1642    ///
1643    /// ```rust
1644    /// use burn_tensor::backend::Backend;
1645    /// use burn_tensor::{Int, Tensor};
1646    ///
1647    /// fn example<B: Backend>() {
1648    ///    let device = Default::default();
1649    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1650    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1651    ///    &device);
1652    ///    let tensor = tensor.clamp_min(4);
1653    ///    println!("{tensor}");
1654    ///    // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
1655    /// }
1656    /// ```
1657    pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
1658        Self::new(K::clamp_min(self.primitive, min.elem()))
1659    }
1660
1661    /// Clamp element wise to a maximum value.
1662    ///
1663    /// # Arguments
1664    ///
1665    /// * `max` - The maximum value.
1666    ///
1667    /// # Returns
1668    ///
1669    /// A new tensor where every element greater than `max` is replaced by `max`.
1671    ///
1672    /// # Example
1673    ///
1674    /// ```rust
1675    /// use burn_tensor::backend::Backend;
1676    /// use burn_tensor::{Int, Tensor};
1677    ///
1678    /// fn example<B: Backend>() {
1679    ///    let device = Default::default();
1680    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1681    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1682    ///    &device);
1683    ///    let tensor = tensor.clamp_max(5);
1684    ///    println!("{tensor}");
1685    ///    // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
1686    /// }
1687    /// ```
1688    pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
1689        Self::new(K::clamp_max(self.primitive, max.elem()))
1690    }
1691
1692    /// Apply element wise absolute value operation.
1693    ///
1694    /// # Example
1695    ///
1696    /// ```rust
1697    /// use burn_tensor::backend::Backend;
1698    /// use burn_tensor::{Int, Tensor};
1699    ///
1700    /// fn example<B: Backend>() {
1701    ///   let device = Default::default();
1702    ///   let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
1703    ///   let tensor = tensor.abs();
1704    ///   println!("{tensor}");
1705    ///   // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
1706    /// }
1707    /// ```
1708    pub fn abs(self) -> Self {
1709        Self::new(K::abs(self.primitive))
1710    }
1711
1712    /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices;
1713    /// the other elements of the result tensor are set to 0.
1714    ///
1715    /// See also [`triu_mask`](Tensor::triu_mask).
1716    ///
1717    /// # Arguments
1718    ///
1719    /// * `diagonal` - The offset from the main diagonal, where 0 selects the main diagonal, positive
1720    ///   values shift the boundary towards the upper triangle, and negative values towards the lower.
1721    ///
1722    /// # Example
1723    /// ```rust
1724    /// use burn_tensor::backend::Backend;
1725    /// use burn_tensor::{Int, Tensor};
1726    ///
1727    /// fn example<B: Backend>() {
1728    ///    let device = Default::default();
1729    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1730    ///        [
1731    ///          [1, 2, 3],
1732    ///          [4, 5, 6],
1733    ///          [7, 8, 9]
1734    ///        ],
1735    ///        &device
1736    ///    );
1737    ///    let tensor = tensor.triu(1);
1738    ///    println!("{tensor}");
1739    ///    // [
1740    ///    //   [0, 2, 3],
1741    ///    //   [0, 0, 6],
1742    ///    //   [0, 0, 0]
1743    ///    // ]
1744    /// }
1745    /// ```
1746    pub fn triu(self, diagonal: i64) -> Self {
1747        check!(TensorCheck::tri::<{ D }>());
1748
1749        // last two dimensions
1750        let shape = &self.shape().dims[D - 2..].to_owned();
1751
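1752        // Unsqueeze broadcasts the 2-D mask over any leading batch dimensions.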
1752        let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
1753        self.mask_fill(mask, 0)
1754    }
1755
1756    /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices;
1757    /// the other elements of the result tensor are set to 0.
1758    ///
1759    /// See also [`tril_mask`](Tensor::tril_mask).
1760    ///
1761    /// # Arguments
1762    ///
1763    /// * `diagonal` - The offset from the main diagonal, where 0 selects the main diagonal, positive
1764    ///   values shift the boundary towards the upper triangle, and negative values towards the lower.
1765    ///
1766    /// # Example
1767    /// ```rust
1768    /// use burn_tensor::backend::Backend;
1769    /// use burn_tensor::{Int, Tensor};
1770    ///
1771    /// fn example<B: Backend>() {
1772    ///    let device = Default::default();
1773    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1774    ///        [
1775    ///          [1, 2, 3],
1776    ///          [4, 5, 6],
1777    ///          [7, 8, 9]
1778    ///        ],
1779    ///        &device
1780    ///    );
1781    ///
1782    ///    let tensor = tensor.tril(-1);
1783    ///    println!("{tensor}");
1784    ///    // [
1785    ///    //   [0, 0, 0],
1786    ///    //   [4, 0, 0],
1787    ///    //   [7, 8, 0]
1788    ///    // ]
1789    /// }
1790    /// ```
1791    pub fn tril(self, diagonal: i64) -> Self {
1792        check!(TensorCheck::tri::<{ D }>());
1793
1794        // last two dimensions
1795        let shape = &self.shape().dims[D - 2..].to_owned();
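1796        // Unsqueeze broadcasts the 2-D mask over any leading batch dimensions.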
1796        let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
1797
1798        self.mask_fill(mask, 0)
1799    }
1800
1801    /// Applies element wise power operation with a float Tensor.
1802    ///
1803    /// # Arguments
1804    ///
1805    /// * `other` - The tensor to apply the power operation with.
1806    ///
1807    /// # Example
1808    ///
1809    /// ```rust
1810    /// use burn_tensor::backend::Backend;
1811    /// use burn_tensor::{Tensor, Shape};
1812    ///
1813    /// fn example<B: Backend>() {
1814    ///    let device = B::Device::default();
1815    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1816    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1817    ///    let tensor = tensor1.powf(tensor2);
1818    ///    println!("{tensor}");
1819    ///    // [[1.0, -8.0, 81.0], [5.0, 81.0, 216.0]]
1820    /// }
1821    /// ```
1822    pub fn powf(self, other: Self) -> Self {
1823        Self::new(K::powf(self.primitive, other.primitive))
1824    }
1825
1826    /// Applies element wise power operation with a float scalar.
1827    ///
1828    /// # Arguments
1829    ///
1830    /// * `other` - The scalar to apply the power operation with.
1831    ///
1832    /// # Example
1833    ///
1834    /// ```rust
1835    /// use burn_tensor::backend::Backend;
1836    /// use burn_tensor::{Tensor, Shape};
1837    ///
1838    /// fn example<B: Backend>() {
1839    ///    let device = B::Device::default();
1840    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1841    ///    let tensor = tensor.powf_scalar(2.0);
1842    ///    println!("{tensor}");
1843    ///    // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
1844    /// }
1845    /// ```
1846    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
1847        Self::new(K::powf_scalar::<E>(self.primitive, other))
1848    }
1849
1850    /// Applies element wise power operation with an integer Tensor.
1851    ///
1852    /// # Arguments
1853    ///
1854    /// * `other` - The tensor to apply the power operation with.
1855    ///
1856    /// # Example
1857    ///
1858    /// ```rust
1859    /// use burn_tensor::backend::Backend;
1860    /// use burn_tensor::{Tensor, Shape, Int};
1861    ///
1862    /// fn example<B: Backend>() {
1863    ///    let device = B::Device::default();
1864    ///    let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1865    ///    let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
1866    ///    let tensor = tensor1.powi(tensor2);
1867    ///    println!("{tensor}");
1868    ///    // [[1, -8, 81], [5, 81, 216]]
1869    /// }
1870    /// ```
1871    pub fn powi(self, other: Self) -> Self {
1872        Self::new(K::powi(self.primitive, other.primitive))
1873    }
1874
1875    /// Applies element wise power operation with an integer scalar.
1876    ///
1877    /// # Arguments
1878    ///
1879    /// * `other` - The scalar to apply the power operation with.
1880    ///
1881    /// # Example
1882    ///
1883    /// ```rust
1884    /// use burn_tensor::backend::Backend;
1885    /// use burn_tensor::{Tensor, Shape, Int};
1886    ///
1887    /// fn example<B: Backend>() {
1888    ///    let device = B::Device::default();
1889    ///    let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1890    ///    let tensor = tensor.powi_scalar(2);
1891    ///    println!("{tensor}");
1892    ///    // [[1, 4, 9], [25, 81, 36]]
1893    ///
1894    ///    let tensor = Tensor::<B, 2>::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device);
1895    ///    let tensor = tensor.powi_scalar(2);
1896    ///    println!("{tensor}");
1897    ///    // [[2.25, 4., 9.], [25., 81., 36.]]
1898    /// }
1899    /// ```
1900    pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
1901        Self::new(K::powi_scalar::<E>(self.primitive, other))
1902    }
1903
1904    /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
1905    ///
1906    /// # Returns
1907    ///
1908    /// A boolean tensor with the same shape as the input tensor.
1909    ///
1910    /// # Example
1911    ///
1912    /// ```rust
1913    /// use burn_tensor::backend::Backend;
1914    /// use burn_tensor::{Tensor, Shape};
1915    ///
1916    /// fn example<B: Backend>() {
1917    ///   let device = B::Device::default();
1918    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
1919    ///   let tensor = tensor.bool();
1920    ///   println!("{tensor}");
1921    ///   // [
1922    ///   //   [true, true, true],
1923    ///   //   [false, true, true]
1924    ///   // ]
1925    /// }
1926    /// ```
1926    pub fn bool(self) -> Tensor<B, D, Bool> {
1927        Tensor::new(K::not_equal_elem(self.primitive, 0.elem()))
1928    }
1929
1930    /// Create a random tensor of the given shape on the given device where each element is
1931    /// sampled from the given distribution.
1932    ///
1933    /// See also [`random_like`](Tensor::random_like).
1934    ///
1935    /// # Arguments
1936    ///
1937    /// * `shape` - The shape of the tensor.
1938    /// * `distribution` - The distribution to sample from.
1939    /// * `device` - The device to create the tensor on.
1940    ///
1941    /// # Returns
1942    ///
1943    /// A new tensor with the given shape and elements sampled from the given distribution.
1944    ///
1945    /// # Example
1946    ///
1947    /// ```rust
1948    /// use burn_tensor::backend::Backend;
1949    /// use burn_tensor::{Tensor, Shape, Distribution};
1950    ///
1951    /// fn example<B: Backend>() {
1952    ///   let device = B::Device::default();
1953    ///   let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
1954    ///   let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
1955    ///   println!("{tensor}");
1956    ///   // [
1957    ///   //   [0.08347523, 0.70498955, 0.60332155],
1958    ///   //   [0.08173251, 0.18028641, 0.97942924]
1959    ///   // ]
1960    /// }
1961    /// ```
1962    pub fn random<S: Into<Shape>>(
1963        shape: S,
1964        distribution: Distribution,
1965        device: &B::Device,
1966    ) -> Self {
1967        Self::new(K::random(shape.into(), distribution, device))
1968    }
1969
1970    /// Sort the elements by value in ascending order along a given dimension.
1971    ///
1972    /// This sort is unstable (i.e., may reorder equal elements).
1973    ///
1974    /// # Arguments
1975    ///
1976    /// * `dim` - The dimension to sort along.
1977    ///
1978    /// # Returns
1979    ///
1980    /// A new tensor with the elements sorted in ascending order along the given dimension.
1981    ///
1982    /// # Example
1983    ///
1984    /// ```rust
1985    /// use burn_tensor::backend::Backend;
1986    /// use burn_tensor::{Tensor, Shape};
1987    ///
1988    /// fn example<B: Backend>() {
1989    ///   let device = B::Device::default();
1990    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1991    ///   let sorted = tensor.clone().sort(0);
1992    ///   println!("{sorted}");
1993    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
1994    ///   let sorted = tensor.sort(1);
1995    ///   println!("{sorted}");
1996    ///   // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]
1997    /// }
1998    /// ```
1999    pub fn sort(self, dim: usize) -> Self {
2000        check!(TensorCheck::sort_dim::<D>("Sort", dim));
2001        Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
2002    }
2003
2004    /// Sort the elements by value in descending order along a given dimension.
2005    ///
2006    /// This sort is unstable (i.e., may reorder equal elements).
2007    ///
2008    /// # Arguments
2009    ///
2010    /// * `dim` - The dimension to sort along.
2011    ///
2012    /// # Returns
2013    ///
2014    /// A new tensor with the elements sorted in descending order along the given dimension.
2015    ///
2016    /// # Example
2017    ///
2018    /// ```rust
2019    /// use burn_tensor::backend::Backend;
2020    /// use burn_tensor::{Tensor, Shape};
2021    ///
2022    /// fn example<B: Backend>() {
2023    ///    let device = B::Device::default();
2024    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2025    ///    let sorted = tensor.clone().sort_descending(0);
2026    ///    println!("{sorted}");
2027    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
2028    ///    let sorted = tensor.sort_descending(1);
2029    ///    println!("{sorted}");
2030    ///    // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]
2031    /// }
2032    /// ```
2033    pub fn sort_descending(self, dim: usize) -> Self {
2034        check!(TensorCheck::sort_dim::<D>("Sort", dim));
2035        Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
2036    }
2037
2038    /// Sort the elements by value in ascending order along a given dimension.
2039    /// Also returns the indices.
2040    ///
2041    /// This sort is unstable (i.e., may reorder equal elements).
2042    ///
2043    /// # Arguments
2044    ///
2045    /// * `dim` - The dimension to sort along.
2046    ///
2047    /// # Returns
2048    ///
2049    /// A tuple containing the sorted tensor and the indices tensor.
2050    ///
2051    /// # Example
2052    ///
2053    /// ```rust
2054    /// use burn_tensor::backend::Backend;
2055    /// use burn_tensor::{Tensor, Shape};
2056    ///
2057    /// fn example<B: Backend>() {
2058    ///   let device = B::Device::default();
2059    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2060    ///   let (tensor, indices) = tensor.sort_with_indices(0);
2061    ///   println!("{tensor}");
2062    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
2063    ///   println!("{}", indices);
2064    ///   // [[1, 0, 0], [0, 1, 1]]
2065    /// }
2066    /// ```
2067    pub fn sort_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
2068        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
2069        let (values, indices) =
2070            K::sort_with_indices(self.primitive, dim, /*descending*/ false);
2071        (Tensor::new(values), Tensor::new(indices))
2072    }
2073
2074    /// Sort the elements by value in descending order along a given dimension.
2075    /// Also returns the indices.
2076    ///
2077    /// This sort is unstable (i.e., may reorder equal elements).
2078    ///
2079    /// # Arguments
2080    ///
2081    /// * `dim` - The dimension to sort along.
2082    ///
2083    /// # Example
2084    ///
2085    /// ```rust
2086    /// use burn_tensor::backend::Backend;
2087    /// use burn_tensor::{Tensor, Shape};
2088    ///
2089    /// fn example<B: Backend>() {
2090    ///    let device = B::Device::default();
2091    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2092    ///    let (tensor, indices) = tensor.sort_descending_with_indices(0);
2093    ///    println!("{tensor}");
2094    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
2095    ///    println!("{}", indices);
2096    ///    // [[0, 1, 1], [1, 0, 0]]
2097    /// }
2098    /// ```
2099    pub fn sort_descending_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
2100        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
2101        let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
2102        (Tensor::new(values), Tensor::new(indices))
2103    }
2104
2105    /// Returns the indices that sort the elements by value in ascending order along a given dimension.
2106    ///
2107    /// This sort is unstable (i.e., may reorder equal elements).
2108    ///
2109    /// # Arguments
2110    ///
2111    /// * `dim` - The dimension to sort along.
2112    ///
2113    /// # Example
2114    ///
2115    /// ```rust
2116    /// use burn_tensor::backend::Backend;
2117    /// use burn_tensor::{Tensor, Shape};
2118    ///
2119    /// fn example<B: Backend>() {
2120    ///    let device = B::Device::default();
2121    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2122    ///    let tensor = tensor.argsort(0);
2123    ///    println!("{tensor}");
2124    ///    // [[1, 0, 0], [0, 1, 1]]
2125    /// }
2126    /// ```
2127    pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
2128        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
2129        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
2130    }
2131
2132    /// Returns the indices that sort the elements by value in descending order along a given dimension.
2133    ///
2134    /// This sort is unstable (i.e., may reorder equal elements).
2135    ///
2136    /// # Arguments
2137    ///
2138    /// * `dim` - The dimension to sort along.
2139    ///
2140    /// # Example
2141    ///
2142    /// ```rust
2143    /// use burn_tensor::backend::Backend;
2144    /// use burn_tensor::{Tensor, Shape};
2145    ///
2146    /// fn example<B: Backend>() {
2147    ///    let device = B::Device::default();
2148    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2149    ///    let indices = tensor.clone().argsort_descending(0);
2150    ///    println!("{indices}");
2151    ///    // [[0, 1, 1], [1, 0, 0]]
2152    ///    let indices = tensor.argsort_descending(1);
2153    ///    println!("{indices}");
2154    ///    // [[0, 2, 1], [2, 0, 1]]
2155    /// }
2156    /// ```
2157    pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
2158        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
2159        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
2160    }
2161
2162    /// Returns the `k` largest elements of the given input tensor along a given dimension.
2163    ///
2164    /// # Arguments
2165    ///
2166    /// * `k` - The number of elements to return.
2167    ///
2168    /// # Returns
2169    ///
2170    /// A new tensor with the `k` largest elements along the given dimension.
2171    ///
2172    /// # Example
2173    ///
2174    /// ```rust
2175    /// use burn_tensor::backend::Backend;
2176    /// use burn_tensor::{Tensor, Shape};
2177    ///
2178    /// fn example<B: Backend>() {
2179    ///   let device = B::Device::default();
2180    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2181    ///   let top = tensor.clone().topk(2, 0);
2182    ///   println!("{top}");
2183    ///   // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
2184    ///   let top = tensor.topk(1, 1);
2185    ///   println!("{top}");
2186    ///   // [[12.0], [6.0]]
2187    /// }
2188    /// ```
2189    pub fn topk(self, k: usize, dim: usize) -> Self {
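2190        // Sort in descending order, then keep only the first `k` entries along `dim`.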
2190        let k_indices = Tensor::arange(0..k as i64, &self.device());
2191        self.sort_descending(dim).select(dim, k_indices)
2192    }
2193
2194    /// Returns the `k` largest elements of the given input tensor along a given dimension.
2195    /// Also returns the indices.
2196    ///
2197    /// # Arguments
2198    ///
2199    /// * `k` - The number of elements to return.
2200    /// * `dim` - The dimension to sort along.
2201    ///
2202    /// # Example
2203    ///
2204    /// ```rust
2205    /// use burn_tensor::backend::Backend;
2206    /// use burn_tensor::{Tensor, Shape};
2207    ///
2208    /// fn example<B: Backend>() {
2209    ///    let device = B::Device::default();
2210    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2211    ///    let (values, indices) = tensor.clone().topk_with_indices(2, 0);
2212    ///    println!("{values}");
2213    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
2214    ///    println!("{indices}");
2215    ///    // [[0, 1, 1], [1, 0, 0]]
2216    ///    let (values, indices) = tensor.topk_with_indices(1, 1);
2217    ///    println!("{values}");
2218    ///    // [[12.0], [6.0]]
2219    ///    println!("{indices}");
2220    ///    // [[0], [2]]
2221    /// }
2222    /// ```
2223    pub fn topk_with_indices(self, k: usize, dim: usize) -> (Self, Tensor<B, D, Int>) {
2224        let k_indices = Tensor::arange(0..k as i64, &self.device());
2225        let (values, indices) = self.sort_descending_with_indices(dim);
2226        (
2227            values.select(dim, k_indices.clone()),
2228            indices.select(dim, k_indices),
2229        )
2230    }
2231
2232    /// Pad the tensor of rank two or higher with the given value on the last two dimensions.
2233    ///
2234    /// # Arguments
2235    ///
2236    /// * `padding` - A tuple of four integers representing the padding on the left, right, top, and bottom.
2237    /// * `value` - The value to pad the tensor with.
2238    ///
2239    /// # Returns
2240    ///
2241    /// A new tensor with the given padding.
2242    ///
2243    /// # Example
2244    ///
2245    /// ```rust
2246    /// use burn_tensor::backend::Backend;
2247    /// use burn_tensor::{Tensor, Shape};
2248    ///
2249    /// fn example<B: Backend<FloatElem: From<f32>>>() {
2250    ///    let device = B::Device::default();
2251    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2252    ///    let tensor = tensor.pad((1, 1, 1, 1), 0.0);
2253    ///    println!("{tensor}");
2254    ///    // [
2255    ///    //   [0.0, 0.0, 0.0, 0.0, 0.0],
2256    ///    //   [0.0, 12.0, -2.0, 3.0, 0.0],
2257    ///    //   [0.0, 5.0, 3.0, 6.0, 0.0],
2258    ///    //   [0.0, 0.0, 0.0, 0.0, 0.0]
2259    ///    // ]
2260    /// }
2261    /// ```
2262    pub fn pad<E: ElementConversion>(
2263        self,
2264        padding: (usize, usize, usize, usize),
2265        value: E,
2266    ) -> Self {
2267        let (left, right, top, bottom) = padding;
2268
2269        let mut padded_dims: [usize; D] = self.dims();
2270
2271        // Update the last two dimensions with padding
2272        padded_dims[D - 2] += top + bottom;
2273        padded_dims[D - 1] += left + right;
2274
2275        // Create the ranges for the padded tensor
2276        let ranges: [core::ops::Range<usize>; D] = padded_dims
2277            .iter()
2278            .enumerate()
2279            .map(|(i, &dim)| {
2280                if i == D - 2 {
2281                    top..dim - bottom
2282                } else if i == D - 1 {
2283                    left..dim - right
2284                } else {
2285                    0..dim
2286                }
2287            })
2288            .collect::<Vec<core::ops::Range<usize>>>()
2289            .try_into()
2290            .unwrap();
2291
2292        // Create the padded tensor
2293        let padded_tensor = Tensor::full(padded_dims, value, &self.device());
2294
2295        // Assign the original tensor data to the appropriate slice of the padded tensor
2296        padded_tensor.slice_assign(ranges, self)
2297    }
2298
2298    /// Create a one hot tensor.
2299    ///
2300    /// # Example
2301    ///
2302    /// ```rust
2303    /// use burn_tensor::backend::Backend;
2304    /// use burn_tensor::Tensor;
2305    ///
2306    /// fn example<B: Backend>() {
2307    ///     let device = Default::default();
2308    ///     let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
2309    ///     let one_hot: Tensor<B, 2> = indices.one_hot(4);
2310    ///     println!("{}", one_hot.to_data());
2311    ///     // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
2312    /// }
2313    /// ```
2314    pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
2315        check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
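2316        // One-hot is the special case of `one_hot_fill` with on_value 1, off_value 0, on the last axis.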
2316        self.one_hot_fill(num_classes, 1.0, 0.0, -1)
2317    }
2318
2319    /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis`, supporting tensors of any rank.
2320    ///
2321    /// # Arguments
2322    ///
2323    /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
2324    /// * `on_value`: The value to assign for active positions (corresponding to indices).
2325    /// * `off_value`: The value to assign for inactive positions.
2326    /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
2327    ///
2328    /// # Returns
2329    ///
2330    /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
2331    ///
2332    /// # Example
2333    /// ```rust
2334    /// use burn_tensor::backend::Backend;
2335    /// use burn_tensor::{Tensor, Float};
2336    /// fn example<B: Backend<FloatElem: From<f32>>>() {
2337    ///     let device = B::Device::default();
2338    ///     let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
2339    ///     // One-hot encoding
2340    ///     let tensor: Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0, 0.0, -1);
2341    ///     println!("{tensor}");
2342    ///     // [[[5.0, 0.0, 0.0],
2343    ///     //   [0.0, 0.0, 5.0]],
2344    ///     //  [[0.0, 5.0, 0.0],
2345    ///     //   [0.0, 0.0, 5.0]]]
2346    /// }
2347    /// ```
2348    pub fn one_hot_fill<const D2: usize>(
2349        self,
2350        num_classes: usize,
2351        on_value: f32,
2352        off_value: f32,
2353        axis: i64,
2354    ) -> Tensor<B, D2, K> {
2355        check!(TensorCheck::one_hot_tensor_rank::<D, D2>());
2356        // Initialize shape from the current tensor dimensions and prepare for modification
2357        let mut shape = self.shape();
2358        let device = self.device();
2359        let rank = self.dims().len();
2360
2361        // Adjust negative axis to a positive index
2362        let axis = if axis < 0 {
2363            axis + rank as i64 + 1
2364        } else {
2365            axis
2366        };
2367
2368        // Ensure axis is within valid range
2369        if axis < 0 || axis > rank as i64 {
2370            panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
2371        }
2372        // Convert the input tensor to integer indices
2373        let indices: Tensor<B, D, Int> =
2374            Tensor::from_data(self.to_data().convert::<i64>(), &device);
2375        // Insert the new dimension for the one-hot representation
2376        shape.insert(axis as usize, num_classes);
2377        // Adjust indices to valid range and handle invalid indices
2378        let adjusted_indices = indices
2379            .clone()
2380            .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices
2381            .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices
2382        // Unsqueeze the indices tensor along the specified axis
2383        let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);
2384
2385        // Initialize the output tensor with the off_value
2386        let output = Tensor::full(shape.clone(), off_value, &device);
2387
2388        // Prepare scatter tensor for on_value and off_value adjustments
2389        let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
2390            - Tensor::full(indices_unsqueezed.shape(), off_value, &device);
2391
2392        // Scatter on_value at the appropriate indices to create the one-hot representation
2393        output.scatter(axis as usize, indices_unsqueezed, scatter_on_values)
2394    }
2395
2396    /// Applies the matrix multiplication operation.
2397    ///
2398    /// ```math
2399    /// C = AB
2400    /// ```
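2401    ///
2402    /// # Example
2403    ///
2404    /// A short usage sketch (standard 2×2 matrix product):
2405    ///
2406    /// ```rust
2407    /// use burn_tensor::backend::Backend;
2408    /// use burn_tensor::Tensor;
2409    ///
2410    /// fn example<B: Backend>() {
2411    ///    let device = B::Device::default();
2412    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
2413    ///    let tensor2 = Tensor::<B, 2>::from_data([[5.0, 6.0], [7.0, 8.0]], &device);
2414    ///    let tensor = tensor1.matmul(tensor2);
2415    ///    println!("{tensor}");
2416    ///    // [[19.0, 22.0], [43.0, 50.0]]
2417    /// }
2418    /// ```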
2401    pub fn matmul(self, other: Self) -> Self {
2402        check!(TensorCheck::matmul(&self, &other));
2403        Tensor::new(K::matmul(self.primitive, other.primitive))
2404    }
2405}
2406
2407impl<B, K> Tensor<B, 1, K>
2408where
2409    B: Backend,
2410    K: Numeric<B>,
2411    K::Elem: Element,
2412{
2413    /// Calculates the dot product with another tensor.
2414    ///
2415    /// `y = x2.dot(x1)`
2416    ///
2417    /// # Arguments
2418    ///
2419    /// * `other` - The tensor to compute dot product with.
2420    ///
2421    /// # Notes
2422    ///
2423    /// Both tensors must have the same number of elements.
2424    ///
2425    /// # Example
2426    ///
2427    /// ```rust
2428    /// use burn_tensor::backend::Backend;
2429    /// use burn_tensor::{Tensor, Shape};
2430    ///
2431    /// fn example<B: Backend>() {
2432    ///    let device = B::Device::default();
2433    ///    let tensor1 = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
2434    ///    let tensor2 = Tensor::<B, 1>::from_data([-2.0, 3.0], &device);
2435    ///    let tensor = tensor1.dot(tensor2);
2436    ///    println!("{tensor}");
2437    ///    // [4.0]
2438    /// }
2439    /// ```
2440    pub fn dot(self, other: Self) -> Self {
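2441        // The dot product is the element-wise product reduced by a full sum.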
2441        self.mul(other).sum()
2442    }
2443}
2444
2445impl<B, K> Tensor<B, 2, K>
2446where
2447    B: Backend,
2448    K: Numeric<B>,
2449    K::Elem: Element,
2450{
2451    /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
2452    ///
2453    /// # Arguments
2454    ///
2455    /// * `size` - The size of the square matrix.
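2456    /// * `device` - The device on which the tensor will be allocated.
2457    ///
2458    /// # Example
2459    ///
2460    /// A minimal usage sketch; the printed values assume the default float element type.
2461    ///
2462    /// ```rust
2463    /// use burn_tensor::backend::Backend;
2464    /// use burn_tensor::Tensor;
2465    ///
2466    /// fn example<B: Backend>() {
2467    ///    let device = B::Device::default();
2468    ///    let tensor = Tensor::<B, 2>::eye(3, &device);
2469    ///    println!("{tensor}");
2470    ///    // [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
2471    /// }
2472    /// ```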
2456    pub fn eye(size: usize, device: &B::Device) -> Self {
2457        let dtype = K::Elem::dtype();
2458        let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
2459        let ones = K::ones([1, size].into(), device, dtype);
2460        let zeros = K::zeros([size, size].into(), device, dtype);
2461
2462        Self::new(K::scatter(0, zeros, indices.primitive, ones))
2463    }
2464}
2465
2466    /// Trait that lists all operations that can be applied to numerical tensors.
2467///
2468/// # Warnings
2469///
2470/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
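2471    ///
2472    /// A minimal sketch of the preferred usage through the public API:
2473    ///
2474    /// ```rust
2475    /// use burn_tensor::backend::Backend;
2476    /// use burn_tensor::Tensor;
2477    ///
2478    /// fn example<B: Backend>() {
2479    ///     let device = B::Device::default();
2480    ///     let a = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
2481    ///     let b = Tensor::<B, 1>::from_data([3.0, 4.0], &device);
2482    ///     // The `+` operator dispatches to `Numeric::add` internally.
2483    ///     let c = a + b;
2484    ///     println!("{c}");
2485    ///     // [4.0, 6.0]
2486    /// }
2487    /// ```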
2471pub trait Numeric<B: Backend>: BasicOps<B>
2472where
2473    Self::Elem: Element,
2474{
2475    /// Adds two tensors together.
2476    ///
2477    /// # Arguments
2478    ///
2479    /// * `lhs` - The left hand side tensor.
2480    /// * `rhs` - The right hand side tensor.
2481    ///
2482    /// # Returns
2483    ///
2484    /// The sum of the two tensors.
2485    ///
2486    /// # Remarks
2487    ///
2488    /// This is a low-level function used internally by the library to call different backend functions
2489    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2490    /// or use this function directly.
2491    ///
2492    /// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function,
2493    /// which is more high-level and designed for public use.
2494    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2495
2496    /// Adds a scalar to a tensor element-wise.
2497    ///
2498    /// # Arguments
2499    ///
2500    /// * `lhs` - The left hand side tensor.
2501    /// * `rhs` - The right hand side scalar.
2502    ///
2503    /// # Returns
2504    ///
2505    /// The sum of the tensor and the scalar.
2506    ///
2507    /// # Remarks
2508    ///
2509    /// This is a low-level function used internally by the library to call different backend functions
2510    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2511    /// or use this function directly.
2512    ///
2513    /// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function,
2514    /// which is more high-level and designed for public use.
2515    fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2516
2517    /// Subtracts two tensors.
2518    ///
2519    /// # Arguments
2520    ///
2521    /// * `lhs` - The left hand side tensor.
2522    /// * `rhs` - The right hand side tensor.
2523    ///
2524    /// # Returns
2525    ///
2526    /// The difference of the two tensors.
2527    ///
2528    /// # Remarks
2529    ///
2530    /// This is a low-level function used internally by the library to call different backend functions
2531    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2532    /// or use this function directly.
2533    ///
2534    /// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function,
2535    /// which is more high-level and designed for public use.
2536    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2537
2538    /// Subtracts a scalar from a tensor element-wise.
2539    ///
2540    /// # Arguments
2541    ///
2542    /// * `lhs` - The left hand side tensor.
2543    /// * `rhs` - The right hand side scalar.
2544    ///
2545    /// # Returns
2546    ///
2547    /// The difference of the tensor and the scalar.
2548    ///
2549    /// # Remarks
2550    ///
2551    /// This is a low-level function used internally by the library to call different backend functions
2552    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2553    /// or use this function directly.
2554    ///
2555    /// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function,
2556    /// which is more high-level and designed for public use.
2557    fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2558
2559    /// Divides two tensors.
2560    ///
2561    /// # Arguments
2562    ///
2563    /// * `lhs` - The left hand side tensor.
2564    /// * `rhs` - The right hand side tensor.
2565    ///
2566    /// # Returns
2567    ///
2568    /// The quotient of the two tensors.
2569    ///
2570    /// # Remarks
2571    ///
2572    /// This is a low-level function used internally by the library to call different backend functions
2573    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2574    /// or use this function directly.
2575    ///
2576    /// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function,
2577    /// which is more high-level and designed for public use.
2578    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2579
2580    /// Divides a tensor by a scalar element-wise.
2581    ///
2582    /// # Arguments
2583    ///
2584    /// * `lhs` - The left hand side tensor.
2585    /// * `rhs` - The right hand side scalar.
2586    ///
2587    /// # Returns
2588    ///
2589    /// The quotient of the tensor and the scalar.
2590    ///
2591    /// # Remarks
2592    ///
2593    /// This is a low-level function used internally by the library to call different backend functions
2594    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2595    /// or use this function directly.
2596    ///
2597    /// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function,
2598    /// which is more high-level and designed for public use.
2599    fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2600
2601    /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2602    /// less than that of the divisor.
2603    ///
2604    /// # Arguments
2605    ///
2606    /// * `lhs` - The dividend.
2607    /// * `rhs` - The divisor.
2608    ///
2609    /// # Returns
2610    ///
2611    /// The modulo of the input tensor with the divisor.
2612    ///
2613    /// # Remarks
2614    ///
2615    /// This is a low-level function used internally by the library to call different backend functions
2616    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2617    /// or use this function directly.
2618    ///
2619    /// For performing the modulo operation, users should prefer the [Tensor::remainder](Tensor::remainder) function,
2620    /// which is more high-level and designed for public use.
2621    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2622
2623    /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2624    /// less than that of the divisor.
2625    ///
2626    /// # Arguments
2627    ///
2628    /// * `lhs` - The dividend.
2629    /// * `rhs` - The divisor.
2630    ///
2631    /// # Returns
2632    ///
2633    /// The modulo of the input tensor with the divisor.
2634    ///
2635    /// # Remarks
2636    ///
2637    /// This is a low-level function used internally by the library to call different backend functions
2638    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2639    /// or use this function directly.
2640    ///
2641    /// For performing the modulo operation, users should prefer the [Tensor::remainder_scalar](Tensor::remainder_scalar) function,
2642    /// which is more high-level and designed for public use.
2643    fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2644
2645    /// Multiplies two tensors.
2646    ///
2647    /// # Arguments
2648    ///
2649    /// * `lhs` - The left hand side tensor.
2650    /// * `rhs` - The right hand side tensor.
2651    ///
2652    /// # Returns
2653    ///
2654    /// The product of the two tensors.
2655    ///
2656    /// # Remarks
2657    ///
2658    /// This is a low-level function used internally by the library to call different backend functions
2659    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2660    /// or use this function directly.
2661    ///
2662    /// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function,
2663    /// which is more high-level and designed for public use.
2664    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2665
2666    /// Multiplies a tensor by a scalar element-wise.
2667    ///
2668    /// # Arguments
2669    ///
2670    /// * `lhs` - The left hand side tensor.
2671    /// * `rhs` - The right hand side scalar.
2672    ///
2673    /// # Returns
2674    ///
2675    /// The product of the tensor and the scalar.
2676    ///
2677    /// # Remarks
2678    ///
2679    /// This is a low-level function used internally by the library to call different backend functions
2680    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2681    /// or use this function directly.
2682    ///
2683    /// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function,
2684    /// which is more high-level and designed for public use.
2685    fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2686
2687    /// Negates a tensor.
2688    ///
2689    /// # Arguments
2690    ///
2691    /// * `tensor` - The tensor to negate.
2692    ///
2693    /// # Returns
2694    ///
2695    /// The negated tensor.
2696    ///
2697    /// # Remarks
2698    ///
2699    /// This is a low-level function used internally by the library to call different backend functions
2700    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2701    /// or use this function directly.
2702    ///
2703    /// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function,
2704    /// which is more high-level and designed for public use.
2705    fn neg(tensor: Self::Primitive) -> Self::Primitive;
2706
2707    /// Returns the signs of the elements of a tensor.
2708    ///
2709    /// # Arguments
2710    ///
2711    /// * `tensor` - The tensor.
2712    ///
2713    /// # Returns
2714    ///
2715    /// The signs of the elements of the tensor.
2716    ///
2717    /// # Remarks
2718    ///
2719    /// This is a low-level function used internally by the library to call different backend functions
2720    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2721    /// or use this function directly.
2722    ///
2723    /// For getting the signs of the elements of a tensor, users should prefer the [Tensor::sign](Tensor::sign) function,
2724    /// which is more high-level and designed for public use.
2725    fn sign(tensor: Self::Primitive) -> Self::Primitive;
2726
2727    /// Creates a tensor filled with zeros.
2728    ///
2729    /// # Arguments
2730    ///
2731    /// * `shape` - The shape of the tensor.
2732    /// * `device` - The device on which the tensor will be allocated.
2733    /// * `dtype` - The target data type.
2734    ///
2735    /// # Returns
2736    ///
2737    /// The tensor filled with zeros.
2738    ///
2739    /// # Remarks
2740    ///
2741    /// This is a low-level function used internally by the library to call different backend functions
2742    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2743    /// or use this function directly.
2744    ///
2745    /// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function,
2746    /// which is more high-level and designed for public use.
2747    fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
2748
2749    /// Creates a tensor filled with ones.
2750    ///
2751    /// # Arguments
2752    ///
2753    /// * `shape` - The shape of the tensor.
2754    /// * `device` - The device on which the tensor will be allocated.
2755    /// * `dtype` - The target data type.
2756    ///
2757    /// # Returns
2758    ///
2759    /// The tensor filled with ones.
2760    ///
2761    /// # Remarks
2762    ///
2763    /// This is a low-level function used internally by the library to call different backend functions
2764    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2765    /// or use this function directly.
2766    ///
2767    /// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function,
2768    /// which is more high-level and designed for public use.
2769    fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
2770
2771    /// Sums all the elements of the tensor.
2772    ///
2773    /// # Arguments
2774    ///
2775    /// * `tensor` - The tensor to sum.
2776    ///
2777    /// # Returns
2778    ///
2779    /// The sum of all the elements of the tensor.
2780    ///
2781    /// # Remarks
2782    ///
2783    /// This is a low-level function used internally by the library to call different backend functions
2784    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2785    /// or use this function directly.
2786    ///
2787    /// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function,
2788    /// which is more high-level and designed for public use.
2789    fn sum(tensor: Self::Primitive) -> Self::Primitive;
2790
2791    /// Sums all the elements of the tensor along a dimension.
2792    ///
2793    /// # Arguments
2794    ///
2795    /// * `tensor` - The tensor to sum.
2796    /// * `dim` - The dimension along which to sum.
2797    ///
2798    /// # Returns
2799    ///
2800    /// The sum of all the elements of the tensor along the specified dimension.
2801    ///
2802    /// # Remarks
2803    ///
2804    /// This is a low-level function used internally by the library to call different backend functions
2805    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2806    /// or use this function directly.
2807    ///
2808    /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function,
2809    /// which is more high-level and designed for public use.
2810    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2811
2812    /// Computes the product of all the elements of the tensor.
2813    ///
2814    /// # Arguments
2815    ///
2816    /// * `tensor` - The tensor to compute the product of.
2817    ///
2818    /// # Returns
2819    ///
2820    /// The product of all the elements of the tensor.
2821    ///
2822    /// # Remarks
2823    ///
2824    /// This is a low-level function used internally by the library to call different backend functions
2825    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2826    /// or use this function directly.
2827    ///
2828    /// For computing the product of all the elements of a tensor, users should prefer the
2829    /// [Tensor::prod](Tensor::prod) function,
2830    /// which is more high-level and designed for public use.
2831    fn prod(tensor: Self::Primitive) -> Self::Primitive;
2832
2833    /// Computes the product of all the elements of the tensor along a dimension.
2834    ///
2835    /// # Arguments
2836    ///
2837    /// * `tensor` - The tensor to compute the product of.
2838    /// * `dim` - The dimension along which to compute the product.
2839    ///
2840    /// # Returns
2841    ///
2842    /// The product of all the elements of the tensor along the specified dimension.
2843    ///
2844    /// # Remarks
2845    ///
2846    /// This is a low-level function used internally by the library to call different backend functions
2847    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2848    /// or use this function directly.
2849    ///
2850    /// For computing the product of all the elements of a tensor along a dimension, users should
2851    /// prefer the [Tensor::prod_dim](Tensor::prod_dim) function,
2852    /// which is more high-level and designed for public use.
2855    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2856
2857    /// Computes the mean of all the elements of the tensor.
2858    ///
2859    /// # Arguments
2860    ///
2861    /// * `tensor` - The tensor to compute the mean of.
2862    ///
2863    /// # Returns
2864    ///
2865    /// The mean of all the elements of the tensor.
2866    ///
2867    /// # Remarks
2868    ///
2869    /// This is a low-level function used internally by the library to call different backend functions
2870    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2871    /// or use this function directly.
2872    ///
2873    /// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function,
2874    /// which is more high-level and designed for public use.
2875    fn mean(tensor: Self::Primitive) -> Self::Primitive;
2876
2877    /// Computes the mean of all the elements of the tensor along a dimension.
2878    ///
2879    /// # Arguments
2880    ///
2881    /// * `tensor` - The tensor to compute the mean of.
2882    /// * `dim` - The dimension along which to compute the mean.
2883    ///
2884    /// # Returns
2885    ///
2886    /// The mean of all the elements of the tensor along the specified dimension.
2887    ///
2888    /// # Remarks
2889    ///
2890    /// This is a low-level function used internally by the library to call different backend functions
2891    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2892    /// or use this function directly.
2893    ///
2894    /// For computing the mean of all the elements of a tensor along a dimension, users should prefer
2895    /// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use.
2896    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2897
2898    /// Computes the cumulative sum of elements along a dimension.
2899    ///
2900    /// # Arguments
2901    ///
2902    /// * `tensor` - The tensor to compute the cumulative sum of.
2903    /// * `dim` - The dimension along which to compute the cumulative sum.
2904    ///
2905    /// # Returns
2906    ///
2907    /// A tensor with the same shape as the input tensor, where each element is the cumulative sum
2908    /// of all elements up to and including that position along the specified dimension.
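2909    ///
2910    /// For example, the cumulative sum of `[1, 2, 3]` is `[1, 3, 6]`.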
2909    ///
2910    /// # Remarks
2911    ///
2912    /// This is a low-level function used internally by the library to call different backend functions
2913    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2914    /// or use this function directly.
2915    ///
2916    /// For computing the cumulative sum of elements along a dimension, users should prefer
2917    /// the [Tensor::cumsum](Tensor::cumsum) function, which is more high-level and designed for public use.
2918    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2919
2920    /// Computes the cumulative product of elements along a dimension.
2921    ///
2922    /// # Arguments
2923    ///
2924    /// * `tensor` - The tensor to compute the cumulative product of.
2925    /// * `dim` - The dimension along which to compute the cumulative product.
2926    ///
2927    /// # Returns
2928    ///
2929    /// A tensor with the same shape as the input tensor, where each element is the cumulative product
2930    /// of all elements up to and including that position along the specified dimension.
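2931    ///
2932    /// For example, the cumulative product of `[1, 2, 3]` is `[1, 2, 6]`.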
2931    ///
2932    /// # Remarks
2933    ///
2934    /// This is a low-level function used internally by the library to call different backend functions
2935    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2936    /// or use this function directly.
2937    ///
2938    /// For computing the cumulative product of elements along a dimension, users should prefer
2939    /// the [Tensor::cumprod](Tensor::cumprod) function, which is more high-level and designed for public use.
2940    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2941
2942    /// Computes the cumulative minimum of elements along a dimension.
2943    ///
2944    /// # Arguments
2945    ///
2946    /// * `tensor` - The tensor to compute the cumulative minimum of.
2947    /// * `dim` - The dimension along which to compute the cumulative minimum.
2948    ///
2949    /// # Returns
2950    ///
2951    /// A tensor with the same shape as the input tensor, where each element is the minimum
2952    /// of all elements up to and including that position along the specified dimension.
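2953    ///
2954    /// For example, the cumulative minimum of `[3, 1, 2]` is `[3, 1, 1]`.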
2953    ///
2954    /// # Remarks
2955    ///
2956    /// This is a low-level function used internally by the library to call different backend functions
2957    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2958    /// or use this function directly.
2959    ///
2960    /// For computing the cumulative minimum of elements along a dimension, users should prefer
2961    /// the [Tensor::cummin](Tensor::cummin) function, which is more high-level and designed for public use.
2962    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2963
2964    /// Computes the cumulative maximum of elements along a dimension.
2965    ///
2966    /// # Arguments
2967    ///
2968    /// * `tensor` - The tensor to compute the cumulative maximum of.
2969    /// * `dim` - The dimension along which to compute the cumulative maximum.
2970    ///
2971    /// # Returns
2972    ///
2973    /// A tensor with the same shape as the input tensor, where each element is the maximum
2974    /// of all elements up to and including that position along the specified dimension.
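2975    ///
2976    /// For example, the cumulative maximum of `[1, 3, 2]` is `[1, 3, 3]`.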
2975    ///
2976    /// # Remarks
2977    ///
2978    /// This is a low-level function used internally by the library to call different backend functions
2979    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2980    /// or use this function directly.
2981    ///
2982    /// For computing the cumulative maximum of elements along a dimension, users should prefer
2983    /// the [Tensor::cummax](Tensor::cummax) function, which is more high-level and designed for public use.
2984    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2985
2985    /// Element-wise equality between a tensor and a scalar element.
2986    ///
2987    /// # Arguments
2988    ///
2989    /// * `lhs` - The left hand side tensor.
2990    /// * `rhs` - The right hand side scalar element.
2991    ///
2992    /// # Returns
2993    ///
2994    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2995    /// corresponding element of the tensor equals the scalar, and false otherwise.
2996    ///
2997    /// # Remarks
2998    ///
2999    /// This is a low-level function used internally by the library to call different backend functions
3000    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3001    /// or use this function directly.
3002    ///
3003    /// For element-wise equality between two tensors, users should prefer the [Tensor::equal_elem](Tensor::equal_elem)
3004    /// function, which is more high-level and designed for public use.
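    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::equal_elem](Tensor::equal_elem) API
    /// instead (any backend; the printed values assume the small input below):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]], &device);
    ///    let mask = tensor.equal_elem(2.0);
    ///    println!("{mask}");
    ///    // [[false, true, false], [true, true, true]]
    /// }
    /// ```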
3005    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3006
3007    /// Element-wise non-equality between a tensor and a scalar.
3008    ///
3009    /// # Arguments
3010    ///
3011    /// * `lhs` - The left hand side tensor.
3012    /// * `rhs` - The right hand side scalar.
3013    ///
3014    /// # Returns
3015    ///
3016    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
3017    /// corresponding element of the input tensor is not equal to the scalar, and false otherwise.
3018    ///
3019    /// # Remarks
3020    ///
3021    /// This is a low-level function used internally by the library to call different backend functions
3022    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3023    /// or use this function directly.
3024    ///
3025    /// For element-wise non-equality between a tensor and a scalar, users should prefer the [Tensor::not_equal_elem](Tensor::not_equal_elem)
3026    /// function, which is more high-level and designed for public use.
3027    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3028
3029    /// Element-wise greater than comparison between two tensors.
3030    ///
3031    /// # Arguments
3032    ///
3033    /// * `lhs` - The left hand side tensor.
3034    /// * `rhs` - The right hand side tensor.
3035    ///
3036    /// # Returns
3037    ///
3038    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
3039    /// corresponding element of the left hand side tensor is greater than the corresponding element
3040    /// of the right hand side tensor, and false otherwise.
3041    ///
3042    /// # Remarks
3043    ///
3044    /// This is a low-level function used internally by the library to call different backend functions
3045    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3046    /// or use this function directly.
3047    ///
3048    /// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function,
3049    /// which is more high-level and designed for public use.
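    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::greater](Tensor::greater) API instead
    /// (any backend; the printed values assume the small inputs below):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 2.0, 3.0], [1.0, 2.0, 9.0]], &device);
    ///    let mask = tensor1.greater(tensor2);
    ///    println!("{mask}");
    ///    // [[false, true, false], [true, false, false]]
    /// }
    /// ```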
3050    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3051
3052    /// Element-wise greater than comparison between a tensor and a scalar.
3053    ///
3054    /// # Arguments
3055    ///
3056    /// * `lhs` - The left hand side tensor.
3057    /// * `rhs` - The right hand side scalar.
3058    ///
3059    /// # Returns
3060    ///
3061    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
3062    /// corresponding element of the left hand side tensor is greater than the right hand side
3063    /// scalar, and false otherwise.
3064    ///
3065    /// # Remarks
3066    ///
3067    /// This is a low-level function used internally by the library to call different backend functions
3068    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3069    /// or use this function directly.
3070    ///
3071    /// For element-wise greater than comparison between a tensor and a scalar, users should prefer
3072    /// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use.
3073    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3074
3075    /// Element-wise greater than or equal comparison between two tensors.
3076    ///
3077    /// # Arguments
3078    ///
3079    /// * `lhs` - The left hand side tensor.
3080    /// * `rhs` - The right hand side tensor.
3081    ///
3082    /// # Returns
3083    ///
3084    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
3085    /// corresponding element of the left hand side tensor is greater than or equal to the
3086    /// corresponding element of the right hand side tensor, and false otherwise.
3087    ///
3088    /// # Remarks
3089    ///
3090    /// This is a low-level function used internally by the library to call different backend functions
3091    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3092    /// or use this function directly.
3093    ///
3094    /// For element-wise greater than or equal comparison between two tensors, users should prefer
3095    /// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use.
3096    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3097
3098    /// Element-wise greater than or equal comparison between a tensor and a scalar.
3099    ///
3100    /// # Arguments
3101    ///
3102    /// * `lhs` - The left hand side tensor.
3103    /// * `rhs` - The right hand side scalar.
3104    ///
3105    /// # Returns
3106    ///
3107    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
3108    /// corresponding element of the left hand side tensor is greater than or equal to the right
3109    /// hand side scalar, and false otherwise.
3110    ///
3111    /// # Remarks
3112    ///
3113    /// This is a low-level function used internally by the library to call different backend functions
3114    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3115    /// or use this function directly.
3116    ///
3117    /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer
3118    /// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use.
3119    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3120
3121    /// Element-wise less than comparison between two tensors.
3122    ///
3123    /// # Arguments
3124    ///
3125    /// * `lhs` - The left hand side tensor.
3126    /// * `rhs` - The right hand side tensor.
3127    ///
3128    /// # Returns
3129    ///
3130    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
3131    /// corresponding element of the left hand side tensor is less than the corresponding element of
3132    /// the right hand side tensor, and false otherwise.
3133    ///
3134    /// # Remarks
3135    ///
3136    /// This is a low-level function used internally by the library to call different backend functions
3137    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3138    /// or use this function directly.
3139    ///
3140    /// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function,
3141    /// which is more high-level and designed for public use.
3142    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3143
3144    /// Element-wise less than comparison between a tensor and a scalar.
3145    ///
3146    /// # Arguments
3147    ///
3148    /// * `lhs` - The left hand side tensor.
3149    /// * `rhs` - The right hand side scalar.
3150    ///
3151    /// # Returns
3152    ///
3153    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
3154    /// corresponding element of the left hand side tensor is less than the right hand side scalar,
3155    /// and false otherwise.
3156    ///
3157    /// # Remarks
3158    ///
3159    /// This is a low-level function used internally by the library to call different backend functions
3160    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3161    /// or use this function directly.
3162    ///
3163    /// For element-wise less than comparison between a tensor and a scalar, users should prefer
3164    /// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use.
3165    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3166
3167    /// Element-wise less than or equal comparison between two tensors.
3168    ///
3169    /// # Arguments
3170    ///
3171    /// * `lhs` - The left hand side tensor.
3172    /// * `rhs` - The right hand side tensor.
3173    ///
3174    /// # Returns
3175    ///
3176    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
3177    /// corresponding element of the left hand side tensor is less than or equal to the corresponding
3178    /// element of the right hand side tensor, and false otherwise.
3179    ///
3180    /// # Remarks
3181    ///
3182    /// This is a low-level function used internally by the library to call different backend functions
3183    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3184    /// or use this function directly.
3185    ///
3186    /// For element-wise less than or equal comparison between two tensors, users should prefer
3187    /// the [Tensor::lower_equal](Tensor::lower_equal) function, which is more high-level and designed for public use.
3188    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3189
3190    /// Element-wise less than or equal comparison between a tensor and a scalar.
3191    ///
3192    /// # Arguments
3193    ///
3194    /// * `lhs` - The left hand side tensor.
3195    /// * `rhs` - The right hand side scalar.
3196    ///
3197    /// # Returns
3198    ///
3199    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
3200    /// corresponding element of the left hand side tensor is less than or equal to the right hand
3201    /// side scalar, and false otherwise.
3202    ///
3203    /// # Remarks
3204    ///
3205    /// This is a low-level function used internally by the library to call different backend functions
3206    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3207    /// or use this function directly.
3208    ///
3209    /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer
3210    /// the [Tensor::lower_equal_elem](Tensor::lower_equal_elem) function, which is more high-level and designed for public use.
3211    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3212
3213    /// Selects elements from a tensor based on a boolean mask.
3214    ///
3215    /// # Arguments
3216    ///
3217    /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
3218    /// * `mask` - The boolean mask to use for selecting elements.
3219    /// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
3220    ///
3221    /// # Returns
3222    ///
3223    /// A tensor with the same shape as the input tensors, where each element is taken from the
3224    /// corresponding element of `tensor` if the corresponding element of the mask is true, and
3225    /// from the corresponding element of `source` otherwise.
3226    ///
3227    /// # Remarks
3228    ///
3229    /// This is a low-level function used internally by the library to call different backend functions
3230    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3231    /// or use this function directly.
3232    ///
3233    /// For selecting elements from a tensor based on a boolean mask, users should prefer the
3234    /// [Tensor::mask_where](Tensor::mask_where) function, which is more high-level and designed for public use.
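    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::mask_where](Tensor::mask_where) API
    /// instead (any backend; see that function's documentation for which operand is selected
    /// where the mask is true):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Bool, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///    let source = Tensor::<B, 2>::from_data([[9.0, 9.0], [9.0, 9.0]], &device);
    ///    let mask = Tensor::<B, 2, Bool>::from_data([[true, false], [false, true]], &device);
    ///    // Selects element-wise between the two tensors according to the mask.
    ///    let result = tensor.mask_where(mask, source);
    ///    println!("{result}");
    /// }
    /// ```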
3235    fn mask_where(
3236        tensor: Self::Primitive,
3237        mask: B::BoolTensorPrimitive,
3238        source: Self::Primitive,
3239    ) -> Self::Primitive;
3240
3241    /// Fills elements of a tensor based on a boolean mask.
3242    ///
3243    /// # Arguments
3244    ///
3245    /// * `tensor` - The tensor whose elements will be overwritten with the value
3246    ///   when the corresponding element of the mask is true.
3247    /// * `mask` - The boolean mask to use for filling elements.
3248    /// * `value` - The value to fill elements with when the corresponding element of the mask is true.
3249    ///
3250    /// # Returns
3251    ///
3252    /// A tensor with the same shape as the input tensor, where each element is left
3253    /// unmodified if the corresponding element of the mask is false, and set to the given
3254    /// value otherwise.
3255    ///
3256    /// # Remarks
3257    ///
3258    /// This is a low-level function used internally by the library to call different backend functions
3259    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3260    /// or use this function directly.
3261    ///
3262    /// For filling elements of a tensor based on a boolean mask, users should prefer the
3263    /// [Tensor::mask_fill](Tensor::mask_fill) function, which is more high-level and designed for public use.
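    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::mask_fill](Tensor::mask_fill) API
    /// instead (any backend; the printed values assume the small inputs below):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Bool, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///    let mask = Tensor::<B, 2, Bool>::from_data([[true, false], [false, true]], &device);
    ///    // Elements where the mask is true are overwritten with 0.0.
    ///    let tensor = tensor.mask_fill(mask, 0.0);
    ///    println!("{tensor}");
    ///    // [[0.0, 2.0], [3.0, 0.0]]
    /// }
    /// ```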
3264    fn mask_fill(
3265        tensor: Self::Primitive,
3266        mask: B::BoolTensorPrimitive,
3267        value: Self::Elem,
3268    ) -> Self::Primitive;
3269
3270    /// Gathers elements from a tensor along an axis.
3271    ///
3272    /// # Arguments
3273    ///
3274    /// * `dim` - The axis along which to gather elements.
3275    /// * `tensor` - The tensor to gather elements from.
3276    /// * `indices` - The indices of the elements to gather.
3277    ///
3278    /// # Returns
3279    ///
3280    /// A tensor with the same shape as the indices tensor, where each element is gathered from
3281    /// the input tensor at the corresponding index along the specified axis.
3282    ///
3283    /// # Remarks
3284    ///
3285    /// This is a low-level function used internally by the library to call different backend functions
3286    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3287    /// or use this function directly.
3288    ///
3289    /// For gathering elements from a tensor along an axis, users should prefer the
3290    /// [Tensor::gather](Tensor::gather) function, which is more high-level and designed for public use.
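    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::gather](Tensor::gather) API instead
    /// (any backend; the printed values assume the small inputs below):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///    let indices = Tensor::<B, 2, Int>::from_data([[2, 1, 0], [0, 0, 2]], &device);
    ///    // The output has the shape of `indices`; each element is read from `tensor`
    ///    // at that index along dimension 1.
    ///    let tensor = tensor.gather(1, indices);
    ///    println!("{tensor}");
    ///    // [[3.0, 2.0, 1.0], [4.0, 4.0, 6.0]]
    /// }
    /// ```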
3291    fn gather(
3292        dim: usize,
3293        tensor: Self::Primitive,
3294        indices: B::IntTensorPrimitive,
3295    ) -> Self::Primitive;
3296
3297    /// Scatters elements into a tensor along an axis.
3298    ///
3299    /// # Arguments
3300    ///
3301    /// * `dim` - The axis along which to scatter elements.
3302    /// * `tensor` - The tensor to scatter elements into.
3303    /// * `indices` - The indices of the elements to scatter.
3304    /// * `values` - The values to scatter into the tensor.
3305    ///
3306    /// # Returns
3307    ///
3308    /// A tensor with the same shape as the input tensor, where each element is the
3309    /// corresponding element of the input tensor, except at the specified indices along the
3310    /// given axis, which instead receive the corresponding elements of the values
3311    /// tensor.
3312    ///
3313    /// # Remarks
3314    ///
3315    /// This is a low-level function used internally by the library to call different backend functions
3316    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3317    /// or use this function directly.
3318    ///
3319    /// For scattering elements into a tensor along an axis, users should prefer the [Tensor::scatter](Tensor::scatter) function,
3320    /// which is more high-level and designed for public use.
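    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::scatter](Tensor::scatter) API instead
    /// (any backend; see that function's documentation for the exact write semantics):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///    let indices = Tensor::<B, 2, Int>::from_data([[0, 1, 2], [2, 1, 0]], &device);
    ///    let values = Tensor::<B, 2>::from_data([[10.0, 10.0, 10.0], [20.0, 20.0, 20.0]], &device);
    ///    // Scatters `values` into the tensor at `indices` along dimension 1.
    ///    let tensor = tensor.scatter(1, indices, values);
    ///    println!("{tensor}");
    /// }
    /// ```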
3321    fn scatter(
3322        dim: usize,
3323        tensor: Self::Primitive,
3324        indices: B::IntTensorPrimitive,
3325        values: Self::Primitive,
3326    ) -> Self::Primitive;
3327
3328    /// Gets the indices of the maximum elements of a tensor along an axis.
3329    ///
3330    /// # Arguments
3331    ///
3332    /// * `dim` - The axis along which to get the indices of the maximum elements.
3333    /// * `tensor` - The tensor to get the indices of the maximum elements from.
3334    ///
3335    /// # Returns
3336    ///
3337    /// A tensor with the same shape as the input tensor, except the specified dimension has size 1,
3338    /// where each element is the index of the maximum element of the input tensor along that dimension.
3339    ///
3340    /// # Remarks
3341    ///
3342    /// This is a low-level function used internally by the library to call different backend functions
3343    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3344    /// or use this function directly.
3345    ///
3346    /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the
3347    /// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use.
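    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::argmax](Tensor::argmax) API instead
    /// (any backend; the printed values assume the small input below):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 9.0, 3.0], [4.0, 2.0, 6.0]], &device);
    ///    // Indices of the maxima along dimension 1; the output shape is [2, 1].
    ///    let indices = tensor.argmax(1);
    ///    println!("{indices}");
    ///    // [[1], [2]]
    /// }
    /// ```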
3348    fn argmax(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
3349
3350    /// Gets the indices of the minimum elements of a tensor along an axis.
3351    ///
3352    /// # Arguments
3353    ///
3354    /// * `dim` - The axis along which to get the indices of the minimum elements.
3355    /// * `tensor` - The tensor to get the indices of the minimum elements from.
3356    ///
3357    /// # Returns
3358    ///
3359    /// A tensor with the same shape as the input tensor, except the specified dimension has size 1,
3360    /// where each element is the index of the minimum element of the input tensor along that dimension.
3361    ///
3362    /// # Remarks
3363    ///
3364    /// This is a low-level function used internally by the library to call different backend functions
3365    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3366    /// or use this function directly.
3367    ///
3368    /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
3369    /// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use.
3370    fn argmin(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
3371
3372    /// Gets the maximum element of a tensor.
3373    ///
3374    /// # Arguments
3375    ///
3376    /// * `tensor` - The tensor to get the maximum element from.
3377    ///
3378    /// # Returns
3379    ///
3380    /// A single-element tensor containing the maximum element of the input tensor.
3381    ///
3382    /// # Remarks
3383    ///
3384    /// This is a low-level function used internally by the library to call different backend functions
3385    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3386    /// or use this function directly.
3387    ///
3388    /// For getting the maximum element of a tensor, users should prefer the
3389    /// [Tensor::max](Tensor::max) function, which is more high-level and designed for public use.
3390    fn max(tensor: Self::Primitive) -> Self::Primitive;
3391
3392    /// Gets the maximum elements of a tensor along an axis.
3393    ///
3394    /// # Arguments
3395    ///
3396    /// * `tensor` - The tensor to get the maximum elements from.
3397    /// * `dim` - The axis along which to get the maximum elements.
3398    ///
3399    /// # Returns
3400    ///
3401    /// A tensor with the same rank as the input tensor, but with the given dimension reduced to size 1.
3402    /// Each element is the maximum of the input tensor along that dimension.
3403    ///
3404    /// # Remarks
3405    ///
3406    /// This is a low-level function used internally by the library to call different backend functions
3407    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3408    /// or use this function directly.
3409    ///
3410    /// For getting the maximum elements of a tensor along an axis, users should prefer the
3411    /// [Tensor::max_dim](Tensor::max_dim) function, which is more high-level and designed for public use.
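    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::max_dim](Tensor::max_dim) API instead
    /// (any backend; the printed values assume the small input below):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 9.0, 3.0], [4.0, 2.0, 6.0]], &device);
    ///    // Maximum along dimension 1; the output shape is [2, 1].
    ///    let tensor = tensor.max_dim(1);
    ///    println!("{tensor}");
    ///    // [[9.0], [6.0]]
    /// }
    /// ```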
3412    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3413
3414    /// Gets the maximum elements and indices of a tensor along an axis.
3415    ///
3416    /// # Arguments
3417    ///
3418    /// * `tensor` - The tensor to get the maximum elements from.
3419    /// * `dim` - The axis along which to get the maximum elements.
3420    ///
3421    /// # Returns
3422    ///
3423    /// A tuple of two tensors, each with the given dimension reduced to size 1: the maximum
3424    /// elements of the input tensor along that dimension, and the indices of those elements
3425    /// in the input tensor.
3426    ///
3427    /// # Remarks
3428    ///
3429    /// This is a low-level function used internally by the library to call different backend functions
3430    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3431    /// or use this function directly.
3432    ///
3433    /// For getting the maximum elements of a tensor along an axis, users should prefer the
3434    /// [Tensor::max_dim_with_indices](Tensor::max_dim_with_indices) function, which is more high-level and designed for public use.
3435    fn max_dim_with_indices(
3436        tensor: Self::Primitive,
3437        dim: usize,
3438    ) -> (Self::Primitive, B::IntTensorPrimitive);
3439
3440    /// Gets the maximum absolute element of a tensor.
3441    ///
3442    /// # Arguments
3443    ///
3444    /// * `tensor` - The tensor to get the maximum absolute element from.
3445    ///
3446    /// # Returns
3447    ///
3448    /// A single-element tensor containing the maximum absolute element of the input tensor.
3449    ///
3450    /// # Remarks
3451    ///
3452    /// This is a low-level function used internally by the library to call different backend functions
3453    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3454    /// or use this function directly.
3455    ///
3456    /// For getting the maximum absolute element of a tensor, users should prefer the
3457    /// [Tensor::max_abs](Tensor::max_abs) function, which is more high-level and designed for public use.
3458    fn max_abs(tensor: Self::Primitive) -> Self::Primitive;
3459
3460    /// Gets the maximum absolute elements of a tensor along an axis.
3461    ///
3462    /// # Arguments
3463    ///
3464    /// * `tensor` - The tensor to get the maximum absolute elements from.
3465    /// * `dim` - The axis along which to get the maximum absolute elements.
3466    ///
3467    /// # Returns
3468    ///
3469    /// A tensor with the same rank as the input tensor, but with the given dimension reduced to size 1.
3470    /// Each element is the maximum absolute value of the input tensor along that dimension.
3471    ///
3472    /// # Remarks
3473    ///
3474    /// This is a low-level function used internally by the library to call different backend functions
3475    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3476    /// or use this function directly.
3477    ///
3478    /// For getting the maximum absolute elements of a tensor along an axis, users should prefer the
3479    /// [Tensor::max_abs_dim](Tensor::max_abs_dim) function, which is more high-level and designed for public use.
3480    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3481
3482    /// Gets the minimum element of a tensor.
3483    ///
3484    /// # Arguments
3485    ///
3486    /// * `tensor` - The tensor to get the minimum element from.
3487    ///
3488    /// # Returns
3489    ///
3490    /// A single-element tensor containing the minimum element of the input tensor.
3491    ///
3492    /// # Remarks
3493    ///
3494    /// This is a low-level function used internally by the library to call different backend functions
3495    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3496    /// or use this function directly.
3497    ///
3498    /// For getting the minimum element of a tensor, users should prefer the
3499    /// [Tensor::min](Tensor::min) function, which is more high-level and designed for public use.
3500    fn min(tensor: Self::Primitive) -> Self::Primitive;
3501
3502    /// Gets the minimum elements of a tensor along an axis.
3503    ///
3504    /// # Arguments
3505    ///
3506    /// * `tensor` - The tensor to get the minimum elements from.
3507    /// * `dim` - The axis along which to get the minimum elements.
3508    ///
3509    /// # Returns
3510    ///
3511    /// A tensor with the same rank as the input tensor, but with the given dimension reduced to size 1.
3512    /// Each element is the minimum of the input tensor along that dimension.
3513    ///
3514    /// # Remarks
3515    ///
3516    /// This is a low-level function used internally by the library to call different backend functions
3517    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3518    /// or use this function directly.
3519    ///
3520    /// For getting the minimum elements of a tensor along an axis, users should prefer the
3521    /// [Tensor::min_dim](Tensor::min_dim) function, which is more high-level and designed for public use.
3522    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3523
3524    /// Gets the minimum elements and indices of a tensor along an axis.
3525    ///
3526    /// # Arguments
3527    ///
3528    /// * `tensor` - The tensor to get the minimum elements from.
3529    /// * `dim` - The axis along which to get the minimum elements.
3530    ///
3531    /// # Returns
3532    ///
3533    /// A tuple of two tensors, each with the given dimension reduced to size 1: the minimum
3534    /// elements of the input tensor along that dimension, and the indices of those elements.
3535    ///
3536    /// # Remarks
3537    ///
3538    /// This is a low-level function used internally by the library to call different backend functions
3539    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3540    /// or use this function directly.
3541    ///
3542    /// For getting the minimum elements of a tensor along an axis, users should prefer the
3543    /// [Tensor::min_dim_with_indices](Tensor::min_dim_with_indices) function, which is more high-level and designed for public use.
3544    fn min_dim_with_indices(
3545        tensor: Self::Primitive,
3546        dim: usize,
3547    ) -> (Self::Primitive, B::IntTensorPrimitive);
3548
3549    /// Clamp the tensor between the given min and max values.
3550    ///
3551    /// # Arguments
3552    ///
    /// * `tensor` - The tensor to clamp.
3553    /// * `min` - The minimum value.
3554    /// * `max` - The maximum value.
3555    ///
3556    /// # Returns
3557    ///
3558    /// A new tensor with the values clamped between the given min and max values.
3559    ///
3560    /// # Remarks
3561    ///
3562    /// This is a low-level function used internally by the library to call different backend functions
3563    /// with static dispatch. It is not designed for direct usage by users.
3564    ///
3565    /// For clamping a tensor between the given min and max values, users should prefer the
3566    /// [Tensor::clamp](Tensor::clamp) function, which is more high-level and designed for public use.
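    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::clamp](Tensor::clamp) API instead
    /// (any backend; the printed values assume the small input below):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[-1.0, 0.5, 3.0], [2.0, -4.0, 1.0]], &device);
    ///    // Values below 0.0 are raised to 0.0; values above 2.0 are lowered to 2.0.
    ///    let tensor = tensor.clamp(0.0, 2.0);
    ///    println!("{tensor}");
    ///    // [[0.0, 0.5, 2.0], [2.0, 0.0, 1.0]]
    /// }
    /// ```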
3567    fn clamp(tensor: Self::Primitive, min: Self::Elem, max: Self::Elem) -> Self::Primitive;
3568
3569    /// Clamps a tensor to a minimum value.
3570    ///
3571    /// # Arguments
3572    ///
3573    /// * `tensor` - The tensor to clamp.
3574    /// * `min` - The minimum value.
3575    ///
3576    /// # Returns
3577    ///
3578    /// A new tensor where every value is at least the given min value.
3579    ///
3580    /// # Remarks
3581    ///
3582    /// This is a low-level function used internally by the library to call different backend functions
3583    /// with static dispatch. It is not designed for direct usage by users.
3584    ///
3585    /// For clamping a tensor to a minimum value, users should prefer the
3586    /// [Tensor::clamp_min](Tensor::clamp_min) function, which is more high-level and designed for public use.
3587    fn clamp_min(tensor: Self::Primitive, min: Self::Elem) -> Self::Primitive;
3588
3589    /// Clamps a tensor to a maximum value.
3590    ///
3591    /// # Arguments
3592    ///
3593    /// * `tensor` - The tensor to clamp.
3594    /// * `max` - The maximum value.
3595    ///
3596    /// # Returns
3597    ///
3598    /// A new tensor where every value is at most the given max value.
3599    ///
3600    /// # Remarks
3601    ///
3602    /// This is a low-level function used internally by the library to call different backend functions
3603    /// with static dispatch. It is not designed for direct usage by users.
3604    ///
3605    /// For clamping a tensor to a maximum value, users should prefer the
3606    /// [Tensor::clamp_max](Tensor::clamp_max) function, which is more high-level and designed for public use.
3607    fn clamp_max(tensor: Self::Primitive, max: Self::Elem) -> Self::Primitive;
3608
3609    /// Calculates the absolute value of all elements of a tensor.
3610    ///
3611    /// # Arguments
3612    ///
3613    /// * `tensor` - The tensor to apply abs to.
3614    ///
3615    /// # Returns
3616    ///
3617    /// A tensor with absolute values.
3618    ///
3619    /// # Remarks
3620    ///
3621    /// This is a low-level function used internally by the library to call different backend functions
3622    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3623    /// or use this function directly.
3624    ///
3625    /// For calculating the absolute value of the elements of a tensor, users should prefer the [Tensor::abs](Tensor::abs) function,
3626    /// which is more high-level and designed for public use.
3627    fn abs(tensor: Self::Primitive) -> Self::Primitive;
3628
3629    /// Element-wise power of a tensor raised to a float tensor exponent.
3630    ///
3631    /// # Arguments
3632    /// * `lhs` - The tensor to raise to a power.
3633    /// * `rhs` - The exponent tensor.
3634    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3635
3636    /// Element-wise power of a tensor raised to an int tensor exponent.
3637    ///
3638    /// # Arguments
3639    /// * `lhs` - The tensor to raise to a power.
3640    /// * `rhs` - The exponent tensor.
3641    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3642
3643    /// Element-wise power of a tensor raised to a float scalar exponent.
3644    ///
3645    /// # Arguments
3646    /// * `lhs` - The tensor to raise to a power.
3647    /// * `rhs` - The scalar exponent.
3648    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3649
3650    /// Element-wise power of a tensor raised to an int scalar exponent.
3651    ///
3652    /// # Arguments
3653    /// * `lhs` - The tensor to raise to a power.
3654    /// * `rhs` - The scalar exponent.
3655    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3656
3657    /// Create a random tensor.
3658    ///
3659    /// # Arguments
3660    ///
3661    /// * `shape` - The shape of the output tensor.
3662    /// * `distribution` - The distribution used to sample.
3663    /// * `device` - The device to use.
3664    ///
3665    /// # Returns
3666    ///
3667    /// A new tensor.
3668    ///
3669    /// # Remarks
3670    ///
3671    /// This is a low-level function used internally by the library to call different backend functions
3672    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3673    /// or use this function directly.
3674    ///
3675    /// Users should prefer the [Tensor::random](Tensor::random) function,
3676    /// which is more high-level and designed for public use.
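    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::random](Tensor::random) API instead
    /// (any backend; `Distribution::Default` samples uniformly from [0, 1)):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    // A 2x3 tensor of uniformly sampled random values.
    ///    let tensor = Tensor::<B, 2>::random([2, 3], Distribution::Default, &device);
    ///    println!("{tensor}");
    /// }
    /// ```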
3677    fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive;
3678
3679    /// Sort the elements of the input `tensor` by value along a given dimension.
3680    ///
3681    /// This sort is unstable (i.e., may reorder equal elements).
3682    ///
3683    /// # Arguments
3684    ///
3685    /// * `tensor` - The input tensor.
3686    /// * `dim` - The axis along which to sort.
3687    /// * `descending` - The sorting order.
3688    ///
3689    /// # Returns
3690    ///
3691    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
3692    ///
3693    /// # Remarks
3694    /// This is a low-level function used internally by the library to call different backend functions
3695    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3696    /// or use this function directly.
3697    ///
3698    /// Users should prefer the [Tensor::sort](Tensor::sort) function,
3699    /// which is more high-level and designed for public use.
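    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::sort](Tensor::sort) API instead
    /// (any backend; the printed values assume ascending order on the small input below):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [5.0, 4.0, 6.0]], &device);
    ///    // Sorts each row (dimension 1) in ascending order.
    ///    let tensor = tensor.sort(1);
    ///    println!("{tensor}");
    ///    // [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    /// }
    /// ```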
3700    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive;
3701
3702    /// Sort the elements of the input `tensor` by value along a given dimension.
3703    ///
3704    /// This sort is unstable (i.e., may reorder equal elements).
3705    ///
3706    /// # Arguments
3707    ///
3708    /// * `tensor` - The input tensor.
3709    /// * `dim` - The axis along which to sort.
3710    /// * `descending` - The sorting order.
3711    ///
3712    /// # Returns
3713    ///
3714    /// A tuple of two tensors with the same shape as the input tensor: the elements sorted by
3715    /// value, and the corresponding indices, which map back to the original input tensor.
3716    ///
3717    /// # Remarks
3718    /// This is a low-level function used internally by the library to call different backend functions
3719    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3720    /// or use this function directly.
3721    ///
3722    /// For sorting the elements of a tensor, users should prefer the
3723    /// [Tensor::sort_with_indices](Tensor::sort_with_indices) function, which is more high-level
3724    /// and designed for public use.
3725    fn sort_with_indices(
3726        tensor: Self::Primitive,
3727        dim: usize,
3728        descending: bool,
3729    ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive);
3730
3731    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
3732    ///
3733    /// This sort is unstable (i.e., may reorder equal elements).
3734    ///
3735    /// # Arguments
3736    ///
3737    /// * `tensor` - The input tensor.
3738    /// * `dim` - The axis along which to sort.
3739    /// * `descending` - The sorting order.
3740    ///
3741    /// # Returns
3742    ///
3743    /// A tensor with the same shape as the input tensor, where the indices map back to the original input tensor.
3744    ///
3745    /// # Remarks
3746    /// This is a low-level function used internally by the library to call different backend functions
3747    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3748    /// or use this function directly.
3749    ///
3750    /// Users should prefer the [Tensor::argsort](Tensor::argsort) function,
3751    /// which is more high-level and designed for public use.
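    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::argsort](Tensor::argsort) API instead
    /// (any backend; the printed values assume ascending order on the small input below):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [5.0, 4.0, 6.0]], &device);
    ///    // Indices that sort each row (dimension 1) in ascending order.
    ///    let indices = tensor.argsort(1);
    ///    println!("{indices}");
    ///    // [[1, 2, 0], [1, 0, 2]]
    /// }
    /// ```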
3752    fn argsort(
3753        tensor: Self::Primitive,
3754        dim: usize,
3755        descending: bool,
3756    ) -> <Int as TensorKind<B>>::Primitive;
3757
3758    /// Applies the matrix multiplication operation.
3759    ///
3760    /// ```math
3761    /// C = AB
3762    /// ```
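    ///
    /// # Example
    ///
    /// A minimal sketch using the high-level [Tensor::matmul](Tensor::matmul) API instead
    /// (any backend; the printed values assume the small inputs below):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///    let b = Tensor::<B, 2>::from_data([[5.0, 6.0], [7.0, 8.0]], &device);
    ///    // Standard matrix product: c[i][j] = sum_k a[i][k] * b[k][j].
    ///    let c = a.matmul(b);
    ///    println!("{c}");
    ///    // [[19.0, 22.0], [43.0, 50.0]]
    /// }
    /// ```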
3763    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3764}
3765
3766impl<B: Backend> Numeric<B> for Int {
3767    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3768        B::int_add(lhs, rhs)
3769    }
3770    fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3771        B::int_add_scalar(lhs, rhs.elem())
3772    }
3773    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3774        B::int_sub(lhs, rhs)
3775    }
3776    fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3777        B::int_sub_scalar(lhs, rhs.elem())
3778    }
3779    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3780        B::int_div(lhs, rhs)
3781    }
3782    fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3783        B::int_div_scalar(lhs, rhs.elem())
3784    }
3785    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3786        B::int_remainder(lhs, rhs)
3787    }
3788    fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3789        B::int_remainder_scalar(lhs, rhs.elem())
3790    }
3791    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3792        B::int_mul(lhs, rhs)
3793    }
3794    fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3795        B::int_mul_scalar(lhs, rhs.elem())
3796    }
3797    fn neg(tensor: Self::Primitive) -> Self::Primitive {
3798        B::int_neg(tensor)
3799    }
3800    fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
3801        B::int_zeros(shape, device, dtype.into())
3802    }
3803    fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
3804        B::int_ones(shape, device, dtype.into())
3805    }
3806
3807    fn sum(tensor: Self::Primitive) -> Self::Primitive {
3808        B::int_sum(tensor)
3809    }
3810
3811    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3812        B::int_sum_dim(tensor, dim)
3813    }
3814
3815    fn prod(tensor: Self::Primitive) -> Self::Primitive {
3816        B::int_prod(tensor)
3817    }
3818
3819    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3820        B::int_prod_dim(tensor, dim)
3821    }
3822
3823    fn mean(tensor: Self::Primitive) -> Self::Primitive {
3824        B::int_mean(tensor)
3825    }
3826    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3827        B::int_mean_dim(tensor, dim)
3828    }
3829    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3830        B::int_cumsum(tensor, dim)
3831    }
3832    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3833        B::int_cumprod(tensor, dim)
3834    }
3835
3836    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3837        B::int_cummin(tensor, dim)
3838    }
3839
3840    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3841        B::int_cummax(tensor, dim)
3842    }
3843
3844    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3845        B::int_equal_elem(lhs, rhs)
3846    }
3847    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3848        B::int_not_equal_elem(lhs, rhs)
3849    }
3850    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3851        B::int_greater(lhs, rhs)
3852    }
3853
3854    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3855        B::int_greater_elem(lhs, rhs)
3856    }
3857
3858    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3859        B::int_greater_equal(lhs, rhs)
3860    }
3861
3862    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3863        B::int_greater_equal_elem(lhs, rhs)
3864    }
3865
3866    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3867        B::int_lower(lhs, rhs)
3868    }
3869
3870    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3871        B::int_lower_elem(lhs, rhs)
3872    }
3873
3874    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3875        B::int_lower_equal(lhs, rhs)
3876    }
3877
3878    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3879        B::int_lower_equal_elem(lhs, rhs)
3880    }
3881
3882    fn mask_where(
3883        tensor: Self::Primitive,
3884        mask: B::BoolTensorPrimitive,
3885        source: Self::Primitive,
3886    ) -> Self::Primitive {
3887        B::int_mask_where(tensor, mask, source)
3888    }
3889
3890    fn mask_fill(
3891        tensor: Self::Primitive,
3892        mask: B::BoolTensorPrimitive,
3893        value: Self::Elem,
3894    ) -> Self::Primitive {
3895        B::int_mask_fill(tensor, mask, value)
3896    }
3897
3898    fn gather(
3899        dim: usize,
3900        tensor: Self::Primitive,
3901        indices: B::IntTensorPrimitive,
3902    ) -> Self::Primitive {
3903        B::int_gather(dim, tensor, indices)
3904    }
3905
3906    fn scatter(
3907        dim: usize,
3908        tensor: Self::Primitive,
3909        indices: B::IntTensorPrimitive,
3910        values: Self::Primitive,
3911    ) -> Self::Primitive {
3912        B::int_scatter(dim, tensor, indices, values)
3913    }
3914
3915    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3916        B::int_argmax(tensor, dim)
3917    }
3918
3919    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3920        B::int_argmin(tensor, dim)
3921    }
3922
3923    fn max(tensor: Self::Primitive) -> Self::Primitive {
3924        B::int_max(tensor)
3925    }
3926
3927    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3928        B::int_max_dim(tensor, dim)
3929    }
3930
3931    fn max_dim_with_indices(
3932        tensor: Self::Primitive,
3933        dim: usize,
3934    ) -> (Self::Primitive, IntTensor<B>) {
3935        B::int_max_dim_with_indices(tensor, dim)
3936    }
3937
3938    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
3939        B::int_max_abs(tensor)
3940    }
3941
3942    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3943        B::int_max_abs_dim(tensor, dim)
3944    }
3945
3946    fn min(tensor: Self::Primitive) -> Self::Primitive {
3947        B::int_min(tensor)
3948    }
3949
3950    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3951        B::int_min_dim(tensor, dim)
3952    }
3953
3954    fn min_dim_with_indices(
3955        tensor: Self::Primitive,
3956        dim: usize,
3957    ) -> (Self::Primitive, IntTensor<B>) {
3958        B::int_min_dim_with_indices(tensor, dim)
3959    }
3960
3961    fn clamp(tensor: Self::Primitive, min: B::IntElem, max: B::IntElem) -> Self::Primitive {
3962        B::int_clamp(tensor, min, max)
3963    }
3964
3965    fn clamp_min(tensor: Self::Primitive, min: B::IntElem) -> Self::Primitive {
3966        B::int_clamp_min(tensor, min)
3967    }
3968
3969    fn clamp_max(tensor: Self::Primitive, max: B::IntElem) -> Self::Primitive {
3970        B::int_clamp_max(tensor, max)
3971    }
3972
3973    fn abs(tensor: Self::Primitive) -> Self::Primitive {
3974        B::int_abs(tensor)
3975    }
3976
3977    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3978        B::int_powf(lhs, B::int_into_float(rhs))
3979    }
3980
3981    fn powf_scalar<E: ElementConversion>(
3982        lhs: Self::Primitive,
3983        rhs: E,
3984    ) -> <Int as TensorKind<B>>::Primitive {
3985        B::int_powf_scalar(lhs, rhs.elem())
3986    }
3987
3988    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3989        B::int_powi(lhs, rhs)
3990    }
3991
3992    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3993        B::int_powi_scalar(lhs, rhs.elem())
3994    }
3995
3996    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
3997        B::int_random(shape, distribution, device)
3998    }
3999
4000    fn sign(tensor: Self::Primitive) -> Self::Primitive {
4001        B::int_sign(tensor)
4002    }
4003
4004    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
4005        B::int_sort(tensor, dim, descending)
4006    }
4007
4008    fn sort_with_indices(
4009        tensor: Self::Primitive,
4010        dim: usize,
4011        descending: bool,
4012    ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
4013        B::int_sort_with_indices(tensor, dim, descending)
4014    }
4015
4016    fn argsort(
4017        tensor: Self::Primitive,
4018        dim: usize,
4019        descending: bool,
4020    ) -> <Int as TensorKind<B>>::Primitive {
4021        B::int_argsort(tensor, dim, descending)
4022    }
4023
4024    /// Applies the matrix multiplication operation.
4025    ///
4026    /// `C = AB`
4027    ///
4028    /// # Panics
4029    ///
4030    /// If the two tensors don't have a compatible shape.
4031    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4032        B::int_matmul(lhs, rhs)
4033    }
4034}
4035
4036impl<B: Backend> Numeric<B> for Float {
4037    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
4038        q_bin_ops!(lhs, rhs, float_add, q_add)
4039    }
4040
4041    fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4042        match lhs {
4043            TensorPrimitive::Float(lhs) => {
4044                TensorPrimitive::Float(B::float_add_scalar(lhs, rhs.elem()))
4045            }
4046            TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs.elem()),
4047        }
4048    }
4049
4050    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
4051        q_bin_ops!(lhs, rhs, float_sub, q_sub)
4052    }
4053
4054    fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4055        match lhs {
4056            TensorPrimitive::Float(lhs) => {
4057                TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs.elem()))
4058            }
4059            TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs.elem()),
4060        }
4061    }
4062
4063    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
4064        q_bin_ops!(lhs, rhs, float_div, q_div)
4065    }
4066
4067    fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4068        match lhs {
4069            TensorPrimitive::Float(lhs) => {
4070                TensorPrimitive::Float(B::float_div_scalar(lhs, rhs.elem()))
4071            }
4072            TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs.elem()),
4073        }
4074    }
4075    fn remainder(
4076        lhs: Self::Primitive,
4077        rhs: Self::Primitive,
4078    ) -> <Float as TensorKind<B>>::Primitive {
4079        TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))
4080    }
4081
4082    fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4083        TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs.elem()))
4084    }
4085
4086    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
4087        q_bin_ops!(lhs, rhs, float_mul, q_mul)
4088    }
4089
4090    fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4091        match lhs {
4092            TensorPrimitive::Float(lhs) => {
4093                TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs.elem()))
4094            }
4095            TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs.elem()),
4096        }
4097    }
4098    fn neg(tensor: Self::Primitive) -> Self::Primitive {
4099        match tensor {
4100            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
4101            TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),
4102        }
4103    }
4104    fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
4105        TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))
4106    }
4107    fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
4108        TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))
4109    }
4110
4111    fn sum(tensor: Self::Primitive) -> Self::Primitive {
4112        match tensor {
4113            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
4114            TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),
4115        }
4116    }
4117
4118    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4119        match tensor {
4120            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
4121            TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),
4122        }
4123    }
4124
4125    fn prod(tensor: Self::Primitive) -> Self::Primitive {
4126        match tensor {
4127            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
4128            TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),
4129        }
4130    }
4131
4132    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4133        match tensor {
4134            TensorPrimitive::Float(tensor) => {
4135                TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
4136            }
4137            TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),
4138        }
4139    }
4140
4141    fn mean(tensor: Self::Primitive) -> Self::Primitive {
4142        match tensor {
4143            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
4144            TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),
4145        }
4146    }
4147
4148    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4149        match tensor {
4150            TensorPrimitive::Float(tensor) => {
4151                TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
4152            }
4153            TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),
4154        }
4155    }
4156
4157    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4158        match tensor {
4159            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),
4160            TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),
4161        }
4162    }
4163
4164    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4165        match tensor {
4166            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),
4167            TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),
4168        }
4169    }
4170
4171    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4172        match tensor {
4173            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
4174            TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),
4175        }
4176    }
4177
4178    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4179        match tensor {
4180            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)),
4181            TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim),
4182        }
4183    }
4184
4185    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4186        B::float_equal_elem(lhs.tensor(), rhs)
4187    }
4188    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4189        B::float_not_equal_elem(lhs.tensor(), rhs)
4190    }
4191    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4192        B::float_greater(lhs.tensor(), rhs.tensor())
4193    }
4194
4195    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4196        B::float_greater_elem(lhs.tensor(), rhs)
4197    }
4198
4199    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4200        B::float_greater_equal(lhs.tensor(), rhs.tensor())
4201    }
4202
4203    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4204        B::float_greater_equal_elem(lhs.tensor(), rhs)
4205    }
4206
4207    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4208        B::float_lower(lhs.tensor(), rhs.tensor())
4209    }
4210
4211    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4212        B::float_lower_elem(lhs.tensor(), rhs)
4213    }
4214
4215    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4216        B::float_lower_equal(lhs.tensor(), rhs.tensor())
4217    }
4218
4219    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4220        B::float_lower_equal_elem(lhs.tensor(), rhs)
4221    }
4222
4223    fn mask_where(
4224        tensor: Self::Primitive,
4225        mask: B::BoolTensorPrimitive,
4226        source: Self::Primitive,
4227    ) -> Self::Primitive {
4228        TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor()))
4229    }
4230
4231    fn mask_fill(
4232        tensor: Self::Primitive,
4233        mask: B::BoolTensorPrimitive,
4234        value: Self::Elem,
4235    ) -> Self::Primitive {
4236        TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value))
4237    }
4238
4239    fn gather(
4240        dim: usize,
4241        tensor: Self::Primitive,
4242        indices: B::IntTensorPrimitive,
4243    ) -> Self::Primitive {
4244        match tensor {
4245            TensorPrimitive::Float(tensor) => {
4246                TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
4247            }
4248            TensorPrimitive::QFloat(tensor) => {
4249                TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
4250            }
4251        }
4252    }
4253
4254    fn scatter(
4255        dim: usize,
4256        tensor: Self::Primitive,
4257        indices: B::IntTensorPrimitive,
4258        values: Self::Primitive,
4259    ) -> Self::Primitive {
4260        TensorPrimitive::Float(B::float_scatter(
4261            dim,
4262            tensor.tensor(),
4263            indices,
4264            values.tensor(),
4265        ))
4266    }
4267
    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
        match tensor {
            TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim),
            TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim),
        }
    }

    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
        match tensor {
            TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim),
            TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim),
        }
    }

    fn max(tensor: Self::Primitive) -> Self::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
        }
    }

    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
        }
    }

    fn max_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, IntTensor<B>) {
        match tensor {
            TensorPrimitive::Float(tensor) => {
                let (values, indices) = B::float_max_dim_with_indices(tensor, dim);
                (TensorPrimitive::Float(values), indices)
            }
            TensorPrimitive::QFloat(tensor) => {
                let (values, indices) = B::q_max_dim_with_indices(tensor, dim);
                (TensorPrimitive::QFloat(values), indices)
            }
        }
    }

    fn min(tensor: Self::Primitive) -> Self::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
        }
    }

    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
        }
    }

    fn min_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, IntTensor<B>) {
        match tensor {
            TensorPrimitive::Float(tensor) => {
                let (values, indices) = B::float_min_dim_with_indices(tensor, dim);
                (TensorPrimitive::Float(values), indices)
            }
            TensorPrimitive::QFloat(tensor) => {
                let (values, indices) = B::q_min_dim_with_indices(tensor, dim);
                (TensorPrimitive::QFloat(values), indices)
            }
        }
    }

    fn clamp(tensor: Self::Primitive, min: B::FloatElem, max: B::FloatElem) -> Self::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_clamp(tensor, min, max))
            }
            TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
        }
    }

    fn clamp_min(tensor: Self::Primitive, min: B::FloatElem) -> Self::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_clamp_min(tensor, min))
            }
            TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
        }
    }

    fn clamp_max(tensor: Self::Primitive, max: B::FloatElem) -> Self::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_clamp_max(tensor, max))
            }
            TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
        }
    }

    fn abs(tensor: Self::Primitive) -> Self::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
        }
    }

    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        q_bin_ops!(lhs, rhs, float_powf, q_powf)
    }

    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
        match lhs {
            TensorPrimitive::Float(lhs) => {
                TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
            }
            TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs.elem()),
        }
    }

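    // `powi` reuses the float power ops: the exponent is itself a float
    // primitive carrying integral values, so `float_powf`/`q_powf` apply
    // directly (this is not a typo of the `powf` implementation above).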
    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        q_bin_ops!(lhs, rhs, float_powf, q_powf)
    }

    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
        match lhs {
            TensorPrimitive::Float(lhs) => {
                TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs.elem()))
            }
            TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs.elem()),
        }
    }

    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
        TensorPrimitive::Float(B::float_random(shape, distribution, device))
    }

    fn sign(tensor: Self::Primitive) -> Self::Primitive {
        TensorPrimitive::Float(B::float_sign(tensor.tensor()))
    }

    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
            }
        }
    }

    fn sort_with_indices(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
        match tensor {
            TensorPrimitive::Float(tensor) => {
                let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
                (TensorPrimitive::Float(values), indices)
            }
            TensorPrimitive::QFloat(tensor) => {
                let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
                (TensorPrimitive::QFloat(values), indices)
            }
        }
    }

    fn argsort(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> <Int as TensorKind<B>>::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
            TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
        }
    }

    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
        }
    }

    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        match tensor {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
            }
        }
    }

    /// Applies the matrix multiplication operation.
    ///
    /// `C = AB`
    ///
    /// # Panics
    ///
    /// If the two tensors don't have compatible shapes.
    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        match (lhs, rhs) {
            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
                TensorPrimitive::Float(B::float_matmul(lhs, rhs))
            }
            // Any combination involving a quantized operand is delegated to
            // the backend's quantized matmul.
            (lhs, rhs) => B::q_matmul(lhs, rhs),
        }
    }
}

// Tensor + tensor
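/// Enables `lhs + rhs` syntax by delegating to [`Tensor::add`].
///
/// A minimal illustration, in the same style as the examples above:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = B::Device::default();
///     let lhs = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
///     let rhs = Tensor::<B, 1>::from_data([3.0, 4.0], &device);
///     let sum = lhs + rhs;
///     println!("{sum}");
///     // [4.0, 6.0]
/// }
/// ```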
impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Add<Self> for Tensor<B, D, K>
where
    K::Elem: Element,
{
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self::add(self, rhs)
    }
}

// Tensor + scalar
impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<E>
    for Tensor<B, D, K>
where
    K::Elem: Element,
{
    type Output = Self;

    fn add(self, other: E) -> Self::Output {
        Tensor::add_scalar(self, other)
    }
}

// Scalar + tensor
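/// Implements `scalar + tensor` for the listed primitive scalar types by
/// forwarding to [`Tensor::add_scalar`]; addition commutes, so the operand
/// order is irrelevant. A minimal illustration:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = B::Device::default();
///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
///     let shifted = 10.0 + tensor;
///     println!("{shifted}");
///     // [11.0, 12.0]
/// }
/// ```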
macro_rules! impl_tensor_scalar_add {
    ($($t:ty),*) => {
        $(
            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<Tensor<B, D, K>> for $t
            where
                K::Elem: Element,
            {
                type Output = Tensor<B, D, K>;

                fn add(self, tensor: Tensor<B, D, K>) -> Self::Output {
                    Tensor::add_scalar(tensor, self)
                }
            }
        )*
    }
}
impl_tensor_scalar_add!(f32, f64, i32, i64, u32, u64);

// Tensor - tensor
impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Sub<Self> for Tensor<B, D, K>
where
    K::Elem: Element,
{
    type Output = Self;

    fn sub(self, rhs: Self) -> Self::Output {
        Tensor::sub(self, rhs)
    }
}

// Tensor - scalar
impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<E>
    for Tensor<B, D, K>
where
    K::Elem: Element,
{
    type Output = Self;

    fn sub(self, other: E) -> Self::Output {
        Tensor::sub_scalar(self, other)
    }
}

// Scalar - tensor
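// `s - t` is rewritten as `(-t) + s`, i.e. `neg` followed by `add_scalar`,
// so no reversed-subtraction op is required.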
macro_rules! impl_tensor_scalar_sub {
    ($($t:ty),*) => {
        $(
            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<Tensor<B, D, K>> for $t
            where
                K::Elem: Element,
            {
                type Output = Tensor<B, D, K>;

                fn sub(self, tensor: Tensor<B, D, K>) -> Self::Output {
                    Tensor::add_scalar(Tensor::neg(tensor), self)
                }
            }
        )*
    }
}
impl_tensor_scalar_sub!(f32, f64, i32, i64, u32, u64);

// Tensor / tensor
impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Div<Self> for Tensor<B, D, K>
where
    K::Elem: Element,
{
    type Output = Self;

    fn div(self, rhs: Self) -> Self::Output {
        Tensor::div(self, rhs)
    }
}

// Tensor / scalar
impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Div<E>
    for Tensor<B, D, K>
where
    K::Elem: Element,
{
    type Output = Self;

    fn div(self, other: E) -> Self::Output {
        Tensor::div_scalar(self, other)
    }
}

// Scalar / tensor (float only)
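/// Implements `scalar / tensor` by rewriting it as `scalar * tensor.recip()`.
/// Since `recip` is only defined for float tensors, the impl is restricted to
/// the default `Float` kind. A minimal illustration:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = B::Device::default();
///     let tensor = Tensor::<B, 1>::from_data([2.0, 4.0], &device);
///     let inverse = 1.0 / tensor;
///     println!("{inverse}");
///     // [0.5, 0.25]
/// }
/// ```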
macro_rules! impl_tensor_scalar_div {
    ($($t:ty),*) => {
        $(
            impl<const D: usize, B: Backend> core::ops::Div<Tensor<B, D>> for $t {
                type Output = Tensor<B, D>;

                fn div(self, tensor: Tensor<B, D>) -> Self::Output {
                    tensor.recip().mul_scalar(self)
                }
            }
        )*
    }
}

impl_tensor_scalar_div!(f32, f64);

// Tensor % tensor
impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<Self> for Tensor<B, D, K>
where
    K::Elem: Element,
{
    type Output = Self;

    fn rem(self, rhs: Self) -> Self::Output {
        Tensor::remainder(self, rhs)
    }
}

// Tensor % scalar
impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<E>
    for Tensor<B, D, K>
where
    K::Elem: Element,
{
    type Output = Self;

    fn rem(self, other: E) -> Self::Output {
        Tensor::remainder_scalar(self, other)
    }
}

// Tensor * tensor
impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Mul<Self> for Tensor<B, D, K>
where
    K::Elem: Element,
{
    type Output = Self;

    fn mul(self, rhs: Self) -> Self::Output {
        Tensor::mul(self, rhs)
    }
}

// Tensor * scalar
impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<E>
    for Tensor<B, D, K>
where
    K::Elem: Element,
{
    type Output = Self;

    fn mul(self, other: E) -> Self::Output {
        Tensor::mul_scalar(self, other)
    }
}

// Scalar * tensor
macro_rules! impl_tensor_scalar_mul {
    ($($t:ty),*) => {
        $(
            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<Tensor<B, D, K>> for $t
            where
                K::Elem: Element,
            {
                type Output = Tensor<B, D, K>;

                fn mul(self, other: Tensor<B, D, K>) -> Self::Output {
                    Tensor::mul_scalar(other, self)
                }
            }
        )*
    }
}

impl_tensor_scalar_mul!(f32, f64, i32, i64, u32, u64);

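// -Tensor
/// Enables `-tensor` syntax by delegating to [`Tensor::neg`].
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = B::Device::default();
///     let tensor = Tensor::<B, 1>::from_data([1.0, -2.0], &device);
///     let negated = -tensor;
///     println!("{negated}");
///     // [-1.0, 2.0]
/// }
/// ```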
impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    type Output = Self;

    fn neg(self) -> Self::Output {
        Tensor::neg(self)
    }
}