burn_tensor/tensor/api/numeric.rs

1use alloc::vec::Vec;
2
3use crate::{alloc::borrow::ToOwned, cast::ToElement};
4
5use crate::TensorPrimitive;
6use crate::{
7    BasicOps, Bool, Distribution, Element, ElementConversion, Float, Int, Shape, Tensor,
8    TensorKind,
9    backend::Backend,
10    check,
11    check::TensorCheck,
12    ops::{Device, IntTensor},
13};
14
/// Relative tolerance (RTOL) used by default in `is_close` and `all_close`.
pub const DEFAULT_RTOL: f64 = 0.000_01;

/// Absolute tolerance (ATOL) used by default in `is_close` and `all_close`.
pub const DEFAULT_ATOL: f64 = 0.000_000_01;
20
21impl<B, const D: usize, K> Tensor<B, D, K>
22where
23    B: Backend,
24    K: Numeric<B>,
25    K::Elem: Element,
26{
27    /// Applies element wise addition operation.
28    ///
29    /// `y = x2 + x1`
30    ///
31    /// # Arguments
32    ///
33    /// * `other` - The tensor to add.
34    ///
35    /// # Example
36    ///
37    /// ```rust
38    /// use burn_tensor::backend::Backend;
39    /// use burn_tensor::{Tensor, Shape};
40    ///
41    /// fn example<B: Backend>() {
42    ///    let device = B::Device::default();
43    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
44    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
45    ///    let tensor = tensor1 + tensor2;
46    ///    println!("{tensor}");
47    ///    // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
48    /// }
49    /// ```
50    #[allow(clippy::should_implement_trait)]
51    pub fn add(self, other: Self) -> Self {
52        check!(TensorCheck::binary_ops_ew("Add", &self, &other));
53        Self::new(K::add(self.primitive, other.primitive))
54    }
55
56    /// Applies element wise addition operation with a scalar.
57    ///
58    /// `y = x + s`
59    ///
60    /// # Arguments
61    ///
62    /// * `other` - The scalar to add, element wise.
63    ///
64    /// # Example
65    ///
66    /// ```rust
67    /// use burn_tensor::backend::Backend;
68    /// use burn_tensor::{Tensor, Shape};
69    ///
70    /// fn example<B: Backend>() {
71    ///   let device = B::Device::default();
72    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
73    ///   let scalar = 2.0;
74    ///   let tensor = tensor + scalar;
75    ///   println!("{tensor}");
76    ///   // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
77    /// }
78    /// ```
79    pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
80        Self::new(K::add_scalar::<E>(self.primitive, other))
81    }
82
83    /// Applies element wise subtraction operation.
84    ///
85    /// `y = x2 - x1`
86    ///
87    /// # Arguments
88    ///
89    /// * `other` - The tensor to subtract.
90    ///
91    /// # Example
92    ///
93    /// ```rust
94    /// use burn_tensor::backend::Backend;
95    /// use burn_tensor::{Tensor, Shape};
96    ///
97    /// fn example<B: Backend>() {
98    ///   let device = B::Device::default();
99    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
100    ///   let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
101    ///   let tensor = tensor1 - tensor2;
102    ///   println!("{tensor}");
103    ///   // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
104    /// }
105    /// ```
106    #[allow(clippy::should_implement_trait)]
107    pub fn sub(self, other: Self) -> Self {
108        check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
109        Self::new(K::sub(self.primitive, other.primitive))
110    }
111
112    /// Applies element wise subtraction operation with a scalar.
113    ///
114    /// `y = x - s`
115    ///
116    /// # Arguments
117    ///
118    /// * `other` - The scalar to subtract, element wise.
119    ///
120    /// # Example
121    ///
122    /// ```rust
123    /// use burn_tensor::backend::Backend;
124    /// use burn_tensor::{Tensor, Shape};
125    ///
126    /// fn example<B: Backend>() {
127    ///    let device = B::Device::default();
128    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
129    ///    let scalar = 2.0;
130    ///    let tensor = tensor - scalar;
131    ///    println!("{tensor}");
132    ///    // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
133    /// }
134    /// ```
135    pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
136        Self::new(K::sub_scalar::<E>(self.primitive, other))
137    }
138
139    /// Applies element wise division operation.
140    ///
141    /// `y = x2 / x1`
142    ///
143    /// # Arguments
144    ///
145    /// * `other` - The tensor to divide.
146    ///
147    /// # Example
148    ///
149    /// ```rust
150    /// use burn_tensor::backend::Backend;
151    /// use burn_tensor::{Tensor, Shape};
152    ///
153    /// fn example<B: Backend>() {
154    ///    let device = B::Device::default();
155    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
156    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
157    ///    let tensor = tensor1 / tensor2;
158    ///    println!("{tensor}");
159    ///    // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
160    /// }
161    /// ```
162    #[allow(clippy::should_implement_trait)]
163    pub fn div(self, other: Self) -> Self {
164        check!(TensorCheck::binary_ops_ew("Div", &self, &other));
165        Self::new(K::div(self.primitive, other.primitive))
166    }
167
168    /// Applies element wise division operation with a scalar.
169    ///
170    /// `y = x / s`
171    ///
172    /// # Arguments
173    ///
174    /// * `other` - The scalar to divide, element wise.
175    ///
176    /// # Example
177    ///
178    /// ```rust
179    /// use burn_tensor::backend::Backend;
180    /// use burn_tensor::{Tensor, Shape};
181    ///
182    /// fn example<B: Backend>() {
183    ///    let device = B::Device::default();
184    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
185    ///    let scalar = 2.0;
186    ///    let tensor = tensor / scalar;
187    ///    println!("{tensor}");
188    ///    // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
189    /// }
190    /// ```
191    pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
192        Self::new(K::div_scalar::<E>(self.primitive, other))
193    }
194
195    /// Applies element wise the remainder operation with a scalar.
196    ///
197    /// `y = x2 % x1`
198    pub fn remainder(self, other: Self) -> Self {
199        Self::new(K::remainder(self.primitive, other.primitive))
200    }
201
202    /// Applies element wise the remainder operation with a scalar.
203    ///
204    /// `y = x % s`
205    ///
206    /// # Arguments
207    ///
208    /// * `other` - The scalar to divide, element wise.
209    ///
210    /// # Example
211    ///
212    /// ```rust
213    /// use burn_tensor::backend::Backend;
214    /// use burn_tensor::{Tensor, Shape};
215    ///
216    /// fn example<B: Backend>() {
217    ///    let device = B::Device::default();
218    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
219    ///    let scalar = 2.0;
220    ///    let tensor = tensor1 % scalar;
221    ///    println!("{tensor}");
222    ///    // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
223    /// }
224    /// ```
225    pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
226        Self::new(K::remainder_scalar::<E>(self.primitive, other))
227    }
228
229    /// Applies element wise multiplication operation.
230    ///
231    /// `y = x2 * x1`
232    ///
233    /// # Arguments
234    ///
235    /// * `other` - The tensor to multiply.
236    ///
237    /// # Example
238    ///
239    /// ```rust
240    /// use burn_tensor::backend::Backend;
241    /// use burn_tensor::{Tensor, Shape};
242    ///
243    /// fn example<B: Backend>() {
244    ///    let device = B::Device::default();
245    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
246    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
247    ///    let tensor = tensor1 * tensor2;
248    ///    println!("{tensor}");
249    ///    // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
250    /// }
251    /// ```
252    #[allow(clippy::should_implement_trait)]
253    pub fn mul(self, other: Self) -> Self {
254        check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
255        Self::new(K::mul(self.primitive, other.primitive))
256    }
257
258    /// Applies element wise multiplication operation with a scalar.
259    ///
260    /// `y = x * s`
261    ///
262    /// # Arguments
263    ///
264    /// * `other` - The scalar to multiply, element wise.
265    ///
266    /// # Example
267    ///
268    /// ```rust
269    /// use burn_tensor::backend::Backend;
270    /// use burn_tensor::{Tensor, Shape};
271    ///
272    /// fn example<B: Backend>() {
273    ///    let device = B::Device::default();
274    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
275    ///    let scalar = 2.0;
276    ///    let tensor = tensor * scalar;
277    ///    println!("{tensor}");
278    ///    // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
279    /// }
280    /// ```
281    pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
282        Self::new(K::mul_scalar::<E>(self.primitive, other))
283    }
284
285    /// Switch sign of each element in the tensor.
286    ///
287    /// `y = -x`
288    ///
289    /// # Example
290    ///
291    /// ```rust
292    /// use burn_tensor::backend::Backend;
293    /// use burn_tensor::{Tensor, Shape};
294    ///
295    /// fn example<B: Backend>() {
296    ///    let device = B::Device::default();
297    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
298    ///    let tensor = -tensor;
299    ///    println!("{tensor}");
300    ///    // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
301    /// }
302    /// ```
303    #[allow(clippy::should_implement_trait)]
304    pub fn neg(self) -> Self {
305        Self::new(K::neg(self.primitive))
306    }
307
308    /// Returns the signs of the elements of the input tensor.
309    ///
310    /// # Example
311    ///
312    /// ```rust
313    /// use burn_tensor::backend::Backend;
314    /// use burn_tensor::{Tensor, Shape};
315    ///
316    /// fn example<B: Backend>() {
317    ///    let device = B::Device::default();
318    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
319    ///    let tensor = tensor.sign();
320    ///    println!("{tensor}");
321    ///    // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
322    /// }
323    /// ```
324    pub fn sign(self) -> Self {
325        Self::new(K::sign(self.primitive))
326    }
327
328    /// Create a tensor of the given shape where each element is zero.
329    ///
330    /// # Example
331    ///
332    /// ```rust
333    /// use burn_tensor::backend::Backend;
334    /// use burn_tensor::{Tensor, Shape};
335    ///
336    /// fn example<B: Backend>() {
337    ///    let device = B::Device::default();
338    ///    let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
339    ///    println!("{tensor}");
340    ///    // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
341    /// }
342    /// ```
343    pub fn zeros<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
344        let shape = shape.into();
345        check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
346        Self::new(K::zeros(shape, device))
347    }
348
349    /// Returns a new tensor with the same shape and device as the current tensor filled with zeros.
350    ///
351    /// # Example
352    ///
353    /// ```rust
354    /// use burn_tensor::backend::Backend;
355    /// use burn_tensor::{Tensor, Shape};
356    ///
357    /// fn example<B: Backend>() {
358    ///   let device = B::Device::default();
359    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
360    ///   let tensor = tensor.zeros_like();
361    ///   println!("{tensor}");
362    ///   // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
363    /// }
364    /// ```
365    pub fn zeros_like(&self) -> Self {
366        Self::zeros(self.shape(), &self.device())
367    }
368
369    /// Create a tensor of the given shape where each element is one.
370    ///
371    /// # Example
372    ///
373    /// ```rust
374    /// use burn_tensor::backend::Backend;
375    /// use burn_tensor::{Tensor, Shape};
376    ///
377    /// fn example<B: Backend>() {
378    ///   let device = B::Device::default();
379    ///   let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
380    ///   println!("{tensor}");
381    ///   // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
382    /// }
383    /// ```
384    pub fn ones<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
385        let shape = shape.into();
386        check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
387        Self::new(K::ones(shape, device))
388    }
389
390    /// Returns a new tensor with the same shape and device as the current tensor filled with ones.
391    ///
392    /// # Example
393    ///
394    /// ```rust
395    /// use burn_tensor::backend::Backend;
396    /// use burn_tensor::{Tensor, Shape};
397    ///
398    /// fn example<B: Backend>() {
399    ///    let device = B::Device::default();
400    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
401    ///    let tensor = tensor.ones_like();
402    ///    println!("{tensor}");
403    ///    // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
404    /// }
405    /// ```
406    pub fn ones_like(&self) -> Self {
407        Self::ones(self.shape(), &self.device())
408    }
409
410    /// Create a tensor of the given shape where each element is equal to the provided value.
411    ///
412    /// # Example
413    ///
414    /// ```rust
415    /// use burn_tensor::backend::Backend;
416    /// use burn_tensor::{Tensor, Shape};
417    ///
418    /// fn example<B: Backend>() {
419    ///   let device = B::Device::default();
420    ///   let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
421    ///   println!("{tensor}");
422    ///   // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
423    /// }
424    /// ```
425    pub fn full<S: Into<Shape>, E: ElementConversion>(
426        shape: S,
427        fill_value: E,
428        device: &B::Device,
429    ) -> Self {
430        let shape = shape.into();
431        check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
432        Self::new(K::full(shape, fill_value, device))
433    }
434
435    ///Returns a new tensor with the same shape and device as the current tensor filled with the provided value.
436    ///
437    /// # Example
438    ///
439    /// ```rust
440    /// use burn_tensor::backend::Backend;
441    /// use burn_tensor::{Tensor, Shape};
442    ///
443    /// fn example<B: Backend>() {
444    ///    let device = B::Device::default();
445    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
446    ///    let tensor = tensor.full_like(5.0);
447    ///    println!("{tensor}");
448    ///    // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
449    /// }
450    /// ```
451    pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
452        Self::full(self.shape(), fill_value, &self.device())
453    }
454
455    /// Aggregate all elements in the tensor with the mean operation.
456    ///
457    /// # Example
458    ///
459    /// ```rust
460    /// use burn_tensor::backend::Backend;
461    /// use burn_tensor::{Tensor, Shape};
462    ///
463    /// fn example<B: Backend>() {
464    ///    let device = B::Device::default();
465    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
466    ///    let tensor = tensor.mean();
467    ///    println!("{tensor}");
468    ///    // [3.6666667]
469    /// }
470    /// ```
471    pub fn mean(self) -> Tensor<B, 1, K> {
472        Tensor::new(K::mean(self.primitive))
473    }
474
475    /// Aggregate all elements in the tensor with the sum operation.
476    ///
477    /// # Example
478    ///
479    /// ```rust
480    /// use burn_tensor::backend::Backend;
481    /// use burn_tensor::{Tensor, Shape};
482    ///
483    /// fn example<B: Backend>() {
484    ///   let device = B::Device::default();
485    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
486    ///   let tensor = tensor.sum();
487    ///   println!("{tensor}");
488    ///   // [22.0]
489    /// }
490    /// ```
491    pub fn sum(self) -> Tensor<B, 1, K> {
492        Tensor::new(K::sum(self.primitive))
493    }
494
495    /// Aggregate all elements along the given *dimension* or *axis*
496    /// in the tensor with the mean operation.
497    ///
498    /// # Arguments
499    ///
500    /// * `dim` - The dimension or axis along which to aggregate the elements.
501    ///
502    /// # Example
503    ///
504    /// ```rust
505    /// use burn_tensor::backend::Backend;
506    /// use burn_tensor::{Tensor, Shape};
507    ///
508    /// fn example<B: Backend>() {
509    ///   let device = B::Device::default();
510    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
511    ///   let tensor = tensor.clone().mean_dim(0);
512    ///   println!("{tensor}");
513    ///   // [[3.0, 3.5, 4.5]]
514    ///   let tensor = tensor.clone().mean_dim(1);
515    ///   println!("{tensor}");
516    ///   // [[0.6666667], [6.6666665]]
517    /// }
518    /// ```
519    pub fn mean_dim(self, dim: usize) -> Self {
520        check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
521        Self::new(K::mean_dim(self.primitive, dim))
522    }
523
524    /// Aggregate all elements along the given *dimension* or *axis*
525    /// in the tensor with the sum operation.
526    ///
527    /// # Arguments
528    ///
529    /// * `dim` - The dimension or axis along which to aggregate the elements.
530    ///
531    /// # Example
532    ///
533    /// ```rust
534    /// use burn_tensor::backend::Backend;
535    /// use burn_tensor::{Tensor, Shape};
536    ///
537    /// fn example<B: Backend>() {
538    ///    let device = B::Device::default();
539    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
540    ///    let tensor = tensor.clone().sum_dim(0);
541    ///    println!("{tensor}");
542    ///    // [[6.0, 7.0, 9.0]]
543    ///    let tensor = tensor.clone().sum_dim(1);
544    ///    println!("{tensor}");
545    ///    // [[2.0], [20.0]]
546    /// }
547    /// ```
548    pub fn sum_dim(self, dim: usize) -> Self {
549        check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
550        Self::new(K::sum_dim(self.primitive, dim))
551    }
552
553    /// Aggregate all elements in the tensor with the product operation.
554    ///
555    /// # Example
556    ///
557    /// ```rust
558    /// use burn_tensor::backend::Backend;
559    /// use burn_tensor::{Tensor, Shape};
560    ///
561    /// fn example<B: Backend>() {
562    ///    let device = B::Device::default();
563    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
564    ///    let tensor = tensor.prod();
565    ///    println!("{tensor}");
566    ///    // [-1620.0]
567    /// }
568    /// ```
569    pub fn prod(self) -> Tensor<B, 1, K> {
570        Tensor::new(K::prod(self.primitive))
571    }
572
573    /// Aggregate all elements along the given *dimension* or *axis*
574    /// in the tensor with the product operation.
575    ///
576    /// # Arguments
577    ///
578    /// * `dim` - The dimension or axis along which to aggregate the elements.
579    ///
580    /// # Example
581    ///
582    /// ```rust
583    /// use burn_tensor::backend::Backend;
584    /// use burn_tensor::{Tensor, Shape};
585    ///
586    /// fn example<B: Backend>() {
587    ///    let device = B::Device::default();
588    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
589    ///    let tensor = tensor.clone().prod_dim(0);
590    ///    println!("{tensor}");
591    ///    // [[5.0, -18.0, 18.0]]
592    ///    let tensor = tensor.clone().prod_dim(1);
593    ///    println!("{tensor}");
594    ///    // [[-6.0], [270.0]]
595    /// }
596    /// ```
597    pub fn prod_dim(self, dim: usize) -> Self {
598        check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
599        Self::new(K::prod_dim(self.primitive, dim))
600    }
601
602    /// Applies element wise equal comparison and returns a boolean tensor.
603    ///
604    /// # Arguments
605    ///
606    /// * `other` - The element to compare.
607    ///
608    /// # Example
609    ///
610    /// ```rust
611    /// use burn_tensor::backend::Backend;
612    /// use burn_tensor::{Tensor, Shape};
613    ///
614    /// fn example<B: Backend>() {
615    ///    let device = B::Device::default();
616    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
617    ///    let tensor = tensor.equal_elem(3.0);
618    ///    println!("{tensor}");
619    ///    // [[false, false, true], [false, false, false]]
620    /// }
621    /// ```
622    pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
623        Tensor::new(K::equal_elem(self.primitive, other.elem()))
624    }
625
626    /// Applies element wise non-equality comparison and returns a boolean tensor.
627    ///
628    /// # Arguments
629    ///
630    /// * `other` - The element to compare.
631    ///
632    /// # Example
633    ///
634    /// ```rust
635    /// use burn_tensor::backend::Backend;
636    /// use burn_tensor::{Tensor, Shape};
637    ///
638    /// fn example<B: Backend>() {
639    ///    let device = B::Device::default();
640    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
641    ///    let tensor = tensor.not_equal_elem(3.0);
642    ///    println!("{tensor}");
643    ///    // [[true, true, false], [true, true, true]]
644    /// }
645    /// ```
646    pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
647        Tensor::new(K::not_equal_elem(self.primitive, other.elem()))
648    }
649
650    /// Applies element wise greater comparison and returns a boolean tensor.
651    ///
652    /// # Panics
653    ///
654    /// If the two tensors don't have the same shape.
655    ///
656    /// # Example
657    ///
658    /// ```rust
659    /// use burn_tensor::backend::Backend;
660    /// use burn_tensor::{Tensor, Shape};
661    ///
662    /// fn example<B: Backend>() {
663    ///   let device = B::Device::default();
664    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
665    ///   let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
666    ///   let tensor = tensor1.greater(tensor2);
667    ///   println!("{tensor}");
668    ///   // [[false, false, false], [true, true, true]]
669    /// }
670    /// ```
671    pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
672        check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
673        Tensor::new(K::greater(self.primitive, other.primitive))
674    }
675
676    /// Applies element wise greater-equal comparison and returns a boolean tensor.
677    ///
678    /// # Panics
679    ///
680    /// If the two tensors don't have the same shape.
681    ///
682    /// # Example
683    ///
684    /// ```rust
685    /// use burn_tensor::backend::Backend;
686    /// use burn_tensor::{Tensor, Shape};
687    ///
688    /// fn example<B: Backend>() {
689    ///    let device = B::Device::default();
690    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
691    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
692    ///    let tensor = tensor1.greater_equal(tensor2);
693    ///    println!("{tensor}");
694    ///    // [[true, false, false], [true, true, true]]
695    /// }
696    /// ```
697    pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
698        check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
699        Tensor::new(K::greater_equal(self.primitive, other.primitive))
700    }
701
702    /// Applies element wise lower comparison and returns a boolean tensor.
703    ///
704    /// # Panics
705    ///
706    /// If the two tensors don't have the same shape.
707    ///
708    /// # Example
709    ///
710    /// ```rust
711    /// use burn_tensor::backend::Backend;
712    /// use burn_tensor::{Tensor, Shape};
713    ///
714    /// fn example<B: Backend>() {
715    ///    let device = B::Device::default();
716    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
717    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
718    ///    let tensor = tensor1.lower(tensor2);
719    ///    println!("{tensor}");
720    ///    // [[false, true, true], [false, false, false]]
721    /// }
722    /// ```
723    pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
724        check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
725        Tensor::new(K::lower(self.primitive, other.primitive))
726    }
727
728    /// Applies element wise lower-equal comparison and returns a boolean tensor.
729    ///
730    /// # Panics
731    ///
732    /// If the two tensors don't have the same shape.
733    ///
734    /// # Example
735    ///
736    /// ```rust
737    /// use burn_tensor::backend::Backend;
738    /// use burn_tensor::{Tensor, Shape};
739    ///
740    /// fn example<B: Backend>() {
741    ///    let device = B::Device::default();
742    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
743    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
744    ///    let tensor = tensor1.lower_equal(tensor2);
745    ///    println!("{tensor}");
746    ///    // [[true, true, true], [false, false, false]]
747    /// }
748    /// ```
749    pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
750        check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
751        Tensor::new(K::lower_equal(self.primitive, other.primitive))
752    }
753
754    /// Applies greater than `other` comparison and returns a boolean tensor.
755    ///
756    /// # Arguments
757    ///
758    /// * `other` - The element to compare.
759    ///
760    /// # Example
761    ///
762    /// ```rust
763    /// use burn_tensor::backend::Backend;
764    /// use burn_tensor::{Tensor, Shape};
765    ///
766    /// fn example<B: Backend>() {
767    ///    let device = B::Device::default();
768    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
769    ///    let tensor = tensor.greater_elem(3.0);
770    ///    println!("{tensor}");
771    ///    // [[false, false, true], [true, true, true]]
772    /// }
773    /// ```
774    pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
775        Tensor::new(K::greater_elem(self.primitive, other.elem()))
776    }
777
778    /// Applies greater-equal than `other` comparison and returns a boolean tensor.
779    ///
780    /// # Arguments
781    ///
782    /// * `other` - The element to compare.
783    ///
784    /// # Example
785    ///
786    /// ```rust
787    /// use burn_tensor::backend::Backend;
788    /// use burn_tensor::{Tensor, Shape};
789    ///
790    /// fn example<B: Backend>() {
791    ///    let device = B::Device::default();
792    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
793    ///    let tensor = tensor.greater_equal_elem(3.0);
794    ///    println!("{tensor}");
795    ///    // [[false, false, true], [true, true, true]]
796    /// }
797    /// ```
798    pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
799        Tensor::new(K::greater_equal_elem(self.primitive, other.elem()))
800    }
801
802    /// Applies lower than `other` comparison and returns a boolean tensor.
803    ///
804    /// # Arguments
805    ///
806    /// * `other` - The element to compare.
807    ///
808    /// # Example
809    ///
810    /// ```rust
811    /// use burn_tensor::backend::Backend;
812    /// use burn_tensor::{Tensor, Shape};
813    ///
814    /// fn example<B: Backend>() {
815    ///     let device = B::Device::default();
816    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
817    ///     let tensor = tensor.lower_elem(3.0);
818    ///     println!("{tensor}");
819    ///     // [[true, true, false], [false, false, false]]
820    /// }
821    /// ```
822    pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
823        Tensor::new(K::lower_elem(self.primitive, other.elem()))
824    }
825
826    /// Applies lower-equal than `other` comparison and returns a boolean tensor.
827    ///
828    /// # Arguments
829    ///
830    /// * `other` - The element to compare.
831    ///
832    /// # Example
833    ///
834    /// ```rust
835    /// use burn_tensor::backend::Backend;
836    /// use burn_tensor::{Tensor, Shape};
837    ///
838    /// fn example<B: Backend>() {
839    ///    let device = B::Device::default();
840    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
841    ///    let tensor = tensor.lower_equal_elem(3.0);
842    ///    println!("{tensor}");
843    ///    // [[true, true, true], [false, false, false]]
844    /// }
845    /// ```
846    pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
847        Tensor::new(K::lower_equal_elem(self.primitive, other.elem()))
848    }
849
850    /// Update the given tensor with the value tensor where the mask is true.
851    ///
852    /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
853    /// a scalar.
854    ///
855    /// # Example
856    ///
857    /// ```rust
858    /// use burn_tensor::backend::Backend;
859    /// use burn_tensor::{Tensor, Shape, Bool};
860    ///
861    /// fn example<B: Backend>() {
862    ///   let device = B::Device::default();
863    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
864    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
865    ///   let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
866    ///   let tensor = tensor.mask_where(mask, value);
867    ///   println!("{tensor}");
868    ///   // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
869    /// }
870    /// ```
871    pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
872        Self::new(K::mask_where(
873            self.primitive,
874            mask.primitive,
875            value.primitive,
876        ))
877    }
878
879    /// Update the given tensor with the value where the mask is true.
880    ///
881    /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
882    /// a tensor.
883    ///
884    /// # Example
885    ///
886    /// ```rust
887    /// use burn_tensor::backend::Backend;
888    /// use burn_tensor::{Tensor, Shape, Bool};
889    ///
890    /// fn example<B: Backend>() {
891    ///   let device = B::Device::default();
892    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
893    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
894    ///   let tensor = tensor.mask_fill(mask, 3.0);
895    ///   println!("{tensor}");
896    ///   // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
897    /// }
898    /// ```
899    pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
900        Self::new(K::mask_fill(self.primitive, mask.primitive, value.elem()))
901    }
902
903    /// Gather tensor elements corresponding to the given indices from the specified dim.
904    ///
905    /// Example using a 3D tensor:
906    ///
907    /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
908    /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
909    /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
910    ///
911    /// # Notes
912    ///
913    /// The index tensor should have the same shape as the original tensor except for the dim
914    /// specified.
915    ///
916    /// # Warning
917    /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
918    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
919    pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
920        check!(TensorCheck::gather::<D>(
921            dim,
922            &self.shape(),
923            &indices.shape()
924        ));
925
926        Self::new(K::gather(dim, self.primitive, indices.primitive))
927    }
928
929    /// Assign the gathered elements corresponding to the given indices along the specified dimension
930    /// from the value tensor to the original tensor using sum reduction.
931    ///
932    /// Example using a 3D tensor:
933    ///
934    /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
935    /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
936    /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
937    ///
938    /// # Notes
939    ///
940    /// The index tensor should have the same shape as the original tensor except for the specified
941    /// dimension. The value and index tensors should have the same shape.
942    ///
943    /// Other references to the input tensor will not be modified by this operation.
944    ///
945    /// # Warning
946    /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
947    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
948    pub fn scatter(self, dim: usize, indices: Tensor<B, D, Int>, values: Self) -> Self {
949        check!(TensorCheck::scatter::<D>(
950            dim,
951            &self.shape(),
952            &indices.shape(),
953            &values.shape()
954        ));
955
956        Self::new(K::scatter(
957            dim,
958            self.primitive,
959            indices.primitive,
960            values.primitive,
961        ))
962    }
963
964    /// Select the tensor elements along the given dimension corresponding to the given indices.
965    ///
966    /// Example using a 3D tensor:
967    ///
968    /// `output[i, j, k] = input[indices[i], j, k]; // dim = 0`
969    /// `output[i, j, k] = input[i, indices[j], k]; // dim = 1`
970    /// `output[i, j, k] = input[i, j, indices[k]]; // dim = 2`
971    ///
972    /// # Warning
973    /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
974    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
975    ///
976    /// # Example
977    ///
978    /// ```rust
979    /// use burn_tensor::backend::Backend;
980    /// use burn_tensor::{Tensor, Shape, Int};
981    ///
982    /// fn example<B: Backend>() {
983    ///   let device = B::Device::default();
984    ///   let tensor = Tensor::<B, 3>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
985    ///   let indices = Tensor::<B, 1, Int>::from_data([0], &device);
986    ///   let tensor = tensor.select(0, indices);
987    ///   println!("{tensor}");
988    ///   //  [[1.0, -2.0, 3.0]]
989    /// }
990    /// ```
991    pub fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {
992        check!(TensorCheck::select::<D>(dim));
993        Self::new(K::select(self.primitive, dim, indices))
994    }
995
996    /// Assign the selected elements along the given dimension corresponding to the given indices
997    /// from the value tensor to the original tensor using sum reduction.
998    ///
999    /// Example using a 3D tensor:
1000    ///
1001    /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
1002    /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
1003    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
1004    ///
1005    /// # Warning
1006    /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
1007    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1008    pub fn select_assign(
1009        self,
1010        dim: usize,
1011        indices: Tensor<B, 1, Int>,
1012        values: Tensor<B, D, K>,
1013    ) -> Self {
1014        check!(TensorCheck::select_assign::<D>(dim));
1015
1016        Self::new(K::select_assign(
1017            self.primitive,
1018            dim,
1019            indices,
1020            values.primitive,
1021        ))
1022    }
1023
1024    /// Applies the argmax function along the given dimension and returns an integer tensor.
1025    ///
1026    /// # Example
1027    ///
1028    /// ```rust
1029    /// use burn_tensor::backend::Backend;
1030    /// use burn_tensor::{Tensor, Shape};
1031    ///
1032    /// fn example<B: Backend>() {
1033    ///     let device = B::Device::default();
1034    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
1035    ///     let tensor = tensor.argmax(1);
1036    ///     println!("{:?}", tensor.shape());
1037    ///     // Shape { dims: [2, 1, 3] }
1038    /// }
1039    /// ```
1040    pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
1041        Tensor::new(K::argmax(self.primitive, dim))
1042    }
1043
1044    /// Find the maximum value.
1045    ///
1046    /// # Example
1047    ///
1048    /// ```rust
1049    /// use burn_tensor::backend::Backend;
1050    /// use burn_tensor::{Tensor, Shape};
1051    ///
1052    /// fn example<B: Backend>() {
1053    ///   let device = B::Device::default();
1054    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1055    ///   let tensor = tensor.max();
1056    ///   println!("{tensor}");
1057    ///   // [9.0]
1058    /// }
1059    /// ```
1060    pub fn max(self) -> Tensor<B, 1, K> {
1061        Tensor::new(K::max(self.primitive))
1062    }
1063
1064    /// Find the maximum value along the given dimension.
1065    ///
1066    /// # Example
1067    ///
1068    /// ```rust
1069    /// use burn_tensor::backend::Backend;
1070    /// use burn_tensor::{Tensor, Shape};
1071    ///
1072    /// fn example<B: Backend>() {
1073    ///   let device = B::Device::default();
1074    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1075    ///   let tensor = tensor.max_dim(0);
1076    ///   println!("{tensor}");
1077    ///   // [[5.0, 9.0, 6.0]]
1078    /// }
1079    /// ```
1080    pub fn max_dim(self, dim: usize) -> Tensor<B, D, K> {
1081        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1082
1083        Tensor::new(K::max_dim(self.primitive, dim))
1084    }
1085
1086    /// Find the maximum value along the given dimension.
1087    ///
1088    /// Also returns the indices.
1089    ///
1090    /// # Example
1091    ///
1092    /// ```rust
1093    /// use burn_tensor::backend::Backend;
1094    /// use burn_tensor::{Tensor, Shape};
1095    ///
1096    /// fn example<B: Backend>() {
1097    ///    let device = B::Device::default();
1098    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1099    ///    let (tensor, index) = tensor.max_dim_with_indices(0);
1100    ///    // [[5.0, 9.0, 6.0]]
1101    ///    println!("{tensor}");
1102    ///    // [[1, 1, 1]]
1103    ///    println!("{index}");
1104    /// }
1105    /// ```
1106    pub fn max_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1107        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1108
1109        let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
1110
1111        let tensor = Tensor::new(tensor);
1112        let index = Tensor::new(index);
1113
1114        (tensor, index)
1115    }
1116
1117    /// Finds the maximum pair wise values with another tensor.
1118    ///
1119    /// # Arguments
1120    ///
1121    /// * `other` - Other tensor to find maximum elements with
1122    ///
1123    /// # Returns
1124    ///
1125    /// A tensor with the same shape as the input tensors containing the maximum value found
1126    /// in the input tensors.
1127    ///
1128    /// # Example
1129    ///
1130    /// ```rust
1131    /// use burn_tensor::backend::Backend;
1132    /// use burn_tensor::{Tensor, Shape};
1133    ///
1134    /// fn example<B: Backend>() {
1135    ///    let device = B::Device::default();
1136    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1137    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1138    ///    let tensor = tensor1.max_pair(tensor2);
1139    ///    println!("{tensor}");
1140    ///    // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
1141    /// }
1142    /// ```
1143    pub fn max_pair(self, other: Self) -> Self {
1144        let mask = self.clone().lower(other.clone());
1145        self.mask_where(mask, other)
1146    }
1147
1148    /// Find the maximum absolute value.
1149    ///
1150    /// # Example
1151    ///
1152    /// ```rust
1153    /// use burn_tensor::backend::Backend;
1154    /// use burn_tensor::{Tensor, Shape};
1155    ///
1156    /// fn example<B: Backend>() {
1157    ///   let device = B::Device::default();
1158    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
1159    ///   let tensor = tensor.max_abs();
1160    ///   println!("{tensor}");
1161    ///   // [7.0]
1162    /// }
1163    /// ```
1164    pub fn max_abs(self) -> Tensor<B, 1, K> {
1165        Tensor::new(K::max_abs(self.primitive))
1166    }
1167
1168    /// Find the maximum absolute value along the given dimension.
1169    ///
1170    /// # Example
1171    ///
1172    /// ```rust
1173    /// use burn_tensor::backend::Backend;
1174    /// use burn_tensor::{Tensor, Shape};
1175    ///
1176    /// fn example<B: Backend>() {
1177    ///   let device = B::Device::default();
1178    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1179    ///   let tensor = tensor.max_dim(0);
1180    ///   println!("{tensor}");
1181    ///   // [[5.0, 9.0, 6.0]]
1182    /// }
1183    /// ```
1184    pub fn max_abs_dim(self, dim: usize) -> Tensor<B, D, K> {
1185        check!(TensorCheck::aggregate_dim::<D>("MaxAbs", dim));
1186
1187        Tensor::new(K::max_abs_dim(self.primitive, dim))
1188    }
1189
1190    /// Applies the argmin function along the given dimension and returns an integer tensor.
1191    ///
1192    /// # Example
1193    ///
1194    /// ```rust
1195    /// use burn_tensor::backend::Backend;
1196    /// use burn_tensor::{Tensor, Shape};
1197    ///
1198    /// fn example<B: Backend>() {
1199    ///     let device = Default::default();
1200    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
1201    ///     let tensor = tensor.argmin(1);
1202    ///     println!("{:?}", tensor.shape());
1203    ///     // Shape { dims: [2, 1, 3] }
1204    /// }
1205    /// ```
1206    pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
1207        Tensor::new(K::argmin(self.primitive, dim))
1208    }
1209
1210    /// Find the minimum value.
1211    ///
1212    /// # Example
1213    ///
1214    /// ```rust
1215    /// use burn_tensor::backend::Backend;
1216    /// use burn_tensor::{Tensor, Shape};
1217    ///
1218    /// fn example<B: Backend>() {
1219    ///    let device = B::Device::default();
1220    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1221    ///    let tensor = tensor.min();
1222    ///    println!("{tensor}");
1223    ///    // [-2.0]
1224    /// }
1225    /// ```
1226    pub fn min(self) -> Tensor<B, 1, K> {
1227        Tensor::new(K::min(self.primitive))
1228    }
1229
1230    /// Find the minimum value along the given dimension.
1231    ///
1232    /// # Example
1233    ///
1234    /// ```rust
1235    /// use burn_tensor::backend::Backend;
1236    /// use burn_tensor::{Tensor, Shape};
1237    ///
1238    /// fn example<B: Backend>() {
1239    ///    let device = B::Device::default();
1240    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1241    ///    let tensor = tensor.min_dim(0);
1242    ///    println!("{tensor}");
1243    ///    // [[1.0, -2.0, 3.0]]
1244    /// }
1245    /// ```
1246    pub fn min_dim(self, dim: usize) -> Tensor<B, D, K> {
1247        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
1248        Tensor::new(K::min_dim(self.primitive, dim))
1249    }
1250
1251    /// Find the minimum value along the given dimension.
1252    ///
1253    /// Also returns the indices.
1254    ///
1255    /// # Example
1256    ///
1257    /// ```rust
1258    /// use burn_tensor::backend::Backend;
1259    /// use burn_tensor::{Tensor, Shape};
1260    ///
1261    /// fn example<B: Backend>() {
1262    ///    let device = B::Device::default();
1263    ///    let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1264    ///    let (tensor, index) = tensor.min_dim_with_indices(0);
1265    ///    println!("{tensor}");
1266    ///    // [[5.0, -2.0, 3.0]]
1267    ///    println!("{}", index);
1268    ///    // [[1, 0, 0]]
1269    /// }
1270    /// ```
1271    pub fn min_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1272        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
1273
1274        let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
1275
1276        let tensor = Tensor::new(tensor);
1277        let index = Tensor::new(index);
1278
1279        (tensor, index)
1280    }
1281
1282    /// Finds the minimum pair wise values with another tensor.
1283    ///
1284    /// # Arguments
1285    ///
1286    /// * `other` - Other tensor to find minimum elements with
1287    ///
1288    /// # Returns
1289    ///
1290    /// A tensor with the same shape as the input tensors containing the minimum value found
1291    /// between each element of the two source tensors.
1292    ///
1293    /// # Example
1294    ///
1295    /// ```rust
1296    /// use burn_tensor::backend::Backend;
1297    /// use burn_tensor::{Tensor, Shape};
1298    ///
1299    /// fn example<B: Backend>() {
1300    ///    let device = B::Device::default();
1301    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1302    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1303    ///    let tensor = tensor1.min_pair(tensor2);
1304    ///    println!("{tensor}");
1305    ///    // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
1306    /// }
1307    pub fn min_pair(self, other: Self) -> Self {
1308        let mask = other.clone().lower(self.clone());
1309        self.mask_where(mask, other)
1310    }
1311
1312    /// Clamp element wise between the given min and max values.
1313    ///
1314    /// # Arguments
1315    ///
1316    /// * `min` - The minimum value.
1317    /// * `max` - The maximum value.
1318    ///
1319    /// # Returns
1320    ///
1321    /// A new tensor with the values clamped between the given min and max values.
1322    ///
1323    /// # Example
1324    ///
1325    /// ```rust
1326    /// use burn_tensor::backend::Backend;
1327    /// use burn_tensor::{Int, Tensor};
1328    ///
1329    /// fn example<B: Backend>() {
1330    ///   let device = Default::default();
1331    ///   let tensor = Tensor::<B, 2, Int>::from_ints(
1332    ///    [
1333    ///     [1, 2, 3],
1334    ///     [4, 5, 6],
1335    ///     [7, 8, 9]
1336    ///    ],
1337    ///    &device);
1338    ///    let tensor = tensor.clamp(2, 6);
1339    ///    println!("{tensor}");
1340    ///    // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
1341    /// }
1342    /// ```
1343    pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
1344        Self::new(K::clamp(self.primitive, min.elem(), max.elem()))
1345    }
1346
1347    /// Clamp element wise under a minimum value.
1348    ///
1349    /// # Arguments
1350    ///
1351    /// * `tensor` - The tensor to clamp.
1352    /// * `min` - The minimum value.
1353    ///
1354    /// # Returns
1355    ///
1356    /// A new tensor with the values clamped under the given min value.
1357    ///
1358    /// # Example
1359    ///
1360    /// ```rust
1361    /// use burn_tensor::backend::Backend;
1362    /// use burn_tensor::{Int, Tensor};
1363    ///
1364    /// fn example<B: Backend>() {
1365    ///    let device = Default::default();
1366    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1367    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1368    ///    &device);
1369    ///    let tensor = tensor.clamp_min(4);
1370    ///    println!("{tensor}");
1371    ///    // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
1372    /// }
1373    /// ```
1374    pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
1375        Self::new(K::clamp_min(self.primitive, min.elem()))
1376    }
1377
1378    /// Clamp element wise over a maximum value.
1379    ///
1380    /// # Arguments
1381    ///
1382    /// * `tensor` - The tensor to clamp.
1383    /// * `max` - The maximum value.
1384    ///
1385    /// # Returns
1386    ///
1387    /// A new tensor with the values clamped over the given max value.
1388    ///
1389    /// # Example
1390    ///
1391    /// ```rust
1392    /// use burn_tensor::backend::Backend;
1393    /// use burn_tensor::{Int, Tensor};
1394    ///
1395    /// fn example<B: Backend>() {
1396    ///    let device = Default::default();
1397    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1398    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1399    ///    &device);
1400    ///    let tensor = tensor.clamp_max(5);
1401    ///    println!("{tensor}");
1402    ///    // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
1403    /// }
1404    /// ```
1405    pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
1406        Self::new(K::clamp_max(self.primitive, max.elem()))
1407    }
1408
1409    /// Apply element wise absolute value operation.
1410    ///
1411    /// # Example
1412    ///
1413    /// ```rust
1414    /// use burn_tensor::backend::Backend;
1415    /// use burn_tensor::{Int, Tensor};
1416    ///
1417    /// fn example<B: Backend>() {
1418    ///   let device = Default::default();
1419    ///   let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
1420    ///   let tensor = tensor.abs();
1421    ///   println!("{tensor}");
1422    ///   // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
1423    /// }
1424    /// ```
1425    pub fn abs(self) -> Self {
1426        Self::new(K::abs(self.primitive))
1427    }
1428
1429    /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,
1430    /// the other elements of the result tensor out are set to 0.
1431    ///
1432    /// See also [`triu_mask`](Tensor::triu_mask).
1433    ///
1434    /// # Arguments
1435    ///
1436    /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
1437    ///   towards the upper triangle.
1438    ///
1439    /// # Example
1440    /// ```rust
1441    /// use burn_tensor::backend::Backend;
1442    /// use burn_tensor::{Int, Tensor};
1443    ///
1444    /// fn example<B: Backend>() {
1445    ///    let device = Default::default();
1446    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1447    ///        [
1448    ///          [1, 2, 3],
1449    ///          [4, 5, 6],
1450    ///          [7, 8, 9]
1451    ///        ],
1452    ///        &device
1453    ///    );
1454    ///    let tensor = tensor.triu(1);
1455    ///    println!("{tensor}");
1456    ///    // [
1457    ///    //   [0, 2, 3],
1458    ///    //   [0, 0, 6],
1459    ///    //   [0, 0, 0]
1460    ///    // ]
1461    /// }
1462    /// ```
1463    pub fn triu(self, diagonal: i64) -> Self {
1464        check!(TensorCheck::tri::<{ D }>());
1465
1466        // last two dimensions
1467        let shape = &self.shape().dims[D - 2..].to_owned();
1468
1469        let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
1470        self.mask_fill(mask, 0)
1471    }
1472
1473    /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,
1474    /// the other elements of the result tensor out are set to 0.
1475    ///
1476    /// See also [`tril_mask`](Tensor::tril_mask).
1477    ///
1478    /// # Arguments
1479    ///
1480    /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
1481    ///   towards the upper triangle.
1482    ///
1483    /// # Example
1484    /// ```rust
1485    /// use burn_tensor::backend::Backend;
1486    /// use burn_tensor::{Int, Tensor};
1487    ///
1488    /// fn example<B: Backend>() {
1489    ///    let device = Default::default();
1490    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1491    ///        [
1492    ///          [1, 2, 3],
1493    ///          [4, 5, 6],
1494    ///          [7, 8, 9]
1495    ///        ],
1496    ///        &device
1497    ///    );
1498    ///
1499    ///    let tensor = tensor.tril(-1);
1500    ///    println!("{tensor}");
1501    ///    // [
1502    ///    //   [0, 0, 0],
1503    ///    //   [4, 0, 0],
1504    ///    //   [7, 8, 0]
1505    ///    // ]
1506    /// }
1507    /// ```
1508    pub fn tril(self, diagonal: i64) -> Self {
1509        check!(TensorCheck::tri::<{ D }>());
1510
1511        // last two dimensions
1512        let shape = &self.shape().dims[D - 2..].to_owned();
1513        let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
1514
1515        self.mask_fill(mask, 0)
1516    }
1517
1518    /// Applies element wise power operation with a float Tensor
1519    ///
1520    /// # Arguments
1521    ///
1522    /// * `other` - The tensor to apply the power operation with.
1523    ///
1524    /// # Example
1525    ///
1526    /// ```rust
1527    /// use burn_tensor::backend::Backend;
1528    /// use burn_tensor::{Tensor, Shape};
1529    ///
1530    /// fn example<B: Backend>() {
1531    ///    let device = B::Device::default();
1532    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1533    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1534    ///    let tensor = tensor1.powf(tensor2);
1535    ///    println!("{tensor}");
1536    ///    // [[1.0, 8.0, 81.0], [5.0, 81.0, 216.0]]
1537    /// }
1538    /// ```
1539    pub fn powf(self, other: Self) -> Self {
1540        Self::new(K::powf(self.primitive, other.primitive))
1541    }
1542
1543    /// Applies element wise power operation with a float scalar
1544    ///
1545    /// # Arguments
1546    ///
1547    /// * `other` - The scalar to apply the power operation with.
1548    ///
1549    /// # Example
1550    ///
1551    /// ```rust
1552    /// use burn_tensor::backend::Backend;
1553    /// use burn_tensor::{Tensor, Shape};
1554    ///
1555    /// fn example<B: Backend>() {
1556    ///    let device = B::Device::default();
1557    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1558    ///    let tensor = tensor.powf_scalar(2.0);
1559    ///    println!("{tensor}");
1560    ///    // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
1561    /// }
1562    /// ```
1563    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
1564        Self::new(K::powf_scalar::<E>(self.primitive, other))
1565    }
1566
1567    /// Applies element wise power operation with a integer Tensor
1568    ///
1569    /// # Arguments
1570    ///
1571    /// * `other` - The tensor to apply the power operation with.
1572    ///
1573    /// # Example
1574    ///
1575    /// ```rust
1576    /// use burn_tensor::backend::Backend;
1577    /// use burn_tensor::{Tensor, Shape, Int};
1578    ///
1579    /// fn example<B: Backend>() {
1580    ///    let device = B::Device::default();
1581    ///    let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1582    ///    let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
1583    ///    let tensor = tensor1.powi(tensor2);
1584    ///    println!("{tensor}");
1585    ///    // [[1, -8, 81], [5, 81, 216]]
1586    /// }
1587    /// ```
1588    pub fn powi(self, other: Self) -> Self {
1589        Self::new(K::powi(self.primitive, other.primitive))
1590    }
1591
1592    /// Applies element wise power operation with a integer scalar
1593    ///
1594    /// # Arguments
1595    ///
1596    /// * `other` - The scalar to apply the power operation with.
1597    ///
1598    /// # Example
1599    ///
1600    /// ```rust
1601    /// use burn_tensor::backend::Backend;
1602    /// use burn_tensor::{Tensor, Shape, Int};
1603    ///
1604    /// fn example<B: Backend>() {
1605    ///    let device = B::Device::default();
1606    ///    let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1607    ///    let tensor = tensor.powi_scalar(2);
1608    ///    println!("{tensor}");
1609    ///
1610    ///    // [[1, 4, 9], [25, 81, 36]]
1611    ///    let tensor = Tensor::<B, 2>::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device);
1612    ///    let tensor = tensor.powi_scalar(2);
1613    ///    println!("{tensor}");
1614    ///    // [[2.25, 4., 9.], [25., 81., 36.]]
1615    /// }
1616    /// ```
1617    pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
1618        Self::new(K::powi_scalar::<E>(self.primitive, other))
1619    }
1620
1621    /// Checks element wise if the tensor is close to another tensor.
1622    ///
1623    /// The tolerance is defined by the following equation:
1624    ///
1625    /// ```text
1626    /// abs(a - b) <= (atol + rtol * abs(b))
1627    ///
1628    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
1629    /// and `atol` is the absolute tolerance.
1630    /// ```
1631    ///
1632    /// # Arguments
1633    ///
1634    /// * `other` - The tensor to compare with.
1635    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
1636    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
1637    ///
1638    /// # Returns
1639    ///
1640    /// A boolean tensor with the same shape as the input tensors.
1641    ///
1642    /// # Example
1643    ///
1644    /// ```rust
1645    /// use burn_tensor::backend::Backend;
1646    /// use burn_tensor::{Tensor, Shape};
1647    ///
1648    /// fn example<B: Backend>() {
1649    ///    let device = B::Device::default();
1650    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1651    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1652    ///    let tensor = tensor1.is_close(tensor2, None, None);
1653    ///    println!("{tensor}");
1654    ///    // [[true, true, true], [true, true, true]]
1655    /// }
1656    /// ```
1657    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
1658        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
1659        let atol = atol.unwrap_or(DEFAULT_ATOL);
1660
1661        Tensor::new(K::lower_equal(
1662            K::abs(K::sub(self.primitive, other.primitive.clone())),
1663            K::add_scalar(K::mul_scalar(K::abs(other.primitive), rtol), atol),
1664        ))
1665    }
1666
1667    /// Checks if all elements are close to another tensor.
1668    ///
1669    /// The tolerance is defined by the following equation:
1670    ///
1671    /// ```text
1672    ///
1673    /// abs(a - b) <= (atol + rtol * abs(b))
1674    ///
1675    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
1676    /// and `atol` is the absolute tolerance.
1677    ///
1678    /// ```
1679    ///
1680    /// # Arguments
1681    ///
1682    /// * `other` - The tensor to compare with.
1683    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
1684    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
1685    ///
1686    /// # Returns
1687    ///
1688    /// A boolean scalar.
1689    ///
1690    /// # Remarks
1691    ///
1692    /// # Example
1693    ///
1694    /// ```rust
1695    /// use burn_tensor::backend::Backend;
1696    /// use burn_tensor::{Tensor, Shape};
1697    ///
1698    /// fn example<B: Backend>() {
1699    ///    let device = B::Device::default();
1700    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1701    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1702    ///    let result = tensor1.all_close(tensor2, None, None);
1703    ///    println!("{}", result);
1704    ///    // true
1705    /// }
1706    /// ```
1707    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
1708        self.is_close(other, rtol, atol)
1709            .all()
1710            .into_scalar()
1711            .to_bool()
1712    }
1713
1714    /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
1715    ///
1716    /// # Returns
1717    ///
1718    /// A boolean tensor with the same shape as the input tensor.
1719    ///
1720    /// # Example
1721    ///
1722    /// ```rust
1723    /// use burn_tensor::backend::Backend;
1724    /// use burn_tensor::{Tensor, Shape};
1725    ///
1726    /// fn example<B: Backend>() {
1727    ///   let device = B::Device::default();
1728    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
1729    ///   let tensor = tensor.bool();
1730    ///   println!("{tensor}");
1731    ///   // [
1732    ///   //   [true, true, true],
1733    ///   //   [false, true, true]
1734    ///   // ]
1735    /// }
1736    pub fn bool(self) -> Tensor<B, D, Bool> {
1737        Tensor::new(K::not_equal_elem(self.primitive, 0.elem()))
1738    }
1739
1740    /// Create a random tensor of the given shape on the given device where each element is
1741    /// sampled from the given distribution.
1742    ///
1743    /// See also [`random_like`](Tensor::random_like).
1744    ///
1745    /// # Arguments
1746    ///
1747    /// * `shape` - The shape of the tensor.
1748    /// * `distribution` - The distribution to sample from.
1749    /// * `device` - The device to create the tensor on.
1750    ///
1751    /// # Returns
1752    ///
1753    /// A new tensor with the given shape and elements sampled from the given distribution.
1754    ///
1755    /// # Example
1756    ///
1757    /// ```rust
1758    /// use burn_tensor::backend::Backend;
1759    /// use burn_tensor::{Tensor, Shape, Distribution};
1760    ///
1761    /// fn example<B: Backend>() {
1762    ///   let device = B::Device::default();
1763    ///   let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
1764    ///   let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
1765    ///   println!("{tensor}");
1766    ///   // [
1767    ///   //   [0.08347523, 0.70498955, 0.60332155],
1768    ///   //   [0.08173251, 0.18028641, 0.97942924]
1769    ///   // ]
1770    /// }
1771    /// ```
1772    pub fn random<S: Into<Shape>>(
1773        shape: S,
1774        distribution: Distribution,
1775        device: &B::Device,
1776    ) -> Self {
1777        Self::new(K::random(shape.into(), distribution, device))
1778    }
1779
1780    /// Sort the elements by value in ascending order along a given dimension.
1781    ///
1782    /// This sort is unstable (i.e., may reorder equal elements).
1783    ///
1784    /// # Arguments
1785    ///
1786    /// * `dim` - The dimension to sort along.
1787    ///
1788    /// # Returns
1789    ///
1790    /// A new tensor with the elements sorted in ascending order along the given dimension.
1791    ///
1792    /// # Example
1793    ///
1794    /// ```rust
1795    /// use burn_tensor::backend::Backend;
1796    /// use burn_tensor::{Tensor, Shape};
1797    ///
1798    /// fn example<B: Backend>() {
1799    ///   let device = B::Device::default();
1800    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1801    ///   let tensor = tensor.sort(0);
1802    ///   println!("{tensor}");
1803    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
1804    ///   let tensor = tensor.sort(1);
1805    ///   println!("{tensor}");
1806    ///   // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]
1807    /// }
1808    /// ```
1809    pub fn sort(self, dim: usize) -> Tensor<B, D, K> {
1810        check!(TensorCheck::sort_dim::<D>("Sort", dim));
1811        Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
1812    }
1813
1814    /// Sort the elements by value in descending order along a given dimension.
1815    ///
1816    /// This sort is unstable (i.e., may reorder equal elements).
1817    ///
1818    /// # Arguments
1819    ///
1820    /// * `dim` - The dimension to sort along.
1821    ///
1822    /// # Returns
1823    ///
1824    /// A new tensor with the elements sorted in descending order along the given dimension.
1825    ///
1826    /// # Example
1827    ///
1828    /// ```rust
1829    /// use burn_tensor::backend::Backend;
1830    /// use burn_tensor::{Tensor, Shape};
1831    ///
1832    /// fn example<B: Backend>() {
1833    ///    let device = B::Device::default();
1834    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1835    ///    let tensor = tensor.sort_descending(0);
1836    ///    println!("{tensor}");
1837    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1838    ///    let tensor = tensor.sort_descending(1);
1839    ///    println!("{tensor}");
1840    ///    // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]
1841    /// }
1842    /// ```
1843    pub fn sort_descending(self, dim: usize) -> Tensor<B, D, K> {
1844        check!(TensorCheck::sort_dim::<D>("Sort", dim));
1845        Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
1846    }
1847
1848    /// Sort the elements by value in ascending order along a given dimension.
1849    /// Also returns the indices.
1850    ///
1851    /// This sort is unstable (i.e., may reorder equal elements).
1852    ///
1853    /// # Arguments
1854    ///
1855    /// * `dim` - The dimension to sort along.
1856    ///
1857    /// # Returns
1858    ///
1859    /// A tuple containing the sorted tensor and the indices tensor.
1860    ///
1861    /// # Example
1862    ///
1863    /// ```rust
1864    /// use burn_tensor::backend::Backend;
1865    /// use burn_tensor::{Tensor, Shape};
1866    ///
1867    /// fn example<B: Backend>() {
1868    ///   let device = B::Device::default();
1869    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1870    ///   let (tensor, indices) = tensor.sort_with_indices(0);
1871    ///   println!("{tensor}");
1872    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
1873    ///   println!("{}", indices);
1874    ///   // [[1, 0, 0], [0, 1, 1]]
1875    /// }
1876    /// ```
1877    pub fn sort_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1878        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1879        let (values, indices) =
1880            K::sort_with_indices(self.primitive, dim, /*descending*/ false);
1881        (Tensor::new(values), Tensor::new(indices))
1882    }
1883
1884    /// Sort the elements by value in descending order along a given dimension.
1885    /// Also returns the indices.
1886    ///
1887    /// This sort is unstable (i.e., may reorder equal elements).
1888    ///
1889    /// # Arguments
1890    ///
1891    /// * `dim` - The dimension to sort along.
1892    ///
1893    /// # Example
1894    ///
1895    /// ```rust
1896    /// use burn_tensor::backend::Backend;
1897    /// use burn_tensor::{Tensor, Shape};
1898    ///
1899    /// fn example<B: Backend>() {
1900    ///    let device = B::Device::default();
1901    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1902    ///    let (tensor, indices) = tensor.sort_descending_with_indices(0);
1903    ///    println!("{tensor}");
1904    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1905    ///    println!("{}", indices);
1906    ///    // [[0, 1, 1], [1, 0, 0]]
1907    /// }
1908    /// ```
1909    pub fn sort_descending_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1910        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1911        let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
1912        (Tensor::new(values), Tensor::new(indices))
1913    }
1914
1915    /// Returns the indices that sort the elements by value in ascending order along a given dimension.
1916    ///
1917    /// This sort is unstable (i.e., may reorder equal elements).
1918    ///
1919    /// # Arguments
1920    ///
1921    /// * `dim` - The dimension to sort along.
1922    ///
1923    /// # Example
1924    ///
1925    /// ```rust
1926    /// use burn_tensor::backend::Backend;
1927    /// use burn_tensor::{Tensor, Shape};
1928    ///
1929    /// fn example<B: Backend>() {
1930    ///    let device = B::Device::default();
1931    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1932    ///    let tensor = tensor.argsort(0);
1933    ///    println!("{tensor}");
1934    ///    // [[1, 0, 0], [0, 1, 1]]
1935    /// }
1936    /// ```
1937    pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
1938        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1939        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
1940    }
1941
1942    /// Returns the indices that sort the elements by value in descending order along a given dimension.
1943    ///
1944    /// This sort is unstable (i.e., may reorder equal elements).
1945    ///
1946    /// # Arguments
1947    ///
1948    /// * `dim` - The dimension to sort along.
1949    ///
1950    /// # Example
1951    ///
1952    /// ```rust
1953    /// use burn_tensor::backend::Backend;
1954    /// use burn_tensor::{Tensor, Shape};
1955    ///
1956    /// fn example<B: Backend>() {
1957    ///    let device = B::Device::default();
1958    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1959    ///    let tensor = tensor.argsort_descending(0);
1960    ///    println!("{tensor}");
1961    ///    // [[0, 1, 1], [1, 0, 0]]
1962    ///    let tensor = tensor.argsort_descending(1);
1963    ///    println!("{tensor}");
1964    ///    // [[0, 2, 1], [2, 0, 1]]
1965    /// }
1966    /// ```
1967    pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
1968        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1969        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
1970    }
1971
1972    /// Returns the `k` largest elements of the given input tensor along a given dimension.
1973    ///
1974    /// # Arguments
1975    ///
1976    /// * `k` - The number of elements to return.
1977    ///
1978    /// # Returns
1979    ///
1980    /// A new tensor with the `k` largest elements along the given dimension.
1981    ///
1982    /// # Example
1983    ///
1984    /// ```rust
1985    /// use burn_tensor::backend::Backend;
1986    /// use burn_tensor::{Tensor, Shape};
1987    ///
1988    /// fn example<B: Backend>() {
1989    ///   let device = B::Device::default();
1990    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1991    ///   let tensor = tensor.topk(2, 0);
1992    ///   println!("{tensor}");
1993    ///   // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1994    ///   let tensor = tensor.topk(1, 1);
1995    ///   println!("{tensor}");   
1996    ///   // [[12.0], [6.0]]
1997    /// }
1998    /// ```
1999    pub fn topk(self, k: usize, dim: usize) -> Tensor<B, D, K> {
2000        let k_indices = Tensor::arange(0..k as i64, &self.device());
2001        self.sort_descending(dim).select(dim, k_indices)
2002    }
2003
2004    /// Returns the `k` largest elements of the given input tensor along a given dimension.
2005    /// Also returns the indices.
2006    ///
2007    /// # Arguments
2008    ///
2009    /// * `k` - The number of elements to return.
2010    /// * `dim` - The dimension to sort along.
2011    ///
2012    /// # Example
2013    ///
2014    /// ```rust
2015    /// use burn_tensor::backend::Backend;
2016    /// use burn_tensor::{Tensor, Shape};
2017    ///
2018    /// fn example<B: Backend>() {
2019    ///    let device = B::Device::default();
2020    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2021    ///    let (tensor, indices) = tensor.topk_with_indices(2, 0);
2022    ///    println!("{tensor}");
2023    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
2024    ///    println!("{}", indices);
2025    ///    // [[0, 1, 1], [1, 0, 0]]
2026    ///    let (tensor, indices) = tensor.topk_with_indices(1, 1);
2027    ///    println!("{tensor}");
2028    ///    // [[12.0], [6.0]]
2029    ///    println!("{indices}");
2030    ///    // [[0], [2]]
2031    /// }
2032    /// ```
2033    pub fn topk_with_indices(self, k: usize, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
2034        let k_indices = Tensor::arange(0..k as i64, &self.device());
2035        let (values, indices) = self.sort_descending_with_indices(dim);
2036        (
2037            values.select(dim, k_indices.clone()),
2038            indices.select(dim, k_indices),
2039        )
2040    }
2041
2042    /// Pad the tensor of rank two or higher with the given value on the last two dimensions.
2043    ///
2044    /// # Arguments
2045    ///
2046    /// * `padding` - A tuple of four integers representing the padding on the left, right, top, and bottom.
2047    /// * `value` - The value to pad the tensor with.
2048    ///
2049    /// # Returns
2050    ///
2051    /// A new tensor with the given padding.
2052    ///
2053    /// # Example
2054    ///
2055    /// ```rust
2056    /// use burn_tensor::backend::Backend;
2057    /// use burn_tensor::{Tensor, Shape};
2058    ///
2059    /// fn example<B: Backend<FloatElem: From<f32>>>() {
2060    ///    let device = B::Device::default();
2061    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2062    ///    let tensor = tensor.pad((1, 1, 1, 1), 0.0);
2063    ///    println!("{tensor}");
2064    ///    // [
2065    ///    //   [0.0, 0.0, 0.0, 0.0, 0.0],
2066    ///    //   [0.0, 12.0, -2.0, 3.0, 0.0],
2067    ///    //   [0.0, 5.0, 3.0, 6.0, 0.0],
2068    ///    //   [0.0, 0.0, 0.0, 0.0, 0.0]
2069    ///    // ]
2070    /// }
2071    /// ```
2072    pub fn pad<E: ElementConversion>(
2073        self,
2074        padding: (usize, usize, usize, usize),
2075        value: E,
2076    ) -> Tensor<B, D, K> {
2077        let (left, right, top, bottom) = padding;
2078
2079        let mut padded_dims: [usize; D] = self.dims();
2080
2081        // Update the last two dimensions with padding
2082        padded_dims[D - 2] += top + bottom;
2083        padded_dims[D - 1] += left + right;
2084
2085        // Create the ranges for the padded tensor
2086        let ranges: [core::ops::Range<usize>; D] = padded_dims
2087            .iter()
2088            .enumerate()
2089            .map(|(i, &dim)| {
2090                if i == D - 2 {
2091                    top..dim - bottom
2092                } else if i == D - 1 {
2093                    left..dim - right
2094                } else {
2095                    0..dim
2096                }
2097            })
2098            .collect::<Vec<core::ops::Range<usize>>>()
2099            .try_into()
2100            .unwrap();
2101
2102        // Create the padded tensor
2103        let padded_tensor = Tensor::full(padded_dims, value, &self.device());
2104
2105        // Assign the original tensor data to the appropriate slice of the padded tensor
2106        padded_tensor.slice_assign(ranges, self)
2107    }
2108    /// Create a one hot tensor.
2109    ///
2110    /// # Example
2111    ///
2112    /// ```rust
2113    /// use burn_tensor::backend::Backend;
2114    /// use burn_tensor::Tensor;
2115    ///
2116    /// fn example<B: Backend>(){
2117    ///     let device = Default::default();
2118    ///     let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
2119    ///     let one_hot: Tensor<B, 2> = indices.one_hot(4);
2120    ///     println!("{}", one_hot.to_data());
2121    ///     // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
2122    /// }
2123    /// ```
2124    pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
2125        check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
2126        self.one_hot_fill(num_classes, 1.0, 0.0, -1)
2127    }
2128
2129    /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors.
2130    ///
2131    /// # Arguments
2132    ///
2133    /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
2134    /// * `on_value`: The value to assign for active positions (corresponding to indices).
2135    /// * `off_value`: The value to assign for inactive positions.
2136    /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
2137    ///
2138    /// # Returns
2139    ///
2140    /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
2141    ///
2142    /// # Example
2143    /// ```rust
2144    /// use burn_tensor::backend::Backend;
2145    /// use burn_tensor::{Tensor, Float};
2146    /// fn example<B: Backend<FloatElem: From<f32>>>() {
2147    ///     let device = B::Device::default();
2148    ///     let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
2149    ///     // One-hot encoding
2150    ///     let tensor:Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1);
2151    ///     println!("{tensor}");
2152    ///     // [[[5.0, 0.0, 0.0],
2153    ///     // [0.0, 0.0, 5.0]],
2154    ///     // [[0.0, 5.0, 0.0],
2155    ///     // [0.0, 0.0, 5.0]]]
2156    /// }
2157    /// ```
2158    pub fn one_hot_fill<const D2: usize>(
2159        self,
2160        num_classes: usize,
2161        on_value: f32,
2162        off_value: f32,
2163        axis: i64,
2164    ) -> Tensor<B, D2, K> {
2165        check!(TensorCheck::one_hot_tensor_rank::<D, D2>());
2166        // Initialize shape from the current tensor dimensions and prepare for modification
2167        let mut shape = self.shape().dims::<D>().to_vec();
2168        let device = self.device();
2169        let rank = self.dims().len();
2170
2171        // Adjust negative axis to a positive index
2172        let axis = if axis < 0 {
2173            axis + rank as i64 + 1
2174        } else {
2175            axis
2176        };
2177
2178        // Ensure axis is within valid range
2179        if axis < 0 || axis > rank as i64 {
2180            panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
2181        }
2182        // Convert the input tensor to integer indices
2183        let indices: Tensor<B, D, Int> =
2184            Tensor::from_data(self.to_data().convert::<i64>(), &device);
2185        // Insert the new dimension for the one-hot representation
2186        shape.insert(axis as usize, num_classes);
2187        // Adjust indices to valid range and handle invalid indices
2188        let adjusted_indices = indices
2189            .clone()
2190            .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices
2191            .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices
2192        // Unsqueeze the indices tensor along the specified axis
2193        let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);
2194
2195        // Initialize the output tensor with the off_value
2196        let output = Tensor::full(shape.clone(), off_value, &device);
2197
2198        // Prepare scatter tensor for on_value and off_value adjustments
2199        let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
2200            - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device());
2201
2202        // Scatter on_value at the appropriate indices to create the one-hot representation
2203        output.scatter(axis as usize, indices_unsqueezed, scatter_on_values)
2204    }
2205
2206    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
2207    ///
2208    /// # Returns
2209    ///
2210    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
2211    ///
2212    /// # Example
2213    ///
2214    /// ```rust
2215    /// use burn_tensor::backend::Backend;
2216    /// use burn_tensor::{Tensor, Bool, Shape};
2217    ///
2218    /// fn example<B: Backend>() {
2219    ///    let device = B::Device::default();
2220    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
2221    ///    let tensor = tensor.is_nan();
2222    ///    println!("{tensor}");
2223    ///    // [[false, true, false], [false, false, false]]
2224    /// }
2225    /// ```
2226    pub fn is_nan(&self) -> Tensor<B, D, Bool> {
2227        // Check if the input tensor is NaN by comparing it to itself
2228        // NaN is the only value that is not equal to itself
2229        Tensor::new(K::not_equal(self.primitive.clone(), self.primitive.clone()))
2230    }
2231
2232    /// Checks if the tensor contains any NaN values.
2233    ///
2234    /// # Returns
2235    ///
2236    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
2237    ///
2238    /// # Example
2239    ///
2240    /// ```rust
2241    /// use burn_tensor::backend::Backend;
2242    /// use burn_tensor::{Tensor, Bool, Shape};
2243    ///
2244    /// fn example<B: Backend>() {
2245    ///   let device = B::Device::default();
2246    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
2247    ///   let tensor = tensor.contains_nan();
2248    ///   println!("{tensor}");
2249    ///   // [true]
2250    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2251    ///   let tensor = tensor.contains_nan();
2252    ///   println!("{tensor}");
2253    ///   // [false]
2254    /// }
2255    /// ```
2256    pub fn contains_nan(&self) -> Tensor<B, 1, Bool> {
2257        // Summing the tensor will result in NaN if the tensor contains any NaN values
2258        // This is faster than checking each element individually
2259        // because it rolls up the NaN values into a single value
2260        let sum = K::sum(self.primitive.clone());
2261
2262        // Check if the sum is NaN by comparing it to itself
2263        Tensor::new(K::not_equal(sum.clone(), sum))
2264    }
2265}
2266
2267impl<B, K> Tensor<B, 2, K>
2268where
2269    B: Backend,
2270    K: Numeric<B>,
2271    K::Elem: Element,
2272{
2273    /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
2274    ///
2275    /// # Arguments
2276    ///
2277    /// * `size` - The size of the square matrix.
2278    pub fn eye(size: usize, device: &B::Device) -> Self {
2279        let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
2280        let ones = K::ones([1, size].into(), device);
2281        let zeros = K::zeros([size, size].into(), device);
2282
2283        Self::new(K::scatter(0, zeros, indices.primitive, ones))
2284    }
2285}
2286
/// Trait that lists all operations that can be applied to all numerical tensors.
2288///
2289/// # Warnings
2290///
2291/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
2292pub trait Numeric<B: Backend>: BasicOps<B>
2293where
2294    Self::Elem: Element,
2295{
2296    /// Adds two tensors together.
2297    ///
2298    /// # Arguments
2299    ///
2300    /// * `lhs` - The left hand side tensor.
2301    /// * `rhs` - The right hand side tensor.
2302    ///
2303    /// # Returns
2304    ///
2305    /// The sum of the two tensors.
2306    ///
2307    /// # Remarks
2308    ///
2309    /// This is a low-level function used internally by the library to call different backend functions
2310    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2311    /// or use this function directly.
2312    ///
2313    /// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function,
2314    /// which is more high-level and designed for public use.
2315    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2316
2317    /// Adds a scalar to a tensor element-wise.
2318    ///
2319    /// # Arguments
2320    ///
2321    /// * `lhs` - The left hand side tensor.
2322    /// * `rhs` - The right hand side scalar.
2323    ///
2324    /// # Returns
2325    ///
2326    /// The sum of the tensor and the scalar.
2327    ///
2328    /// # Remarks
2329    ///
2330    /// This is a low-level function used internally by the library to call different backend functions
2331    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2332    /// or use this function directly.
2333    ///
2334    /// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function,
2335    /// which is more high-level and designed for public use.
2336    fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2337
2338    /// Subtracts two tensors.
2339    ///
2340    /// # Arguments
2341    ///
2342    /// * `lhs` - The left hand side tensor.
2343    /// * `rhs` - The right hand side tensor.
2344    ///
2345    /// # Returns
2346    ///
2347    /// The difference of the two tensors.
2348    ///
2349    /// # Remarks
2350    ///
2351    /// This is a low-level function used internally by the library to call different backend functions
2352    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2353    /// or use this function directly.
2354    ///
2355    /// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function,
2356    /// which is more high-level and designed for public use.
2357    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2358
2359    /// Subtracts a scalar from a tensor element-wise.
2360    ///
2361    /// # Arguments
2362    ///
2363    /// * `lhs` - The left hand side tensor.
2364    /// * `rhs` - The right hand side scalar.
2365    ///
2366    /// # Returns
2367    ///
2368    /// The difference of the tensor and the scalar.
2369    ///
2370    /// # Remarks
2371    ///
2372    /// This is a low-level function used internally by the library to call different backend functions
2373    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2374    /// or use this function directly.
2375    ///
2376    /// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function,
2377    /// which is more high-level and designed for public use.
2378    fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2379
2380    /// Divides two tensors.
2381    ///
2382    /// # Arguments
2383    ///
2384    /// * `lhs` - The left hand side tensor.
2385    /// * `rhs` - The right hand side tensor.
2386    ///
2387    /// # Returns
2388    ///
2389    /// The quotient of the two tensors.
2390    ///
2391    /// # Remarks
2392    ///
2393    /// This is a low-level function used internally by the library to call different backend functions
2394    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2395    /// or use this function directly.
2396    ///
2397    /// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function,
2398    /// which is more high-level and designed for public use.
2399    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2400
2401    /// Divides a tensor by a scalar element-wise.
2402    ///
2403    /// # Arguments
2404    ///
2405    /// * `lhs` - The left hand side tensor.
2406    /// * `rhs` - The right hand side scalar.
2407    ///
2408    /// # Returns
2409    ///
2410    /// The quotient of the tensor and the scalar.
2411    ///
2412    /// # Remarks
2413    ///
2414    /// This is a low-level function used internally by the library to call different backend functions
2415    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2416    /// or use this function directly.
2417    ///
2418    /// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function,
2419    /// which is more high-level and designed for public use.
2420    fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2421
2422    /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2423    /// less than that of the divisor.
2424    ///
2425    /// # Arguments
2426    ///
2427    /// * `lhs` - The dividend.
2428    /// * `rhs` - The divisor.
2429    ///
2430    /// # Returns
2431    ///
2432    /// The modulo of the input tensor with the divisor.
2433    ///
2434    /// # Remarks
2435    ///
2436    /// This is a low-level function used internally by the library to call different backend functions
2437    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2438    /// or use this function directly.
2439    ///
2440    /// For performing the modulo operation, users should prefer the [Tensor::remainder](Tensor::remainder) function,
2441    /// which is more high-level and designed for public use.
2442    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2443
2444    /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2445    /// less than that of the divisor.
2446    ///
2447    /// # Arguments
2448    ///
2449    /// * `lhs` - The dividend.
2450    /// * `rhs` - The divisor.
2451    ///
2452    /// # Returns
2453    ///
2454    /// The modulo of the input tensor with the divisor.
2455    ///
2456    /// # Remarks
2457    ///
2458    /// This is a low-level function used internally by the library to call different backend functions
2459    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2460    /// or use this function directly.
2461    ///
2462    /// For performing the modulo operation, users should prefer the [Tensor::remainder_scalar](Tensor::remainder_scalar) function,
2463    /// which is more high-level and designed for public use.
2464    fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2465
2466    /// Multiplies two tensors.
2467    ///
2468    /// # Arguments
2469    ///
2470    /// * `lhs` - The left hand side tensor.
2471    /// * `rhs` - The right hand side tensor.
2472    ///
2473    /// # Returns
2474    ///
2475    /// The product of the two tensors.
2476    ///
2477    /// # Remarks
2478    ///
2479    /// This is a low-level function used internally by the library to call different backend functions
2480    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2481    /// or use this function directly.
2482    ///
2483    /// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function,
2484    /// which is more high-level and designed for public use.
2485    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2486
2487    /// Multiplies a tensor by a scalar element-wise.
2488    ///
2489    /// # Arguments
2490    ///
2491    /// * `lhs` - The left hand side tensor.
2492    /// * `rhs` - The right hand side scalar.
2493    ///
2494    /// # Returns
2495    ///
2496    /// The product of the tensor and the scalar.
2497    ///
2498    /// # Remarks
2499    ///
2500    /// This is a low-level function used internally by the library to call different backend functions
2501    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2502    /// or use this function directly.
2503    ///
2504    /// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function,
2505    /// which is more high-level and designed for public use.
2506    fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2507
2508    /// Negates a tensor.
2509    ///
2510    /// # Arguments
2511    ///
2512    /// * `tensor` - The tensor to negate.
2513    ///
2514    /// # Returns
2515    ///
2516    /// The negated tensor.
2517    ///
2518    /// # Remarks
2519    ///
2520    /// This is a low-level function used internally by the library to call different backend functions
2521    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2522    /// or use this function directly.
2523    ///
2524    /// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function,
2525    /// which is more high-level and designed for public use.
2526    fn neg(tensor: Self::Primitive) -> Self::Primitive;
2527
2528    /// Returns the signs of the elements of a tensor.
2529    ///
2530    /// # Arguments
2531    ///
2532    /// * `tensor` - The tensor.
2533    ///
2534    /// # Returns
2535    ///
2536    /// The signs of the elements of the tensor.
2537    ///
2538    /// # Remarks
2539    ///
2540    /// This is a low-level function used internally by the library to call different backend functions
2541    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2542    /// or use this function directly.
2543    ///
2544    /// For getting the signs of the elements of a tensor, users should prefer the [Tensor::sign](Tensor::sign) function,
2545    /// which is more high-level and designed for public use.
2546    fn sign(tensor: Self::Primitive) -> Self::Primitive;
2547
2548    /// Creates a tensor filled with zeros.
2549    ///
2550    /// # Arguments
2551    ///
2552    /// * `shape` - The shape of the tensor.
2553    /// * `device` - The device on which the tensor will be allocated.
2554    ///
2555    /// # Returns
2556    ///
2557    /// The tensor filled with zeros.
2558    ///
2559    /// # Remarks
2560    ///
2561    /// This is a low-level function used internally by the library to call different backend functions
2562    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2563    /// or use this function directly.
2564    ///
2565    /// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function,
2566    /// which is more high-level and designed for public use.
2567    fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive;
2568
2569    /// Creates a tensor filled with ones.
2570    ///
2571    /// # Arguments
2572    ///
2573    /// * `shape` - The shape of the tensor.
2574    /// * `device` - The device on which the tensor will be allocated.
2575    ///
2576    /// # Returns
2577    ///
2578    /// The tensor filled with ones.
2579    ///
2580    /// # Remarks
2581    ///
2582    /// This is a low-level function used internally by the library to call different backend functions
2583    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2584    /// or use this function directly.
2585    ///
2586    /// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function,
2587    /// which is more high-level and designed for public use.
2588    fn ones(shape: Shape, device: &B::Device) -> Self::Primitive;
2589
2590    /// Creates a tensor filled with elements equal to the given value.
2591    ///
2592    /// # Arguments
2593    ///
2594    /// * `shape` - The shape of the tensor.
2595    /// * `fill_value` - The value with which to fill the tensor
2596    /// * `device` - The device on which the tensor will be allocated.
2597    ///
2598    /// # Returns
2599    ///
2600    /// The tensor filled with elements equal to the given value
2601    ///
2602    /// # Remarks
2603    ///
2604    /// This is a low-level function used internally by the library to call different backend functions
2605    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2606    /// or use this function directly.
2607    ///
2608    /// For creating a tensor filled with a specific value, users should prefer the [Tensor::full](Tensor::full) function,
2609    /// which is more high-level and designed for public use.
2610    fn full<E: ElementConversion>(
2611        shape: Shape,
2612        fill_value: E,
2613        device: &B::Device,
2614    ) -> Self::Primitive;
2615
2616    /// Sums all the elements of the tensor.
2617    ///
2618    /// # Arguments
2619    ///
2620    /// * `tensor` - The tensor to sum.
2621    ///
2622    /// # Returns
2623    ///
2624    /// The sum of all the elements of the tensor.
2625    ///
2626    /// # Remarks
2627    ///
2628    /// This is a low-level function used internally by the library to call different backend functions
2629    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2630    /// or use this function directly.
2631    ///
2632    /// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function,
2633    /// which is more high-level and designed for public use.
2634    fn sum(tensor: Self::Primitive) -> Self::Primitive;
2635
2636    /// Sums all the elements of the tensor along a dimension.
2637    ///
2638    /// # Arguments
2639    ///
2640    /// * `tensor` - The tensor to sum.
2641    /// * `dim` - The dimension along which to sum.
2642    ///
2643    /// # Returns
2644    ///
2645    /// The sum of all the elements of the tensor along the specified dimension.
2646    ///
2647    /// # Remarks
2648    ///
2649    /// This is a low-level function used internally by the library to call different backend functions
2650    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2651    /// or use this function directly.
2652    ///
2653    /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function,
2654    /// which is more high-level and designed for public use.
2655    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2656
2657    /// Computes the product of all the elements of the tensor.
2658    ///
2659    /// # Arguments
2660    ///
2661    /// * `tensor` - The tensor to compute the product of.
2662    ///
2663    /// # Returns
2664    ///
2665    /// The product of all the elements of the tensor.
2666    ///
2667    /// # Remarks
2668    ///
2669    /// This is a low-level function used internally by the library to call different backend functions
2670    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2671    /// or use this function directly.
2672    ///
2673    /// For computing the product of all the elements of a tensor, users should prefer the
2674    /// [Tensor::prod](Tensor::prod) function,
2675    /// which is more high-level and designed for public use.
2676    fn prod(tensor: Self::Primitive) -> Self::Primitive;
2677
2678    /// Computes the product of all the elements of the tensor along a dimension.
2679    ///
2680    /// # Arguments
2681    ///
2682    /// * `tensor` - The tensor to compute the product of.
2683    /// * `dim` - The dimension along which to compute the product.
2684    ///
2685    /// # Returns
2686    ///
2687    /// The product of all the elements of the tensor along the specified dimension.
2688    ///
2689    /// # Remarks
2690    ///
2691    /// This is a low-level function used internally by the library to call different backend functions
2692    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2693    /// or use this function directly.
2694    ///
2695    /// For computing the product of all the elements of a tensor along a dimension, users should
2696    /// prefer the [Tensor::prod_dim](Tensor::prod_dim) function,
2697    /// which is more high-level and designed for public use.
2698    ///
2699    ///
2700    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2701
2702    /// Computes the mean of all the elements of the tensor.
2703    ///
2704    /// # Arguments
2705    ///
2706    /// * `tensor` - The tensor to compute the mean of.
2707    ///
2708    /// # Returns
2709    ///
2710    /// The mean of all the elements of the tensor.
2711    ///
2712    /// # Remarks
2713    ///
2714    /// This is a low-level function used internally by the library to call different backend functions
2715    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2716    /// or use this function directly.
2717    ///
2718    /// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function,
2719    /// which is more high-level and designed for public use.
2720    fn mean(tensor: Self::Primitive) -> Self::Primitive;
2721
2722    /// Computes the mean of all the elements of the tensor along a dimension.
2723    ///
2724    /// # Arguments
2725    ///
2726    /// * `tensor` - The tensor to compute the mean of.
2727    /// * `dim` - The dimension along which to compute the mean.
2728    ///
2729    /// # Returns
2730    ///
2731    /// The mean of all the elements of the tensor along the specified dimension.
2732    ///
2733    /// # Remarks
2734    ///
2735    /// This is a low-level function used internally by the library to call different backend functions
2736    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2737    /// or use this function directly.
2738    ///
2739    /// For computing the mean of all the elements of a tensor along a dimension, users should prefer
2740    /// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use.
2741    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2742
2743    /// Element-wise equality between two tensors.
2744    ///
2745    /// # Arguments
2746    ///
2747    /// * `lhs` - The left hand side tensor.
2748    /// * `rhs` - The right hand side tensor.
2749    ///
2750    /// # Returns
2751    ///
2752    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2753    /// corresponding elements of the input tensors are equal, and false otherwise.
2754    ///
2755    /// # Remarks
2756    ///
2757    /// This is a low-level function used internally by the library to call different backend functions
2758    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2759    /// or use this function directly.
2760    ///
2761    /// For element-wise equality between two tensors, users should prefer the [Tensor::equal_elem](Tensor::equal_elem)
2762    /// function, which is more high-level and designed for public use.
2763    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2764
2765    /// Element-wise non-equality between two tensors.
2766    ///
2767    /// # Arguments
2768    ///
2769    /// * `lhs` - The left hand side tensor.
2770    /// * `rhs` - The right hand side tensor.
2771    ///
2772    /// # Returns
2773    ///
2774    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2775    /// corresponding elements of the input tensors are equal, and false otherwise.
2776    ///
2777    /// # Remarks
2778    ///
2779    /// This is a low-level function used internally by the library to call different backend functions
2780    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2781    /// or use this function directly.
2782    ///
2783    /// For element-wise non-equality between two tensors, users should prefer the [Tensor::not_equal_elem](Tensor::not_equal_elem)
2784    /// function, which is more high-level and designed for public use.
2785    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2786
2787    /// Element-wise greater than comparison between two tensors.
2788    ///
2789    /// # Arguments
2790    ///
2791    /// * `lhs` - The left hand side tensor.
2792    /// * `rhs` - The right hand side tensor.
2793    ///
2794    /// # Returns
2795    ///
2796    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2797    /// corresponding element of the left hand side tensor is greater than the corresponding element
2798    /// of the right hand side tensor, and false otherwise.
2799    ///
2800    /// # Remarks
2801    ///
2802    /// This is a low-level function used internally by the library to call different backend functions
2803    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2804    /// or use this function directly.
2805    ///
2806    /// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function,
2807    /// which is more high-level and designed for public use.
2808    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2809
2810    /// Element-wise greater than comparison between a tensor and a scalar.
2811    ///
2812    /// # Arguments
2813    ///
2814    /// * `lhs` - The left hand side tensor.
2815    /// * `rhs` - The right hand side scalar.
2816    ///
2817    /// # Returns
2818    ///
2819    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2820    /// corresponding element of the left hand side tensor is greater than the right hand side
2821    /// scalar, and false otherwise.
2822    ///
2823    /// # Remarks
2824    ///
2825    /// This is a low-level function used internally by the library to call different backend functions
2826    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2827    /// or use this function directly.
2828    ///
2829    /// For element-wise greater than comparison between a tensor and a scalar, users should prefer
2830    /// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use.
2831    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2832
2833    /// Element-wise greater than or equal comparison between two tensors.
2834    ///
2835    /// # Arguments
2836    ///
2837    /// * `lhs` - The left hand side tensor.
2838    /// * `rhs` - The right hand side tensor.
2839    ///
2840    /// # Returns
2841    ///
2842    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2843    /// corresponding element of the left hand side tensor is greater than or equal to the
2844    /// corresponding element of the right hand side tensor, and false otherwise.
2845    ///
2846    /// # Remarks
2847    ///
2848    /// This is a low-level function used internally by the library to call different backend functions
2849    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2850    /// or use this function directly.
2851    ///
2852    /// For element-wise greater than or equal comparison between two tensors, users should prefer
2853    /// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use.
2854    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2855
2856    /// Element-wise greater than or equal comparison between a tensor and a scalar.
2857    ///
2858    /// # Arguments
2859    ///
2860    /// * `lhs` - The left hand side tensor.
2861    /// * `rhs` - The right hand side scalar.
2862    ///
2863    /// # Returns
2864    ///
2865    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2866    /// corresponding element of the left hand side tensor is greater than or equal to the right
2867    /// hand side scalar, and false otherwise.
2868    ///
2869    /// # Remarks
2870    ///
2871    /// This is a low-level function used internally by the library to call different backend functions
2872    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2873    /// or use this function directly.
2874    ///
2875    /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer
2876    /// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use.
2877    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2878
2879    /// Element-wise less than comparison between two tensors.
2880    ///
2881    /// # Arguments
2882    ///
2883    /// * `lhs` - The left hand side tensor.
2884    /// * `rhs` - The right hand side tensor.
2885    ///
2886    /// # Returns
2887    ///
2888    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2889    /// corresponding element of the left hand side tensor is less than the corresponding element of
2890    /// the right hand side tensor, and false otherwise.
2891    ///
2892    /// # Remarks
2893    ///
2894    /// This is a low-level function used internally by the library to call different backend functions
2895    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2896    /// or use this function directly.
2897    ///
2898    /// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function,
2899    /// which is more high-level and designed for public use.
2900    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2901
2902    /// Element-wise less than comparison between a tensor and a scalar.
2903    ///
2904    /// # Arguments
2905    ///
2906    /// * `lhs` - The left hand side tensor.
2907    /// * `rhs` - The right hand side scalar.
2908    ///
2909    /// # Returns
2910    ///
2911    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2912    /// corresponding element of the left hand side tensor is less than the right hand side scalar,
2913    /// and false otherwise.
2914    ///
2915    /// # Remarks
2916    ///
2917    /// This is a low-level function used internally by the library to call different backend functions
2918    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2919    /// or use this function directly.
2920    ///
2921    /// For element-wise less than comparison between a tensor and a scalar, users should prefer
2922    /// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use.
2923    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2924
2925    /// Element-wise less than or equal comparison between two tensors.
2926    ///
2927    /// # Arguments
2928    ///
2929    /// * `lhs` - The left hand side tensor.
2930    /// * `rhs` - The right hand side tensor.
2931    ///
2932    /// # Returns
2933    ///
2934    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2935    /// corresponding element of the left hand side tensor is less than or equal to the corresponding
2936    /// element of the right hand side tensor, and false otherwise.
2937    ///
2938    /// # Remarks
2939    ///
2940    /// This is a low-level function used internally by the library to call different backend functions
2941    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2942    /// or use this function directly.
2943    ///
2944    /// For element-wise less than or equal comparison between two tensors, users should prefer
2945    /// the [Tensor::lower_equal](Tensor::lower_equal) function, which is more high-level and designed for public use.
2946    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2947
2948    /// Element-wise less than or equal comparison between a tensor and a scalar.
2949    ///
2950    /// # Arguments
2951    ///
2952    /// * `lhs` - The left hand side tensor.
2953    /// * `rhs` - The right hand side scalar.
2954    ///
2955    /// # Returns
2956    ///
2957    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2958    /// corresponding element of the left hand side tensor is less than or equal to the right hand
2959    /// side scalar, and false otherwise.
2960    ///
2961    /// # Remarks
2962    ///
2963    /// This is a low-level function used internally by the library to call different backend functions
2964    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2965    /// or use this function directly.
2966    ///
2967    /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer
2968    /// the [Tensor::lower_equal_elem](Tensor::lower_equal_elem) function, which is more high-level and designed for public use.
2969    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2970
2971    /// Selects elements from a tensor based on a boolean mask.
2972    ///
2973    /// # Arguments
2974    ///
2975    /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
2976    /// * `mask` - The boolean mask to use for selecting elements.
2977    /// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
2978    ///
2979    /// # Returns
2980    ///
2981    /// A tensor with the same shape as the input tensors, where each element is taken from the
2982    /// corresponding element of the left hand side tensor if the corresponding element of the mask
2983    /// is true, and from the corresponding element of the right hand side tensor otherwise.
2984    ///
2985    /// # Remarks
2986    ///
2987    /// This is a low-level function used internally by the library to call different backend functions
2988    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2989    /// or use this function directly.
2990    ///
2991    /// For selecting elements from a tensor based on a boolean mask, users should prefer the
2992    /// [Tensor::mask_where](Tensor::mask_where) function, which is more high-level and designed for public use.
2993    fn mask_where(
2994        tensor: Self::Primitive,
2995        mask: B::BoolTensorPrimitive,
2996        source: Self::Primitive,
2997    ) -> Self::Primitive;
2998
2999    /// Fills elements of a tensor based on a boolean mask.
3000    ///
3001    /// # Arguments
3002    ///
3003    /// * `tensor` - The tensor where will be overwritten with the value
3004    ///   when the corresponding element of the mask is true.
3005    /// * `mask` - The boolean mask to use for filling elements.
3006    /// * `value` - The value to fill elements with when the corresponding element of the mask is true.
3007    ///
3008    /// # Returns
3009    ///
3010    /// A tensor with the same shape as the input tensors, where each element is taken from the
3011    /// corresponding element unmodified if the corresponding element of the mask is false, and
3012    /// filled with the value otherwise.
3013    ///
3014    /// # Remarks
3015    ///
3016    /// This is a low-level function used internally by the library to call different backend functions
3017    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3018    /// or use this function directly.
3019    ///
3020    /// For filling elements of a tensor based on a boolean mask, users should prefer the
3021    /// [Tensor::mask_fill](Tensor::mask_fill) function, which is more high-level and designed for public use.
3022    fn mask_fill(
3023        tensor: Self::Primitive,
3024        mask: B::BoolTensorPrimitive,
3025        value: Self::Elem,
3026    ) -> Self::Primitive;
3027
3028    /// Gathers elements from a tensor along an axis.
3029    ///
3030    /// # Arguments
3031    ///
3032    /// * `dim` - The axis along which to gather elements.
3033    /// * `tensor` - The tensor to gather elements from.
3034    /// * `indices` - The indices of the elements to gather.
3035    ///
3036    /// # Returns
3037    ///
3038    /// A tensor with the same shape as the input tensor, where each element is taken from the
3039    /// corresponding element of the input tensor at the corresponding index along the specified axis.
3040    ///
3041    /// # Remarks
3042    ///
3043    /// This is a low-level function used internally by the library to call different backend functions
3044    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3045    /// or use this function directly.
3046    ///
3047    /// For gathering elements from a tensor along an axis, users should prefer the
3048    /// [Tensor::gather](Tensor::gather) function, which is more high-level and designed for public use.
3049    fn gather(
3050        dim: usize,
3051        tensor: Self::Primitive,
3052        indices: B::IntTensorPrimitive,
3053    ) -> Self::Primitive;
3054
3055    /// Scatters elements into a tensor along an axis.
3056    ///
3057    /// # Arguments
3058    ///
3059    /// * `dim` - The axis along which to scatter elements.
3060    /// * `tensor` - The tensor to scatter elements into.
3061    /// * `indices` - The indices of the elements to scatter.
3062    /// * `values` - The values to scatter into the tensor.
3063    ///
3064    /// # Returns
3065    ///
3066    /// A tensor with the same shape as the input tensor, where each element is taken from the
3067    /// corresponding element of the input tensor at the corresponding index along the specified axis,
3068    /// except for the elements at the specified indices, which are taken from the corresponding
3069    /// element of the values tensor.
3070    ///
3071    /// # Remarks
3072    ///
3073    /// This is a low-level function used internally by the library to call different backend functions
3074    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3075    /// or use this function directly.
3076    ///
3077    /// For scattering elements into a tensor along an axis, users should prefer the [Tensor::scatter](Tensor::scatter) function,
3078    /// which is more high-level and designed for public use.
3079    fn scatter(
3080        dim: usize,
3081        tensor: Self::Primitive,
3082        indices: B::IntTensorPrimitive,
3083        values: Self::Primitive,
3084    ) -> Self::Primitive;
3085
3086    /// Select tensor elements along the given dimension corresponding for the given indices.
3087    ///
3088    /// # Arguments
3089    ///
3090    /// * `tensor` - The tensor to select elements from.
3091    /// * `dim` - The axis along which to select elements.
3092    /// * `indices` - The indices of the elements to select.
3093    ///
3094    /// # Returns
3095    ///
3096    /// A tensor with the same shape as the input tensor, where each element is taken from the
3097    /// corresponding element of the input tensor at the corresponding index along the specified axis.
3098    ///
3099    /// # Remarks
3100    ///
3101    /// This is a low-level function used internally by the library to call different backend functions
3102    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3103    /// or use this function directly.
3104    ///
3105    /// For selecting elements from a tensor along an axis, users should prefer the
3106    /// [Tensor::select](Tensor::select) function, which is more high-level and designed for public use.
3107    fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive;
3108
3109    /// Assign the selected elements along the given dimension corresponding to the given indices
3110    /// from the value tensor.
3111    ///
3112    /// # Arguments
3113    ///
3114    /// * `tensor` - The tensor to assign elements to.
3115    /// * `dim` - The axis along which to assign elements.
3116    /// * `indices` - The indices of the elements to assign.
3117    /// * `values` - The values to assign to the tensor.
3118    ///
3119    /// # Returns
3120    ///
3121    /// A tensor with the same shape as the input tensor, where each element is taken from the
3122    /// corresponding element of the input tensor at the corresponding index along the specified axis,
3123    /// except for the elements at the specified indices, which are taken from the corresponding
3124    /// element of the values tensor.
3125    ///
3126    /// # Remarks
3127    ///
3128    /// This is a low-level function used internally by the library to call different backend functions
3129    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3130    /// or use this function directly.
3131    ///
3132    /// For assigning elements to a tensor along an axis, users should prefer the
3133    /// [Tensor::select_assign](Tensor::select_assign) function, which is more high-level and designed for public use.
3134    fn select_assign(
3135        tensor: Self::Primitive,
3136        dim: usize,
3137        indices: Tensor<B, 1, Int>,
3138        values: Self::Primitive,
3139    ) -> Self::Primitive;
3140
3141    /// Gets the indices of the maximum elements of a tensor along an axis.
3142    ///
3143    /// # Arguments
3144    ///
3145    /// * `dim` - The axis along which to get the indices of the maximum elements.
3146    /// * `tensor` - The tensor to get the indices of the maximum elements from.
3147    ///
3148    /// # Returns
3149    ///
3150    /// A tensor with the same shape as the input tensor, where each element is the index of the
3151    /// maximum element of the input tensor at the corresponding index along the specified axis.
3152    ///
3153    /// # Remarks
3154    ///
3155    /// This is a low-level function used internally by the library to call different backend functions
3156    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3157    /// or use this function directly.
3158    ///
3159    /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the
3160    /// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use.
3161    fn argmax(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
3162
3163    /// Gets the indices of the minimum elements of a tensor along an axis.
3164    ///
3165    /// # Arguments
3166    ///
3167    /// * `dim` - The axis along which to get the indices of the minimum elements.
3168    /// * `tensor` - The tensor to get the indices of the minimum elements from.
3169    ///
3170    /// # Returns
3171    ///
3172    /// A tensor with the same shape as the input tensor, where each element is the index of the
3173    /// minimum element of the input tensor at the corresponding index along the specified axis.
3174    ///
3175    /// # Remarks
3176    ///
3177    /// This is a low-level function used internally by the library to call different backend functions
3178    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3179    /// or use this function directly.
3180    ///
3181    /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
3182    /// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use.
3183    fn argmin(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
3184
3185    /// Gets the maximum elements of a tensor along an axis.
3186    ///
3187    /// # Arguments
3188    ///
3189    /// * `dim` - The axis along which to get the maximum elements.
3190    ///
3191    /// # Returns
3192    ///
3193    /// A single-element tensor containing the maximum element of the input tensor.
3194    ///
3195    /// # Remarks
3196    ///
3197    /// This is a low-level function used internally by the library to call different backend functions
3198    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3199    /// or use this function directly.
3200    ///
3201    /// For getting the maximum elements of a tensor along an axis, users should prefer the
3202    /// [Tensor::max](Tensor::max) function, which is more high-level and designed for public use.
3203    fn max(tensor: Self::Primitive) -> Self::Primitive;
3204
    /// Gets the maximum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements from.
    /// * `dim` - The axis along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.
    /// Each element is the maximum of the input elements lying along `dim` at that position.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the maximum elements of a tensor along an axis, users should prefer the
    /// [Tensor::max_dim](Tensor::max_dim) function, which is more high-level and designed for public use.
    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3226
    /// Gets the maximum elements of a tensor along an axis, together with their indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements from.
    /// * `dim` - The axis along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// A tuple `(values, indices)`. Both tensors have the same rank as the input tensor with the
    /// given dim set to a shape of 1 (the same output shape as `max_dim`): `values` holds the
    /// maximum element along `dim`, and the integer tensor `indices` holds the position of that
    /// element along `dim` in the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the maximum elements of a tensor along an axis, users should prefer the
    /// [Tensor::max_dim_with_indices](Tensor::max_dim_with_indices) function, which is more high-level and designed for public use.
    fn max_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, B::IntTensorPrimitive);
3252
    /// Gets the maximum absolute element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum absolute element from.
    ///
    /// # Returns
    ///
    /// A single-element tensor containing the maximum absolute element of the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the maximum absolute element of a tensor, users should prefer the
    /// [Tensor::max_abs](Tensor::max_abs) function, which is more high-level and designed for public use.
    fn max_abs(tensor: Self::Primitive) -> Self::Primitive;
3272
    /// Gets the maximum absolute elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum absolute elements from.
    /// * `dim` - The axis along which to get the maximum absolute elements.
    ///
    /// # Returns
    ///
    /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.
    /// Each element is the maximum absolute element of the corresponding input dim.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the maximum absolute elements of a tensor along an axis, users should prefer the
    /// [Tensor::max_abs_dim](Tensor::max_abs_dim) function, which is more high-level and designed for public use.
    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3294
    /// Gets the minimum element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum element from.
    ///
    /// # Returns
    ///
    /// A single-element tensor containing the minimum element of the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the minimum element of a tensor, users should prefer the
    /// [Tensor::min](Tensor::min) function, which is more high-level and designed for public use.
    fn min(tensor: Self::Primitive) -> Self::Primitive;
3314
    /// Gets the minimum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements from.
    /// * `dim` - The axis along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.
    /// Each element is the minimum of the input elements lying along `dim` at that position.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the minimum elements of a tensor along an axis, users should prefer the
    /// [Tensor::min_dim](Tensor::min_dim) function, which is more high-level and designed for public use.
    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3336
    /// Gets the minimum elements and indices of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements from.
    /// * `dim` - The axis along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tuple `(values, indices)`. Both tensors have the same rank as the input tensor with the
    /// given dim set to a shape of 1 (the same output shape as `min_dim`): `values` holds the
    /// minimum element along `dim`, and the integer tensor `indices` holds the position of that
    /// element along `dim` in the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the minimum elements of a tensor along an axis, users should prefer the
    /// [Tensor::min_dim_with_indices](Tensor::min_dim_with_indices) function, which is more high-level and designed for public use.
    fn min_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, B::IntTensorPrimitive);
3361
    /// Clamp the tensor between the given min and max values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value.
    /// * `max` - The maximum value.
    ///
    /// # Returns
    ///
    /// A new tensor with the values clamped between the given min and max values.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor between the given min and max values, users should prefer the
    /// [Tensor::clamp](Tensor::clamp) function, which is more high-level and designed for public use.
    fn clamp(tensor: Self::Primitive, min: Self::Elem, max: Self::Elem) -> Self::Primitive;
3381
    /// Clamps a tensor from below, so that no element is smaller than the given minimum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value.
    ///
    /// # Returns
    ///
    /// A new tensor in which every element smaller than `min` is replaced by `min`.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor under a minimum value, users should prefer the
    /// [Tensor::clamp_min](Tensor::clamp_min) function, which is more high-level and designed for public use.
    fn clamp_min(tensor: Self::Primitive, min: Self::Elem) -> Self::Primitive;
3401
    /// Clamps a tensor from above, so that no element is larger than the given maximum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `max` - The maximum value.
    ///
    /// # Returns
    ///
    /// A new tensor in which every element larger than `max` is replaced by `max`.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor over a maximum value, users should prefer the
    /// [Tensor::clamp_max](Tensor::clamp_max) function, which is more high-level and designed for public use.
    fn clamp_max(tensor: Self::Primitive, max: Self::Elem) -> Self::Primitive;
3421
    /// Calculates the absolute value of every element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to apply abs to.
    ///
    /// # Returns
    ///
    /// A tensor of the same shape holding the absolute values.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For calculating abs of the elements of a tensor, users should prefer the [Tensor::abs](Tensor::abs) function,
    /// which is more high-level and designed for public use.
    fn abs(tensor: Self::Primitive) -> Self::Primitive;
3441
    /// Element-wise power of a tensor, using float exponentiation with the exponents in `rhs`.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply power to.
    /// * `rhs` - The tensor of exponents, applied element wise.
    ///
    /// # Returns
    ///
    /// A tensor of the same kind as `lhs` holding `lhs ^ rhs` element wise.
    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3448
    /// Element-wise power of a tensor to the exponents held in another tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply power to.
    /// * `rhs` - The tensor of exponents, applied element wise.
    ///
    /// # Returns
    ///
    /// A tensor of the same kind as `lhs` holding `lhs ^ rhs` element wise.
    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3455
    /// Element-wise power of a tensor to a scalar float.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply power to.
    /// * `rhs` - The scalar exponent, converted to the backend's element type.
    ///
    /// # Returns
    ///
    /// A tensor of the same kind as `lhs` holding `lhs ^ rhs` element wise.
    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3462
    /// Element-wise power of a tensor to a scalar int.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply power to.
    /// * `rhs` - The scalar exponent, converted to the backend's element type.
    ///
    /// # Returns
    ///
    /// A tensor of the same kind as `lhs` holding `lhs ^ rhs` element wise.
    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3469
    /// Create a random tensor.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the output tensor.
    /// * `distribution` - The distribution used to sample.
    /// * `device` - The device to use.
    ///
    /// # Returns
    ///
    /// A new tensor of the given shape, created on `device`, whose elements are sampled from
    /// `distribution`.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [Tensor::random](Tensor::random) function,
    /// which is more high-level and designed for public use.
    fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive;
3491
    /// Sort the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order (`true` for descending, `false` for ascending).
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [Tensor::sort](Tensor::sort) function,
    /// which is more high-level and designed for public use.
    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive;
3514
    /// Sort the elements of the input `tensor` by value along a given dimension, also returning
    /// the indices of the sorted elements.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order (`true` for descending, `false` for ascending).
    ///
    /// # Returns
    ///
    /// A tuple of the sorted tensor (same shape as the input) and an integer tensor of the same
    /// shape whose indices map each sorted element back to its position in the original input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For sorting the elements of a tensor, users should prefer the
    /// [Tensor::sort_with_indices](Tensor::sort_with_indices) function, which is more high-level
    /// and designed for public use.
    fn sort_with_indices(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive);
3543
    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order (`true` for descending, `false` for ascending).
    ///
    /// # Returns
    ///
    /// An integer tensor with the same shape as the input tensor, whose indices map each sorted
    /// position back to the corresponding element of the original input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [Tensor::argsort](Tensor::argsort) function,
    /// which is more high-level and designed for public use.
    fn argsort(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> <Int as TensorKind<B>>::Primitive;
3570}
3571
3572impl<B: Backend> Numeric<B> for Int {
3573    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3574        B::int_add(lhs, rhs)
3575    }
3576    fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3577        B::int_add_scalar(lhs, rhs.elem())
3578    }
3579    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3580        B::int_sub(lhs, rhs)
3581    }
3582    fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3583        B::int_sub_scalar(lhs, rhs.elem())
3584    }
3585    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3586        B::int_div(lhs, rhs)
3587    }
3588    fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3589        B::int_div_scalar(lhs, rhs.elem())
3590    }
3591    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3592        B::int_remainder(lhs, rhs)
3593    }
3594    fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3595        B::int_remainder_scalar(lhs, rhs.elem())
3596    }
3597    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3598        B::int_mul(lhs, rhs)
3599    }
3600    fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3601        B::int_mul_scalar(lhs, rhs.elem())
3602    }
3603    fn neg(tensor: Self::Primitive) -> Self::Primitive {
3604        B::int_neg(tensor)
3605    }
3606    fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive {
3607        B::int_zeros(shape, device)
3608    }
3609    fn ones(shape: Shape, device: &B::Device) -> Self::Primitive {
3610        B::int_ones(shape, device)
3611    }
3612    fn full<E: ElementConversion>(
3613        shape: Shape,
3614        fill_value: E,
3615        device: &B::Device,
3616    ) -> Self::Primitive {
3617        B::int_full(shape, fill_value.elem(), device)
3618    }
3619
3620    fn sum(tensor: Self::Primitive) -> Self::Primitive {
3621        B::int_sum(tensor)
3622    }
3623
3624    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3625        B::int_sum_dim(tensor, dim)
3626    }
3627
3628    fn prod(tensor: Self::Primitive) -> Self::Primitive {
3629        B::int_prod(tensor)
3630    }
3631
3632    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3633        B::int_prod_dim(tensor, dim)
3634    }
3635
3636    fn mean(tensor: Self::Primitive) -> Self::Primitive {
3637        B::int_mean(tensor)
3638    }
3639    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3640        B::int_mean_dim(tensor, dim)
3641    }
3642
3643    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3644        B::int_equal_elem(lhs, rhs)
3645    }
3646    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3647        B::int_not_equal_elem(lhs, rhs)
3648    }
3649    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3650        B::int_greater(lhs, rhs)
3651    }
3652
3653    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3654        B::int_greater_elem(lhs, rhs)
3655    }
3656
3657    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3658        B::int_greater_equal(lhs, rhs)
3659    }
3660
3661    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3662        B::int_greater_equal_elem(lhs, rhs)
3663    }
3664
3665    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3666        B::int_lower(lhs, rhs)
3667    }
3668
3669    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3670        B::int_lower_elem(lhs, rhs)
3671    }
3672
3673    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3674        B::int_lower_equal(lhs, rhs)
3675    }
3676
3677    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3678        B::int_lower_equal_elem(lhs, rhs)
3679    }
3680
3681    fn mask_where(
3682        tensor: Self::Primitive,
3683        mask: B::BoolTensorPrimitive,
3684        source: Self::Primitive,
3685    ) -> Self::Primitive {
3686        B::int_mask_where(tensor, mask, source)
3687    }
3688
3689    fn mask_fill(
3690        tensor: Self::Primitive,
3691        mask: B::BoolTensorPrimitive,
3692        value: Self::Elem,
3693    ) -> Self::Primitive {
3694        B::int_mask_fill(tensor, mask, value)
3695    }
3696
3697    fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3698        B::int_select(tensor, dim, indices.primitive)
3699    }
3700
3701    fn select_assign(
3702        tensor: Self::Primitive,
3703        dim: usize,
3704        indices: Tensor<B, 1, Int>,
3705        values: Self::Primitive,
3706    ) -> Self::Primitive {
3707        B::int_select_assign(tensor, dim, indices.primitive, values)
3708    }
3709    fn gather(
3710        dim: usize,
3711        tensor: Self::Primitive,
3712        indices: B::IntTensorPrimitive,
3713    ) -> Self::Primitive {
3714        B::int_gather(dim, tensor, indices)
3715    }
3716
3717    fn scatter(
3718        dim: usize,
3719        tensor: Self::Primitive,
3720        indices: B::IntTensorPrimitive,
3721        values: Self::Primitive,
3722    ) -> Self::Primitive {
3723        B::int_scatter(dim, tensor, indices, values)
3724    }
3725
3726    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3727        B::int_argmax(tensor, dim)
3728    }
3729
3730    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3731        B::int_argmin(tensor, dim)
3732    }
3733
3734    fn max(tensor: Self::Primitive) -> Self::Primitive {
3735        B::int_max(tensor)
3736    }
3737
3738    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3739        B::int_max_dim(tensor, dim)
3740    }
3741
3742    fn max_dim_with_indices(
3743        tensor: Self::Primitive,
3744        dim: usize,
3745    ) -> (Self::Primitive, IntTensor<B>) {
3746        B::int_max_dim_with_indices(tensor, dim)
3747    }
3748
3749    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
3750        B::int_max_abs(tensor)
3751    }
3752
3753    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3754        B::int_max_abs_dim(tensor, dim)
3755    }
3756
3757    fn min(tensor: Self::Primitive) -> Self::Primitive {
3758        B::int_min(tensor)
3759    }
3760
3761    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3762        B::int_min_dim(tensor, dim)
3763    }
3764
3765    fn min_dim_with_indices(
3766        tensor: Self::Primitive,
3767        dim: usize,
3768    ) -> (Self::Primitive, IntTensor<B>) {
3769        B::int_min_dim_with_indices(tensor, dim)
3770    }
3771
3772    fn clamp(tensor: Self::Primitive, min: B::IntElem, max: B::IntElem) -> Self::Primitive {
3773        B::int_clamp(tensor, min, max)
3774    }
3775
3776    fn clamp_min(tensor: Self::Primitive, min: B::IntElem) -> Self::Primitive {
3777        B::int_clamp_min(tensor, min)
3778    }
3779
3780    fn clamp_max(tensor: Self::Primitive, max: B::IntElem) -> Self::Primitive {
3781        B::int_clamp_max(tensor, max)
3782    }
3783
3784    fn abs(tensor: Self::Primitive) -> Self::Primitive {
3785        B::int_abs(tensor)
3786    }
3787
3788    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3789        B::int_powf(lhs, B::int_into_float(rhs))
3790    }
3791
3792    fn powf_scalar<E: ElementConversion>(
3793        lhs: Self::Primitive,
3794        rhs: E,
3795    ) -> <Int as TensorKind<B>>::Primitive {
3796        B::int_powf_scalar(lhs, rhs.elem())
3797    }
3798
3799    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3800        B::int_powi(lhs, rhs)
3801    }
3802
3803    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3804        B::int_powf_scalar(lhs, rhs.elem())
3805    }
3806
3807    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
3808        B::int_random(shape, distribution, device)
3809    }
3810
3811    fn sign(tensor: Self::Primitive) -> Self::Primitive {
3812        B::int_sign(tensor)
3813    }
3814
3815    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
3816        B::int_sort(tensor, dim, descending)
3817    }
3818
3819    fn sort_with_indices(
3820        tensor: Self::Primitive,
3821        dim: usize,
3822        descending: bool,
3823    ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
3824        B::int_sort_with_indices(tensor, dim, descending)
3825    }
3826
3827    fn argsort(
3828        tensor: Self::Primitive,
3829        dim: usize,
3830        descending: bool,
3831    ) -> <Int as TensorKind<B>>::Primitive {
3832        B::int_argsort(tensor, dim, descending)
3833    }
3834}
3835
3836impl<B: Backend> Numeric<B> for Float {
3837    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3838        match (lhs, rhs) {
3839            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3840                TensorPrimitive::Float(B::float_add(lhs, rhs))
3841            }
3842            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3843                TensorPrimitive::QFloat(B::q_add(lhs, rhs))
3844            }
3845            _ => panic!("Primitive type mismatch for lhs and rhs"),
3846        }
3847    }
3848    fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3849        match lhs {
3850            TensorPrimitive::Float(lhs) => {
3851                TensorPrimitive::Float(B::float_add_scalar(lhs, rhs.elem()))
3852            }
3853            TensorPrimitive::QFloat(lhs) => {
3854                TensorPrimitive::QFloat(B::q_add_scalar(lhs, rhs.elem()))
3855            }
3856        }
3857    }
3858    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3859        match (lhs, rhs) {
3860            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3861                TensorPrimitive::Float(B::float_sub(lhs, rhs))
3862            }
3863            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3864                TensorPrimitive::QFloat(B::q_sub(lhs, rhs))
3865            }
3866            _ => panic!("Primitive type mismatch for lhs and rhs"),
3867        }
3868    }
3869    fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3870        match lhs {
3871            TensorPrimitive::Float(lhs) => {
3872                TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs.elem()))
3873            }
3874            TensorPrimitive::QFloat(lhs) => {
3875                TensorPrimitive::QFloat(B::q_sub_scalar(lhs, rhs.elem()))
3876            }
3877        }
3878    }
3879    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3880        match (lhs, rhs) {
3881            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3882                TensorPrimitive::Float(B::float_div(lhs, rhs))
3883            }
3884            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3885                TensorPrimitive::QFloat(B::q_div(lhs, rhs))
3886            }
3887            _ => panic!("Primitive type mismatch for lhs and rhs"),
3888        }
3889    }
3890    fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3891        match lhs {
3892            TensorPrimitive::Float(lhs) => {
3893                TensorPrimitive::Float(B::float_div_scalar(lhs, rhs.elem()))
3894            }
3895            TensorPrimitive::QFloat(lhs) => {
3896                TensorPrimitive::QFloat(B::q_div_scalar(lhs, rhs.elem()))
3897            }
3898        }
3899    }
3900    fn remainder(
3901        lhs: Self::Primitive,
3902        rhs: Self::Primitive,
3903    ) -> <Float as TensorKind<B>>::Primitive {
3904        match (lhs, rhs) {
3905            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3906                TensorPrimitive::Float(B::float_remainder(lhs, rhs))
3907            }
3908            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3909                TensorPrimitive::QFloat(B::q_remainder(lhs, rhs))
3910            }
3911            _ => panic!("Primitive type mismatch for lhs and rhs"),
3912        }
3913    }
3914    fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3915        match lhs {
3916            TensorPrimitive::Float(lhs) => {
3917                TensorPrimitive::Float(B::float_remainder_scalar(lhs, rhs.elem()))
3918            }
3919            TensorPrimitive::QFloat(lhs) => {
3920                TensorPrimitive::QFloat(B::q_remainder_scalar(lhs, rhs.elem()))
3921            }
3922        }
3923    }
3924    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3925        match (lhs, rhs) {
3926            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3927                TensorPrimitive::Float(B::float_mul(lhs, rhs))
3928            }
3929            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3930                TensorPrimitive::QFloat(B::q_mul(lhs, rhs))
3931            }
3932            _ => panic!("Primitive type mismatch for lhs and rhs"),
3933        }
3934    }
3935    fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3936        match lhs {
3937            TensorPrimitive::Float(lhs) => {
3938                TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs.elem()))
3939            }
3940            TensorPrimitive::QFloat(lhs) => {
3941                TensorPrimitive::QFloat(B::q_mul_scalar(lhs, rhs.elem()))
3942            }
3943        }
3944    }
3945    fn neg(tensor: Self::Primitive) -> Self::Primitive {
3946        match tensor {
3947            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
3948            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_neg(tensor)),
3949        }
3950    }
3951    fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive {
3952        TensorPrimitive::Float(B::float_zeros(shape, device))
3953    }
3954    fn ones(shape: Shape, device: &B::Device) -> Self::Primitive {
3955        TensorPrimitive::Float(B::float_ones(shape, device))
3956    }
3957
3958    fn full<E: ElementConversion>(
3959        shape: Shape,
3960        fill_value: E,
3961        device: &B::Device,
3962    ) -> Self::Primitive {
3963        TensorPrimitive::Float(B::float_full(shape, fill_value.elem(), device))
3964    }
3965
3966    fn sum(tensor: Self::Primitive) -> Self::Primitive {
3967        match tensor {
3968            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
3969            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum(tensor)),
3970        }
3971    }
3972
3973    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3974        match tensor {
3975            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
3976            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum_dim(tensor, dim)),
3977        }
3978    }
3979
3980    fn prod(tensor: Self::Primitive) -> Self::Primitive {
3981        match tensor {
3982            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
3983            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod(tensor)),
3984        }
3985    }
3986
3987    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3988        match tensor {
3989            TensorPrimitive::Float(tensor) => {
3990                TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
3991            }
3992            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod_dim(tensor, dim)),
3993        }
3994    }
3995
3996    fn mean(tensor: Self::Primitive) -> Self::Primitive {
3997        match tensor {
3998            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
3999            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean(tensor)),
4000        }
4001    }
4002
4003    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4004        match tensor {
4005            TensorPrimitive::Float(tensor) => {
4006                TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
4007            }
4008            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean_dim(tensor, dim)),
4009        }
4010    }
4011
4012    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4013        B::float_equal_elem(lhs.tensor(), rhs)
4014    }
4015    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4016        B::float_not_equal_elem(lhs.tensor(), rhs)
4017    }
4018    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4019        B::float_greater(lhs.tensor(), rhs.tensor())
4020    }
4021
4022    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4023        B::float_greater_elem(lhs.tensor(), rhs)
4024    }
4025
4026    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4027        B::float_greater_equal(lhs.tensor(), rhs.tensor())
4028    }
4029
4030    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4031        B::float_greater_equal_elem(lhs.tensor(), rhs)
4032    }
4033
4034    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4035        B::float_lower(lhs.tensor(), rhs.tensor())
4036    }
4037
4038    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4039        B::float_lower_elem(lhs.tensor(), rhs)
4040    }
4041
4042    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4043        B::float_lower_equal(lhs.tensor(), rhs.tensor())
4044    }
4045
4046    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4047        B::float_lower_equal_elem(lhs.tensor(), rhs)
4048    }
4049
4050    fn mask_where(
4051        tensor: Self::Primitive,
4052        mask: B::BoolTensorPrimitive,
4053        source: Self::Primitive,
4054    ) -> Self::Primitive {
4055        match (tensor, source) {
4056            (TensorPrimitive::Float(tensor), TensorPrimitive::Float(source)) => {
4057                TensorPrimitive::Float(B::float_mask_where(tensor, mask, source))
4058            }
4059            (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(source)) => {
4060                TensorPrimitive::QFloat(B::q_mask_where(tensor, mask, source))
4061            }
4062            _ => panic!("Primitive type mismatch for tensor and source"),
4063        }
4064    }
4065
4066    fn mask_fill(
4067        tensor: Self::Primitive,
4068        mask: B::BoolTensorPrimitive,
4069        value: Self::Elem,
4070    ) -> Self::Primitive {
4071        match tensor {
4072            TensorPrimitive::Float(tensor) => {
4073                TensorPrimitive::Float(B::float_mask_fill(tensor, mask, value))
4074            }
4075            TensorPrimitive::QFloat(tensor) => {
4076                TensorPrimitive::QFloat(B::q_mask_fill(tensor, mask, value))
4077            }
4078        }
4079    }
4080
4081    fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
4082        match tensor {
4083            TensorPrimitive::Float(tensor) => {
4084                TensorPrimitive::Float(B::float_select(tensor, dim, indices.primitive))
4085            }
4086            TensorPrimitive::QFloat(tensor) => {
4087                TensorPrimitive::QFloat(B::q_select(tensor, dim, indices.primitive))
4088            }
4089        }
4090    }
4091
4092    fn select_assign(
4093        tensor: Self::Primitive,
4094        dim: usize,
4095        indices: Tensor<B, 1, Int>,
4096        values: Self::Primitive,
4097    ) -> Self::Primitive {
4098        match (tensor, values) {
4099            (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => {
4100                TensorPrimitive::Float(B::float_select_assign(
4101                    tensor,
4102                    dim,
4103                    indices.primitive,
4104                    values,
4105                ))
4106            }
4107            (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => {
4108                TensorPrimitive::QFloat(B::q_select_assign(tensor, dim, indices.primitive, values))
4109            }
4110            _ => panic!("Primitive type mismatch for tensor and values"),
4111        }
4112    }
4113
4114    fn gather(
4115        dim: usize,
4116        tensor: Self::Primitive,
4117        indices: B::IntTensorPrimitive,
4118    ) -> Self::Primitive {
4119        match tensor {
4120            TensorPrimitive::Float(tensor) => {
4121                TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
4122            }
4123            TensorPrimitive::QFloat(tensor) => {
4124                TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
4125            }
4126        }
4127    }
4128
4129    fn scatter(
4130        dim: usize,
4131        tensor: Self::Primitive,
4132        indices: B::IntTensorPrimitive,
4133        values: Self::Primitive,
4134    ) -> Self::Primitive {
4135        match (tensor, values) {
4136            (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => {
4137                TensorPrimitive::Float(B::float_scatter(dim, tensor, indices, values))
4138            }
4139            (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => {
4140                TensorPrimitive::QFloat(B::q_scatter(dim, tensor, indices, values))
4141            }
4142            _ => panic!("Primitive type mismatch for tensor and values"),
4143        }
4144    }
4145
4146    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
4147        match tensor {
4148            TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim),
4149            TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim),
4150        }
4151    }
4152
4153    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
4154        match tensor {
4155            TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim),
4156            TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim),
4157        }
4158    }
4159
4160    fn max(tensor: Self::Primitive) -> Self::Primitive {
4161        match tensor {
4162            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
4163            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
4164        }
4165    }
4166
4167    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4168        match tensor {
4169            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
4170            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
4171        }
4172    }
4173
4174    fn max_dim_with_indices(
4175        tensor: Self::Primitive,
4176        dim: usize,
4177    ) -> (Self::Primitive, IntTensor<B>) {
4178        match tensor {
4179            TensorPrimitive::Float(tensor) => {
4180                let (values, indices) = B::float_max_dim_with_indices(tensor, dim);
4181                (TensorPrimitive::Float(values), indices)
4182            }
4183            TensorPrimitive::QFloat(tensor) => {
4184                let (values, indices) = B::q_max_dim_with_indices(tensor, dim);
4185                (TensorPrimitive::QFloat(values), indices)
4186            }
4187        }
4188    }
4189
4190    fn min(tensor: Self::Primitive) -> Self::Primitive {
4191        match tensor {
4192            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
4193            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
4194        }
4195    }
4196
4197    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4198        match tensor {
4199            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
4200            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
4201        }
4202    }
4203
4204    fn min_dim_with_indices(
4205        tensor: Self::Primitive,
4206        dim: usize,
4207    ) -> (Self::Primitive, IntTensor<B>) {
4208        match tensor {
4209            TensorPrimitive::Float(tensor) => {
4210                let (values, indices) = B::float_min_dim_with_indices(tensor, dim);
4211                (TensorPrimitive::Float(values), indices)
4212            }
4213            TensorPrimitive::QFloat(tensor) => {
4214                let (values, indices) = B::q_min_dim_with_indices(tensor, dim);
4215                (TensorPrimitive::QFloat(values), indices)
4216            }
4217        }
4218    }
4219
4220    fn clamp(tensor: Self::Primitive, min: B::FloatElem, max: B::FloatElem) -> Self::Primitive {
4221        match tensor {
4222            TensorPrimitive::Float(tensor) => {
4223                TensorPrimitive::Float(B::float_clamp(tensor, min, max))
4224            }
4225            TensorPrimitive::QFloat(tensor) => {
4226                TensorPrimitive::QFloat(B::q_clamp(tensor, min, max))
4227            }
4228        }
4229    }
4230
4231    fn clamp_min(tensor: Self::Primitive, min: B::FloatElem) -> Self::Primitive {
4232        match tensor {
4233            TensorPrimitive::Float(tensor) => {
4234                TensorPrimitive::Float(B::float_clamp_min(tensor, min))
4235            }
4236            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_min(tensor, min)),
4237        }
4238    }
4239
4240    fn clamp_max(tensor: Self::Primitive, max: B::FloatElem) -> Self::Primitive {
4241        match tensor {
4242            TensorPrimitive::Float(tensor) => {
4243                TensorPrimitive::Float(B::float_clamp_max(tensor, max))
4244            }
4245            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_max(tensor, max)),
4246        }
4247    }
4248
4249    fn abs(tensor: Self::Primitive) -> Self::Primitive {
4250        match tensor {
4251            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
4252            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
4253        }
4254    }
4255
4256    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4257        match (lhs, rhs) {
4258            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
4259                TensorPrimitive::Float(B::float_powf(lhs, rhs))
4260            }
4261            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
4262                TensorPrimitive::QFloat(B::q_powf(lhs, rhs))
4263            }
4264            _ => panic!("Primitive type mismatch for lhs and rhs"),
4265        }
4266    }
4267
4268    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4269        match lhs {
4270            TensorPrimitive::Float(lhs) => {
4271                TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
4272            }
4273            TensorPrimitive::QFloat(lhs) => {
4274                TensorPrimitive::QFloat(B::q_powf_scalar(lhs, rhs.elem()))
4275            }
4276        }
4277    }
4278
4279    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4280        match (lhs, rhs) {
4281            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
4282                TensorPrimitive::Float(B::float_powf(lhs, rhs))
4283            }
4284            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
4285                TensorPrimitive::QFloat(B::q_powf(lhs, rhs))
4286            }
4287            _ => panic!("Primitive type mismatch for lhs and rhs"),
4288        }
4289    }
4290
4291    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4292        match lhs {
4293            TensorPrimitive::Float(lhs) => {
4294                TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
4295            }
4296            TensorPrimitive::QFloat(lhs) => {
4297                TensorPrimitive::QFloat(B::q_powf_scalar(lhs, rhs.elem()))
4298            }
4299        }
4300    }
4301
4302    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
4303        TensorPrimitive::Float(B::float_random(shape, distribution, device))
4304    }
4305
4306    fn sign(tensor: Self::Primitive) -> Self::Primitive {
4307        TensorPrimitive::Float(B::float_sign(tensor.tensor()))
4308    }
4309
4310    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
4311        match tensor {
4312            TensorPrimitive::Float(tensor) => {
4313                TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
4314            }
4315            TensorPrimitive::QFloat(tensor) => {
4316                TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
4317            }
4318        }
4319    }
4320
4321    fn sort_with_indices(
4322        tensor: Self::Primitive,
4323        dim: usize,
4324        descending: bool,
4325    ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
4326        match tensor {
4327            TensorPrimitive::Float(tensor) => {
4328                let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
4329                (TensorPrimitive::Float(values), indices)
4330            }
4331            TensorPrimitive::QFloat(tensor) => {
4332                let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
4333                (TensorPrimitive::QFloat(values), indices)
4334            }
4335        }
4336    }
4337
4338    fn argsort(
4339        tensor: Self::Primitive,
4340        dim: usize,
4341        descending: bool,
4342    ) -> <Int as TensorKind<B>>::Primitive {
4343        match tensor {
4344            TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
4345            TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
4346        }
4347    }
4348
4349    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
4350        match tensor {
4351            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
4352            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
4353        }
4354    }
4355
4356    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4357        match tensor {
4358            TensorPrimitive::Float(tensor) => {
4359                TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
4360            }
4361            TensorPrimitive::QFloat(tensor) => {
4362                TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
4363            }
4364        }
4365    }
4366}
4367
4368impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>
4369where
4370    B: Backend,
4371    K: Numeric<B>,
4372    K::Elem: Element,
4373{
4374    type Output = Self;
4375
4376    fn add(self, rhs: Tensor<B, D, K>) -> Self {
4377        Self::add(self, rhs)
4378    }
4379}
4380
4381impl<E, const D: usize, B, K> core::ops::Add<E> for Tensor<B, D, K>
4382where
4383    E: ElementConversion,
4384    B: Backend,
4385    K: Numeric<B>,
4386    K::Elem: Element,
4387{
4388    type Output = Self;
4389
4390    fn add(self, other: E) -> Self {
4391        Tensor::add_scalar(self, other)
4392    }
4393}
4394
4395impl<B, const D: usize, K> core::ops::Sub<Tensor<B, D, K>> for Tensor<B, D, K>
4396where
4397    B: Backend,
4398    K: Numeric<B>,
4399    K::Elem: Element,
4400{
4401    type Output = Self;
4402
4403    fn sub(self, rhs: Tensor<B, D, K>) -> Self {
4404        Tensor::sub(self, rhs)
4405    }
4406}
4407
4408impl<E, const D: usize, B, K> core::ops::Sub<E> for Tensor<B, D, K>
4409where
4410    E: ElementConversion,
4411    B: Backend,
4412    K: Numeric<B>,
4413    K::Elem: Element,
4414{
4415    type Output = Self;
4416
4417    fn sub(self, other: E) -> Self {
4418        Tensor::sub_scalar(self, other)
4419    }
4420}
4421
4422impl<B, const D: usize, K> core::ops::Div<Tensor<B, D, K>> for Tensor<B, D, K>
4423where
4424    B: Backend,
4425    K: Numeric<B>,
4426    K::Elem: Element,
4427{
4428    type Output = Self;
4429
4430    fn div(self, rhs: Tensor<B, D, K>) -> Self {
4431        Tensor::div(self, rhs)
4432    }
4433}
4434
4435impl<E, const D: usize, B, K> core::ops::Div<E> for Tensor<B, D, K>
4436where
4437    E: ElementConversion,
4438    B: Backend,
4439    K: Numeric<B>,
4440    K::Elem: Element,
4441{
4442    type Output = Self;
4443
4444    fn div(self, other: E) -> Self {
4445        Tensor::div_scalar(self, other)
4446    }
4447}
4448
4449impl<const D: usize, B, K> core::ops::Rem<Tensor<B, D, K>> for Tensor<B, D, K>
4450where
4451    B: Backend,
4452    K: Numeric<B>,
4453    K::Elem: Element,
4454{
4455    type Output = Self;
4456    fn rem(self, rhs: Tensor<B, D, K>) -> Self::Output {
4457        Tensor::remainder(self, rhs)
4458    }
4459}
4460
4461impl<E, const D: usize, B, K> core::ops::Rem<E> for Tensor<B, D, K>
4462where
4463    E: ElementConversion,
4464    B: Backend,
4465    K: Numeric<B>,
4466    K::Elem: Element,
4467{
4468    type Output = Self;
4469
4470    fn rem(self, other: E) -> Self {
4471        Tensor::remainder_scalar(self, other)
4472    }
4473}
4474
4475impl<B, const D: usize, K> core::ops::Mul<Tensor<B, D, K>> for Tensor<B, D, K>
4476where
4477    B: Backend,
4478    K: Numeric<B>,
4479    K::Elem: Element,
4480{
4481    type Output = Self;
4482
4483    fn mul(self, rhs: Tensor<B, D, K>) -> Self {
4484        Tensor::mul(self, rhs)
4485    }
4486}
4487
4488impl<E, const D: usize, B, K> core::ops::Mul<E> for Tensor<B, D, K>
4489where
4490    E: ElementConversion,
4491    B: Backend,
4492    K: Numeric<B>,
4493    K::Elem: Element,
4494{
4495    type Output = Self;
4496
4497    fn mul(self, other: E) -> Self {
4498        Tensor::mul_scalar(self, other)
4499    }
4500}
4501
4502impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
4503where
4504    B: Backend,
4505    K: Numeric<B>,
4506    K::Elem: Element,
4507{
4508    type Output = Self;
4509
4510    fn neg(self) -> Self {
4511        Tensor::neg(self)
4512    }
4513}