burn_tensor/tensor/api/float.rs
use crate::AsIndex;
use crate::Cast;
use crate::Tensor;
use crate::cast::ToElement;
use crate::check;
use crate::check::TensorCheck;
use crate::ops::GridSampleOptions;
use crate::quantization::{QuantScheme, QuantizationParameters};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, TensorData};
use crate::{Bool, Float, Int, TensorPrimitive};
#[cfg(feature = "distributed")]
use burn_backend::AutodiffBackend;
use burn_backend::ElementConversion;
use burn_backend::Scalar;
use burn_backend::TensorMetadata;
#[cfg(feature = "distributed")]
use burn_backend::distributed::DistributedParamId;
use burn_backend::get_device_settings;
use burn_backend::tensor::quantization::QuantizationParametersPrimitive;
use core::f32;

/// Default RTOL value for `is_close` and `all_close`.
pub const DEFAULT_RTOL: f64 = 1e-5;

/// Default ATOL value for `is_close` and `all_close`.
pub const DEFAULT_ATOL: f64 = 1e-8;

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Applies element wise exponential operation.
    ///
    #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
    #[cfg_attr(not(doc), doc = "`y_i = e^(x_i)`")]
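    ///
    /// # Example
    ///
    /// A minimal usage sketch; the printed values are approximate:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.exp()); // approximately [1.0, 2.7183, 7.3891]
    /// }
    /// ```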
    pub fn exp(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_exp(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise natural log operation *ln*.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
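    ///
    /// # Example
    ///
    /// A minimal usage sketch; the printed values are approximate:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 10.0], &device);
    ///     println!("{}", tensor.log()); // approximately [0.0, 0.6931, 2.3026]
    /// }
    /// ```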
    pub fn log(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log(
            self.primitive.tensor(),
        )))
    }

    /// Applies the natural logarithm of one plus the input tensor, element-wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
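    ///
    /// # Example
    ///
    /// A minimal usage sketch; the printed values are approximate:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 9.0], &device);
    ///     println!("{}", tensor.log1p()); // approximately [0.0, 0.6931, 2.3026]
    /// }
    /// ```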
    pub fn log1p(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log1p(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
    ///
    #[cfg_attr(
        doc,
        doc = r#"
$y_i = \text{erf}\(x_i\)$

The error function is defined as:

$$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
"#
    )]
    #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
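    ///
    /// # Example
    ///
    /// A minimal usage sketch; the printed values are approximate:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, -1.0], &device);
    ///     println!("{}", tensor.erf()); // approximately [0.0, 0.8427, -0.8427]
    /// }
    /// ```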
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
    /// (or multiplicative inverse) element wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
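    ///
    /// # Example
    ///
    /// A minimal usage sketch:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 4.0], &device);
    ///     println!("{}", tensor.recip()); // [1.0, 0.5, 0.25]
    /// }
    /// ```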
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise square operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = x_i^2$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
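    ///
    /// # Example
    ///
    /// A minimal usage sketch:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, -2.0, 3.0], &device);
    ///     println!("{}", tensor.square()); // [1.0, 4.0, 9.0]
    /// }
    /// ```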
    pub fn square(self) -> Self {
        self.powi_scalar(2)
    }

    /// Applies element wise square root operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
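    ///
    /// # Example
    ///
    /// A minimal usage sketch:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 4.0, 9.0], &device);
    ///     println!("{}", tensor.sqrt()); // [1.0, 2.0, 3.0]
    /// }
    /// ```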
    pub fn sqrt(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sqrt(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
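    ///
    /// # Example
    ///
    /// A minimal usage sketch; the printed values are approximate:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.5708, 3.1416], &device);
    ///     println!("{}", tensor.cos()); // approximately [1.0, 0.0, -1.0]
    /// }
    /// ```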
    pub fn cos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
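    ///
    /// # Example
    ///
    /// A minimal usage sketch; the printed values are approximate:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.5708, 3.1416], &device);
    ///     println!("{}", tensor.sin()); // approximately [0.0, 1.0, 0.0]
    /// }
    /// ```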
    pub fn sin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
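    ///
    /// # Example
    ///
    /// A minimal usage sketch; the printed values are approximate:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 0.7854, 1.0], &device);
    ///     println!("{}", tensor.tan()); // approximately [0.0, 1.0, 1.5574]
    /// }
    /// ```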
    pub fn tan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
    /// }
    /// ```
    pub fn cosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
    /// }
    /// ```
    pub fn sinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
    /// }
    /// ```
    pub fn tanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arcsin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asin(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asin()); // [ 0.0000, -1.5708, 1.5708]
    /// }
    /// ```
    pub fn asin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_asin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \text{asinh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asinh()); // [ 0.0000, -0.8814, 0.8814]
    /// }
    /// ```
    pub fn asinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_asinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arccos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acos(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.acos()); // [1.5708, 3.1416, 0.0]
    /// }
    /// ```
    pub fn acos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_acos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \text{acosh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     println!("{}", tensor.acosh()); // [0.0000, 1.3170, 1.7627]
    /// }
    /// ```
    pub fn acosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_acosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arctan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.atan()); // [ 0.0, -0.7854, 1.1071]
    /// }
    /// ```
    pub fn atan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \text{atanh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -0.5, 0.5], &device);
    ///     println!("{}", tensor.atanh()); // [ 0.0, -0.5493, 0.5493]
    /// }
    /// ```
    pub fn atanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse tangent operation using the signs of arguments to determine the correct quadrant.
    ///
    #[cfg_attr(doc, doc = r#"$z_i = \text{atan2}\(y_i, x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`z_i = atan2(y_i, x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let lhs = Tensor::<B, 1>::from_data([-2.0, 2.0, -2.0], &device);
    ///     let rhs = Tensor::<B, 1>::from_data([1.0, -1.0, -1.0], &device);
    ///     println!("{}", lhs.atan2(rhs)); // [-1.1071, 2.0344, -2.0344]
    /// }
    /// ```
    pub fn atan2(self, other: Self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atan2(
            self.primitive.tensor(),
            other.primitive.tensor(),
        )))
    }

    /// Converts each of the elements of the input tensor from angles in degrees to radians.
    ///
    /// # Example
    /// ```ignore
    /// let tensor_in_radians = tensor.deg2rad();
    /// ```
    pub fn deg2rad(self) -> Self {
        self.mul_scalar(f32::consts::PI / 180.0)
    }

    /// Converts each of the elements of the input tensor from angles in radians to degrees.
    ///
    /// # Example
    /// ```ignore
    /// let tensor_in_degrees = tensor.rad2deg();
    /// ```
    pub fn rad2deg(self) -> Self {
        self.mul_scalar(180.0 / f32::consts::PI)
    }

    /// Applies element wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
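    ///
    /// # Example
    ///
    /// A minimal sketch of the round-half-to-even behavior; note that both halfway
    /// cases round to the even value 2.0:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.5, 2.5, 3.7], &device);
    ///     println!("{}", tensor.round()); // [2.0, 2.0, 4.0]
    /// }
    /// ```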
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise floor operation.
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise ceil operation.
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        let out_dtype = get_device_settings::<B>(&self.device()).int_dtype;
        Tensor::new(B::float_into_int(self.primitive.tensor(), out_dtype))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor, filled
    /// with random values sampled from the given distribution.
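    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `Distribution` is exported at the crate root as in the
    /// other tensor APIs:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::zeros([2, 3], &device);
    ///     // Same shape, dtype, and device; values drawn uniformly from [0, 1).
    ///     let random = tensor.random_like(Distribution::Uniform(0.0, 1.0));
    /// }
    /// ```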
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Self::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
            self.dtype().into(),
        )))
    }

    /// Calculate the variance along the given dimension.
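    ///
    /// # Example
    ///
    /// A minimal sketch; the printed values are approximate:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]], &device);
    ///     // Sample variance (with Bessel's correction) along dim 1.
    ///     println!("{}", tensor.var(1)); // approximately [[1.0], [4.0]]
    /// }
    /// ```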
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Returns the median value along the specified dimension.
    ///
    /// The median is not unique for input tensors with an even number of elements
    /// in the reduced dimension. In this case, the lower of the two medians is returned,
    /// following PyTorch's behavior.
    ///
    /// # Note
    ///
    /// The current implementation performs a full sort along the specified dimension,
    /// which has O(n log n) complexity. Additionally, most backends currently fall back
    /// to CPU for the sort operation, which may result in slower performance compared
    /// to native GPU operations.
    ///
    /// # Arguments
    ///
    /// - `dim` - The dimension along which to compute the median.
    ///
    /// # Returns
    ///
    /// - A tensor containing the median values along the specified dimension.
    ///
    /// # Example 1
    ///
    /// ```ignore
    /// // Assuming backend B
    /// let device = B::Device::default();
    /// let tensor = Tensor::<B, 2>::from_data(
    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
    ///     &device,
    /// );
    ///
    /// // Median along dimension 0:
    /// // sorted columns are [1.0, 8.0], [4.0, 5.0], [3.0, 6.0], [2.0, 7.0]
    /// let median = tensor.median(0);
    /// // Result: [[1.0, 4.0, 3.0, 2.0]]
    ///
    /// // Median along dimension 1:
    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
    /// let median = tensor.median(1);
    /// // Result: [[2.0], [6.0]]
    /// ```
    ///
    /// # Example 2
    ///
    /// The median across all elements can be calculated as follows:
    ///
    /// ```ignore
    /// // D is the number of dimensions of the tensor
    /// let flattened_tensor: Tensor<B, 1> = tensor.flatten(0, D - 1);
    ///
    /// // Calculate median for dim 0 since the tensor has become 1 dimensional
    /// let median = flattened_tensor.median(0);
    /// // Result: [4.0]
    /// ```
    pub fn median(self, dim: usize) -> Self {
        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
        // instead of leveraging a full sort to get the median.
        stats::median(self, dim)
    }

    /// Returns the median value along the specified dimension and its index.
    ///
    /// The median is not unique for input tensors with an even number of elements
    /// in the reduced dimension. In this case, the lower of the two medians is returned,
    /// following PyTorch's behavior.
    ///
    /// # Note
    ///
    /// The current implementation performs a full sort along the specified dimension,
    /// which has O(n log n) complexity. Additionally, most backends currently fall back
    /// to CPU for the sort operation, which may result in slower performance compared
    /// to native GPU operations.
    ///
    /// # Arguments
    ///
    /// - `dim` - The dimension along which to compute the median.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - A tensor with the median values.
    /// - A tensor with the indices of the median values in the original tensor.
    ///
    /// # Example
    ///
    /// ```ignore
    /// // Assuming backend B
    /// let device = B::Device::default();
    /// let tensor = Tensor::<B, 2>::from_data(
    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
    ///     &device,
    /// );
    ///
    /// // Median along dimension 1:
    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
    /// let (values, indices) = tensor.median_with_indices(1);
    /// // values: [[2.0], [6.0]], indices: [[3], [2]] (position in the original tensor)
    /// ```
    pub fn median_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
        // instead of leveraging a full sort to get the median.
        stats::median_with_indices(self, dim)
    }

    /// Converts a tensor to the specified data type.
    ///
    /// Supports both within-kind casting (e.g., `FloatDType::F64`) and cross-kind casting
    /// (e.g., `IntDType::I64` to produce an int tensor).
    ///
    /// This is a no-op when casting to the current dtype within the same kind.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, FloatDType, IntDType};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.5], &device);
    ///
    ///     // Within-kind cast (float to float)
    ///     let f64_tensor = float_tensor.clone().cast(FloatDType::F64);
    ///
    ///     // Cross-kind cast (float to int)
    ///     let int_tensor = float_tensor.cast(IntDType::I64);
    /// }
    /// ```
    #[must_use]
    pub fn cast<T: Cast<B, Float>>(self, dtype: T) -> Tensor<B, D, T::OutputKind> {
        Tensor::new(T::cast(self.primitive, dtype))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculates the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which the observations are laid out.
    /// * `correction_factor` - Is usually 1 for samples and 0 for population.
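    ///
    /// # Example
    ///
    /// A minimal sketch; the printed values are approximate and follow from the
    /// centered-product formula in the implementation below:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     // Two perfectly correlated rows, with observations along dim 1.
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], &device);
    ///     println!("{}", tensor.cov(1, 1)); // approximately [[1.0, 1.0], [1.0, 1.0]]
    /// }
    /// ```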
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize(
        self,
        scheme: &QuantScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            QuantizationParametersPrimitive {
                scales: qparams.scales.primitive.tensor(),
            },
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    ///
    /// # Notes
    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }

    /// Checks element wise if the tensor is close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor1.is_close(tensor2, None, None);
    ///     println!("{tensor}");
    ///     // [[true, true, true], [true, true, true]]
    /// }
    /// ```
    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
        let atol = atol.unwrap_or(DEFAULT_ATOL);

        // check finite difference is close
        let is_close_finite_val = self
            .clone()
            .sub(other.clone())
            .abs()
            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
            .bool_and(self.clone().is_finite())
            .bool_and(other.clone().is_finite());

        // check if both are infinite and have same sign
        let inf_same_sign = self
            .clone()
            .is_finite()
            .bool_not()
            .bool_and(other.clone().is_finite().bool_not())
            .bool_and(self.equal(other));

        is_close_finite_val.bool_or(inf_same_sign)
    }

    /// Checks if all elements are close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let result = tensor1.all_close(tensor2, None, None);
    ///     println!("{}", result);
    ///     // true
    /// }
    /// ```
    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
        self.is_close(other, rtol, atol)
            .all()
            .into_scalar()
            .to_bool()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_nan();
    ///     println!("{tensor}");
    ///     // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_nan(self) -> Tensor<B, D, Bool> {
        let out_dtype = get_device_settings::<B>(&self.device()).bool_dtype;
        Tensor::new(B::float_is_nan(self.primitive.tensor(), out_dtype))
    }

    /// Checks if the tensor contains any NaN values.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.contains_nan();
    ///     println!("{tensor}");
    ///     // [true]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.contains_nan();
    ///     println!("{tensor}");
    ///     // [false]
    /// }
    /// ```
    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
        // Summing the tensor will result in NaN if the tensor contains any NaN values
        // This is faster than checking each element individually
        // because it rolls up the NaN values into a single value
        let sum = self.sum();

        sum.is_nan()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is infinite.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_inf();
    ///     println!("{tensor}");
    ///     // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_inf(self) -> Tensor<B, D, Bool> {
        let out_dtype = get_device_settings::<B>(&self.device()).bool_dtype;
        Tensor::new(B::float_is_inf(self.primitive.tensor(), out_dtype))
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
    /// either INF, -INF or NAN.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_finite();
    ///     println!("{tensor}");
    ///     // [[true, false, true], [false, true, true]]
    /// }
    /// ```
    pub fn is_finite(self) -> Tensor<B, D, Bool> {
        self.clone()
            .is_nan()
            .bool_not()
            .bool_and(self.is_inf().bool_not())
    }

    /// Samples the tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
    /// using the given locations in [-1, 1].
    ///
    /// # Arguments
    ///
    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are in [-1, 1];
    ///   `[x = -1, y = -1]` corresponds to the top-left corner, and `[x = 1, y = 1]` to the bottom-right.
    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
    ///
    /// # Returns
    ///
    /// A tensor with shape (N, C, H_out, W_out)
    ///
    /// # Example
    ///
    /// ```ignore
    /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
    ///
    /// // Default options (bilinear, zeros padding, align_corners=false)
    /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());
    ///
    /// // Custom options
    /// let options = GridSampleOptions::new(InterpolateMode::Bilinear)
    ///     .with_padding_mode(GridSamplePaddingMode::Border)
    ///     .with_align_corners(true);
    /// let output = tensor.grid_sample_2d(grid, options);
    /// ```
    pub fn grid_sample_2d(
        self,
        grid: Tensor<B, D>,
        options: impl Into<GridSampleOptions>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
            self.primitive.tensor(),
            grid.primitive.tensor(),
            options.into(),
        )))
    }

    /// Computes the cross product of `self` and another tensor along a given dimension.
    ///
    /// Both `self` and `other` **must have size 3** along the specified `dim`,
    /// because the cross product is only defined in three-dimensional space.
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to take the cross product with.
    /// * `dim` - The dimension along which to compute the cross product.
    ///
    /// # Returns
    ///
    /// A tensor containing the cross product of `self` and `other` along `dim`.
    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::cross(&self, &other, dim));
        Tensor::new(TensorPrimitive::Float(B::float_cross(
            self.primitive.tensor(),
            other.primitive.tensor(),
            dim,
        )))
    }

    /// Applies element wise power operation with a float Tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1.powf(tensor2);
    ///     println!("{tensor}");
    ///     // [[1.0, -8.0, 81.0], [5.0, 81.0, 216.0]]
    /// }
    /// ```
    pub fn powf(self, other: Self) -> Self {
        let primitive = match (self.primitive, other.primitive) {
            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
                TensorPrimitive::Float(B::float_powf(lhs, rhs))
            }
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_powf(lhs, rhs),
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {
                let dtype = rhs.dtype();
                TensorPrimitive::Float(B::float_powf(B::dequantize(lhs, dtype.into()), rhs))
            }
            (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
                let dtype = lhs.dtype();
                TensorPrimitive::Float(B::float_powf(lhs, B::dequantize(rhs, dtype.into())))
            }
        };

        Tensor::new(primitive)
    }

    /// Applies element wise power operation with a float scalar.
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.powf_scalar(2.0);
    ///     println!("{tensor}");
    ///     // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
    /// }
    /// ```
    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
        let rhs = Scalar::new(other, &self.dtype());

        let primitive = match self.primitive {
            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs)),
            TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs),
        };

        Tensor::new(primitive)
    }
}

impl<const D: usize, B: Backend> Tensor<B, D> {
    /// Draws samples from a categorical distribution defined by the last dimension
    /// of the input tensor.
    ///
    /// The last dimension is treated as a (possibly unnormalized) set of weights
    /// defining a categorical distribution over categories. All leading dimensions
    /// are treated as batch dimensions. The method returns integer indices of the
    /// sampled categories.
    ///
    /// # Arguments
    ///
    /// * `num_samples` - Number of samples to draw per distribution. Must be >= 1.
    ///
    /// # Panics
    ///
    /// Panics if `num_samples` is 0.
    ///
    /// # Note
    ///
    /// Distributions with all-zero weights produce undefined (NaN-based) sampling
    /// results. Callers should ensure each distribution has at least one positive
    /// weight.
    ///
    /// # Returns
    ///
    /// An integer tensor with the same shape as the input, except the last dimension
    /// is replaced by `num_samples`, containing sampled category indices in
    /// `[0, num_categories)`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let probs = Tensor::<B, 2>::from_floats(
    ///         [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    ///         &device,
    ///     );
    ///     let samples = probs.categorical(4);
    ///     // First row always samples index 1, second row always samples index 2
    ///     println!("{samples}");
    /// }
    /// ```
    pub fn categorical(self, num_samples: usize) -> Tensor<B, D, Int> {
        assert!(num_samples > 0, "categorical: num_samples must be >= 1");

        let shape = self.shape();
        let num_categories = shape[D - 1];
        let batch_size = (shape.num_elements() / num_categories).max(1);
        let device = self.device();

        // Flatten leading dimensions into a single batch dimension: [batch, categories]
        let flat: Tensor<B, 2> = self.reshape([batch_size, num_categories]);

        // Normalize weights to probabilities
        let sum = flat.clone().sum_dim(1); // [batch, 1]
        let probs = flat / sum;

        // Cumulative sum along categories dimension
        let cumsum = probs.cumsum(1); // [batch, categories]

        // Uniform random values for each sample
        let uniform = Tensor::<B, 2>::random(
            [batch_size, num_samples],
            Distribution::Uniform(0.0, 1.0),
            &device,
        ); // [batch, num_samples]

        // Expand dimensions for broadcasting:
        // cumsum: [batch, categories, 1]
        // uniform: [batch, 1, num_samples]
        let cumsum_3d: Tensor<B, 3> = cumsum.unsqueeze_dim(2);
        let uniform_3d: Tensor<B, 3> = uniform.unsqueeze_dim(1);

        // Count categories where cumsum < uniform (inverse CDF)
        let mask: Tensor<B, 3, Bool> = cumsum_3d.lower(uniform_3d);
        let indices: Tensor<B, 2, Int> = mask.int().sum_dim(1).squeeze_dim::<2>(1);

        // Clamp to valid range to guard against floating-point imprecision in cumsum
        let indices = indices.clamp(0, num_categories as i64 - 1);

        // Reshape back to [...leading_dims, num_samples]
        let mut out_shape = shape;
        out_shape[D - 1] = num_samples;
        indices.reshape(out_shape)
    }
}

#[cfg(feature = "distributed")]
impl<const D: usize, B> Tensor<B, D>
where
    B: AutodiffBackend,
{
    /// Returns true if the tensor is marked as distributed.
    pub fn is_distributed(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::is_distributed(tensor),
            TensorPrimitive::QFloat(_) => unimplemented!(),
        }
    }

    /// Mark the tensor as distributed.
    ///
    /// This function does nothing when autodiff or distributed is not enabled.
    pub fn set_distributed(self, param_id: DistributedParamId) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::set_distributed_params(tensor, param_id))
            }
            TensorPrimitive::QFloat(_) => unimplemented!(),
        };
        Self::new(primitive)
    }
}
1219}