// burn_tensor/tensor/api/float.rs

use crate::AsIndex;
use crate::FloatDType;
use crate::Tensor;
use crate::cast::ToElement;
use crate::check;
use crate::check::TensorCheck;
use crate::indexing::canonicalize_dim;
use crate::ops::InterpolateMode;
use crate::quantization::{QuantScheme, QuantizationParameters};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, TensorData};
use crate::{Int, TensorPrimitive};

use super::Bool;

/// Default RTOL value for `is_close` and `all_close`.
pub const DEFAULT_RTOL: f64 = 1e-5;

/// Default ATOL value for `is_close` and `all_close`.
pub const DEFAULT_ATOL: f64 = 1e-8;

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Applies element wise exponential operation.
    ///
    #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
    #[cfg_attr(not(doc), doc = "`y = e^x`")]
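    ///
    /// # Example
    ///
    /// A small illustrative sketch; printed values are rounded and the exact display
    /// format may vary by backend.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.exp()); // [1.0, 2.7183, 7.3891]
    /// }
    /// ```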
    pub fn exp(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_exp(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise natural log operation *ln*.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
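    ///
    /// # Example
    ///
    /// A sketch of the expected output; values are approximate.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.7183, 7.3891], &device);
    ///     println!("{}", tensor.log()); // [0.0, 1.0, 2.0] (approximately)
    /// }
    /// ```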
    pub fn log(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log(
            self.primitive.tensor(),
        )))
    }

    /// Applies the natural logarithm of one plus the input tensor, element-wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
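    ///
    /// # Example
    ///
    /// An illustrative sketch; values are rounded to four decimals.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, -0.5], &device);
    ///     println!("{}", tensor.log1p()); // [0.0, 0.6931, -0.6931]
    /// }
    /// ```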
    pub fn log1p(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log1p(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
    ///
    #[cfg_attr(
        doc,
        doc = r#"
$y_i = \text{erf}\(x_i\)$

The error function is defined as:

$$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
"#
    )]
    #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
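    ///
    /// # Example
    ///
    /// A brief sketch; the printed values are approximate.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, -1.0], &device);
    ///     println!("{}", tensor.erf()); // [0.0, 0.8427, -0.8427]
    /// }
    /// ```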
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
    /// (or multiplicative inverse) element wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
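    ///
    /// # Example
    ///
    /// A minimal sketch of the expected output.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 4.0], &device);
    ///     println!("{}", tensor.recip()); // [1.0, 0.5, 0.25]
    /// }
    /// ```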
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise square operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = x_i * x_i$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
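    ///
    /// # Example
    ///
    /// A short sketch; display formatting may differ by backend.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([2.0, -3.0, 0.5], &device);
    ///     println!("{}", tensor.square()); // [4.0, 9.0, 0.25]
    /// }
    /// ```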
    pub fn square(self) -> Self {
        self.powi_scalar(2)
    }

    /// Applies element wise square root operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
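    ///
    /// # Example
    ///
    /// Expected output for a few perfect squares; exact formatting may vary by backend.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 4.0, 9.0], &device);
    ///     println!("{}", tensor.sqrt()); // [1.0, 2.0, 3.0]
    /// }
    /// ```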
    pub fn sqrt(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sqrt(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
    pub fn cos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
    pub fn sin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
    pub fn tan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
    /// }
    /// ```
    pub fn cosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
    /// }
    /// ```
    pub fn sinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
    /// }
    /// ```
    pub fn tanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
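    ///
    /// # Example
    ///
    /// Illustrates the round-half-to-even behavior; display formatting may vary by backend.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.5, 2.5, 3.2], &device);
    ///     println!("{}", tensor.round()); // [2.0, 2.0, 3.0]
    /// }
    /// ```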
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise floor operation.
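    ///
    /// # Example
    ///
    /// A small sketch of the expected behavior.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.9, -1.1, 2.0], &device);
    ///     println!("{}", tensor.floor()); // [1.0, -2.0, 2.0]
    /// }
    /// ```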
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise ceil operation.
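    ///
    /// # Example
    ///
    /// A small sketch mirroring the `floor` example above.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.1, -1.9, 2.0], &device);
    ///     println!("{}", tensor.ceil()); // [2.0, -1.0, 2.0]
    /// }
    /// ```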
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        Tensor::new(B::float_into_int(self.primitive.tensor()))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor, filled
    /// with random values sampled from the given distribution.
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Self::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
        )))
        .cast(self.dtype())
    }

    /// Calculate the variance along the given dimension.
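    ///
    /// # Example
    ///
    /// A sketch of the sample variance (with Bessel's correction); the printed value is approximate.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0, 4.0]], &device);
    ///     // Variance of [1, 2, 3, 4] with n - 1 in the denominator: 5 / 3.
    ///     println!("{}", tensor.var(1)); // [[1.6667]]
    /// }
    /// ```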
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel’s correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel’s correction and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Converts a tensor to the specified floating point data type.
    ///
    /// This is always a no-op when casting to the current dtype.
    ///
    /// # Warning
    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
    /// have the same floating point precision data type for operations with multiple input tensors (e.g., binary ops).
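    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the backend supports half precision (`f16`).
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{FloatDType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     // Cast from the default float dtype to f16.
    ///     let tensor = tensor.cast(FloatDType::F16);
    /// }
    /// ```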
    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
        let dtype = dtype.into();
        let self_type: FloatDType = self.dtype().into();
        if dtype == self_type {
            // no-op.
            return self;
        }

        Tensor::new(TensorPrimitive::Float(B::float_cast(
            self.primitive.tensor(),
            dtype,
        )))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which the covariance is computed.
    /// * `correction_factor` - Is usually 1 for samples and 0 for population.
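    ///
    /// # Example
    ///
    /// A sketch of the expected result for a small matrix, with entries along `dim = 1`
    /// and `correction_factor = 1` (sample covariance).
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     println!("{}", tensor.cov(1, 1)); // [[0.5, 0.5], [0.5, 0.5]]
    /// }
    /// ```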
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize(
        self,
        scheme: &QuantScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            qparams.into(),
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    ///
    /// # Notes
    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }

    /// Checks element wise if the tensor is close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor1.is_close(tensor2, None, None);
    ///    println!("{tensor}");
    ///    // [[true, true, true], [true, true, true]]
    /// }
    /// ```
    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
        let atol = atol.unwrap_or(DEFAULT_ATOL);

        // check finite difference is close
        let is_close_finite_val = self
            .clone()
            .sub(other.clone())
            .abs()
            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
            .bool_and(self.clone().is_finite())
            .bool_and(other.clone().is_finite());

        // check if both are infinite and have same sign
        let inf_same_sign = self
            .clone()
            .is_finite()
            .bool_not()
            .bool_and(other.clone().is_finite().bool_not())
            .bool_and(self.equal(other));

        is_close_finite_val.bool_or(inf_same_sign)
    }

    /// Checks if all elements are close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let result = tensor1.all_close(tensor2, None, None);
    ///    println!("{}", result);
    ///    // true
    /// }
    /// ```
    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
        self.is_close(other, rtol, atol)
            .all()
            .into_scalar()
            .to_bool()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_nan();
    ///    println!("{tensor}");
    ///    // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_nan(self) -> Tensor<B, D, Bool> {
        Tensor::new(B::float_is_nan(self.primitive.tensor()))
    }

    /// Checks if the tensor contains any NaN values.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///   let tensor = tensor.contains_nan();
    ///   println!("{tensor}");
    ///   // [true]
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.contains_nan();
    ///   println!("{tensor}");
    ///   // [false]
    /// }
    /// ```
    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
        // Summing the tensor will result in NaN if the tensor contains any NaN values
        // This is faster than checking each element individually
        // because it rolls up the NaN values into a single value
        let sum = self.sum();

        sum.is_nan()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is infinite.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_inf();
    ///    println!("{tensor}");
    ///    // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_inf(self) -> Tensor<B, D, Bool> {
        Tensor::new(B::float_is_inf(self.primitive.tensor()))
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
    /// either +INF, -INF, or NaN.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_finite();
    ///    println!("{tensor}");
    ///    // [[true, false, true], [false, true, true]]
    /// }
    /// ```
    pub fn is_finite(self) -> Tensor<B, D, Bool> {
        self.clone()
            .is_nan()
            .bool_not()
            .bool_and(self.is_inf().bool_not())
    }

    /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
    /// using the given locations in [-1, 1].
    ///
    /// Interpolation between sample locations is determined by `method`.
    /// Padding is border: out-of-bounds locations are clamped to the nearest border.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor being sampled from, shape (N, C, H_in, W_in)
    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are in [-1, 1];
    ///   [x = -1, y = -1] is the top-left corner and [x = 1, y = 1] is the bottom-right corner.
    /// * `method` - How to interpolate between samples
    ///
    /// # Returns
    ///
    /// A tensor with shape (N, C, H_out, W_out)
    pub fn grid_sample_2d(self, grid: Tensor<B, D>, method: InterpolateMode) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
            self.primitive.tensor(),
            grid.primitive.tensor(),
            method,
        )))
    }

    /// Computes the cross product of `self` and another tensor along a given dimension.
    ///
    /// Both `self` and `other` **must have size 3** along the specified `dim`,
    /// because the cross product is only defined in three-dimensional space.
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to take the cross product with.
    /// * `dim`   - The dimension along which to compute the cross product.
    ///
    /// # Returns
    ///
    /// A tensor containing the cross product of `self` and `other` along `dim`.
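    ///
    /// # Example
    ///
    /// A sketch using the standard basis vectors; `dim = 1` selects the dimension of size 3.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let a = Tensor::<B, 2>::from_floats([[1.0, 0.0, 0.0]], &device);
    ///     let b = Tensor::<B, 2>::from_floats([[0.0, 1.0, 0.0]], &device);
    ///     println!("{}", a.cross(b, 1)); // [[0.0, 0.0, 1.0]]
    /// }
    /// ```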
    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
        let dim = canonicalize_dim(dim, D, false);
        check!(TensorCheck::cross(&self, &other, dim));
        Tensor::new(TensorPrimitive::Float(B::float_cross(
            self.primitive.tensor(),
            other.primitive.tensor(),
            dim,
        )))
    }
}