// burn_tensor/tensor/api/float.rs

use crate::AsIndex;
use crate::Cast;
use crate::Tensor;
use crate::cast::ToElement;
use crate::check;
use crate::check::TensorCheck;
use crate::ops::GridSampleOptions;
use crate::quantization::{QuantScheme, QuantizationParameters};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, TensorData};
use crate::{Bool, Float, Int, TensorPrimitive};
#[cfg(feature = "distributed")]
use burn_backend::AutodiffBackend;
use burn_backend::ElementConversion;
use burn_backend::Scalar;
use burn_backend::TensorMetadata;
#[cfg(feature = "distributed")]
use burn_backend::distributed::DistributedParamId;
use burn_backend::get_device_settings;
use burn_backend::tensor::FloatMathOps;
use burn_backend::tensor::quantization::QuantizationParametersPrimitive;
use core::f32;

/// Default RTOL value for `is_close` and `all_close`.
pub const DEFAULT_RTOL: f64 = 1e-5;

/// Default ATOL value for `is_close` and `all_close`.
pub const DEFAULT_ATOL: f64 = 1e-8;

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
    ///
    #[cfg_attr(
        doc,
        doc = r#"
$y_i = \text{erf}\(x_i\)$

The error function is defined as:

$$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
"#
    )]
    #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
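    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch with illustrative values, assuming a `device` in scope.
    /// let tensor = Tensor::<B, 1>::from_floats([0.0, 1.0, -1.0], &device);
    /// let output = tensor.erf(); // approximately [0.0, 0.8427, -0.8427]
    /// ```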
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
    /// (or multiplicative inverse) element wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
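    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch with illustrative values, assuming a `device` in scope.
    /// let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 4.0], &device);
    /// let output = tensor.recip(); // [1.0, 0.5, 0.25]
    /// ```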
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Converts each of the elements of the input tensor from angles in degrees to radians.
    ///
    /// # Example
    /// ```ignore
    /// let tensor_in_radians = tensor.deg2rad();
    /// ```
    pub fn deg2rad(self) -> Self {
        self.mul_scalar(f32::consts::PI / 180.0)
    }

    /// Converts each of the elements of the input tensor from angles in radians to degrees.
    ///
    /// # Example
    /// ```ignore
    /// let tensor_in_degrees = tensor.rad2deg();
    /// ```
    pub fn rad2deg(self) -> Self {
        self.mul_scalar(180.0 / f32::consts::PI)
    }

    /// Applies element wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
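    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch: halfway cases round to the nearest even value.
    /// let tensor = Tensor::<B, 1>::from_floats([0.5, 1.5, 2.5, 3.5], &device);
    /// let output = tensor.round(); // [0.0, 2.0, 2.0, 4.0]
    /// ```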
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise floor operation.
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise ceil operation.
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        let out_dtype = get_device_settings::<B>(&self.device()).int_dtype;
        Tensor::new(B::float_into_int(self.primitive.tensor(), out_dtype))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor, filled
    /// with random values sampled from the given distribution.
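    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch: sample uniform noise with the same shape as `tensor`.
    /// let noise = tensor.random_like(Distribution::Uniform(0.0, 1.0));
    /// ```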
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Self::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
            self.dtype().into(),
        )))
    }

    /// Calculate the variance along the given dimension.
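    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch: sample variance (with Bessel's correction) along dim 1.
    /// let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0]], &device);
    /// let var = tensor.var(1); // [[1.0]]
    /// ```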
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension, returning it together with the mean.
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction,
    /// returning it together with the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Returns the median value along the specified dimension.
    ///
    /// The median is not unique for input tensors with an even number of elements
    /// in the reduced dimension. In this case, the lower of the two medians is returned,
    /// following PyTorch's behavior.
    ///
    /// # Note
    ///
    /// The current implementation performs a full sort along the specified dimension,
    /// which has O(n log n) complexity. Additionally, most backends currently fall back
    /// to CPU for the sort operation, which may result in slower performance compared
    /// to native GPU operations.
    ///
    /// # Arguments
    ///
    /// - `dim` - The dimension along which to compute the median.
    ///
    /// # Returns
    ///
    /// - A tensor containing the median values along the specified dimension.
    ///
    /// # Example 1
    ///
    /// ```ignore
    /// // Assuming backend B
    /// let device = B::Device::default();
    /// let tensor = Tensor::<B, 2>::from_data(
    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
    ///     &device,
    /// );
    ///
    /// // Median along dimension 0:
    /// // sorted columns are [1.0, 8.0], [4.0, 5.0], [3.0, 6.0], [2.0, 7.0]
    /// let median = tensor.median(0);
    /// // Result: [[1.0, 4.0, 3.0, 2.0]]
    ///
    /// // Median along dimension 1:
    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
    /// let median = tensor.median(1);
    /// // Result: [[2.0], [6.0]]
    /// ```
    ///
    /// # Example 2
    ///
    /// The median across all elements can be calculated as follows:
    ///
    /// ```ignore
    /// // D is the number of dimensions of the tensor
    /// let flattened_tensor: Tensor<B, 1> = tensor.flatten(0, D - 1);
    ///
    /// // Calculate median for dim 0 since the tensor has become 1 dimensional
    /// let median = flattened_tensor.median(0);
    /// // Result: [4.0]
    /// ```
    pub fn median(self, dim: usize) -> Self {
        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
        // instead of leveraging a full sort to get the median.
        stats::median(self, dim)
    }

    /// Returns the median value along the specified dimension and its index.
    ///
    /// The median is not unique for input tensors with an even number of elements
    /// in the reduced dimension. In this case, the lower of the two medians is returned,
    /// following PyTorch's behavior.
    ///
    /// # Note
    ///
    /// The current implementation performs a full sort along the specified dimension,
    /// which has O(n log n) complexity. Additionally, most backends currently fall back
    /// to CPU for the sort operation, which may result in slower performance compared
    /// to native GPU operations.
    ///
    /// # Arguments
    ///
    /// - `dim` - The dimension along which to compute the median.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - A tensor with the median values.
    /// - A tensor with the indices of the median values in the original tensor.
    ///
    /// # Example
    ///
    /// ```ignore
    /// // Assuming backend B
    /// let device = B::Device::default();
    /// let tensor = Tensor::<B, 2>::from_data(
    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
    ///     &device,
    /// );
    ///
    /// // Median along dimension 1:
    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
    /// let (values, indices) = tensor.median_with_indices(1);
    /// // values: [[2.0], [6.0]], indices: [[3], [2]] (position in the original tensor)
    /// ```
    pub fn median_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
        // instead of leveraging a full sort to get the median.
        stats::median_with_indices(self, dim)
    }

    /// Converts a tensor to the specified data type.
    ///
    /// Supports both within-kind casting (e.g., `FloatDType::F64`) and cross-kind casting
    /// (e.g., `IntDType::I64` to produce an int tensor).
    ///
    /// This is a no-op when casting to the current dtype within the same kind.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, FloatDType, IntDType};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.5], &device);
    ///
    ///     // Within-kind cast (float to float)
    ///     let f64_tensor = float_tensor.clone().cast(FloatDType::F64);
    ///
    ///     // Cross-kind cast (float to int)
    ///     let int_tensor = float_tensor.cast(IntDType::I64);
    /// }
    /// ```
    #[must_use]
    pub fn cast<T: Cast<B, Float>>(self, dtype: T) -> Tensor<B, D, T::OutputKind> {
        Tensor::new(T::cast(self.primitive, dtype))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which the entries are treated as observations.
    /// * `correction_factor` - Is usually 1 for samples and 0 for population.
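    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch: dim 0 indexes observations, and the result is the
    /// // sample covariance matrix between the remaining entries.
    /// let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [2.0, 4.0]], &device);
    /// let cov = tensor.cov(0, 1); // [[0.5, 1.0], [1.0, 2.0]]
    /// ```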
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize(
        self,
        scheme: &QuantScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            QuantizationParametersPrimitive {
                scales: qparams.scales.primitive.tensor(),
            },
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    ///
    /// # Notes
    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
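    ///
    /// # Example
    ///
    /// ```ignore
    /// // A hedged sketch; `QuantScheme::default()` is assumed here purely for
    /// // illustration, and a real caller would configure the scheme explicitly.
    /// let scheme = QuantScheme::default();
    /// let quantized = tensor.quantize_dynamic(&scheme);
    /// ```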
    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }

    /// Checks element wise if the tensor is close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor1.is_close(tensor2, None, None);
    ///    println!("{tensor}");
    ///    // [[true, true, true], [true, true, true]]
    /// }
    /// ```
    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
        let atol = atol.unwrap_or(DEFAULT_ATOL);

        // check finite difference is close
        let is_close_finite_val = self
            .clone()
            .sub(other.clone())
            .abs()
            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
            .bool_and(self.clone().is_finite())
            .bool_and(other.clone().is_finite());

        // check if both are infinite and have same sign
        let inf_same_sign = self
            .clone()
            .is_finite()
            .bool_not()
            .bool_and(other.clone().is_finite().bool_not())
            .bool_and(self.equal(other));

        is_close_finite_val.bool_or(inf_same_sign)
    }

    /// Checks if all elements are close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let result = tensor1.all_close(tensor2, None, None);
    ///    println!("{}", result);
    ///    // true
    /// }
    /// ```
    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
        self.is_close(other, rtol, atol)
            .all()
            .into_scalar()
            .to_bool()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_nan();
    ///    println!("{tensor}");
    ///    // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_nan(self) -> Tensor<B, D, Bool> {
        let out_dtype = get_device_settings::<B>(&self.device()).bool_dtype;
        Tensor::new(B::float_is_nan(self.primitive.tensor(), out_dtype))
    }

    /// Checks if the tensor contains any NaN values.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///   let tensor = tensor.contains_nan();
    ///   println!("{tensor}");
    ///   // [true]
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.contains_nan();
    ///   println!("{tensor}");
    ///   // [false]
    /// }
    /// ```
    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
        // Summing the tensor will result in NaN if the tensor contains any NaN values
        // This is faster than checking each element individually
        // because it rolls up the NaN values into a single value
        let sum = self.sum();

        sum.is_nan()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is infinite.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_inf();
    ///    println!("{tensor}");
    ///    // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_inf(self) -> Tensor<B, D, Bool> {
        let out_dtype = get_device_settings::<B>(&self.device()).bool_dtype;
        Tensor::new(B::float_is_inf(self.primitive.tensor(), out_dtype))
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
    /// either INF, -INF or NAN.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///    let tensor = tensor.is_finite();
    ///    println!("{tensor}");
    ///    // [[true, false, true], [false, true, true]]
    /// }
    /// ```
    pub fn is_finite(self) -> Tensor<B, D, Bool> {
        self.clone()
            .is_nan()
            .bool_not()
            .bool_and(self.is_inf().bool_not())
    }

    /// Samples the tensor as a two-dimensional spatial grid of (possibly multi-channel) values
    /// at the given locations, which are expressed in the range [-1, 1].
    ///
    /// # Arguments
    ///
    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are in [-1, 1];
    ///   `[x = -1, y = -1]` maps to the top-left corner and `[x = 1, y = 1]` to the bottom-right.
    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
    ///
    /// # Returns
    ///
    /// A tensor with shape (N, C, H_out, W_out)
    ///
    /// # Example
    ///
    /// ```ignore
    /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
    ///
    /// // Default options (bilinear, zeros padding, align_corners=false)
    /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());
    ///
    /// // Custom options
    /// let options = GridSampleOptions::new(InterpolateMode::Bilinear)
    ///     .with_padding_mode(GridSamplePaddingMode::Border)
    ///     .with_align_corners(true);
    /// let output = tensor.grid_sample_2d(grid, options);
    /// ```
    pub fn grid_sample_2d(
        self,
        grid: Tensor<B, D>,
        options: impl Into<GridSampleOptions>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
            self.primitive.tensor(),
            grid.primitive.tensor(),
            options.into(),
        )))
    }

    /// Computes the cross product of `self` and another tensor along a given dimension.
    ///
    /// Both `self` and `other` **must have size 3** along the specified `dim`,
    /// because the cross product is only defined in three-dimensional space.
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to take the cross product with.
    /// * `dim`   - The dimension along which to compute the cross product.
    ///
    /// # Returns
    ///
    /// A tensor containing the cross product of `self` and `other` along `dim`.
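    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch: the standard basis identity x × y = z.
    /// let x = Tensor::<B, 1>::from_floats([1.0, 0.0, 0.0], &device);
    /// let y = Tensor::<B, 1>::from_floats([0.0, 1.0, 0.0], &device);
    /// let z = x.cross(y, 0); // [0.0, 0.0, 1.0]
    /// ```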
    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::cross(&self, &other, dim));
        Tensor::new(TensorPrimitive::Float(B::float_cross(
            self.primitive.tensor(),
            other.primitive.tensor(),
            dim,
        )))
    }

    /// Applies element wise power operation with a float Tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1.powf(tensor2);
    ///    println!("{tensor}");
    ///    // [[1.0, -8.0, 81.0], [5.0, 81.0, 216.0]]
    /// }
    /// ```
    pub fn powf(self, other: Self) -> Self {
        let primitive = match (self.primitive, other.primitive) {
            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
                TensorPrimitive::Float(B::float_powf(lhs, rhs))
            }
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_powf(lhs, rhs),
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {
                let dtype = rhs.dtype();
                TensorPrimitive::Float(B::float_powf(B::dequantize(lhs, dtype.into()), rhs))
            }
            (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
                let dtype = lhs.dtype();
                TensorPrimitive::Float(B::float_powf(lhs, B::dequantize(rhs, dtype.into())))
            }
        };

        Tensor::new(primitive)
    }

    /// Applies element wise power operation with a float scalar.
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.powf_scalar(2.0);
    ///    println!("{tensor}");
    ///    // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
    /// }
    /// ```
    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
        let rhs = Scalar::new(other, &self.dtype());

        let primitive = match self.primitive {
            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs)),
            TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs),
        };

        Tensor::new(primitive)
    }
}

impl<const D: usize, B: Backend> Tensor<B, D> {
    /// Draws samples from a categorical distribution defined by the last dimension
    /// of the input tensor.
    ///
    /// The last dimension is treated as a (possibly unnormalized) set of weights
    /// defining a categorical distribution over categories. All leading dimensions
    /// are treated as batch dimensions. The method returns integer indices of the
    /// sampled categories.
    ///
    /// # Arguments
    ///
    /// * `num_samples` - Number of samples to draw per distribution. Must be >= 1.
    ///
    /// # Panics
    ///
    /// Panics if `num_samples` is 0.
    ///
    /// # Note
    ///
    /// Distributions with all-zero weights produce undefined (NaN-based) sampling
    /// results. Callers should ensure each distribution has at least one positive
    /// weight.
    ///
    /// # Returns
    ///
    /// An integer tensor with the same shape as the input, except the last dimension
    /// is replaced by `num_samples`, containing sampled category indices in
    /// `[0, num_categories)`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let probs = Tensor::<B, 2>::from_floats(
    ///         [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    ///         &device,
    ///     );
    ///     let samples = probs.categorical(4);
    ///     // First row always samples index 1, second row always samples index 2
    ///     println!("{samples}");
    /// }
    /// ```
    pub fn categorical(self, num_samples: usize) -> Tensor<B, D, Int> {
        assert!(num_samples > 0, "categorical: num_samples must be >= 1");

        let shape = self.shape();
        let num_categories = shape[D - 1];
        let batch_size = (shape.num_elements() / num_categories).max(1);
        let device = self.device();

        // Flatten leading dimensions into a single batch dimension: [batch, categories]
        let flat: Tensor<B, 2> = self.reshape([batch_size, num_categories]);

        // Normalize weights to probabilities
        let sum = flat.clone().sum_dim(1); // [batch, 1]
        let probs = flat / sum;

        // Cumulative sum along categories dimension
        let cumsum = probs.cumsum(1); // [batch, categories]

        // Uniform random values for each sample
        let uniform = Tensor::<B, 2>::random(
            [batch_size, num_samples],
            Distribution::Uniform(0.0, 1.0),
            &device,
        ); // [batch, num_samples]

        // Expand dimensions for broadcasting:
        //   cumsum: [batch, categories, 1]
        //   uniform: [batch, 1, num_samples]
        let cumsum_3d: Tensor<B, 3> = cumsum.unsqueeze_dim(2);
        let uniform_3d: Tensor<B, 3> = uniform.unsqueeze_dim(1);

        // Count categories where cumsum < uniform (inverse CDF)
        let mask: Tensor<B, 3, Bool> = cumsum_3d.lower(uniform_3d);
        let indices: Tensor<B, 2, Int> = mask.int().sum_dim(1).squeeze_dim::<2>(1);

        // Clamp to valid range to guard against floating-point imprecision in cumsum
        let indices = indices.clamp(0, num_categories as i64 - 1);

        // Reshape back to [...leading_dims, num_samples]
        let mut out_shape = shape;
        out_shape[D - 1] = num_samples;
        indices.reshape(out_shape)
    }
}

#[cfg(feature = "distributed")]
impl<const D: usize, B> Tensor<B, D>
where
    B: AutodiffBackend,
{
    /// Returns true if the tensor is marked as distributed.
    pub fn is_distributed(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::is_distributed(tensor),
            TensorPrimitive::QFloat(_) => unimplemented!(),
        }
    }

    /// Mark the tensor as distributed.
    ///
    /// This function does nothing when autodiff or distributed is not enabled.
    pub fn set_distributed(self, param_id: DistributedParamId) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::set_distributed_params(tensor, param_id))
            }
            TensorPrimitive::QFloat(_) => unimplemented!(),
        };
        Self::new(primitive)
    }
}

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: FloatMathOps<B>,
{
    /// Applies element wise square operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = x_i^2$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
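    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch with illustrative values, assuming a `device` in scope.
    /// let tensor = Tensor::<B, 1>::from_floats([2.0, -3.0], &device);
    /// let output = tensor.square(); // [4.0, 9.0]
    /// ```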
    pub fn square(self) -> Self {
        Self::new(K::square(self.primitive))
    }

    /// Applies element wise exponential operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = e^{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = e^(x_i)`")]
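    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch with illustrative values.
    /// let tensor = Tensor::<B, 1>::from_floats([0.0, 1.0], &device);
    /// let output = tensor.exp(); // approximately [1.0, 2.7183]
    /// ```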
    pub fn exp(self) -> Self {
        Self::new(K::exp(self.primitive))
    }

    /// Applies element wise natural logarithm of one plus the input tensor.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log1p(x_i)`")]
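    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch with illustrative values.
    /// let tensor = Tensor::<B, 1>::from_floats([0.0, 1.0], &device);
    /// let output = tensor.log1p(); // approximately [0.0, 0.6931]
    /// ```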
    pub fn log1p(self) -> Self {
        Self::new(K::log1p(self.primitive))
    }

    /// Applies element wise natural log operation *ln*.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
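    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch with illustrative values.
    /// let tensor = Tensor::<B, 1>::from_floats([1.0, 10.0], &device);
    /// let output = tensor.log(); // approximately [0.0, 2.3026]
    /// ```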
    pub fn log(self) -> Self {
        Self::new(K::log(self.primitive))
    }

    /// Applies element wise square root operation.
    ///
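    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch with illustrative values.
    /// let tensor = Tensor::<B, 1>::from_floats([4.0, 9.0], &device);
    /// let output = tensor.sqrt(); // [2.0, 3.0]
    /// ```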
    pub fn sqrt(self) -> Self {
        Tensor::new(K::sqrt(self.primitive))
    }

    /// Applies element wise cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
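    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch with illustrative values (angles in radians).
    /// let tensor = Tensor::<B, 1>::from_floats([0.0, 3.1416], &device);
    /// let output = tensor.cos(); // approximately [1.0, -1.0]
    /// ```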
    pub fn cos(self) -> Self {
        Tensor::new(K::cos(self.primitive))
    }

    /// Applies element wise sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
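    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch with illustrative values (angles in radians).
    /// let tensor = Tensor::<B, 1>::from_floats([0.0, 1.5708], &device);
    /// let output = tensor.sin(); // approximately [0.0, 1.0]
    /// ```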
    pub fn sin(self) -> Self {
        Tensor::new(K::sin(self.primitive))
    }

    /// Applies element wise tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
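    ///
    /// # Example
    ///
    /// ```ignore
    /// // A minimal sketch with illustrative values (angles in radians).
    /// let tensor = Tensor::<B, 1>::from_floats([0.0, 0.7854], &device);
    /// let output = tensor.tan(); // approximately [0.0, 1.0]
    /// ```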
    pub fn tan(self) -> Self {
        Tensor::new(K::tan(self.primitive))
    }

    /// Applies element wise hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
    /// }
    /// ```
    pub fn cosh(self) -> Self {
        Tensor::new(K::cosh(self.primitive))
    }

    /// Applies element wise hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
    /// }
    /// ```
    pub fn sinh(self) -> Self {
        Tensor::new(K::sinh(self.primitive))
    }

    /// Applies element wise hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
    /// }
    /// ```
    pub fn tanh(self) -> Self {
        Tensor::new(K::tanh(self.primitive))
    }

    /// Applies element wise inverse cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arccos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acos(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.acos()); // [1.5708, 3.1416, 0.0]
    /// }
    /// ```
    pub fn acos(self) -> Self {
        Tensor::new(K::acos(self.primitive))
    }

    /// Applies element wise inverse hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{acosh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     println!("{}", tensor.acosh()); // [0.0000, 1.3170, 1.7627]
    /// }
    /// ```
    pub fn acosh(self) -> Self {
        Tensor::new(K::acosh(self.primitive))
    }

    /// Applies element wise inverse sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arcsin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asin(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asin()); // [ 0.0000, -1.5708,  1.5708]
    /// }
    /// ```
    pub fn asin(self) -> Self {
        Tensor::new(K::asin(self.primitive))
    }

    /// Applies element wise inverse hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{asinh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asinh()); // [ 0.0000, -0.8814,  0.8814]
    /// }
    /// ```
    pub fn asinh(self) -> Self {
        Tensor::new(K::asinh(self.primitive))
    }

    /// Applies element wise inverse tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arctan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.atan()); // [ 0.0, -0.7854,  1.1071]
    /// }
    /// ```
    pub fn atan(self) -> Self {
        Tensor::new(K::atan(self.primitive))
    }

    /// Applies element wise inverse hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{atanh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -0.5, 0.5], &device);
    ///     println!("{}", tensor.atanh()); // [ 0.0, -0.5493,  0.5493]
    /// }
    /// ```
    pub fn atanh(self) -> Self {
        Tensor::new(K::atanh(self.primitive))
    }

    /// Applies element wise inverse tangent operation using the signs of arguments to determine the correct quadrant.
    ///
    #[cfg_attr(doc, doc = r#"$z_i = \operatorname{atan2}\(y_i, x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`z_i = atan2(y_i, x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let lhs = Tensor::<B, 1>::from_data([-2.0, 2.0, -2.0], &device);
    ///     let rhs = Tensor::<B, 1>::from_data([1.0, -1.0, -1.0], &device);
    ///     println!("{}", lhs.atan2(rhs)); // [-1.1071,  2.0344, -2.0344]
    /// }
    /// ```
    pub fn atan2(self, other: Self) -> Self {
        Tensor::new(K::atan2(self.primitive, other.primitive))
    }
}