burn_tensor/tensor/api/
float.rs

1use crate::AsIndex;
2use crate::FloatDType;
3use crate::Tensor;
4use crate::cast::ToElement;
5use crate::check;
6use crate::check::TensorCheck;
7use crate::ops::GridSampleOptions;
8use crate::quantization::{QuantScheme, QuantizationParameters};
9use crate::tensor::backend::Backend;
10use crate::tensor::stats;
11use crate::tensor::{Distribution, TensorData};
12use crate::{Bool, Int, TensorPrimitive};
13use burn_backend::tensor::quantization::QuantizationParametersPrimitive;
14
/// Relative tolerance used by `is_close` and `all_close` when the caller passes `None`.
pub const DEFAULT_RTOL: f64 = 1.0e-5;

/// Absolute tolerance used by `is_close` and `all_close` when the caller passes `None`.
pub const DEFAULT_ATOL: f64 = 1.0e-8;
20
21impl<const D: usize, B> Tensor<B, D>
22where
23    B: Backend,
24{
25    /// Applies element wise exponential operation.
26    ///
27    #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
28    #[cfg_attr(not(doc), doc = "`y = e^x`")]
29    pub fn exp(self) -> Self {
30        Self::new(TensorPrimitive::Float(B::float_exp(
31            self.primitive.tensor(),
32        )))
33    }
34
35    /// Applies element wise natural log operation *ln*.
36    ///
37    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
38    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
39    pub fn log(self) -> Self {
40        Self::new(TensorPrimitive::Float(B::float_log(
41            self.primitive.tensor(),
42        )))
43    }
44
45    /// Applies the natural logarithm of one plus the input tensor, element-wise.
46    ///
47    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
48    #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
49    pub fn log1p(self) -> Self {
50        Self::new(TensorPrimitive::Float(B::float_log1p(
51            self.primitive.tensor(),
52        )))
53    }
54
55    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
56    ///
57    #[cfg_attr(
58        doc,
59        doc = r#"
60$y_i = \text{erf}\(x_i\)$
61
62The error function is defined as:
63
64$$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
65"#
66    )]
67    #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
68    pub fn erf(self) -> Self {
69        Self::new(TensorPrimitive::Float(B::float_erf(
70            self.primitive.tensor(),
71        )))
72    }
73
74    /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
75    /// (or multiplicative inverse) element wise.
76    ///
77    #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
78    #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
79    pub fn recip(self) -> Self {
80        Self::new(TensorPrimitive::Float(B::float_recip(
81            self.primitive.tensor(),
82        )))
83    }
84
85    /// Applies element wise square operation.
86    ///
87    #[cfg_attr(doc, doc = r#"$y_i = x_i * x_i$"#)]
88    #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
89    pub fn square(self) -> Self {
90        self.powi_scalar(2)
91    }
92
93    /// Applies element wise root square operation.
94    ///
95    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
96    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
97    pub fn sqrt(self) -> Self {
98        Self::new(TensorPrimitive::Float(B::float_sqrt(
99            self.primitive.tensor(),
100        )))
101    }
102
103    /// Applies element wise cosine operation.
104    ///
105    #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)]
106    #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
107    pub fn cos(self) -> Self {
108        Self::new(TensorPrimitive::Float(B::float_cos(
109            self.primitive.tensor(),
110        )))
111    }
112
113    /// Applies element wise sine operation.
114    ///
115    #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)]
116    #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
117    pub fn sin(self) -> Self {
118        Self::new(TensorPrimitive::Float(B::float_sin(
119            self.primitive.tensor(),
120        )))
121    }
122
123    /// Applies element wise tangent operation.
124    ///
125    #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)]
126    #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
127    pub fn tan(self) -> Self {
128        Self::new(TensorPrimitive::Float(B::float_tan(
129            self.primitive.tensor(),
130        )))
131    }
132
133    /// Applies element wise hyperbolic cosine operation.
134    ///
135    #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)]
136    #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
137    ///
138    /// # Example
139    ///
140    /// ```rust
141    /// use burn_tensor::backend::Backend;
142    /// use burn_tensor::Tensor;
143    ///
144    /// fn example<B: Backend>() {
145    ///     let device = Default::default();
146    ///
147    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
148    ///     println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
149    /// }
150    /// ```
151    pub fn cosh(self) -> Self {
152        Self::new(TensorPrimitive::Float(B::float_cosh(
153            self.primitive.tensor(),
154        )))
155    }
156
157    /// Applies element wise hyperbolic sine operation.
158    ///
159    #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)]
160    #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
161    ///
162    /// # Example
163    ///
164    /// ```rust
165    /// use burn_tensor::backend::Backend;
166    /// use burn_tensor::Tensor;
167    ///
168    /// fn example<B: Backend>() {
169    ///     let device = Default::default();
170    ///
171    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
172    ///     println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
173    /// }
174    /// ```
175    pub fn sinh(self) -> Self {
176        Self::new(TensorPrimitive::Float(B::float_sinh(
177            self.primitive.tensor(),
178        )))
179    }
180
181    /// Applies element wise hyperbolic tangent operation.
182    ///
183    #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)]
184    #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
185    ///
186    /// # Example
187    ///
188    /// ```rust
189    /// use burn_tensor::backend::Backend;
190    /// use burn_tensor::Tensor;
191    ///
192    /// fn example<B: Backend>() {
193    ///     let device = Default::default();
194    ///
195    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
196    ///     println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
197    /// }
198    /// ```
199    pub fn tanh(self) -> Self {
200        Self::new(TensorPrimitive::Float(B::float_tanh(
201            self.primitive.tensor(),
202        )))
203    }
204
205    /// Applies element wise round operation.
206    ///
207    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
208    /// strategy, with halfway cases rounded to the nearest even integer value.
209    pub fn round(self) -> Self {
210        Self::new(TensorPrimitive::Float(B::float_round(
211            self.primitive.tensor(),
212        )))
213    }
214
215    /// Applies element wise floor operation.
216    pub fn floor(self) -> Self {
217        Self::new(TensorPrimitive::Float(B::float_floor(
218            self.primitive.tensor(),
219        )))
220    }
221
222    /// Applies element wise ceil operation.
223    pub fn ceil(self) -> Self {
224        Self::new(TensorPrimitive::Float(B::float_ceil(
225            self.primitive.tensor(),
226        )))
227    }
228
229    /// Create a tensor from floats (f32) on a given device.
230    ///
231    /// # Example
232    ///
233    /// ```rust
234    /// use burn_tensor::backend::Backend;
235    /// use burn_tensor::Tensor;
236    ///
237    /// fn example<B: Backend>() {
238    ///     let device = B::Device::default();
239    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
240    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
241    /// }
242    /// ```
243    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
244        Self::from_data(floats.into().convert::<f32>(), device)
245    }
246
247    /// Returns a new tensor with the same shape and device as the current tensor and the data
248    /// cast to Integer.
249    ///
250    /// # Example
251    ///
252    /// ```rust
253    /// use burn_tensor::backend::Backend;
254    /// use burn_tensor::Tensor;
255    ///
256    /// fn example<B: Backend>() {
257    ///     let device = Default::default();
258    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
259    ///     let int_tensor = float_tensor.int();
260    /// }
261    /// ```
262    pub fn int(self) -> Tensor<B, D, Int> {
263        Tensor::new(B::float_into_int(self.primitive.tensor()))
264    }
265
266    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled random
267    /// values sampled from the given distribution.
268    pub fn random_like(&self, distribution: Distribution) -> Self {
269        Self::new(TensorPrimitive::Float(B::float_random(
270            self.shape(),
271            distribution,
272            &self.device(),
273        )))
274        .cast(self.dtype())
275    }
276
277    /// Calculate the variance along the given dimension.
278    pub fn var(self, dim: usize) -> Self {
279        stats::var(self, dim)
280    }
281
282    /// Calculate the variance along the given dimension without applying the Bessel’s correction.
283    pub fn var_bias(self, dim: usize) -> Self {
284        stats::var_bias(self, dim)
285    }
286
287    /// Calculate the variance along the given dimension and also returns the mean.
288    pub fn var_mean(self, dim: usize) -> (Self, Self) {
289        let mean = self.clone().mean_dim(dim);
290        let var = stats::var_with_mean(self, mean.clone(), dim);
291        (var, mean)
292    }
293
294    /// Calculate the variance along the given dimension without applying the Bessel’s correction and also returns the mean.
295    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
296        let mean = self.clone().mean_dim(dim);
297        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
298        (var, mean)
299    }
300
301    /// Converts a tensor to the specified floating point data type.
302    ///
303    /// This is always a no-op when casting to the current dtype.
304    ///
305    /// # Warning
306    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
307    /// have the same floating point precision data type for operations multiple input tensors (e.g., binary ops).
308    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
309        let dtype = dtype.into();
310        let self_type: FloatDType = self.dtype().into();
311        if dtype == self_type {
312            // no-op.
313            return self;
314        }
315
316        Tensor::new(TensorPrimitive::Float(B::float_cast(
317            self.primitive.tensor(),
318            dtype,
319        )))
320    }
321
322    /// Detach the current tensor from the autodiff graph.
323    ///
324    /// This function does nothing when autodiff is not enabled.
325    /// This can be used in batchers or elsewhere to ensure that previous operations are not
326    /// considered in the autodiff graph.
327    pub fn detach(self) -> Self {
328        Self::new(TensorPrimitive::Float(B::float_detach(
329            self.primitive.tensor(),
330        )))
331    }
332
333    /// Mark the tensor to keep gradients during the backward pass.
334    ///
335    /// This function does nothing when autodiff is not enabled.
336    pub fn require_grad(self) -> Self {
337        self.set_require_grad(true)
338    }
339
340    /// Returns true if the tensor requires gradients during the backward pass.
341    pub fn is_require_grad(&self) -> bool {
342        match &self.primitive {
343            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
344            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
345        }
346    }
347
348    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
349    /// When tracked, the gradients will be available after the backward pass.
350    ///
351    /// This function does nothing when autodiff is not enabled.
352    pub fn set_require_grad(self, require_grad: bool) -> Self {
353        let primitive = match self.primitive {
354            TensorPrimitive::Float(tensor) => {
355                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
356            }
357            TensorPrimitive::QFloat(tensor) => {
358                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
359            }
360        };
361        Self::new(primitive)
362    }
363
364    /// Applies the relu function to the tensor.
365    pub(crate) fn relu(self) -> Self {
366        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
367    }
368
369    /// Calculate covaraince matrix between different entries alongside a given dimension.
370    ///
371    /// # Arguments
372    ///
373    /// * `size` - The size of the square matrix.
374    /// * `correction_factor` - Is usually 1 for samples and 0 for population.
375    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
376        let n = self.dims()[dim];
377        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
378        centered
379            .clone()
380            .transpose()
381            .matmul(centered)
382            .div_scalar(n as f32 - correction_factor as f32)
383    }
384
385    /// Convert the tensor to a lower precision data type based on the quantization scheme.
386    ///
387    /// # Arguments
388    ///
389    /// * `scheme` - The quantization scheme.
390    /// * `qparams` - The pre-computed quantization parameters.
391    ///
392    /// # Returns
393    ///
394    /// The quantized tensor.
395    pub fn quantize(
396        self,
397        scheme: &QuantScheme,
398        qparams: QuantizationParameters<B>,
399    ) -> Tensor<B, D> {
400        Tensor::new(TensorPrimitive::QFloat(B::quantize(
401            self.primitive.tensor(),
402            scheme,
403            QuantizationParametersPrimitive {
404                scales: qparams.scales.primitive.tensor(),
405            },
406        )))
407    }
408
409    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
410    ///
411    /// # Arguments
412    ///
413    /// * `scheme` - The quantization scheme.
414    ///
415    /// # Returns
416    ///
417    /// The quantized tensor.
418    ///
419    /// # Notes
420    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
421    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
422        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
423            self.primitive.tensor(),
424            scheme,
425        )))
426    }
427
428    /// Convert the tensor back to a higher precision data type.
429    ///
430    /// If the tensor is not quantized, its value is simply returned.
431    ///
432    /// # Returns
433    ///
434    /// The dequantized tensor.
435    pub fn dequantize(self) -> Tensor<B, D> {
436        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
437    }
438
439    /// Checks element wise if the tensor is close to another tensor.
440    ///
441    /// The tolerance is defined by the following equation:
442    ///
443    /// ```text
444    /// abs(a - b) <= (atol + rtol * abs(b))
445    ///
446    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
447    /// and `atol` is the absolute tolerance.
448    /// ```
449    ///
450    /// # Arguments
451    ///
452    /// * `other` - The tensor to compare with.
453    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
454    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
455    ///
456    /// # Returns
457    ///
458    /// A boolean tensor with the same shape as the input tensors.
459    ///
460    /// # Example
461    ///
462    /// ```rust
463    /// use burn_tensor::backend::Backend;
464    /// use burn_tensor::{Tensor, Shape};
465    ///
466    /// fn example<B: Backend>() {
467    ///    let device = B::Device::default();
468    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
469    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
470    ///    let tensor = tensor1.is_close(tensor2, None, None);
471    ///    println!("{tensor}");
472    ///    // [[true, true, true], [true, true, true]]
473    /// }
474    /// ```
475    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
476        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
477        let atol = atol.unwrap_or(DEFAULT_ATOL);
478
479        // check finite difference is close
480        let is_close_finite_val = self
481            .clone()
482            .sub(other.clone())
483            .abs()
484            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
485            .bool_and(self.clone().is_finite())
486            .bool_and(other.clone().is_finite());
487
488        // check if both are infinite and have same sign
489        let inf_same_sign = self
490            .clone()
491            .is_finite()
492            .bool_not()
493            .bool_and(other.clone().is_finite().bool_not())
494            .bool_and(self.equal(other));
495
496        is_close_finite_val.bool_or(inf_same_sign)
497    }
498
499    /// Checks if all elements are close to another tensor.
500    ///
501    /// The tolerance is defined by the following equation:
502    ///
503    /// ```text
504    ///
505    /// abs(a - b) <= (atol + rtol * abs(b))
506    ///
507    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
508    /// and `atol` is the absolute tolerance.
509    ///
510    /// ```
511    ///
512    /// # Arguments
513    ///
514    /// * `other` - The tensor to compare with.
515    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
516    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
517    ///
518    /// # Returns
519    ///
520    /// A boolean scalar.
521    ///
522    /// # Remarks
523    ///
524    /// # Example
525    ///
526    /// ```rust
527    /// use burn_tensor::backend::Backend;
528    /// use burn_tensor::{Tensor, Shape};
529    ///
530    /// fn example<B: Backend>() {
531    ///    let device = B::Device::default();
532    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
533    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
534    ///    let result = tensor1.all_close(tensor2, None, None);
535    ///    println!("{}", result);
536    ///    // true
537    /// }
538    /// ```
539    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
540        self.is_close(other, rtol, atol)
541            .all()
542            .into_scalar()
543            .to_bool()
544    }
545
546    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
547    ///
548    /// # Returns
549    ///
550    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
551    ///
552    /// # Example
553    ///
554    /// ```rust
555    /// use burn_tensor::backend::Backend;
556    /// use burn_tensor::{Tensor, Bool, Shape};
557    ///
558    /// fn example<B: Backend>() {
559    ///    let device = B::Device::default();
560    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
561    ///    let tensor = tensor.is_nan();
562    ///    println!("{tensor}");
563    ///    // [[false, true, false], [false, false, false]]
564    /// }
565    /// ```
566    pub fn is_nan(self) -> Tensor<B, D, Bool> {
567        Tensor::new(B::float_is_nan(self.primitive.tensor()))
568    }
569
570    /// Checks if the tensor contains any NaN values.
571    ///
572    /// # Returns
573    ///
574    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
575    ///
576    /// # Example
577    ///
578    /// ```rust
579    /// use burn_tensor::backend::Backend;
580    /// use burn_tensor::{Tensor, Bool, Shape};
581    ///
582    /// fn example<B: Backend>() {
583    ///   let device = B::Device::default();
584    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
585    ///   let tensor = tensor.contains_nan();
586    ///   println!("{tensor}");
587    ///   // [true]
588    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
589    ///   let tensor = tensor.contains_nan();
590    ///   println!("{tensor}");
591    ///   // [false]
592    /// }
593    /// ```
594    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
595        // Summing the tensor will result in NaN if the tensor contains any NaN values
596        // This is faster than checking each element individually
597        // because it rolls up the NaN values into a single value
598        let sum = self.sum();
599
600        sum.is_nan()
601    }
602
603    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
604    ///
605    /// # Returns
606    ///
607    /// A boolean tensor where `true` indicates that the value is infinite
608    ///
609    /// # Example
610    ///
611    /// ```rust
612    /// use burn_tensor::backend::Backend;
613    /// use burn_tensor::{Tensor, Bool, Shape};
614    ///
615    /// fn example<B: Backend>() {
616    ///    let device = B::Device::default();
617    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
618    ///    let tensor = tensor.is_finite();
619    ///    println!("{tensor}");
620    ///    // [[false, true, false], [false, false, false]]
621    /// }
622    /// ```
623    pub fn is_inf(self) -> Tensor<B, D, Bool> {
624        Tensor::new(B::float_is_inf(self.primitive.tensor()))
625    }
626
627    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite
628    ///
629    /// # Returns
630    ///
631    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
632    /// either INF, -INF or NAN
633    ///
634    /// # Example
635    ///
636    /// ```rust
637    /// use burn_tensor::backend::Backend;
638    /// use burn_tensor::{Tensor, Bool, Shape};
639    ///
640    /// fn example<B: Backend>() {
641    ///    let device = B::Device::default();
642    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
643    ///    let tensor = tensor.is_finite();
644    ///    println!("{tensor}");
645    ///    // [[true, false, true], [false, true, true]]
646    /// }
647    /// ```
648    pub fn is_finite(self) -> Tensor<B, D, Bool> {
649        self.clone()
650            .is_nan()
651            .bool_not()
652            .bool_and(self.is_inf().bool_not())
653    }
654
655    /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
656    /// using the given locations in [-1, 1].
657    ///
658    /// # Arguments
659    ///
660    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
661    ///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
662    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
663    ///
664    /// # Returns
665    ///
666    /// A tensor with shape (N, C, H_out, W_out)
667    ///
668    /// # Example
669    ///
670    /// ```ignore
671    /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
672    ///
673    /// // Default options (bilinear, zeros padding, align_corners=false)
674    /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());
675    ///
676    /// // Custom options
677    /// let options = GridSampleOptions::new(InterpolateMode::Bilinear)
678    ///     .with_padding_mode(GridSamplePaddingMode::Border)
679    ///     .with_align_corners(true);
680    /// let output = tensor.grid_sample_2d(grid, options);
681    /// ```
682    pub fn grid_sample_2d(self, grid: Tensor<B, D>, options: GridSampleOptions) -> Tensor<B, D> {
683        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
684            self.primitive.tensor(),
685            grid.primitive.tensor(),
686            options,
687        )))
688    }
689
690    /// Computes the cross product of `self` and another tensor along a given dimension.
691    ///
692    /// Both `self` and `other` **must have size 3** along the specified `dim`,
693    /// because the cross product is only defined in three-dimensional space.
694    ///
695    /// # Arguments
696    ///
697    /// * `other` - The other tensor to take the cross product with.
698    /// * `dim`   - The dimension along which to compute the cross product.
699    ///
700    /// # Returns
701    ///
702    /// A tensor containing the cross product of `self` and `other` along `dim`.
703    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
704        let dim = dim.expect_dim_index(D);
705        check!(TensorCheck::cross(&self, &other, dim));
706        Tensor::new(TensorPrimitive::Float(B::float_cross(
707            self.primitive.tensor(),
708            other.primitive.tensor(),
709            dim,
710        )))
711    }
712}