burn_tensor/tensor/api/float.rs
use crate::AsIndex;
use crate::FloatDType;
use crate::Tensor;
use crate::cast::ToElement;
use crate::check;
use crate::check::TensorCheck;
use crate::ops::GridSampleOptions;
use crate::quantization::{QuantScheme, QuantizationParameters};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, TensorData};
use crate::{Bool, Int, TensorPrimitive};
use burn_backend::tensor::quantization::QuantizationParametersPrimitive;

/// Default RTOL value for `is_close` and `all_close`.
pub const DEFAULT_RTOL: f64 = 1e-5;

/// Default ATOL value for `is_close` and `all_close`.
pub const DEFAULT_ATOL: f64 = 1e-8;
impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Applies element wise exponential operation.
    ///
    #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
    #[cfg_attr(not(doc), doc = "`y_i = e^(x_i)`")]
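    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.exp()); // [1.0, 2.7183, 7.3891]
    /// }
    /// ```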
    pub fn exp(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_exp(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise natural log operation *ln*.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
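    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.7183, 10.0], &device);
    ///     println!("{}", tensor.log()); // [0.0, 1.0, 2.3026]
    /// }
    /// ```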
    pub fn log(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log(
            self.primitive.tensor(),
        )))
    }

    /// Applies the natural logarithm of one plus the input tensor, element-wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
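    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 9.0], &device);
    ///     println!("{}", tensor.log1p()); // [0.0, 0.6931, 2.3026]
    /// }
    /// ```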
    pub fn log1p(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log1p(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
    ///
    #[cfg_attr(
        doc,
        doc = r#"
$y_i = \text{erf}\(x_i\)$

The error function is defined as:

$$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
"#
    )]
    #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
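    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, -1.0], &device);
    ///     println!("{}", tensor.erf()); // [0.0, 0.8427, -0.8427]
    /// }
    /// ```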
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
    /// (or multiplicative inverse) element wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
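    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([2.0, 4.0, 0.5], &device);
    ///     println!("{}", tensor.recip()); // [0.5, 0.25, 2.0]
    /// }
    /// ```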
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise square operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = x_i * x_i$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
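    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([2.0, 3.0, 4.0], &device);
    ///     println!("{}", tensor.square()); // [4.0, 9.0, 16.0]
    /// }
    /// ```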
    pub fn square(self) -> Self {
        self.powi_scalar(2)
    }

    /// Applies element wise square root operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
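    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 4.0, 9.0], &device);
    ///     println!("{}", tensor.sqrt()); // [1.0, 2.0, 3.0]
    /// }
    /// ```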
    pub fn sqrt(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sqrt(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
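    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 3.1416], &device);
    ///     println!("{}", tensor.cos()); // [1.0, 0.5403, -1.0]
    /// }
    /// ```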
    pub fn cos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
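    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 1.5708], &device);
    ///     println!("{}", tensor.sin()); // [0.0, 0.8415, 1.0]
    /// }
    /// ```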
    pub fn sin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
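    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 0.7854, 1.0], &device);
    ///     println!("{}", tensor.tan()); // [0.0, 1.0, 1.5574]
    /// }
    /// ```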
    pub fn tan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
    /// }
    /// ```
    pub fn cosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
    /// }
    /// ```
    pub fn sinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
    /// }
    /// ```
    pub fn tanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arcsin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asin(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asin()); // [ 0.0000, -1.5708, 1.5708]
    /// }
    /// ```
    pub fn asin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_asin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \text{asinh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asinh()); // [ 0.0000, -0.8814, 0.8814]
    /// }
    /// ```
    pub fn asinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_asinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arccos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acos(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.acos()); // [1.5708, 3.1416, 0.0]
    /// }
    /// ```
    pub fn acos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_acos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \text{acosh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     println!("{}", tensor.acosh()); // [0.0000, 1.3170, 1.7627]
    /// }
    /// ```
    pub fn acosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_acosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arctan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.atan()); // [ 0.0, -0.7854, 1.1071]
    /// }
    /// ```
    pub fn atan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \text{atanh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -0.5, 0.5], &device);
    ///     println!("{}", tensor.atanh()); // [ 0.0, -0.5493, 0.5493]
    /// }
    /// ```
    pub fn atanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse tangent operation using the signs of arguments to determine the correct quadrant.
    ///
    #[cfg_attr(doc, doc = r#"$z_i = \text{atan2}\(y_i, x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`z_i = atan2(y_i, x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let lhs = Tensor::<B, 1>::from_data([-2.0, 2.0, -2.0], &device);
    ///     let rhs = Tensor::<B, 1>::from_data([1.0, -1.0, -1.0], &device);
    ///     println!("{}", lhs.atan2(rhs)); // [-1.1071, 2.0344, -2.0344]
    /// }
    /// ```
    pub fn atan2(self, other: Self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atan2(
            self.primitive.tensor(),
            other.primitive.tensor(),
        )))
    }

    /// Applies element wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
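    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.5, 2.5, 3.7], &device);
    ///     // Halfway cases round to the nearest even integer.
    ///     println!("{}", tensor.round()); // [2.0, 2.0, 4.0]
    /// }
    /// ```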
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise floor operation.
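    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.9, -1.1, 2.0], &device);
    ///     println!("{}", tensor.floor()); // [1.0, -2.0, 2.0]
    /// }
    /// ```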
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise ceil operation.
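    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.1, -1.9, 2.0], &device);
    ///     println!("{}", tensor.ceil()); // [2.0, -1.0, 2.0]
    /// }
    /// ```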
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        Tensor::new(B::float_into_int(self.primitive.tensor()))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor, filled
    /// with random values sampled from the given distribution.
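    ///
    /// # Example
    ///
    /// A minimal usage sketch; the sampled values are random, so none are shown.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // Same shape, dtype, and device as `tensor`, with values drawn uniformly from [0, 1).
    ///     let noise = tensor.random_like(Distribution::Uniform(0.0, 1.0));
    /// }
    /// ```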
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Self::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
        )))
        .cast(self.dtype())
    }

    /// Calculate the variance along the given dimension.
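    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0, 4.0]], &device);
    ///     // Sample variance (Bessel's correction applied) along dimension 1.
    ///     println!("{}", tensor.var(1)); // [[1.6667]]
    /// }
    /// ```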
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction,
    /// and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Converts a tensor to the specified floating point data type.
    ///
    /// This is always a no-op when casting to the current dtype.
    ///
    /// # Warning
    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
    /// have the same floating point precision data type for operations with multiple input tensors (e.g., binary ops).
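    ///
    /// # Example
    ///
    /// A minimal sketch; it assumes the target precision (`F16` here) is supported by the backend.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{FloatDType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
    ///     // Casting to the tensor's current dtype would be a no-op.
    ///     let half = tensor.cast(FloatDType::F16);
    /// }
    /// ```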
    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
        let dtype = dtype.into();
        let self_type: FloatDType = self.dtype().into();
        if dtype == self_type {
            // no-op.
            return self;
        }

        Tensor::new(TensorPrimitive::Float(B::float_cast(
            self.primitive.tensor(),
            dtype,
        )))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which entries are treated as observations.
    /// * `correction_factor` - Usually 1 for samples and 0 for the population.
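    ///
    /// # Example
    ///
    /// A small sketch; each row is treated as a variable and each column as an observation.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     // Two variables with three observations each.
    ///     let data = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///     // 2x2 sample covariance matrix (correction factor 1).
    ///     println!("{}", data.cov(1, 1)); // [[1.0, 1.0], [1.0, 1.0]]
    /// }
    /// ```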
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize(
        self,
        scheme: &QuantScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            QuantizationParametersPrimitive {
                scales: qparams.scales.primitive.tensor(),
            },
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    ///
    /// # Notes
    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }

    /// Checks element wise if the tensor is close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor1.is_close(tensor2, None, None);
    ///     println!("{tensor}");
    ///     // [[true, true, true], [true, true, true]]
    /// }
    /// ```
    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
        let atol = atol.unwrap_or(DEFAULT_ATOL);

        // check finite difference is close
        let is_close_finite_val = self
            .clone()
            .sub(other.clone())
            .abs()
            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
            .bool_and(self.clone().is_finite())
            .bool_and(other.clone().is_finite());

        // check if both are infinite and have same sign
        let inf_same_sign = self
            .clone()
            .is_finite()
            .bool_not()
            .bool_and(other.clone().is_finite().bool_not())
            .bool_and(self.equal(other));

        is_close_finite_val.bool_or(inf_same_sign)
    }

    /// Checks if all elements are close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let result = tensor1.all_close(tensor2, None, None);
    ///     println!("{}", result);
    ///     // true
    /// }
    /// ```
    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
        self.is_close(other, rtol, atol)
            .all()
            .into_scalar()
            .to_bool()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_nan();
    ///     println!("{tensor}");
    ///     // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_nan(self) -> Tensor<B, D, Bool> {
        Tensor::new(B::float_is_nan(self.primitive.tensor()))
    }

    /// Checks if the tensor contains any NaN values.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.contains_nan();
    ///     println!("{tensor}");
    ///     // [true]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.contains_nan();
    ///     println!("{tensor}");
    ///     // [false]
    /// }
    /// ```
    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
        // Summing the tensor will result in NaN if the tensor contains any NaN values
        // This is faster than checking each element individually
        // because it rolls up the NaN values into a single value
        let sum = self.sum();

        sum.is_nan()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is infinite.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_inf();
    ///     println!("{tensor}");
    ///     // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_inf(self) -> Tensor<B, D, Bool> {
        Tensor::new(B::float_is_inf(self.primitive.tensor()))
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
    /// either INF, -INF, or NaN.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_finite();
    ///     println!("{tensor}");
    ///     // [[true, false, true], [false, true, true]]
    /// }
    /// ```
    pub fn is_finite(self) -> Tensor<B, D, Bool> {
        self.clone()
            .is_nan()
            .bool_not()
            .bool_and(self.is_inf().bool_not())
    }

    /// Samples the tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
    /// using the given locations in [-1, 1].
    ///
    /// # Arguments
    ///
    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are in [-1, 1];
    ///   [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right.
    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
    ///
    /// # Returns
    ///
    /// A tensor with shape (N, C, H_out, W_out)
    ///
    /// # Example
    ///
    /// ```ignore
    /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
    ///
    /// // Default options (bilinear, zeros padding, align_corners=false)
    /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());
    ///
    /// // Custom options
    /// let options = GridSampleOptions::new(InterpolateMode::Bilinear)
    ///     .with_padding_mode(GridSamplePaddingMode::Border)
    ///     .with_align_corners(true);
    /// let output = tensor.grid_sample_2d(grid, options);
    /// ```
    pub fn grid_sample_2d(
        self,
        grid: Tensor<B, D>,
        options: impl Into<GridSampleOptions>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
            self.primitive.tensor(),
            grid.primitive.tensor(),
            options.into(),
        )))
    }

    /// Computes the cross product of `self` and another tensor along a given dimension.
    ///
    /// Both `self` and `other` **must have size 3** along the specified `dim`,
    /// because the cross product is only defined in three-dimensional space.
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to take the cross product with.
    /// * `dim` - The dimension along which to compute the cross product.
    ///
    /// # Returns
    ///
    /// A tensor containing the cross product of `self` and `other` along `dim`.
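    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let x = Tensor::<B, 1>::from_data([1.0, 0.0, 0.0], &device);
    ///     let y = Tensor::<B, 1>::from_data([0.0, 1.0, 0.0], &device);
    ///     println!("{}", x.cross(y, 0)); // [0.0, 0.0, 1.0]
    /// }
    /// ```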
    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::cross(&self, &other, dim));
        Tensor::new(TensorPrimitive::Float(B::float_cross(
            self.primitive.tensor(),
            other.primitive.tensor(),
            dim,
        )))
    }
}
886}