burn_tensor/tensor/api/float.rs
use crate::AsIndex;
use crate::FloatDType;
use crate::Tensor;
use crate::cast::ToElement;
use crate::check;
use crate::check::TensorCheck;
use crate::ops::GridSampleOptions;
use crate::quantization::{QuantScheme, QuantizationParameters};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, TensorData};
use crate::{Bool, Int, TensorPrimitive};
use burn_backend::tensor::quantization::QuantizationParametersPrimitive;
use core::f32;

/// Default RTOL value for `is_close` and `all_close`.
pub const DEFAULT_RTOL: f64 = 1e-5;

/// Default ATOL value for `is_close` and `all_close`.
pub const DEFAULT_ATOL: f64 = 1e-8;

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Applies element wise exponential operation.
    ///
    #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
    #[cfg_attr(not(doc), doc = "`y = e^x`")]
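    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.exp()); // approximately [1.0, 2.7183, 7.3891]
    /// }
    /// ```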
    pub fn exp(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_exp(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise natural log operation *ln*.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
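    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 4.0], &device);
    ///     println!("{}", tensor.log()); // approximately [0.0, 0.6931, 1.3863]
    /// }
    /// ```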
    pub fn log(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log(
            self.primitive.tensor(),
        )))
    }

    /// Applies the natural logarithm of one plus the input tensor, element-wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
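    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.log1p()); // approximately [0.0, 0.6931, 1.0986]
    /// }
    /// ```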
    pub fn log1p(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log1p(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
    ///
    #[cfg_attr(
        doc,
        doc = r#"
$y_i = \text{erf}\(x_i\)$

The error function is defined as:

$$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
"#
    )]
    #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
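    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.erf()); // approximately [0.0, 0.8427, 0.9953]
    /// }
    /// ```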
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
    /// (or multiplicative inverse) element wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
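    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 4.0], &device);
    ///     println!("{}", tensor.recip()); // [1.0, 0.5, 0.25]
    /// }
    /// ```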
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise square operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = x_i * x_i$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
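    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, -2.0, 3.0], &device);
    ///     println!("{}", tensor.square()); // [1.0, 4.0, 9.0]
    /// }
    /// ```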
    pub fn square(self) -> Self {
        self.powi_scalar(2)
    }

    /// Applies element wise square root operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
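    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 4.0, 9.0], &device);
    ///     println!("{}", tensor.sqrt()); // [1.0, 2.0, 3.0]
    /// }
    /// ```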
    pub fn sqrt(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sqrt(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
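    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.cos()); // approximately [1.0, 0.5403, -0.4161]
    /// }
    /// ```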
    pub fn cos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
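    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.sin()); // approximately [0.0, 0.8415, 0.9093]
    /// }
    /// ```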
    pub fn sin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
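    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.tan()); // approximately [0.0, 1.5574, -2.1850]
    /// }
    /// ```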
    pub fn tan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
    /// }
    /// ```
    pub fn cosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
    /// }
    /// ```
    pub fn sinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
    /// }
    /// ```
    pub fn tanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arcsin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asin(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asin()); // [0.0000, -1.5708, 1.5708]
    /// }
    /// ```
    pub fn asin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_asin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{asinh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asinh()); // [0.0000, -0.8814, 0.8814]
    /// }
    /// ```
    pub fn asinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_asinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arccos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acos(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.acos()); // [1.5708, 3.1416, 0.0]
    /// }
    /// ```
    pub fn acos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_acos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{acosh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     println!("{}", tensor.acosh()); // [0.0000, 1.3170, 1.7627]
    /// }
    /// ```
    pub fn acosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_acosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arctan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.atan()); // [0.0, -0.7854, 1.1071]
    /// }
    /// ```
    pub fn atan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{atanh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -0.5, 0.5], &device);
    ///     println!("{}", tensor.atanh()); // [0.0, -0.5493, 0.5493]
    /// }
    /// ```
    pub fn atanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse tangent operation using the signs of arguments to determine the correct quadrant.
    ///
    #[cfg_attr(doc, doc = r#"$z_i = \operatorname{atan2}\(y_i, x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`z_i = atan2(y_i, x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let lhs = Tensor::<B, 1>::from_data([-2.0, 2.0, -2.0], &device);
    ///     let rhs = Tensor::<B, 1>::from_data([1.0, -1.0, -1.0], &device);
    ///     println!("{}", lhs.atan2(rhs)); // [-1.1071, 2.0344, -2.0344]
    /// }
    /// ```
    pub fn atan2(self, other: Self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atan2(
            self.primitive.tensor(),
            other.primitive.tensor(),
        )))
    }

    /// Converts each of the elements of the input tensor from angles in degrees to radians.
    ///
    /// # Example
    /// ```ignore
    /// let tensor_in_radians = tensor.deg2rad();
    /// ```
    pub fn deg2rad(self) -> Self {
        self.mul_scalar(f32::consts::PI / 180.0)
    }

    /// Converts each of the elements of the input tensor from angles in radians to degrees.
    ///
    /// # Example
    /// ```ignore
    /// let tensor_in_degrees = tensor.rad2deg();
    /// ```
    pub fn rad2deg(self) -> Self {
        self.mul_scalar(180.0 / f32::consts::PI)
    }

    /// Applies element wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
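    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.4, 1.5, 2.5], &device);
    ///     // Halfway cases round to the nearest even integer.
    ///     println!("{}", tensor.round()); // [1.0, 2.0, 2.0]
    /// }
    /// ```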
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise floor operation.
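    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.7, -1.2, 2.0], &device);
    ///     println!("{}", tensor.floor()); // [1.0, -2.0, 2.0]
    /// }
    /// ```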
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise ceil operation.
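    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.2, -1.7, 2.0], &device);
    ///     println!("{}", tensor.ceil()); // [2.0, -1.0, 2.0]
    /// }
    /// ```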
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        Tensor::new(B::float_into_int(self.primitive.tensor()))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor, filled
    /// with random values sampled from the given distribution.
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Self::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
        )))
        .cast(self.dtype())
    }

    /// Calculate the variance along the given dimension.
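    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 8.0, 12.0]], &device);
    ///     // Sample variance (with Bessel's correction) along dim 1
    ///     println!("{}", tensor.var(1)); // [[1.0], [16.0]]
    /// }
    /// ```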
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Returns the median value along the specified dimension.
    ///
    /// The median is not unique for input tensors with an even number of elements
    /// in the reduced dimension. In this case, the lower of the two medians is returned,
    /// following PyTorch's behavior.
    ///
    /// # Note
    ///
    /// The current implementation performs a full sort along the specified dimension,
    /// which has O(n log n) complexity. Additionally, most backends currently fall back
    /// to CPU for the sort operation, which may result in slower performance compared
    /// to native GPU operations.
    ///
    /// # Arguments
    ///
    /// - `dim` - The dimension along which to compute the median.
    ///
    /// # Returns
    ///
    /// - A tensor containing the median values along the specified dimension.
    ///
    /// # Example 1
    ///
    /// ```ignore
    /// // Assuming backend B
    /// let device = B::Device::default();
    /// let tensor = Tensor::<B, 2>::from_data(
    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
    ///     &device,
    /// );
    ///
    /// // Median along dimension 0:
    /// // sorted columns are [1.0, 8.0], [4.0, 5.0], [3.0, 6.0], [2.0, 7.0]
    /// let median = tensor.median(0);
    /// // Result: [[1.0, 4.0, 3.0, 2.0]]
    ///
    /// // Median along dimension 1:
    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
    /// let median = tensor.median(1);
    /// // Result: [[2.0], [6.0]]
    /// ```
    ///
    /// # Example 2
    ///
    /// The median across all elements can be calculated as follows:
    ///
    /// ```ignore
    /// // D is the number of dimensions of the tensor
    /// let flattened_tensor: Tensor<B, 1> = tensor.flatten(0, D - 1);
    ///
    /// // Calculate median for dim 0 since the tensor has become 1 dimensional
    /// let median = flattened_tensor.median(0);
    /// // Result: [4.0]
    /// ```
    pub fn median(self, dim: usize) -> Self {
        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
        // instead of leveraging a full sort to get the median.
        stats::median(self, dim)
    }

    /// Returns the median value along the specified dimension and its index.
    ///
    /// The median is not unique for input tensors with an even number of elements
    /// in the reduced dimension. In this case, the lower of the two medians is returned,
    /// following PyTorch's behavior.
    ///
    /// # Note
    ///
    /// The current implementation performs a full sort along the specified dimension,
    /// which has O(n log n) complexity. Additionally, most backends currently fall back
    /// to CPU for the sort operation, which may result in slower performance compared
    /// to native GPU operations.
    ///
    /// # Arguments
    ///
    /// - `dim` - The dimension along which to compute the median.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - A tensor with the median values.
    /// - A tensor with the indices of the median values in the original tensor.
    ///
    /// # Example
    ///
    /// ```ignore
    /// // Assuming backend B
    /// let device = B::Device::default();
    /// let tensor = Tensor::<B, 2>::from_data(
    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
    ///     &device,
    /// );
    ///
    /// // Median along dimension 1:
    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
    /// let (values, indices) = tensor.median_with_indices(1);
    /// // values: [[2.0], [6.0]], indices: [[3], [2]] (position in the original tensor)
    /// ```
    pub fn median_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
        // instead of leveraging a full sort to get the median.
        stats::median_with_indices(self, dim)
    }

    /// Converts a tensor to the specified floating point data type.
    ///
    /// This is always a no-op when casting to the current dtype.
    ///
    /// # Warning
    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
    /// have the same floating point precision data type for operations with multiple input tensors (e.g., binary ops).
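    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the target backend supports half precision:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{FloatDType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     // Cast to f16; casting to the current dtype would be a no-op.
    ///     let tensor = tensor.cast(FloatDType::F16);
    /// }
    /// ```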
    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
        let dtype = dtype.into();
        let self_type: FloatDType = self.dtype().into();
        if dtype == self_type {
            // no-op.
            return self;
        }

        Tensor::new(TensorPrimitive::Float(B::float_cast(
            self.primitive.tensor(),
            dtype,
        )))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which the covariance is computed.
    /// * `correction_factor` - Is usually 1 for samples and 0 for population.
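    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     // Two rows of three observations each
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 8.0, 12.0]], &device);
    ///     // Sample covariance (correction factor 1) along dim 1
    ///     println!("{}", tensor.cov(1, 1)); // [[1.0, 4.0], [4.0, 16.0]]
    /// }
    /// ```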
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize(
        self,
        scheme: &QuantScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            QuantizationParametersPrimitive {
                scales: qparams.scales.primitive.tensor(),
            },
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    ///
    /// # Notes
    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }

    /// Checks element wise if the tensor is close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor1.is_close(tensor2, None, None);
    ///     println!("{tensor}");
    ///     // [[true, true, true], [true, true, true]]
    /// }
    /// ```
    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
        let atol = atol.unwrap_or(DEFAULT_ATOL);

        // check finite difference is close
        let is_close_finite_val = self
            .clone()
            .sub(other.clone())
            .abs()
            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
            .bool_and(self.clone().is_finite())
            .bool_and(other.clone().is_finite());

        // check if both are infinite and have same sign
        let inf_same_sign = self
            .clone()
            .is_finite()
            .bool_not()
            .bool_and(other.clone().is_finite().bool_not())
            .bool_and(self.equal(other));

        is_close_finite_val.bool_or(inf_same_sign)
    }

    /// Checks if all elements are close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let result = tensor1.all_close(tensor2, None, None);
    ///     println!("{}", result);
    ///     // true
    /// }
    /// ```
    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
        self.is_close(other, rtol, atol)
            .all()
            .into_scalar()
            .to_bool()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_nan();
    ///     println!("{tensor}");
    ///     // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_nan(self) -> Tensor<B, D, Bool> {
        Tensor::new(B::float_is_nan(self.primitive.tensor()))
    }

    /// Checks if the tensor contains any NaN values.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.contains_nan();
    ///     println!("{tensor}");
    ///     // [true]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.contains_nan();
    ///     println!("{tensor}");
    ///     // [false]
    /// }
    /// ```
    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
        // Summing the tensor will result in NaN if the tensor contains any NaN values.
        // This is faster than checking each element individually
        // because it rolls up the NaN values into a single value.
        let sum = self.sum();

        sum.is_nan()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is infinite.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_inf();
    ///     println!("{tensor}");
    ///     // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_inf(self) -> Tensor<B, D, Bool> {
        Tensor::new(B::float_is_inf(self.primitive.tensor()))
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
    /// either INF, -INF or NaN.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_finite();
    ///     println!("{tensor}");
    ///     // [[true, false, true], [false, true, true]]
    /// }
    /// ```
    pub fn is_finite(self) -> Tensor<B, D, Bool> {
        self.clone()
            .is_nan()
            .bool_not()
            .bool_and(self.is_inf().bool_not())
    }

    /// Samples the tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
    /// using the given locations in [-1, 1].
    ///
    /// # Arguments
    ///
    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are in [-1, 1];
    ///   [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right.
    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
    ///
    /// # Returns
    ///
    /// A tensor with shape (N, C, H_out, W_out)
    ///
    /// # Example
    ///
    /// ```ignore
    /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
    ///
    /// // Default options (bilinear, zeros padding, align_corners=false)
    /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());
    ///
    /// // Custom options
    /// let options = GridSampleOptions::new(InterpolateMode::Bilinear)
    ///     .with_padding_mode(GridSamplePaddingMode::Border)
    ///     .with_align_corners(true);
    /// let output = tensor.grid_sample_2d(grid, options);
    /// ```
    pub fn grid_sample_2d(
        self,
        grid: Tensor<B, D>,
        options: impl Into<GridSampleOptions>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
            self.primitive.tensor(),
            grid.primitive.tensor(),
            options.into(),
        )))
    }

    /// Computes the cross product of `self` and another tensor along a given dimension.
    ///
    /// Both `self` and `other` **must have size 3** along the specified `dim`,
    /// because the cross product is only defined in three-dimensional space.
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to take the cross product with.
    /// * `dim` - The dimension along which to compute the cross product.
    ///
    /// # Returns
    ///
    /// A tensor containing the cross product of `self` and `other` along `dim`.
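    ///
    /// # Example
    ///
    /// A minimal sketch with 3-element vectors, passing the dimension as a plain index:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     // Unit vectors x and y; their cross product is the unit vector z.
    ///     let x = Tensor::<B, 1>::from_data([1.0, 0.0, 0.0], &device);
    ///     let y = Tensor::<B, 1>::from_data([0.0, 1.0, 0.0], &device);
    ///     println!("{}", x.cross(y, 0)); // [0.0, 0.0, 1.0]
    /// }
    /// ```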
    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::cross(&self, &other, dim));
        Tensor::new(TensorPrimitive::Float(B::float_cross(
            self.primitive.tensor(),
            other.primitive.tensor(),
            dim,
        )))
    }
}