Skip to main content

burn_tensor/tensor/api/
float.rs

1use crate::AsIndex;
2use crate::FloatDType;
3use crate::Tensor;
4use crate::cast::ToElement;
5use crate::check;
6use crate::check::TensorCheck;
7use crate::ops::GridSampleOptions;
8use crate::quantization::{QuantScheme, QuantizationParameters};
9use crate::tensor::backend::Backend;
10use crate::tensor::stats;
11use crate::tensor::{Distribution, TensorData};
12use crate::{Bool, Int, TensorPrimitive};
13use burn_backend::tensor::quantization::QuantizationParametersPrimitive;
14use core::f32;
15
/// Default relative tolerance (`rtol`) used by `is_close` and `all_close` when `None` is given.
pub const DEFAULT_RTOL: f64 = 1e-5;

/// Default absolute tolerance (`atol`) used by `is_close` and `all_close` when `None` is given.
pub const DEFAULT_ATOL: f64 = 1e-8;
21
22impl<const D: usize, B> Tensor<B, D>
23where
24    B: Backend,
25{
26    /// Applies element wise exponential operation.
27    ///
28    #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
29    #[cfg_attr(not(doc), doc = "`y = e^x`")]
30    pub fn exp(self) -> Self {
31        Self::new(TensorPrimitive::Float(B::float_exp(
32            self.primitive.tensor(),
33        )))
34    }
35
36    /// Applies element wise natural log operation *ln*.
37    ///
38    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
39    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
40    pub fn log(self) -> Self {
41        Self::new(TensorPrimitive::Float(B::float_log(
42            self.primitive.tensor(),
43        )))
44    }
45
46    /// Applies the natural logarithm of one plus the input tensor, element-wise.
47    ///
48    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
49    #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
50    pub fn log1p(self) -> Self {
51        Self::new(TensorPrimitive::Float(B::float_log1p(
52            self.primitive.tensor(),
53        )))
54    }
55
56    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
57    ///
58    #[cfg_attr(
59        doc,
60        doc = r#"
61$y_i = \text{erf}\(x_i\)$
62
63The error function is defined as:
64
65$$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
66"#
67    )]
68    #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
69    pub fn erf(self) -> Self {
70        Self::new(TensorPrimitive::Float(B::float_erf(
71            self.primitive.tensor(),
72        )))
73    }
74
75    /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
76    /// (or multiplicative inverse) element wise.
77    ///
78    #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
79    #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
80    pub fn recip(self) -> Self {
81        Self::new(TensorPrimitive::Float(B::float_recip(
82            self.primitive.tensor(),
83        )))
84    }
85
86    /// Applies element wise square operation.
87    ///
88    #[cfg_attr(doc, doc = r#"$y_i = x_i * x_i$"#)]
89    #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
90    pub fn square(self) -> Self {
91        self.powi_scalar(2)
92    }
93
94    /// Applies element wise root square operation.
95    ///
96    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
97    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
98    pub fn sqrt(self) -> Self {
99        Self::new(TensorPrimitive::Float(B::float_sqrt(
100            self.primitive.tensor(),
101        )))
102    }
103
104    /// Applies element wise cosine operation.
105    ///
106    #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)]
107    #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
108    pub fn cos(self) -> Self {
109        Self::new(TensorPrimitive::Float(B::float_cos(
110            self.primitive.tensor(),
111        )))
112    }
113
114    /// Applies element wise sine operation.
115    ///
116    #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)]
117    #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
118    pub fn sin(self) -> Self {
119        Self::new(TensorPrimitive::Float(B::float_sin(
120            self.primitive.tensor(),
121        )))
122    }
123
124    /// Applies element wise tangent operation.
125    ///
126    #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)]
127    #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
128    pub fn tan(self) -> Self {
129        Self::new(TensorPrimitive::Float(B::float_tan(
130            self.primitive.tensor(),
131        )))
132    }
133
134    /// Applies element wise hyperbolic cosine operation.
135    ///
136    #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)]
137    #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
138    ///
139    /// # Example
140    ///
141    /// ```rust
142    /// use burn_tensor::backend::Backend;
143    /// use burn_tensor::Tensor;
144    ///
145    /// fn example<B: Backend>() {
146    ///     let device = Default::default();
147    ///
148    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
149    ///     println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
150    /// }
151    /// ```
152    pub fn cosh(self) -> Self {
153        Self::new(TensorPrimitive::Float(B::float_cosh(
154            self.primitive.tensor(),
155        )))
156    }
157
158    /// Applies element wise hyperbolic sine operation.
159    ///
160    #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)]
161    #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
162    ///
163    /// # Example
164    ///
165    /// ```rust
166    /// use burn_tensor::backend::Backend;
167    /// use burn_tensor::Tensor;
168    ///
169    /// fn example<B: Backend>() {
170    ///     let device = Default::default();
171    ///
172    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
173    ///     println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
174    /// }
175    /// ```
176    pub fn sinh(self) -> Self {
177        Self::new(TensorPrimitive::Float(B::float_sinh(
178            self.primitive.tensor(),
179        )))
180    }
181
182    /// Applies element wise hyperbolic tangent operation.
183    ///
184    #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)]
185    #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
186    ///
187    /// # Example
188    ///
189    /// ```rust
190    /// use burn_tensor::backend::Backend;
191    /// use burn_tensor::Tensor;
192    ///
193    /// fn example<B: Backend>() {
194    ///     let device = Default::default();
195    ///
196    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
197    ///     println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
198    /// }
199    /// ```
200    pub fn tanh(self) -> Self {
201        Self::new(TensorPrimitive::Float(B::float_tanh(
202            self.primitive.tensor(),
203        )))
204    }
205
206    /// Applies element wise inverse sine operation.
207    ///
208    #[cfg_attr(doc, doc = r#"$y_i = \asin\(x_i\)$"#)]
209    #[cfg_attr(not(doc), doc = "`y_i = asin(x_i)`")]
210    ///
211    /// # Example
212    ///
213    /// ```rust
214    /// use burn_tensor::backend::Backend;
215    /// use burn_tensor::Tensor;
216    ///
217    /// fn example<B: Backend>() {
218    ///     let device = Default::default();
219    ///
220    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
221    ///     println!("{}", tensor.asin()); // [ 0.0000, -1.5708,  1.5708]
222    /// }
223    /// ```
224    pub fn asin(self) -> Self {
225        Self::new(TensorPrimitive::Float(B::float_asin(
226            self.primitive.tensor(),
227        )))
228    }
229
230    /// Applies element wise inverse hyperbolic sine operation.
231    ///
232    #[cfg_attr(doc, doc = r#"$y_i = \asinh\(x_i\)$"#)]
233    #[cfg_attr(not(doc), doc = "`y_i = asinh(x_i)`")]
234    ///
235    /// # Example
236    ///
237    /// ```rust
238    /// use burn_tensor::backend::Backend;
239    /// use burn_tensor::Tensor;
240    ///
241    /// fn example<B: Backend>() {
242    ///     let device = Default::default();
243    ///
244    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
245    ///     println!("{}", tensor.asinh()); // [ 0.0000, -0.8814,  0.8814]
246    /// }
247    /// ```
248    pub fn asinh(self) -> Self {
249        Self::new(TensorPrimitive::Float(B::float_asinh(
250            self.primitive.tensor(),
251        )))
252    }
253
254    /// Applies element wise inverse cosine operation.
255    ///
256    #[cfg_attr(doc, doc = r#"$y_i = \acos\(x_i\)$"#)]
257    #[cfg_attr(not(doc), doc = "`y_i = acos(x_i)`")]
258    ///
259    /// # Example
260    ///
261    /// ```rust
262    /// use burn_tensor::backend::Backend;
263    /// use burn_tensor::Tensor;
264    ///
265    /// fn example<B: Backend>() {
266    ///     let device = Default::default();
267    ///
268    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
269    ///     println!("{}", tensor.acos()); // [1.5708, 3.1416, 0.0]
270    /// }
271    /// ```
272    pub fn acos(self) -> Self {
273        Self::new(TensorPrimitive::Float(B::float_acos(
274            self.primitive.tensor(),
275        )))
276    }
277
    /// Applies element wise inverse hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \acosh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     println!("{}", tensor.acosh()); // [0.0000, 1.3170, 1.7627]
    /// }
    /// ```
    pub fn acosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_acosh(
            self.primitive.tensor(),
        )))
    }
