burn_tensor/tensor/api/float.rs
use crate::AsIndex;
use crate::FloatDType;
use crate::Tensor;
use crate::cast::ToElement;
use crate::check;
use crate::check::TensorCheck;
use crate::ops::GridSampleOptions;
use crate::quantization::{QuantScheme, QuantizationParameters};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, TensorData};
use crate::{Bool, Int, TensorPrimitive};
use burn_backend::tensor::quantization::QuantizationParametersPrimitive;

/// Default RTOL value for `is_close` and `all_close`.
pub const DEFAULT_RTOL: f64 = 1e-5;

/// Default ATOL value for `is_close` and `all_close`.
pub const DEFAULT_ATOL: f64 = 1e-8;

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Applies element wise exponential operation.
    ///
    #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
    #[cfg_attr(not(doc), doc = "`y = e^x`")]
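    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.exp()); // approximately [1.0, 2.7183, 7.3891]
    /// }
    /// ```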
    pub fn exp(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_exp(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise natural log operation *ln*.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
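    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 4.0], &device);
    ///     println!("{}", tensor.log()); // approximately [0.0, 0.6931, 1.3863]
    /// }
    /// ```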
    pub fn log(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log(
            self.primitive.tensor(),
        )))
    }

    /// Applies the natural logarithm of one plus the input tensor, element-wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
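    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 3.0], &device);
    ///     println!("{}", tensor.log1p()); // approximately [0.0, 0.6931, 1.3863]
    /// }
    /// ```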
    pub fn log1p(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log1p(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
    ///
    #[cfg_attr(
        doc,
        doc = r#"
$y_i = \text{erf}\(x_i\)$

The error function is defined as:

$$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
"#
    )]
    #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
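    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.erf()); // approximately [0.0, -0.8427, 0.8427]
    /// }
    /// ```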
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
    /// (or multiplicative inverse) element wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
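    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 4.0], &device);
    ///     println!("{}", tensor.recip()); // [1.0, 0.5, 0.25]
    /// }
    /// ```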
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise square operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = x_i * x_i$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
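    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([2.0, -3.0, 4.0], &device);
    ///     println!("{}", tensor.square()); // [4.0, 9.0, 16.0]
    /// }
    /// ```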
    pub fn square(self) -> Self {
        self.powi_scalar(2)
    }

    /// Applies element wise square root operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
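    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 4.0, 9.0], &device);
    ///     println!("{}", tensor.sqrt()); // [1.0, 2.0, 3.0]
    /// }
    /// ```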
    pub fn sqrt(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sqrt(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
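    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.cos()); // approximately [1.0, 0.5403, -0.4161]
    /// }
    /// ```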
    pub fn cos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
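    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.sin()); // approximately [0.0, 0.8415, 0.9093]
    /// }
    /// ```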
    pub fn sin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
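    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, -1.0], &device);
    ///     println!("{}", tensor.tan()); // approximately [0.0, 1.5574, -1.5574]
    /// }
    /// ```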
    pub fn tan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
    /// }
    /// ```
    pub fn cosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
    /// }
    /// ```
    pub fn sinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
    /// }
    /// ```
    pub fn tanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
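    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.5, 2.5, -3.4], &device);
    ///     println!("{}", tensor.round()); // [2.0, 2.0, -3.0] (halfway cases go to the nearest even value)
    /// }
    /// ```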
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise floor operation.
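    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.7, -1.2, 3.0], &device);
    ///     println!("{}", tensor.floor()); // [1.0, -2.0, 3.0]
    /// }
    /// ```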
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise ceil operation.
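    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.2, -1.7, 3.0], &device);
    ///     println!("{}", tensor.ceil()); // [2.0, -1.0, 3.0]
    /// }
    /// ```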
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        Tensor::new(B::float_into_int(self.primitive.tensor()))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor, filled
    /// with random values sampled from the given distribution.
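    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // Same shape, dtype, and device; values sampled from the requested distribution.
    ///     let noise = tensor.random_like(Distribution::Default);
    /// }
    /// ```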
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Self::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
        )))
        .cast(self.dtype())
    }

    /// Calculate the variance along the given dimension.
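    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]], &device);
    ///     // Sample variance (Bessel's correction applied) along dimension 1.
    ///     println!("{}", tensor.var(1)); // [[1.0], [4.0]]
    /// }
    /// ```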
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel’s correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
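    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]], &device);
    ///     let (var, mean) = tensor.var_mean(1);
    ///     println!("{var}"); // [[1.0], [4.0]]
    ///     println!("{mean}"); // [[2.0], [6.0]]
    /// }
    /// ```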
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel’s correction and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Converts a tensor to the specified floating point data type.
    ///
    /// This is always a no-op when casting to the current dtype.
    ///
    /// # Warning
    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
    /// have the same floating point precision data type for operations with multiple input tensors (e.g., binary ops).
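    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{FloatDType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     // Cast to half precision (assuming the backend supports f16).
    ///     let tensor = tensor.cast(FloatDType::F16);
    /// }
    /// ```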
    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
        let dtype = dtype.into();
        let self_type: FloatDType = self.dtype().into();
        if dtype == self_type {
            // no-op.
            return self;
        }

        Tensor::new(TensorPrimitive::Float(B::float_cast(
            self.primitive.tensor(),
            dtype,
        )))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which the observations are indexed.
    /// * `correction_factor` - Usually 1 for samples and 0 for the population.
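    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     // Two variables, each observed three times along dimension 1.
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]], &device);
    ///     println!("{}", tensor.cov(1, 1)); // [[1.0, 2.0], [2.0, 4.0]]
    /// }
    /// ```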
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize(
        self,
        scheme: &QuantScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            QuantizationParametersPrimitive {
                scales: qparams.scales.primitive.tensor(),
            },
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    ///
    /// # Notes
    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }

    /// Checks element wise if the tensor is close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor1.is_close(tensor2, None, None);
    ///     println!("{tensor}");
    ///     // [[true, true, true], [true, true, true]]
    /// }
    /// ```
    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
        let atol = atol.unwrap_or(DEFAULT_ATOL);

        // check finite difference is close
        let is_close_finite_val = self
            .clone()
            .sub(other.clone())
            .abs()
            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
            .bool_and(self.clone().is_finite())
            .bool_and(other.clone().is_finite());

        // check if both are infinite and have same sign
        let inf_same_sign = self
            .clone()
            .is_finite()
            .bool_not()
            .bool_and(other.clone().is_finite().bool_not())
            .bool_and(self.equal(other));

        is_close_finite_val.bool_or(inf_same_sign)
    }

    /// Checks if all elements are close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let result = tensor1.all_close(tensor2, None, None);
    ///     println!("{}", result);
    ///     // true
    /// }
    /// ```
    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
        self.is_close(other, rtol, atol)
            .all()
            .into_scalar()
            .to_bool()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_nan();
    ///     println!("{tensor}");
    ///     // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_nan(self) -> Tensor<B, D, Bool> {
        Tensor::new(B::float_is_nan(self.primitive.tensor()))
    }

    /// Checks if the tensor contains any NaN values.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.contains_nan();
    ///     println!("{tensor}");
    ///     // [true]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.contains_nan();
    ///     println!("{tensor}");
    ///     // [false]
    /// }
    /// ```
    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
        // Summing the tensor will result in NaN if the tensor contains any NaN values
        // This is faster than checking each element individually
        // because it rolls up the NaN values into a single value
        let sum = self.sum();

        sum.is_nan()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is infinite.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_inf();
    ///     println!("{tensor}");
    ///     // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_inf(self) -> Tensor<B, D, Bool> {
        Tensor::new(B::float_is_inf(self.primitive.tensor()))
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
    /// either INF, -INF or NAN.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_finite();
    ///     println!("{tensor}");
    ///     // [[true, false, true], [false, true, true]]
    /// }
    /// ```
    pub fn is_finite(self) -> Tensor<B, D, Bool> {
        self.clone()
            .is_nan()
            .bool_not()
            .bool_and(self.is_inf().bool_not())
    }

    /// Samples the tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
    /// using the given locations in [-1, 1].
    ///
    /// # Arguments
    ///
    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are in [-1, 1];
    ///   [x = -1, y = -1] is the top-left corner and [x = 1, y = 1] the bottom-right.
    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
    ///
    /// # Returns
    ///
    /// A tensor with shape (N, C, H_out, W_out)
    ///
    /// # Example
    ///
    /// ```ignore
    /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
    ///
    /// // Default options (bilinear, zeros padding, align_corners=false)
    /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());
    ///
    /// // Custom options
    /// let options = GridSampleOptions::new(InterpolateMode::Bilinear)
    ///     .with_padding_mode(GridSamplePaddingMode::Border)
    ///     .with_align_corners(true);
    /// let output = tensor.grid_sample_2d(grid, options);
    /// ```
    pub fn grid_sample_2d(self, grid: Tensor<B, D>, options: GridSampleOptions) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
            self.primitive.tensor(),
            grid.primitive.tensor(),
            options,
        )))
    }

    /// Computes the cross product of `self` and another tensor along a given dimension.
    ///
    /// Both `self` and `other` **must have size 3** along the specified `dim`,
    /// because the cross product is only defined in three-dimensional space.
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to take the cross product with.
    /// * `dim` - The dimension along which to compute the cross product.
    ///
    /// # Returns
    ///
    /// A tensor containing the cross product of `self` and `other` along `dim`.
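    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     // Unit vectors x and y as rows of shape [1, 3]; their cross product is the unit vector z.
    ///     let x = Tensor::<B, 2>::from_data([[1.0, 0.0, 0.0]], &device);
    ///     let y = Tensor::<B, 2>::from_data([[0.0, 1.0, 0.0]], &device);
    ///     println!("{}", x.cross(y, 1)); // [[0.0, 0.0, 1.0]]
    /// }
    /// ```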
    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::cross(&self, &other, dim));
        Tensor::new(TensorPrimitive::Float(B::float_cross(
            self.primitive.tensor(),
            other.primitive.tensor(),
            dim,
        )))
    }
}