burn_tensor/tensor/api/float.rs
use crate::AsIndex;
use crate::Cast;
use crate::Tensor;
use crate::cast::ToElement;
use crate::check;
use crate::check::TensorCheck;
use crate::ops::GridSampleOptions;
use crate::quantization::{QuantScheme, QuantizationParameters};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, TensorData};
use crate::{Bool, Float, Int, TensorPrimitive};
#[cfg(feature = "distributed")]
use burn_backend::AutodiffBackend;
use burn_backend::ElementConversion;
use burn_backend::Scalar;
use burn_backend::TensorMetadata;
#[cfg(feature = "distributed")]
use burn_backend::distributed::DistributedParamId;
use burn_backend::get_device_settings;
use burn_backend::tensor::FloatMathOps;
use burn_backend::tensor::quantization::QuantizationParametersPrimitive;
use core::f32;

/// Default RTOL value for `is_close` and `all_close`.
pub const DEFAULT_RTOL: f64 = 1e-5;

/// Default ATOL value for `is_close` and `all_close`.
pub const DEFAULT_ATOL: f64 = 1e-8;

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
    ///
    #[cfg_attr(
        doc,
        doc = r#"
$y_i = \text{erf}\(x_i\)$

The error function is defined as:

$$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
"#
    )]
    #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
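    ///
    /// # Example
    ///
    /// A minimal sketch (assuming backend `B`, with `device` in scope; printed values are rounded):
    ///
    /// ```ignore
    /// let tensor = Tensor::<B, 1>::from_floats([0.0, 1.0, -1.0], &device);
    /// let tensor = tensor.erf();
    /// // [0.0, 0.8427, -0.8427]
    /// ```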
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
    /// (or multiplicative inverse) element wise.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
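    ///
    /// # Example
    ///
    /// A minimal sketch (assuming backend `B`, with `device` in scope):
    ///
    /// ```ignore
    /// let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 4.0], &device);
    /// let tensor = tensor.recip();
    /// // [1.0, 0.5, 0.25]
    /// ```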
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Converts each of the elements of the input tensor from angles in degrees to radians.
    ///
    /// # Example
    /// ```ignore
    /// let tensor_in_radians = tensor.deg2rad();
    /// ```
    pub fn deg2rad(self) -> Self {
        self.mul_scalar(f32::consts::PI / 180.0)
    }

    /// Converts each of the elements of the input tensor from angles in radians to degrees.
    ///
    /// # Example
    /// ```ignore
    /// let tensor_in_degrees = tensor.rad2deg();
    /// ```
    pub fn rad2deg(self) -> Self {
        self.mul_scalar(180.0 / f32::consts::PI)
    }

    /// Applies element wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
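    ///
    /// # Example
    ///
    /// A minimal sketch showing the half-to-even behavior (assuming backend `B`, with `device` in scope):
    ///
    /// ```ignore
    /// let tensor = Tensor::<B, 1>::from_floats([1.5, 2.5, 3.2], &device);
    /// let tensor = tensor.round();
    /// // [2.0, 2.0, 3.0] (both 1.5 and 2.5 round to the even value 2.0)
    /// ```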
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise floor operation.
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise ceil operation.
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        let out_dtype = get_device_settings::<B>(&self.device()).int_dtype;
        Tensor::new(B::float_into_int(self.primitive.tensor(), out_dtype))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor, filled
    /// with random values sampled from the given distribution.
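    ///
    /// # Example
    ///
    /// A minimal sketch drawing standard normal noise (assuming a `tensor` is in scope):
    ///
    /// ```ignore
    /// use burn_tensor::Distribution;
    ///
    /// // Same shape, dtype, and device as `tensor`.
    /// let noise = tensor.random_like(Distribution::Normal(0.0, 1.0));
    /// ```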
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Self::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
            self.dtype().into(),
        )))
    }

    /// Calculate the variance along the given dimension.
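    ///
    /// # Example
    ///
    /// A minimal sketch (sample variance, i.e. with Bessel's correction; assuming backend `B`, with `device` in scope):
    ///
    /// ```ignore
    /// let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0]], &device);
    /// let var = tensor.var(1);
    /// // [[1.0]]
    /// ```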
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction,
    /// and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Returns the median value along the specified dimension.
    ///
    /// The median is not unique for input tensors with an even number of elements
    /// in the reduced dimension. In this case, the lower of the two medians is returned,
    /// following PyTorch's behavior.
    ///
    /// # Note
    ///
    /// The current implementation performs a full sort along the specified dimension,
    /// which has O(n log n) complexity. Additionally, most backends currently fall back
    /// to CPU for the sort operation, which may result in slower performance compared
    /// to native GPU operations.
    ///
    /// # Arguments
    ///
    /// - `dim` - The dimension along which to compute the median.
    ///
    /// # Returns
    ///
    /// - A tensor containing the median values along the specified dimension.
    ///
    /// # Example 1
    ///
    /// ```ignore
    /// // Assuming backend B
    /// let device = B::Device::default();
    /// let tensor = Tensor::<B, 2>::from_data(
    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
    ///     &device,
    /// );
    ///
    /// // Median along dimension 0:
    /// // sorted columns are [1.0, 8.0], [4.0, 5.0], [3.0, 6.0], [2.0, 7.0]
    /// let median = tensor.median(0);
    /// // Result: [[1.0, 4.0, 3.0, 2.0]]
    ///
    /// // Median along dimension 1:
    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
    /// let median = tensor.median(1);
    /// // Result: [[2.0], [6.0]]
    /// ```
    ///
    /// # Example 2
    ///
    /// The median across all elements can be calculated as follows:
    ///
    /// ```ignore
    /// // D is the number of dimensions of the tensor
    /// let flattened_tensor: Tensor<B, 1> = tensor.flatten(0, D - 1);
    ///
    /// // Calculate median for dim 0 since the tensor has become 1 dimensional
    /// let median = flattened_tensor.median(0);
    /// // Result: [4.0]
    /// ```
    pub fn median(self, dim: usize) -> Self {
        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
        // instead of leveraging a full sort to get the median.
        stats::median(self, dim)
    }

    /// Returns the median value along the specified dimension and its index.
    ///
    /// The median is not unique for input tensors with an even number of elements
    /// in the reduced dimension. In this case, the lower of the two medians is returned,
    /// following PyTorch's behavior.
    ///
    /// # Note
    ///
    /// The current implementation performs a full sort along the specified dimension,
    /// which has O(n log n) complexity. Additionally, most backends currently fall back
    /// to CPU for the sort operation, which may result in slower performance compared
    /// to native GPU operations.
    ///
    /// # Arguments
    ///
    /// - `dim` - The dimension along which to compute the median.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - A tensor with the median values.
    /// - A tensor with the indices of the median values in the original tensor.
    ///
    /// # Example
    ///
    /// ```ignore
    /// // Assuming backend B
    /// let device = B::Device::default();
    /// let tensor = Tensor::<B, 2>::from_data(
    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],
    ///     &device,
    /// );
    ///
    /// // Median along dimension 1:
    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]
    /// let (values, indices) = tensor.median_with_indices(1);
    /// // values: [[2.0], [6.0]], indices: [[3], [2]] (position in the original tensor)
    /// ```
    pub fn median_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl
        // instead of leveraging a full sort to get the median.
        stats::median_with_indices(self, dim)
    }

    /// Converts a tensor to the specified data type.
    ///
    /// Supports both within-kind casting (e.g., `FloatDType::F64`) and cross-kind casting
    /// (e.g., `IntDType::I64` to produce an int tensor).
    ///
    /// This is a no-op when casting to the current dtype within the same kind.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, FloatDType, IntDType};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.5], &device);
    ///
    ///     // Within-kind cast (float to float)
    ///     let f64_tensor = float_tensor.clone().cast(FloatDType::F64);
    ///
    ///     // Cross-kind cast (float to int)
    ///     let int_tensor = float_tensor.cast(IntDType::I64);
    /// }
    /// ```
    #[must_use]
    pub fn cast<T: Cast<B, Float>>(self, dtype: T) -> Tensor<B, D, T::OutputKind> {
        Tensor::new(T::cast(self.primitive, dtype))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
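    ///
    /// # Example
    ///
    /// A minimal sketch (assuming a `tensor` is in scope):
    ///
    /// ```ignore
    /// // Gradients will not flow back through `detached` during the backward pass.
    /// let detached = tensor.detach();
    /// ```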
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which the covariance is computed.
    /// * `correction_factor` - Is usually 1 for samples and 0 for population.
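    ///
    /// # Example
    ///
    /// A minimal sketch (sample covariance over dimension 1; assuming backend `B`, with `device` in scope):
    ///
    /// ```ignore
    /// // Two variables, three observations each.
    /// let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]], &device);
    /// let cov = tensor.cov(1, 1);
    /// // [[1.0, 2.0], [2.0, 4.0]]
    /// ```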
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
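    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a `scheme` and a pre-computed `scales` tensor are already available:
    ///
    /// ```ignore
    /// let qparams = QuantizationParameters { scales };
    /// let quantized = tensor.quantize(&scheme, qparams);
    /// ```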
    pub fn quantize(
        self,
        scheme: &QuantScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            QuantizationParametersPrimitive {
                scales: qparams.scales.primitive.tensor(),
            },
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    ///
    /// # Notes
    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
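    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a `scheme` is already configured:
    ///
    /// ```ignore
    /// // Quantization parameters are derived on the fly from the tensor's min/max values.
    /// let quantized = tensor.quantize_dynamic(&scheme);
    /// ```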
    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }

    /// Checks element wise if the tensor is close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor1.is_close(tensor2, None, None);
    ///     println!("{tensor}");
    ///     // [[true, true, true], [true, true, true]]
    /// }
    /// ```
    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
        let atol = atol.unwrap_or(DEFAULT_ATOL);

        // Check whether the difference between finite values is within tolerance.
        let is_close_finite_val = self
            .clone()
            .sub(other.clone())
            .abs()
            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
            .bool_and(self.clone().is_finite())
            .bool_and(other.clone().is_finite());

        // Check whether both values are non-finite yet equal, i.e. infinities with the
        // same sign (NaN never equals NaN, so NaN pairs are excluded).
        let inf_same_sign = self
            .clone()
            .is_finite()
            .bool_not()
            .bool_and(other.clone().is_finite().bool_not())
            .bool_and(self.equal(other));

        is_close_finite_val.bool_or(inf_same_sign)
    }

    /// Checks if all elements are close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let result = tensor1.all_close(tensor2, None, None);
    ///     println!("{}", result);
    ///     // true
    /// }
    /// ```
    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
        self.is_close(other, rtol, atol)
            .all()
            .into_scalar()
            .to_bool()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_nan();
    ///     println!("{tensor}");
    ///     // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_nan(self) -> Tensor<B, D, Bool> {
        let out_dtype = get_device_settings::<B>(&self.device()).bool_dtype;
        Tensor::new(B::float_is_nan(self.primitive.tensor(), out_dtype))
    }

    /// Checks if the tensor contains any NaN values.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.contains_nan();
    ///     println!("{tensor}");
    ///     // [true]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.contains_nan();
    ///     println!("{tensor}");
    ///     // [false]
    /// }
    /// ```
    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
        // Summing the tensor will result in NaN if the tensor contains any NaN values.
        // This is faster than checking each element individually because it rolls up
        // the NaN values into a single value.
        let sum = self.sum();

        sum.is_nan()
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is infinite.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_inf();
    ///     println!("{tensor}");
    ///     // [[false, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn is_inf(self) -> Tensor<B, D, Bool> {
        let out_dtype = get_device_settings::<B>(&self.device()).bool_dtype;
        Tensor::new(B::float_is_inf(self.primitive.tensor(), out_dtype))
    }

    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite.
    ///
    /// # Returns
    ///
    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
    /// either INF, -INF, or NaN.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Bool, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
    ///     let tensor = tensor.is_finite();
    ///     println!("{tensor}");
    ///     // [[true, false, true], [false, true, true]]
    /// }
    /// ```
    pub fn is_finite(self) -> Tensor<B, D, Bool> {
        self.clone()
            .is_nan()
            .bool_not()
            .bool_and(self.is_inf().bool_not())
    }

    /// Samples the tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
    /// using the given locations in [-1, 1].
    ///
    /// # Arguments
    ///
    /// * `grid` - A tensor of locations with shape (N, H_out, W_out, 2) and values in [-1, 1],
    ///   where [x = -1, y = -1] means top-left and [x = 1, y = 1] means bottom-right.
    /// * `options` - Grid sampling options (mode, padding_mode, align_corners).
    ///
    /// # Returns
    ///
    /// A tensor with shape (N, C, H_out, W_out).
    ///
    /// # Example
    ///
    /// ```ignore
    /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
    ///
    /// // Default options (bilinear, zeros padding, align_corners=false)
    /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());
    ///
    /// // Custom options
    /// let options = GridSampleOptions::new(InterpolateMode::Bilinear)
    ///     .with_padding_mode(GridSamplePaddingMode::Border)
    ///     .with_align_corners(true);
    /// let output = tensor.grid_sample_2d(grid, options);
    /// ```
    pub fn grid_sample_2d(
        self,
        grid: Tensor<B, D>,
        options: impl Into<GridSampleOptions>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
            self.primitive.tensor(),
            grid.primitive.tensor(),
            options.into(),
        )))
    }

    /// Computes the cross product of `self` and another tensor along a given dimension.
    ///
    /// Both `self` and `other` **must have size 3** along the specified `dim`,
    /// because the cross product is only defined in three-dimensional space.
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to take the cross product with.
    /// * `dim` - The dimension along which to compute the cross product.
    ///
    /// # Returns
    ///
    /// A tensor containing the cross product of `self` and `other` along `dim`.
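    ///
    /// # Example
    ///
    /// A minimal sketch using the standard basis vectors (assuming backend `B`, with `device` in scope):
    ///
    /// ```ignore
    /// let x = Tensor::<B, 1>::from_floats([1.0, 0.0, 0.0], &device);
    /// let y = Tensor::<B, 1>::from_floats([0.0, 1.0, 0.0], &device);
    /// let z = x.cross(y, 0);
    /// // [0.0, 0.0, 1.0]
    /// ```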
    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::cross(&self, &other, dim));
        Tensor::new(TensorPrimitive::Float(B::float_cross(
            self.primitive.tensor(),
            other.primitive.tensor(),
            dim,
        )))
    }

    /// Applies element wise power operation with a float tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1.powf(tensor2);
    ///     println!("{tensor}");
    ///     // [[1.0, -8.0, 81.0], [5.0, 81.0, 216.0]]
    /// }
    /// ```
    pub fn powf(self, other: Self) -> Self {
        let primitive = match (self.primitive, other.primitive) {
            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
                TensorPrimitive::Float(B::float_powf(lhs, rhs))
            }
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_powf(lhs, rhs),
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {
                let dtype = rhs.dtype();
                TensorPrimitive::Float(B::float_powf(B::dequantize(lhs, dtype.into()), rhs))
            }
            (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
                let dtype = lhs.dtype();
                TensorPrimitive::Float(B::float_powf(lhs, B::dequantize(rhs, dtype.into())))
            }
        };

        Tensor::new(primitive)
    }

    /// Applies element wise power operation with a float scalar.
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.powf_scalar(2.0);
    ///     println!("{tensor}");
    ///     // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
    /// }
    /// ```
    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
        let rhs = Scalar::new(other, &self.dtype());

        let primitive = match self.primitive {
            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs)),
            TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs),
        };

        Tensor::new(primitive)
    }
}

impl<const D: usize, B: Backend> Tensor<B, D> {
    /// Draws samples from a categorical distribution defined by the last dimension
    /// of the input tensor.
    ///
    /// The last dimension is treated as a (possibly unnormalized) set of weights
    /// defining a categorical distribution over categories. All leading dimensions
    /// are treated as batch dimensions. The method returns integer indices of the
    /// sampled categories.
    ///
    /// # Arguments
    ///
    /// * `num_samples` - Number of samples to draw per distribution. Must be >= 1.
    ///
    /// # Panics
    ///
    /// Panics if `num_samples` is 0.
    ///
    /// # Note
    ///
    /// Distributions with all-zero weights produce undefined (NaN-based) sampling
    /// results. Callers should ensure each distribution has at least one positive
    /// weight.
    ///
    /// # Returns
    ///
    /// An integer tensor with the same shape as the input, except the last dimension
    /// is replaced by `num_samples`, containing sampled category indices in
    /// `[0, num_categories)`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let probs = Tensor::<B, 2>::from_floats(
    ///         [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    ///         &device,
    ///     );
    ///     let samples = probs.categorical(4);
    ///     // First row always samples index 1, second row always samples index 2
    ///     println!("{samples}");
    /// }
    /// ```
    pub fn categorical(self, num_samples: usize) -> Tensor<B, D, Int> {
        assert!(num_samples > 0, "categorical: num_samples must be >= 1");

        let shape = self.shape();
        let num_categories = shape[D - 1];
        let batch_size = (shape.num_elements() / num_categories).max(1);
        let device = self.device();

        // Flatten leading dimensions into a single batch dimension: [batch, categories]
        let flat: Tensor<B, 2> = self.reshape([batch_size, num_categories]);

        // Normalize weights to probabilities
        let sum = flat.clone().sum_dim(1); // [batch, 1]
        let probs = flat / sum;

        // Cumulative sum along categories dimension
        let cumsum = probs.cumsum(1); // [batch, categories]

        // Uniform random values for each sample
        let uniform = Tensor::<B, 2>::random(
            [batch_size, num_samples],
            Distribution::Uniform(0.0, 1.0),
            &device,
        ); // [batch, num_samples]

        // Expand dimensions for broadcasting:
        // cumsum: [batch, categories, 1]
        // uniform: [batch, 1, num_samples]
        let cumsum_3d: Tensor<B, 3> = cumsum.unsqueeze_dim(2);
        let uniform_3d: Tensor<B, 3> = uniform.unsqueeze_dim(1);

        // Count categories where cumsum < uniform (inverse CDF)
        let mask: Tensor<B, 3, Bool> = cumsum_3d.lower(uniform_3d);
        let indices: Tensor<B, 2, Int> = mask.int().sum_dim(1).squeeze_dim::<2>(1);

        // Clamp to valid range to guard against floating-point imprecision in cumsum
        let indices = indices.clamp(0, num_categories as i64 - 1);

        // Reshape back to [...leading_dims, num_samples]
        let mut out_shape = shape;
        out_shape[D - 1] = num_samples;
        indices.reshape(out_shape)
    }
}

#[cfg(feature = "distributed")]
impl<const D: usize, B> Tensor<B, D>
where
    B: AutodiffBackend,
{
    /// Returns true if the tensor is marked as distributed.
    pub fn is_distributed(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::is_distributed(tensor),
            TensorPrimitive::QFloat(_) => unimplemented!(),
        }
    }

    /// Mark the tensor as distributed.
    ///
    /// This function does nothing when autodiff or distributed is not enabled.
    pub fn set_distributed(self, param_id: DistributedParamId) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::set_distributed_params(tensor, param_id))
            }
            TensorPrimitive::QFloat(_) => unimplemented!(),
        };
        Self::new(primitive)
    }
}

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: FloatMathOps<B>,
{
    /// Applies element wise square operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = x_i^2$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
    pub fn square(self) -> Self {
        Self::new(K::square(self.primitive))
    }

    /// Applies element wise exponential operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = e^{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = e^(x_i)`")]
    pub fn exp(self) -> Self {
        Self::new(K::exp(self.primitive))
    }

    /// Applies element wise natural logarithm of one plus the input tensor.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log1p(x_i)`")]
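    ///
    /// # Example
    ///
    /// A minimal sketch (assuming backend `B`, with `device` in scope); `log1p` is typically
    /// more accurate than computing `(x + 1).log()` when `x` is close to zero:
    ///
    /// ```ignore
    /// let tensor = Tensor::<B, 1>::from_floats([0.0, 1.0], &device);
    /// let tensor = tensor.log1p();
    /// // [0.0, 0.6931]
    /// ```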
    pub fn log1p(self) -> Self {
        Self::new(K::log1p(self.primitive))
    }

    /// Applies element wise natural log operation *ln*.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
    pub fn log(self) -> Self {
        Self::new(K::log(self.primitive))
    }

    /// Applies element wise square root operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
    pub fn sqrt(self) -> Self {
        Tensor::new(K::sqrt(self.primitive))
    }

    /// Applies element wise cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
    pub fn cos(self) -> Self {
        Tensor::new(K::cos(self.primitive))
    }

    /// Applies element wise sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
    pub fn sin(self) -> Self {
        Tensor::new(K::sin(self.primitive))
    }

    /// Applies element wise tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
    pub fn tan(self) -> Self {
        Tensor::new(K::tan(self.primitive))
    }

    /// Applies element wise hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
    /// }
    /// ```
    pub fn cosh(self) -> Self {
        Tensor::new(K::cosh(self.primitive))
    }

    /// Applies element wise hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
    /// }
    /// ```
    pub fn sinh(self) -> Self {
        Tensor::new(K::sinh(self.primitive))
    }

    /// Applies element wise hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
    /// }
    /// ```
    pub fn tanh(self) -> Self {
        Tensor::new(K::tanh(self.primitive))
    }

    /// Applies element wise inverse cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arccos\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acos(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.acos()); // [1.5708, 3.1416, 0.0]
    /// }
    /// ```
    pub fn acos(self) -> Self {
        Tensor::new(K::acos(self.primitive))
    }

    /// Applies element wise inverse hyperbolic cosine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{acosh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = acosh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
    ///     println!("{}", tensor.acosh()); // [0.0000, 1.3170, 1.7627]
    /// }
    /// ```
    pub fn acosh(self) -> Self {
        Tensor::new(K::acosh(self.primitive))
    }

    /// Applies element wise inverse sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arcsin\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asin(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asin()); // [ 0.0000, -1.5708, 1.5708]
    /// }
    /// ```
    pub fn asin(self) -> Self {
        Tensor::new(K::asin(self.primitive))
    }

    /// Applies element wise inverse hyperbolic sine operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{asinh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = asinh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
    ///     println!("{}", tensor.asinh()); // [ 0.0000, -0.8814, 0.8814]
    /// }
    /// ```
    pub fn asinh(self) -> Self {
        Tensor::new(K::asinh(self.primitive))
    }

    /// Applies element wise inverse tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \arctan\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
    ///     println!("{}", tensor.atan()); // [ 0.0, -0.7854, 1.1071]
    /// }
    /// ```
    pub fn atan(self) -> Self {
        Tensor::new(K::atan(self.primitive))
    }

    /// Applies element wise inverse hyperbolic tangent operation.
    ///
    #[cfg_attr(doc, doc = r#"$y_i = \operatorname{atanh}\(x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`y_i = atanh(x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -0.5, 0.5], &device);
    ///     println!("{}", tensor.atanh()); // [ 0.0, -0.5493, 0.5493]
    /// }
    /// ```
    pub fn atanh(self) -> Self {
        Tensor::new(K::atanh(self.primitive))
    }

    /// Applies element wise inverse tangent operation using the signs of arguments to determine the correct quadrant.
    ///
    #[cfg_attr(doc, doc = r#"$z_i = \operatorname{atan2}\(y_i, x_i\)$"#)]
    #[cfg_attr(not(doc), doc = "`z_i = atan2(y_i, x_i)`")]
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///
    ///     let lhs = Tensor::<B, 1>::from_data([-2.0, 2.0, -2.0], &device);
    ///     let rhs = Tensor::<B, 1>::from_data([1.0, -1.0, -1.0], &device);
    ///     println!("{}", lhs.atan2(rhs)); // [-1.1071, 2.0344, -2.0344]
    /// }
    /// ```
    pub fn atan2(self, other: Self) -> Self {
        Tensor::new(K::atan2(self.primitive, other.primitive))
    }
}