301
    /// Applies element wise inverse tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \atan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.atan()); // [ 0.0, -0.7854,  1.1071]
    /// }
    /// ```
    pub fn atan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atan(
            self.primitive.tensor(),
        )))
    }
325
    /// Applies element wise inverse hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \atanh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -0.5, 0.5], &device);
    ///     println!("{}", tensor.atanh()); // [ 0.0, -0.5493,  0.5493]
    /// }
    /// ```
    pub fn atanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atanh(
            self.primitive.tensor(),
        )))
    }
349
350    /// Applies element wise inverse tangent operation using the signs of arguments to determine the correct quadrant.
351    ///
352    #[cfg_attr(doc, doc = r#"$z_i = \atan2\(y_i, x_i\)$"#)]
353    #[cfg_attr(not(doc), doc = "`z_i = atan2(y_i, x_i)`")]
354    ///
355    /// # Example
356    ///
357    /// ```rust
358    /// use burn_tensor::backend::Backend;
359    /// use burn_tensor::Tensor;
360    ///
361    /// fn example<B: Backend>() {
362    ///     let device = Default::default();
363    ///
364    ///     let lhs = Tensor::<B, 1>::from_data([-2.0, 2.0, -2.0], &device);
365    ///     let rhs = Tensor::<B, 1>::from_data([1.0, -1.0, -1.0], &device);
366    ///     println!("{}", lhs.atan2(rhs)); // [-1.1071,  2.0344, -2.0344]
367    /// }
368    /// ```
369    pub fn atan2(self, other: Self) -> Self {
370        Self::new(TensorPrimitive::Float(B::float_atan2(
371            self.primitive.tensor(),
372            other.primitive.tensor(),
373        )))
374    }
375
376    /// Converts each of the elements of the input tensor from angles in degrees to radians.
377    ///
378    /// # Example
379    /// ```ignore
380    /// let tensor_in_radians = tensor.deg2rad();
381    /// ```
382    pub fn deg2rad(self) -> Self {
383        self.mul_scalar(f32::consts::PI / 180.0)
384    }
385
386    /// Converts each of the elements of the input tensor from angles in radians to degrees.
387    ///
388    /// # Example
389    /// ```ignore
390    /// let tensor_in_degrees = tensor.rad2deg();
391    /// ```
392    pub fn rad2deg(self) -> Self {
393        self.mul_scalar(180.0 / f32::consts::PI)
394    }
395
396    /// Applies element wise round operation.
397    ///
398    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
399    /// strategy, with halfway cases rounded to the nearest even integer value.
400    pub fn round(self) -> Self {
401        Self::new(TensorPrimitive::Float(B::float_round(
402            self.primitive.tensor(),
403        )))
404    }
405
406    /// Applies element wise floor operation.
407    pub fn floor(self) -> Self {
408        Self::new(TensorPrimitive::Float(B::float_floor(
409            self.primitive.tensor(),
410        )))
411    }
412
413    /// Applies element wise ceil operation.
414    pub fn ceil(self) -> Self {
415        Self::new(TensorPrimitive::Float(B::float_ceil(
416            self.primitive.tensor(),
417        )))
418    }
419
420    /// Create a tensor from floats (f32) on a given device.
421    ///
422    /// # Example
423    ///
424    /// ```rust
425    /// use burn_tensor::backend::Backend;
426    /// use burn_tensor::Tensor;
427    ///
428    /// fn example<B: Backend>() {
429    ///     let device = B::Device::default();
430    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
431    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
432    /// }
433    /// ```
434    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
435        Self::from_data(floats.into().convert::<f32>(), device)
436    }
437
438    /// Returns a new tensor with the same shape and device as the current tensor and the data
439    /// cast to Integer.
440    ///
441    /// # Example
442    ///
443    /// ```rust
444    /// use burn_tensor::backend::Backend;
445    /// use burn_tensor::Tensor;
446    ///
447    /// fn example<B: Backend>() {
448    ///     let device = Default::default();
449    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
450    ///     let int_tensor = float_tensor.int();
451    /// }
452    /// ```
453    pub fn int(self) -> Tensor<B, D, Int> {
454        Tensor::new(B::float_into_int(self.primitive.tensor()))
455    }
456
457    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled random
458    /// values sampled from the given distribution.
459    pub fn random_like(&self, distribution: Distribution) -> Self {
460        Self::new(TensorPrimitive::Float(B::float_random(
461            self.shape(),
462            distribution,
463            &self.device(),
464        )))
465        .cast(self.dtype())
466    }
467
468    /// Calculate the variance along the given dimension.
469    pub fn var(self, dim: usize) -> Self {
470        stats::var(self, dim)
471    }
472
473    /// Calculate the variance along the given dimension without applying the Bessel’s correction.
474    pub fn var_bias(self, dim: usize) -> Self {
475        stats::var_bias(self, dim)
476    }
477
478    /// Calculate the variance along the given dimension and also returns the mean.
479    pub fn var_mean(self, dim: usize) -> (Self, Self) {
480        let mean = self.clone().mean_dim(dim);
481        let var = stats::var_with_mean(self, mean.clone(), dim);
482        (var, mean)
483    }
484
485    /// Calculate the variance along the given dimension without applying the Bessel’s correction and also returns the mean.
486    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
487        let mean = self.clone().mean_dim(dim);
488        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
489        (var, mean)
490    }
491
492    /// Returns the median value along the specified dimension.
493    ///
494    /// The median is not unique for input tensors with an even number of elements
495    /// in the reduced dimension. In this case, the lower of the two medians is returned,
496    /// following PyTorch's behavior.
497    ///
498    /// # Note
499    ///
500    /// The current implementation performs a full sort along the specified dimension,
501    /// which has O(nlog(n)) complexity. Additionally, most backends currently fall back
502    /// to CPU for the sort operation, which may result in slower performance compared
503    /// to native GPU operations.
504    ///
505    /// # Arguments
506    ///
507    /// - `dim` - The dimension along which to compute the median.
508    ///
509    /// # Returns
510    ///
511    /// - A tensor containing the median values along the specified dimension.
512    ///
513    /// # Example 1
514    ///
515    /// ```ignore
516    /// // Assuming backend B
517    /// let device = B::Device::default();
518    /// let tensor = Tensor::<B, 2>::from_data(
519    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
520    ///     &device,
521    /// );
522    ///
523    /// // Median along dimension 0:
524    /// // sorted columns are [1.0, 8.0], [4.0, 5.0], [3.0, 6.0], [2.0, 7.0]
525    /// let median = tensor.median(0);
526    /// // Result: [[1.0, 4.0, 3.0, 2.0]]
527    ///
528    /// // Median along dimension 1:
529    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
530    /// let median = tensor.median(1);
531    /// // Result: [[2.0], [6.0]]
532    /// ```
533    ///
534    /// # Example 2
535    ///
536    /// The median across all elements can be calculated as follows:
537    ///
538    /// ```ignore
539    /// // D is the number of dimensions of the tensor
540    /// let flattened_tensor: Tensor<B, 1> = tensor.flatten(0, D - 1);
541    ///
542    /// // Calculate median for dim 0 since the tensor has become 1 dimensional
543    /// let median = flattened_tensor.median(0);
544    /// // Result: [4.0]
545    /// ```
546    pub fn median(self, dim: usize) -> Self {
547        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
548        // instead of leveraging a full sort to get the median.
549        stats::median(self, dim)
550    }
551
552    /// Returns the median value along the specified dimension and its index.
553    ///
554    /// The median is not unique for input tensors with an even number of elements
555    /// in the reduced dimension. In this case, the lower of the two medians is returned,
556    /// following PyTorch's behavior.
557    ///
558    /// # Note
559    ///
560    /// The current implementation performs a full sort along the specified dimension,
561    /// which has O(nlog(n)) complexity. Additionally, most backends currently fall back
562    /// to CPU for the sort operation, which may result in slower performance compared
563    /// to native GPU operations.
564    ///
565    /// # Arguments
566    ///
567    /// - `dim` - The dimension along which to compute the median.
568    ///
569    /// # Returns
570    ///
571    /// A tuple containing:
572    /// - A tensor with the median values.
573    /// - A tensor with the indices of the median values in the original tensor.
574    ///
575    /// # Example
576    ///
577    /// ```ignore
578    /// // Assuming backend B
579    /// let device = B::Device::default();
580    /// let tensor = Tensor::<B, 2>::from_data(
581    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
582    ///     &device,
583    /// );
584    ///
585    /// // Median along dimension 1:
586    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
587    /// let (values, indices) = tensor.median_with_indices(1);
588    /// // values: [[2.0], [6.0]], indices: [[3], [2]] (position in the original tensor)
589    /// ```
590    pub fn median_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
591        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
592        // instead of leveraging a full sort to get the median.
593        stats::median_with_indices(self, dim)
594    }
595
596    /// Converts a tensor to the specified floating point data type.
597    ///
598    /// This is always a no-op when casting to the current dtype.
599    ///
600    /// # Warning
601    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
602    /// have the same floating point precision data type for operations multiple input tensors (e.g., binary ops).
603    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
604        let dtype = dtype.into();
605        let self_type: FloatDType = self.dtype().into();
606        if dtype == self_type {
607            // no-op.
608            return self;
609        }
610
611        Tensor::new(TensorPrimitive::Float(B::float_cast(
612            self.primitive.tensor(),
613            dtype,
614        )))
615    }
616
617    /// Detach the current tensor from the autodiff graph.
618    ///
619    /// This function does nothing when autodiff is not enabled.
620    /// This can be used in batchers or elsewhere to ensure that previous operations are not
621    /// considered in the autodiff graph.
622    pub fn detach(self) -> Self {
623        Self::new(TensorPrimitive::Float(B::float_detach(
624            self.primitive.tensor(),
625        )))
626    }
627
628    /// Mark the tensor to keep gradients during the backward pass.
629    ///
630    /// This function does nothing when autodiff is not enabled.
631    pub fn require_grad(self) -> Self {
632        self.set_require_grad(true)
633    }
634
635    /// Returns true if the tensor requires gradients during the backward pass.
636    pub fn is_require_grad(&self) -> bool {
637        match &self.primitive {
638            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
639            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
640        }
641    }
642
643    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
644    /// When tracked, the gradients will be available after the backward pass.
645    ///
646    /// This function does nothing when autodiff is not enabled.
647    pub fn set_require_grad(self, require_grad: bool) -> Self {
648        let primitive = match self.primitive {
649            TensorPrimitive::Float(tensor) => {
650                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
651            }
652            TensorPrimitive::QFloat(tensor) => {
653                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
654            }
655        };
656        Self::new(primitive)
657    }
658
659    /// Applies the relu function to the tensor.
660    pub(crate) fn relu(self) -> Self {
661        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
662    }
663
    /// Calculates the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which entries are treated as observations.
    /// * `correction_factor` - Bessel correction subtracted from the observation count;
    ///   usually 1 for samples and 0 for the population.
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        // Center around the mean, then move `dim` to the front so the matmul
        // contracts over observations.
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }
679
680    /// Convert the tensor to a lower precision data type based on the quantization scheme.
681    ///
682    /// # Arguments
683    ///
684    /// * `scheme` - The quantization scheme.
685    /// * `qparams` - The pre-computed quantization parameters.
686    ///
687    /// # Returns
688    ///
689    /// The quantized tensor.
690    pub fn quantize(
691        self,
692        scheme: &QuantScheme,
693        qparams: QuantizationParameters<B>,
694    ) -> Tensor<B, D> {
695        Tensor::new(TensorPrimitive::QFloat(B::quantize(
696            self.primitive.tensor(),
697            scheme,
698            QuantizationParametersPrimitive {
699                scales: qparams.scales.primitive.tensor(),
700            },
701        )))
702    }
703
704    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
705    ///
706    /// # Arguments
707    ///
708    /// * `scheme` - The quantization scheme.
709    ///
710    /// # Returns
711    ///
712    /// The quantized tensor.
713    ///
714    /// # Notes
715    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
716    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
717        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
718            self.primitive.tensor(),
719            scheme,
720        )))
721    }
722
723    /// Convert the tensor back to a higher precision data type.
724    ///
725    /// If the tensor is not quantized, its value is simply returned.
726    ///
727    /// # Returns
728    ///
729    /// The dequantized tensor.
730    pub fn dequantize(self) -> Tensor<B, D> {
731        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
732    }
733
734    /// Checks element wise if the tensor is close to another tensor.
735    ///
736    /// The tolerance is defined by the following equation:
737    ///
738    /// ```text
739    /// abs(a - b) <= (atol + rtol * abs(b))
740    ///
741    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
742    /// and `atol` is the absolute tolerance.
743    /// ```
744    ///
745    /// # Arguments
746    ///
747    /// * `other` - The tensor to compare with.
748    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
749    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
750    ///
751    /// # Returns
752    ///
753    /// A boolean tensor with the same shape as the input tensors.
754    ///
755    /// # Example
756    ///
757    /// ```rust
758    /// use burn_tensor::backend::Backend;
759    /// use burn_tensor::{Tensor, Shape};
760    ///
761    /// fn example<B: Backend>() {
762    ///    let device = B::Device::default();
763    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
764    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
765    ///    let tensor = tensor1.is_close(tensor2, None, None);
766    ///    println!("{tensor}");
767    ///    // [[true, true, true], [true, true, true]]
768    /// }
769    /// ```
770    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
771        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
772        let atol = atol.unwrap_or(DEFAULT_ATOL);
773
774        // check finite difference is close
775        let is_close_finite_val = self
776            .clone()
777            .sub(other.clone())
778            .abs()
779            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
780            .bool_and(self.clone().is_finite())
781            .bool_and(other.clone().is_finite());
782
783        // check if both are infinite and have same sign
784        let inf_same_sign = self
785            .clone()
786            .is_finite()
787            .bool_not()
788            .bool_and(other.clone().is_finite().bool_not())
789            .bool_and(self.equal(other));
790
791        is_close_finite_val.bool_or(inf_same_sign)
792    }
793
794    /// Checks if all elements are close to another tensor.
795    ///
796    /// The tolerance is defined by the following equation:
797    ///
798    /// ```text
799    ///
800    /// abs(a - b) <= (atol + rtol * abs(b))
801    ///
802    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
803    /// and `atol` is the absolute tolerance.
804    ///
805    /// ```
806    ///
807    /// # Arguments
808    ///
809    /// * `other` - The tensor to compare with.
810    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
811    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
812    ///
813    /// # Returns
814    ///
815    /// A boolean scalar.
816    ///
817    /// # Remarks
818    ///
819    /// # Example
820    ///
821    /// ```rust
822    /// use burn_tensor::backend::Backend;
823    /// use burn_tensor::{Tensor, Shape};
824    ///
825    /// fn example<B: Backend>() {
826    ///    let device = B::Device::default();
827    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
828    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
829    ///    let result = tensor1.all_close(tensor2, None, None);
830    ///    println!("{}", result);
831    ///    // true
832    /// }
833    /// ```
834    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
835        self.is_close(other, rtol, atol)
836            .all()
837            .into_scalar()
838            .to_bool()
839    }
840
841    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
842    ///
843    /// # Returns
844    ///
845    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
846    ///
847    /// # Example
848    ///
849    /// ```rust
850    /// use burn_tensor::backend::Backend;
851    /// use burn_tensor::{Tensor, Bool, Shape};
852    ///
853    /// fn example<B: Backend>() {
854    ///    let device = B::Device::default();
855    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
856    ///    let tensor = tensor.is_nan();
857    ///    println!("{tensor}");
858    ///    // [[false, true, false], [false, false, false]]
859    /// }
860    /// ```
861    pub fn is_nan(self) -> Tensor<B, D, Bool> {
862        Tensor::new(B::float_is_nan(self.primitive.tensor()))
863    }
864
865    /// Checks if the tensor contains any NaN values.
866    ///
867    /// # Returns
868    ///
869    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
870    ///
871    /// # Example
872    ///
873    /// ```rust
874    /// use burn_tensor::backend::Backend;
875    /// use burn_tensor::{Tensor, Bool, Shape};
876    ///
877    /// fn example<B: Backend>() {
878    ///   let device = B::Device::default();
879    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
880    ///   let tensor = tensor.contains_nan();
881    ///   println!("{tensor}");
882    ///   // [true]
883    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
884    ///   let tensor = tensor.contains_nan();
885    ///   println!("{tensor}");
886    ///   // [false]
887    /// }
888    /// ```
889    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
890        // Summing the tensor will result in NaN if the tensor contains any NaN values
891        // This is faster than checking each element individually
892        // because it rolls up the NaN values into a single value
893        let sum = self.sum();
894
895        sum.is_nan()
896    }
897
    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
    ///
    /// NaN is not infinite, so it maps to `false`.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is infinite
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_inf();
    ///    println!("{tensor}");
    ///    // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_inf(self) -> Tensor<B, D, Bool> {
        Tensor::new(B::float_is_inf(self.primitive.tensor()))
    }
