burn_tensor/tensor/api/float.rs

use crate::AsIndex;
use crate::FloatDType;
use crate::Tensor;
use crate::cast::ToElement;
use crate::check;
use crate::check::TensorCheck;
use crate::ops::GridSampleOptions;
use crate::quantization::{QuantScheme, QuantizationParameters};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, TensorData};
use crate::{Bool, Int, TensorPrimitive};
use burn_backend::tensor::quantization::QuantizationParametersPrimitive;

/// Default RTOL value for `is_close` and `all_close`.
pub const DEFAULT_RTOL: f64 = 1e-5;

/// Default ATOL value for `is_close` and `all_close`.
pub const DEFAULT_ATOL: f64 = 1e-8;

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Applies the element-wise exponential operation.
    ///
    #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
    #[cfg_attr(not(doc), doc = "`y_i = exp(x_i)`")]
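    ///
    /// # Example
    ///
    /// Illustrative usage (printed values rounded):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.exp()); // [1.0, 2.7183, 7.3891]
    /// }
    /// ```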
    pub fn exp(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_exp(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise natural logarithm operation *ln*.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
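    ///
    /// # Example
    ///
    /// Illustrative usage (printed values rounded):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.7183, 10.0], &device);
    ///     println!("{}", tensor.log()); // [0.0, 1.0, 2.3026]
    /// }
    /// ```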
    pub fn log(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log(
            self.primitive.tensor(),
        )))
    }

    /// Applies the natural logarithm of one plus the input tensor, element-wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e(x_i + 1)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
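    ///
    /// # Example
    ///
    /// Illustrative usage (printed values rounded):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 9.0], &device);
    ///     println!("{}", tensor.log1p()); // [0.0, 0.6931, 2.3026]
    /// }
    /// ```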
    pub fn log1p(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log1p(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element-wise.
    ///
    #[cfg_attr(
        doc,
        doc = r#"
$y_i = \text{erf}(x_i)$

The error function is defined as:

$$\text{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
"#
    )]
    #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
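    ///
    /// # Example
    ///
    /// Illustrative usage (printed values rounded):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, -1.0], &device);
    ///     println!("{}", tensor.erf()); // [0.0, 0.8427, -0.8427]
    /// }
    /// ```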
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
    /// (or multiplicative inverse) element-wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
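    ///
    /// # Example
    ///
    /// Illustrative usage (printed values rounded):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 4.0], &device);
    ///     println!("{}", tensor.recip()); // [1.0, 0.5, 0.25]
    /// }
    /// ```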
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise square operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = x_i \cdot x_i$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
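    ///
    /// # Example
    ///
    /// Illustrative usage (printed values rounded):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, -2.0, 3.0], &device);
    ///     println!("{}", tensor.square()); // [1.0, 4.0, 9.0]
    /// }
    /// ```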
    pub fn square(self) -> Self {
        self.powi_scalar(2)
    }

    /// Applies the element-wise square root operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
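    ///
    /// # Example
    ///
    /// Illustrative usage (printed values rounded):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 4.0, 9.0], &device);
    ///     println!("{}", tensor.sqrt()); // [1.0, 2.0, 3.0]
    /// }
    /// ```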
    pub fn sqrt(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sqrt(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cos(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
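    ///
    /// # Example
    ///
    /// Illustrative usage (printed values rounded):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.5708, 3.1416], &device);
    ///     println!("{}", tensor.cos()); // [1.0, 0.0, -1.0]
    /// }
    /// ```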
    pub fn cos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cos(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sin(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
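    ///
    /// # Example
    ///
    /// Illustrative usage (printed values rounded):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.5708, 3.1416], &device);
    ///     println!("{}", tensor.sin()); // [0.0, 1.0, 0.0]
    /// }
    /// ```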
    pub fn sin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sin(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tan(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
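    ///
    /// # Example
    ///
    /// Illustrative usage (printed values rounded):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 0.7854, -0.7854], &device);
    ///     println!("{}", tensor.tan()); // [0.0, 1.0, -1.0]
    /// }
    /// ```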
    pub fn tan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tan(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cosh(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
    /// }
    /// ```
    pub fn cosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sinh(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
    /// }
    /// ```
    pub fn sinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tanh(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
    /// }
    /// ```
    pub fn tanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise inverse sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arcsin(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asin(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asin()); // [ 0.0000, -1.5708,  1.5708]
    /// }
    /// ```
    pub fn asin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_asin(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise inverse hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{asinh}(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asinh()); // [ 0.0000, -0.8814,  0.8814]
    /// }
    /// ```
    pub fn asinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_asinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise inverse cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arccos(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acos(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.acos()); // [1.5708, 3.1416, 0.0]
    /// }
    /// ```
    pub fn acos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_acos(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise inverse hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{acosh}(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     println!("{}", tensor.acosh()); // [0.0000, 1.3170, 1.7627]
    /// }
    /// ```
    pub fn acosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_acosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise inverse tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arctan(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.atan()); // [ 0.0, -0.7854,  1.1071]
    /// }
    /// ```
    pub fn atan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atan(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise inverse hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{atanh}(x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -0.5, 0.5], &device);
    ///     println!("{}", tensor.atanh()); // [ 0.0, -0.5493,  0.5493]
    /// }
    /// ```
    pub fn atanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise inverse tangent operation, using the signs of the arguments to determine the correct quadrant.
    ///
    #[cfg_attr(doc, doc = r#"$z_i = \operatorname{atan2}(y_i, x_i)$"#)]
    #[cfg_attr(not(doc), doc = "`z_i = atan2(y_i, x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let lhs = Tensor::<B, 1>::from_data([-2.0, 2.0, -2.0], &device);
    ///     let rhs = Tensor::<B, 1>::from_data([1.0, -1.0, -1.0], &device);
    ///     println!("{}", lhs.atan2(rhs)); // [-1.1071,  2.0344, -2.0344]
    /// }
    /// ```
    pub fn atan2(self, other: Self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atan2(
            self.primitive.tensor(),
            other.primitive.tensor(),
        )))
    }

    /// Applies the element-wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
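    ///
    /// # Example
    ///
    /// Illustrative usage; note how both halfway cases round to the nearest even value:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.5, 2.5, 2.2], &device);
    ///     println!("{}", tensor.round()); // [2.0, 2.0, 2.0]
    /// }
    /// ```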
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise floor operation.
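    ///
    /// # Example
    ///
    /// Illustrative usage (rounds toward negative infinity):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.9, -1.1, 2.0], &device);
    ///     println!("{}", tensor.floor()); // [1.0, -2.0, 2.0]
    /// }
    /// ```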
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise ceil operation.
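    ///
    /// # Example
    ///
    /// Illustrative usage (rounds toward positive infinity):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.1, -1.9, 2.0], &device);
    ///     println!("{}", tensor.ceil()); // [2.0, -1.0, 2.0]
    /// }
    /// ```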
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor, with the data
    /// cast to `Int`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        Tensor::new(B::float_into_int(self.primitive.tensor()))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor, filled
    /// with random values sampled from the given distribution.
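    ///
    /// # Example
    ///
    /// A usage sketch drawing standard-normal noise; `Distribution::Normal(mean, std)` is one of
    /// the available sampling distributions:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     let noise = tensor.random_like(Distribution::Normal(0.0, 1.0));
    /// }
    /// ```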
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Self::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
        )))
        .cast(self.dtype())
    }

    /// Calculate the variance along the given dimension.
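    ///
    /// # Example
    ///
    /// Illustrative usage; `var` divides by `n - 1` (Bessel's correction):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]], &device);
    ///     println!("{}", tensor.var(1)); // [[1.0], [4.0]]
    /// }
    /// ```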
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction, and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Converts a tensor to the specified floating point data type.
    ///
    /// This is always a no-op when casting to the current dtype.
    ///
    /// # Warning
    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
    /// have the same floating point precision data type for operations with multiple input tensors (e.g., binary ops).
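    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the backend supports `f16`:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{FloatDType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
    ///     // Cast to half precision; casting to the current dtype would be a no-op.
    ///     let tensor = tensor.cast(FloatDType::F16);
    /// }
    /// ```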
    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
        let dtype = dtype.into();
        let self_type: FloatDType = self.dtype().into();
        if dtype == self_type {
            // no-op.
            return self;
        }

        Tensor::new(TensorPrimitive::Float(B::float_cast(
            self.primitive.tensor(),
            dtype,
        )))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the `require_grad` argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which entries are observed.
    /// * `correction_factor` - The degrees-of-freedom correction; usually 1 for samples and 0 for the population.
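    ///
    /// # Example
    ///
    /// Illustrative usage; with `dim = 1`, each row is a variable observed across columns
    /// (values shown for this toy input):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     println!("{}", tensor.cov(1, 1)); // [[0.5, 0.5], [0.5, 0.5]]
    /// }
    /// ```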
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize(
        self,
        scheme: &QuantScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            QuantizationParametersPrimitive {
                scales: qparams.scales.primitive.tensor(),
            },
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    ///
    /// # Notes
    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }

    /// Checks element-wise if the tensor is close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor1.is_close(tensor2, None, None);
    ///    println!("{tensor}");
    ///    // [[true, true, true], [true, true, true]]
    /// }
    /// ```
    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
        let atol = atol.unwrap_or(DEFAULT_ATOL);

        // Check that the finite difference is within tolerance.
        let is_close_finite_val = self
            .clone()
            .sub(other.clone())
            .abs()
            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
            .bool_and(self.clone().is_finite())
            .bool_and(other.clone().is_finite());

        // Check if both values are infinite with the same sign.
        let inf_same_sign = self
            .clone()
            .is_finite()
            .bool_not()
            .bool_and(other.clone().is_finite().bool_not())
            .bool_and(self.equal(other));

        is_close_finite_val.bool_or(inf_same_sign)
    }

    /// Checks if all elements are close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let result = tensor1.all_close(tensor2, None, None);
    ///    println!("{}", result);
    ///    // true
    /// }
    /// ```
    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
        self.is_close(other, rtol, atol)
            .all()
            .into_scalar()
            .to_bool()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_nan();
    ///    println!("{tensor}");
    ///    // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_nan(self) -> Tensor<B, D, Bool> {
        Tensor::new(B::float_is_nan(self.primitive.tensor()))
    }

    /// Checks if the tensor contains any NaN values.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///   let tensor = tensor.contains_nan();
    ///   println!("{tensor}");
    ///   // [true]
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.contains_nan();
    ///   println!("{tensor}");
    ///   // [false]
    /// }
    /// ```
    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
        // Summing the tensor will result in NaN if the tensor contains any NaN values.
        // This is faster than checking each element individually
        // because it rolls up the NaN values into a single value.
        let sum = self.sum();

        sum.is_nan()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is infinite.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_inf();
    ///    println!("{tensor}");
    ///    // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_inf(self) -> Tensor<B, D, Bool> {
        Tensor::new(B::float_is_inf(self.primitive.tensor()))
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
    /// either +INF, -INF, or NaN.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_finite();
    ///    println!("{tensor}");
    ///    // [[true, false, true], [false, true, true]]
    /// }
    /// ```
    pub fn is_finite(self) -> Tensor<B, D, Bool> {
        self.clone()
            .is_nan()
            .bool_not()
            .bool_and(self.is_inf().bool_not())
    }

    /// Samples the tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
    /// using the given locations in [-1, 1].
    ///
    /// # Arguments
    ///
    /// * `grid` - A tensor of locations with shape (N, H_out, W_out, 2), with values in [-1, 1].
    ///   `[x = -1, y = -1]` corresponds to the top-left corner and `[x = 1, y = 1]` to the bottom-right.
    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
    ///
    /// # Returns
    ///
    /// A tensor with shape (N, C, H_out, W_out)
    ///
    /// # Example
    ///
    /// ```ignore
    /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
    ///
    /// // Default options (bilinear, zeros padding, align_corners=false)
    /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());
    ///
    /// // Custom options
    /// let options = GridSampleOptions::new(InterpolateMode::Bilinear)
    ///     .with_padding_mode(GridSamplePaddingMode::Border)
    ///     .with_align_corners(true);
    /// let output = tensor.grid_sample_2d(grid, options);
    /// ```
    pub fn grid_sample_2d(
        self,
        grid: Tensor<B, D>,
        options: impl Into<GridSampleOptions>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
            self.primitive.tensor(),
            grid.primitive.tensor(),
            options.into(),
        )))
    }

    /// Computes the cross product of `self` and another tensor along a given dimension.
    ///
    /// Both `self` and `other` **must have size 3** along the specified `dim`,
    /// because the cross product is only defined in three-dimensional space.
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to take the cross product with.
    /// * `dim`   - The dimension along which to compute the cross product.
    ///
    /// # Returns
    ///
    /// A tensor containing the cross product of `self` and `other` along `dim`.
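    ///
    /// # Example
    ///
    /// Illustrative usage with unit basis vectors, for which `x cross y = z`:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let x = Tensor::<B, 1>::from_data([1.0, 0.0, 0.0], &device);
    ///     let y = Tensor::<B, 1>::from_data([0.0, 1.0, 0.0], &device);
    ///     println!("{}", x.cross(y, 0)); // [0.0, 0.0, 1.0]
    /// }
    /// ```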
    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::cross(&self, &other, dim));
        Tensor::new(TensorPrimitive::Float(B::float_cross(
            self.primitive.tensor(),
            other.primitive.tensor(),
            dim,
        )))
    }
}