burn_tensor/tensor/api/float.rs

use crate::AsIndex;
use crate::Cast;
use crate::Tensor;
use crate::cast::ToElement;
use crate::check;
use crate::check::TensorCheck;
use crate::ops::GridSampleOptions;
use crate::quantization::{QuantScheme, QuantizationParameters};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, TensorData};
use crate::{Bool, Float, Int, TensorPrimitive};
#[cfg(feature = "distributed")]
use burn_backend::AutodiffBackend;
use burn_backend::ElementConversion;
use burn_backend::Scalar;
use burn_backend::TensorMetadata;
#[cfg(feature = "distributed")]
use burn_backend::distributed::DistributedParamId;
use burn_backend::get_device_settings;
use burn_backend::tensor::quantization::QuantizationParametersPrimitive;
use core::f32;

/// Default RTOL value for `is_close` and `all_close`.
pub const DEFAULT_RTOL: f64 = 1e-5;

/// Default ATOL value for `is_close` and `all_close`.
pub const DEFAULT_ATOL: f64 = 1e-8;

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Applies element wise exponential operation.
    ///
    #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
    #[cfg_attr(not(doc), doc = "`y = e^x`")]
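    ///
    /// # Example
    ///
    /// A minimal illustration; printed values are rounded.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.exp()); // [1.0, 2.7183, 7.3891]
    /// }
    /// ```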
    pub fn exp(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_exp(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise natural log operation *ln*.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
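    ///
    /// # Example
    ///
    /// A minimal illustration; printed values are rounded.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.7183, 10.0], &device);
    ///     println!("{}", tensor.log()); // [0.0, 1.0, 2.3026]
    /// }
    /// ```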
    pub fn log(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log(
            self.primitive.tensor(),
        )))
    }

    /// Applies the natural logarithm of one plus the input tensor, element-wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
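    ///
    /// # Example
    ///
    /// A minimal illustration; printed values are rounded.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 9.0], &device);
    ///     println!("{}", tensor.log1p()); // [0.0, 0.6931, 2.3026]
    /// }
    /// ```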
    pub fn log1p(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log1p(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
    ///
    #[cfg_attr(
        doc,
        doc = r#"
$y_i = \text{erf}\(x_i\)$

The error function is defined as:

$$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
"#
    )]
    #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
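    ///
    /// # Example
    ///
    /// A minimal illustration; printed values are rounded.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, -1.0], &device);
    ///     println!("{}", tensor.erf()); // [0.0, 0.8427, -0.8427]
    /// }
    /// ```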
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
    /// (or multiplicative inverse) element wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
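    ///
    /// # Example
    ///
    /// A minimal illustration:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 4.0], &device);
    ///     println!("{}", tensor.recip()); // [1.0, 0.5, 0.25]
    /// }
    /// ```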
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise square operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = x_i * x_i$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
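    ///
    /// # Example
    ///
    /// A minimal illustration:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, -2.0, 3.0], &device);
    ///     println!("{}", tensor.square()); // [1.0, 4.0, 9.0]
    /// }
    /// ```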
    pub fn square(self) -> Self {
        self.powi_scalar(2)
    }

    /// Applies element wise square root operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
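    ///
    /// # Example
    ///
    /// A minimal illustration:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 4.0, 9.0], &device);
    ///     println!("{}", tensor.sqrt()); // [1.0, 2.0, 3.0]
    /// }
    /// ```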
    pub fn sqrt(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sqrt(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
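    ///
    /// # Example
    ///
    /// A minimal illustration; printed values are rounded.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.cos()); // [1.0, 0.5403, -0.4161]
    /// }
    /// ```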
    pub fn cos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
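    ///
    /// # Example
    ///
    /// A minimal illustration; printed values are rounded.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.sin()); // [0.0, 0.8415, 0.9093]
    /// }
    /// ```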
    pub fn sin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
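    ///
    /// # Example
    ///
    /// A minimal illustration; printed values are rounded.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0], &device);
    ///     println!("{}", tensor.tan()); // [0.0, 1.5574, -2.1850]
    /// }
    /// ```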
    pub fn tan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
    /// }
    /// ```
    pub fn cosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
    /// }
    /// ```
    pub fn sinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
    /// }
    /// ```
    pub fn tanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arcsin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asin(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asin()); // [ 0.0000, -1.5708,  1.5708]
    /// }
    /// ```
    pub fn asin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_asin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \text{asinh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asinh()); // [ 0.0000, -0.8814,  0.8814]
    /// }
    /// ```
    pub fn asinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_asinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arccos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acos(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.acos()); // [1.5708, 3.1416, 0.0]
    /// }
    /// ```
    pub fn acos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_acos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \text{acosh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     println!("{}", tensor.acosh()); // [0.0000, 1.3170, 1.7627]
    /// }
    /// ```
    pub fn acosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_acosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arctan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.atan()); // [ 0.0, -0.7854,  1.1071]
    /// }
    /// ```
    pub fn atan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \text{atanh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -0.5, 0.5], &device);
    ///     println!("{}", tensor.atanh()); // [ 0.0, -0.5493,  0.5493]
    /// }
    /// ```
    pub fn atanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise inverse tangent operation using the signs of arguments to determine the correct quadrant.
    ///
    #[cfg_attr(doc, doc = r#"$z_i = \text{atan2}\(y_i, x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`z_i = atan2(y_i, x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let lhs = Tensor::<B, 1>::from_data([-2.0, 2.0, -2.0], &device);
    ///     let rhs = Tensor::<B, 1>::from_data([1.0, -1.0, -1.0], &device);
    ///     println!("{}", lhs.atan2(rhs)); // [-1.1071,  2.0344, -2.0344]
    /// }
    /// ```
    pub fn atan2(self, other: Self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_atan2(
            self.primitive.tensor(),
            other.primitive.tensor(),
        )))
    }

    /// Converts each of the elements of the input tensor from angles in degrees to radians.
    ///
    /// # Example
    /// ```ignore
    /// let tensor_in_radians = tensor.deg2rad();
    /// ```
    pub fn deg2rad(self) -> Self {
        self.mul_scalar(f32::consts::PI / 180.0)
    }

    /// Converts each of the elements of the input tensor from angles in radians to degrees.
    ///
    /// # Example
    /// ```ignore
    /// let tensor_in_degrees = tensor.rad2deg();
    /// ```
    pub fn rad2deg(self) -> Self {
        self.mul_scalar(180.0 / f32::consts::PI)
    }

    /// Applies element wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
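    ///
    /// # Example
    ///
    /// A minimal illustration; note that `1.5` and `2.5` both round to the even value `2.0`.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.4, 1.5, 2.5], &device);
    ///     println!("{}", tensor.round()); // [1.0, 2.0, 2.0]
    /// }
    /// ```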
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise floor operation.
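    ///
    /// # Example
    ///
    /// A minimal illustration:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.9, -1.1, 2.0], &device);
    ///     println!("{}", tensor.floor()); // [1.0, -2.0, 2.0]
    /// }
    /// ```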
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise ceil operation.
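    ///
    /// # Example
    ///
    /// A minimal illustration:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.1, -1.9, 2.0], &device);
    ///     println!("{}", tensor.ceil()); // [2.0, -1.0, 2.0]
    /// }
    /// ```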
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        let out_dtype = get_device_settings::<B>(&self.device()).int_dtype;
        Tensor::new(B::float_into_int(self.primitive.tensor(), out_dtype))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor, filled
    /// with random values sampled from the given distribution.
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Self::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
            self.dtype().into(),
        )))
    }

    /// Calculate the variance along the given dimension.
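    ///
    /// # Example
    ///
    /// A minimal illustration; with Bessel's correction, each row below has variance 1.0.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///     println!("{}", tensor.var(1)); // [[1.0], [1.0]]
    /// }
    /// ```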
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Returns the median value along the specified dimension.
    ///
    /// The median is not unique for input tensors with an even number of elements
    /// in the reduced dimension. In this case, the lower of the two medians is returned,
    /// following PyTorch's behavior.
    ///
    /// # Note
    ///
    /// The current implementation performs a full sort along the specified dimension,
    /// which has O(n log n) complexity. Additionally, most backends currently fall back
    /// to CPU for the sort operation, which may result in slower performance compared
    /// to native GPU operations.
    ///
    /// # Arguments
    ///
    /// - `dim` - The dimension along which to compute the median.
    ///
    /// # Returns
    ///
    /// - A tensor containing the median values along the specified dimension.
    ///
    /// # Example 1
    ///
    /// ```ignore
    /// // Assuming backend B
    /// let device = B::Device::default();
    /// let tensor = Tensor::<B, 2>::from_data(
    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
    ///     &device,
    /// );
    ///
    /// // Median along dimension 0:
    /// // sorted columns are [1.0, 8.0], [4.0, 5.0], [3.0, 6.0], [2.0, 7.0]
    /// let median = tensor.median(0);
    /// // Result: [[1.0, 4.0, 3.0, 2.0]]
    ///
    /// // Median along dimension 1:
    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
    /// let median = tensor.median(1);
    /// // Result: [[2.0], [6.0]]
    /// ```
    ///
    /// # Example 2
    ///
    /// The median across all elements can be calculated as follows:
    ///
    /// ```ignore
    /// // D is the number of dimensions of the tensor
    /// let flattened_tensor: Tensor<B, 1> = tensor.flatten(0, D - 1);
    ///
    /// // Calculate median for dim 0 since the tensor has become 1 dimensional
    /// let median = flattened_tensor.median(0);
    /// // Result: [4.0]
    /// ```
    pub fn median(self, dim: usize) -> Self {
        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
        // instead of leveraging a full sort to get the median.
        stats::median(self, dim)
    }

    /// Returns the median value along the specified dimension and its index.
    ///
    /// The median is not unique for input tensors with an even number of elements
    /// in the reduced dimension. In this case, the lower of the two medians is returned,
    /// following PyTorch's behavior.
    ///
    /// # Note
    ///
    /// The current implementation performs a full sort along the specified dimension,
    /// which has O(n log n) complexity. Additionally, most backends currently fall back
    /// to CPU for the sort operation, which may result in slower performance compared
    /// to native GPU operations.
    ///
    /// # Arguments
    ///
    /// - `dim` - The dimension along which to compute the median.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - A tensor with the median values.
    /// - A tensor with the indices of the median values in the original tensor.
    ///
    /// # Example
    ///
    /// ```ignore
    /// // Assuming backend B
    /// let device = B::Device::default();
    /// let tensor = Tensor::<B, 2>::from_data(
    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
    ///     &device,
    /// );
    ///
    /// // Median along dimension 1:
    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
    /// let (values, indices) = tensor.median_with_indices(1);
    /// // values: [[2.0], [6.0]], indices: [[3], [2]] (position in the original tensor)
    /// ```
    pub fn median_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
        // instead of leveraging a full sort to get the median.
        stats::median_with_indices(self, dim)
    }

    /// Converts a tensor to the specified data type.
    ///
    /// Supports both within-kind casting (e.g., `FloatDType::F64`) and cross-kind casting
    /// (e.g., `IntDType::I64` to produce an int tensor).
    ///
    /// This is a no-op when casting to the current dtype within the same kind.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, FloatDType, IntDType};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.5], &device);
    ///
    ///     // Within-kind cast (float to float)
    ///     let f64_tensor = float_tensor.clone().cast(FloatDType::F64);
    ///
    ///     // Cross-kind cast (float to int)
    ///     let int_tensor = float_tensor.cast(IntDType::I64);
    /// }
    /// ```
    #[must_use]
    pub fn cast<T: Cast<B, Float>>(self, dtype: T) -> Tensor<B, D, T::OutputKind> {
        Tensor::new(T::cast(self.primitive, dtype))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which the covariance is computed.
    /// * `correction_factor` - Usually 1 for samples and 0 for population.
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize(
        self,
        scheme: &QuantScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            QuantizationParametersPrimitive {
                scales: qparams.scales.primitive.tensor(),
            },
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    ///
    /// # Notes
    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }

    /// Checks element wise if the tensor is close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor1.is_close(tensor2, None, None);
    ///    println!("{tensor}");
    ///    // [[true, true, true], [true, true, true]]
    /// }
    /// ```
    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
        let atol = atol.unwrap_or(DEFAULT_ATOL);

        // check finite difference is close
        let is_close_finite_val = self
            .clone()
            .sub(other.clone())
            .abs()
            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
            .bool_and(self.clone().is_finite())
            .bool_and(other.clone().is_finite());

        // check if both are infinite and have same sign
        let inf_same_sign = self
            .clone()
            .is_finite()
            .bool_not()
            .bool_and(other.clone().is_finite().bool_not())
            .bool_and(self.equal(other));

        is_close_finite_val.bool_or(inf_same_sign)
    }

    /// Checks if all elements are close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let result = tensor1.all_close(tensor2, None, None);
    ///    println!("{}", result);
    ///    // true
    /// }
    /// ```
    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
        self.is_close(other, rtol, atol)
            .all()
            .into_scalar()
            .to_bool()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_nan();
    ///    println!("{tensor}");
    ///    // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_nan(self) -> Tensor<B, D, Bool> {
        let out_dtype = get_device_settings::<B>(&self.device()).bool_dtype;
        Tensor::new(B::float_is_nan(self.primitive.tensor(), out_dtype))
    }

    /// Checks if the tensor contains any NaN values.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///   let tensor = tensor.contains_nan();
    ///   println!("{tensor}");
    ///   // [true]
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.contains_nan();
    ///   println!("{tensor}");
    ///   // [false]
    /// }
    /// ```
    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
        // Summing the tensor will result in NaN if the tensor contains any NaN values
        // This is faster than checking each element individually
        // because it rolls up the NaN values into a single value
        let sum = self.sum();

        sum.is_nan()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is infinite.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_inf();
    ///    println!("{tensor}");
    ///    // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_inf(self) -> Tensor<B, D, Bool> {
        let out_dtype = get_device_settings::<B>(&self.device()).bool_dtype;
        Tensor::new(B::float_is_inf(self.primitive.tensor(), out_dtype))
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
    /// either INF, -INF or NAN.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_finite();
    ///    println!("{tensor}");
    ///    // [[true, false, true], [false, true, true]]
    /// }
    /// ```
    pub fn is_finite(self) -> Tensor<B, D, Bool> {
        self.clone()
            .is_nan()
            .bool_not()
            .bool_and(self.is_inf().bool_not())
    }

    /// Samples the tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
    /// using the given locations in [-1, 1].
    ///
    /// # Arguments
    ///
    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are in [-1, 1];
    ///   [x = -1, y = -1] is the top-left corner and [x = 1, y = 1] is the bottom-right.
    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
    ///
    /// # Returns
    ///
    /// A tensor with shape (N, C, H_out, W_out)
    ///
    /// # Example
    ///
    /// ```ignore
    /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
    ///
    /// // Default options (bilinear, zeros padding, align_corners=false)
    /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());
    ///
    /// // Custom options
    /// let options = GridSampleOptions::new(InterpolateMode::Bilinear)
    ///     .with_padding_mode(GridSamplePaddingMode::Border)
    ///     .with_align_corners(true);
    /// let output = tensor.grid_sample_2d(grid, options);
    /// ```
    pub fn grid_sample_2d(
        self,
        grid: Tensor<B, D>,
        options: impl Into<GridSampleOptions>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
            self.primitive.tensor(),
            grid.primitive.tensor(),
            options.into(),
        )))
    }

    /// Computes the cross product of `self` and another tensor along a given dimension.
    ///
    /// Both `self` and `other` **must have size 3** along the specified `dim`,
    /// because the cross product is only defined in three-dimensional space.
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to take the cross product with.
    /// * `dim`   - The dimension along which to compute the cross product.
    ///
    /// # Returns
    ///
    /// A tensor containing the cross product of `self` and `other` along `dim`.
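    ///
    /// # Example
    ///
    /// A minimal sketch with the standard basis vectors, where x cross y = z:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let x = Tensor::<B, 1>::from_floats([1.0, 0.0, 0.0], &device);
    ///     let y = Tensor::<B, 1>::from_floats([0.0, 1.0, 0.0], &device);
    ///     println!("{}", x.cross(y, 0)); // [0.0, 0.0, 1.0]
    /// }
    /// ```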
    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::cross(&self, &other, dim));
        Tensor::new(TensorPrimitive::Float(B::float_cross(
            self.primitive.tensor(),
            other.primitive.tensor(),
            dim,
        )))
    }

    /// Applies element wise power operation with a float Tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1.powf(tensor2);
    ///    println!("{tensor}");
    ///    // [[1.0, -8.0, 81.0], [5.0, 81.0, 216.0]]
    /// }
    /// ```
    pub fn powf(self, other: Self) -> Self {
        let primitive = match (self.primitive, other.primitive) {
            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
                TensorPrimitive::Float(B::float_powf(lhs, rhs))
            }
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_powf(lhs, rhs),
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {
                let dtype = rhs.dtype();
                TensorPrimitive::Float(B::float_powf(B::dequantize(lhs, dtype.into()), rhs))
            }
            (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
                let dtype = lhs.dtype();
                TensorPrimitive::Float(B::float_powf(lhs, B::dequantize(rhs, dtype.into())))
            }
        };

        Tensor::new(primitive)
    }

    /// Applies element wise power operation with a float scalar.
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.powf_scalar(2.0);
    ///    println!("{tensor}");
    ///    // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
    /// }
    /// ```
    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
        let rhs = Scalar::new(other, &self.dtype());

        let primitive = match self.primitive {
            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs)),
            TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs),
        };

        Tensor::new(primitive)
    }
}

impl<const D: usize, B: Backend> Tensor<B, D> {
    /// Draws samples from a categorical distribution defined by the last dimension
    /// of the input tensor.
    ///
    /// The last dimension is treated as a (possibly unnormalized) set of weights
    /// defining a categorical distribution over categories. All leading dimensions
    /// are treated as batch dimensions. The method returns integer indices of the
    /// sampled categories.
    ///
    /// # Arguments
    ///
    /// * `num_samples` - Number of samples to draw per distribution. Must be >= 1.
    ///
    /// # Panics
    ///
    /// Panics if `num_samples` is 0.
    ///
    /// # Note
    ///
    /// Distributions with all-zero weights produce undefined (NaN-based) sampling
    /// results. Callers should ensure each distribution has at least one positive
    /// weight.
    ///
    /// # Returns
    ///
    /// An integer tensor with the same shape as the input, except the last dimension
    /// is replaced by `num_samples`, containing sampled category indices in
    /// `[0, num_categories)`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let probs = Tensor::<B, 2>::from_floats(
    ///         [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    ///         &device,
    ///     );
    ///     let samples = probs.categorical(4);
    ///     // First row always samples index 1, second row always samples index 2
    ///     println!("{samples}");
    /// }
    /// ```
    pub fn categorical(self, num_samples: usize) -> Tensor<B, D, Int> {
        assert!(num_samples > 0, "categorical: num_samples must be >= 1");

        let shape = self.shape();
        let num_categories = shape[D - 1];
        let batch_size = (shape.num_elements() / num_categories).max(1);
        let device = self.device();

        // Flatten leading dimensions into a single batch dimension: [batch, categories]
        let flat: Tensor<B, 2> = self.reshape([batch_size, num_categories]);

        // Normalize weights to probabilities
        let sum = flat.clone().sum_dim(1); // [batch, 1]
        let probs = flat / sum;

        // Cumulative sum along categories dimension
        let cumsum = probs.cumsum(1); // [batch, categories]

        // Uniform random values for each sample
        let uniform = Tensor::<B, 2>::random(
            [batch_size, num_samples],
            Distribution::Uniform(0.0, 1.0),
            &device,
        ); // [batch, num_samples]

        // Expand dimensions for broadcasting:
        //   cumsum: [batch, categories, 1]
        //   uniform: [batch, 1, num_samples]
        let cumsum_3d: Tensor<B, 3> = cumsum.unsqueeze_dim(2);
        let uniform_3d: Tensor<B, 3> = uniform.unsqueeze_dim(1);

        // Count categories where cumsum < uniform (inverse CDF)
        let mask: Tensor<B, 3, Bool> = cumsum_3d.lower(uniform_3d);
        let indices: Tensor<B, 2, Int> = mask.int().sum_dim(1).squeeze_dim::<2>(1);

        // Clamp to valid range to guard against floating-point imprecision in cumsum
        let indices = indices.clamp(0, num_categories as i64 - 1);

        // Reshape back to [...leading_dims, num_samples]
        let mut out_shape = shape;
        out_shape[D - 1] = num_samples;
        indices.reshape(out_shape)
    }
}

#[cfg(feature = "distributed")]
impl<const D: usize, B> Tensor<B, D>
where
    B: AutodiffBackend,
{
    /// Returns true if the tensor is marked as distributed.
    pub fn is_distributed(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::is_distributed(tensor),
            TensorPrimitive::QFloat(_) => unimplemented!(),
        }
    }

    /// Mark the tensor as distributed.
    ///
    /// This function does nothing when autodiff or distributed is not enabled.
    pub fn set_distributed(self, param_id: DistributedParamId) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::set_distributed_params(tensor, param_id))
            }
            TensorPrimitive::QFloat(_) => unimplemented!(),
        };
        Self::new(primitive)
    }
}