921
922    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite
923    ///
924    /// # Returns
925    ///
926    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
927    /// either INF, -INF or NAN
928    ///
929    /// # Example
930    ///
931    /// ```rust
932    /// use burn_tensor::backend::Backend;
933    /// use burn_tensor::{Tensor, Bool, Shape};
934    ///
935    /// fn example<B: Backend>() {
936    ///    let device = B::Device::default();
937    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
938    ///    let tensor = tensor.is_finite();
939    ///    println!("{tensor}");
940    ///    // [[true, false, true], [false, true, true]]
941    /// }
942    /// ```
943    pub fn is_finite(self) -> Tensor<B, D, Bool> {
944        self.clone()
945            .is_nan()
946            .bool_not()
947            .bool_and(self.is_inf().bool_not())
948    }
949
950    /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
951    /// using the given locations in [-1, 1].
952    ///
953    /// # Arguments
954    ///
955    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
956    ///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
957    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
958    ///
959    /// # Returns
960    ///
961    /// A tensor with shape (N, C, H_out, W_out)
962    ///
963    /// # Example
964    ///
965    /// ```ignore
966    /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
967    ///
968    /// // Default options (bilinear, zeros padding, align_corners=false)
969    /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());
970    ///
971    /// // Custom options
972    /// let options = GridSampleOptions::new(InterpolateMode::Bilinear)
973    ///     .with_padding_mode(GridSamplePaddingMode::Border)
974    ///     .with_align_corners(true);
975    /// let output = tensor.grid_sample_2d(grid, options);
976    /// ```
977    pub fn grid_sample_2d(
978        self,
979        grid: Tensor<B, D>,
980        options: impl Into<GridSampleOptions>,
981    ) -> Tensor<B, D> {
982        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
983            self.primitive.tensor(),
984            grid.primitive.tensor(),
985            options.into(),
986        )))
987    }
988
989    /// Computes the cross product of `self` and another tensor along a given dimension.
990    ///
991    /// Both `self` and `other` **must have size 3** along the specified `dim`,
992    /// because the cross product is only defined in three-dimensional space.
993    ///
994    /// # Arguments
995    ///
996    /// * `other` - The other tensor to take the cross product with.
997    /// * `dim`   - The dimension along which to compute the cross product.
998    ///
999    /// # Returns
1000    ///
1001    /// A tensor containing the cross product of `self` and `other` along `dim`.
1002    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
1003        let dim = dim.expect_dim_index(D);
1004        check!(TensorCheck::cross(&self, &other, dim));
1005        Tensor::new(TensorPrimitive::Float(B::float_cross(
1006            self.primitive.tensor(),
1007            other.primitive.tensor(),
1008            dim,
1009        )))
1010    }
1011}