burn_tensor/tensor/api/float.rs
1use crate::AsIndex;
2use crate::FloatDType;
3use crate::Tensor;
4use crate::cast::ToElement;
5use crate::check;
6use crate::check::TensorCheck;
7use crate::indexing::canonicalize_dim;
8use crate::ops::InterpolateMode;
9use crate::quantization::{QuantScheme, QuantizationParameters};
10use crate::tensor::backend::Backend;
11use crate::tensor::stats;
12use crate::tensor::{Distribution, TensorData};
13use crate::{Int, TensorPrimitive};
14
15use super::Bool;
16
/// Relative tolerance used by `is_close` and `all_close` when the caller passes `None`.
pub const DEFAULT_RTOL: f64 = 1.0e-5;

/// Absolute tolerance used by `is_close` and `all_close` when the caller passes `None`.
pub const DEFAULT_ATOL: f64 = 1.0e-8;
22
23impl<const D: usize, B> Tensor<B, D>
24where
25 B: Backend,
26{
27 /// Applies element wise exponential operation.
28 ///
29 #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
30 #[cfg_attr(not(doc), doc = "`y = e^x`")]
31 pub fn exp(self) -> Self {
32 Self::new(TensorPrimitive::Float(B::float_exp(
33 self.primitive.tensor(),
34 )))
35 }
36
37 /// Applies element wise natural log operation *ln*.
38 ///
39 #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
40 #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
41 pub fn log(self) -> Self {
42 Self::new(TensorPrimitive::Float(B::float_log(
43 self.primitive.tensor(),
44 )))
45 }
46
47 /// Applies the natural logarithm of one plus the input tensor, element-wise.
48 ///
49 #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
50 #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
51 pub fn log1p(self) -> Self {
52 Self::new(TensorPrimitive::Float(B::float_log1p(
53 self.primitive.tensor(),
54 )))
55 }
56
57 /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
58 ///
59 #[cfg_attr(
60 doc,
61 doc = r#"
62$y_i = \text{erf}\(x_i\)$
63
64The error function is defined as:
65
66$$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$
67"#
68 )]
69 #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")]
70 pub fn erf(self) -> Self {
71 Self::new(TensorPrimitive::Float(B::float_erf(
72 self.primitive.tensor(),
73 )))
74 }
75
76 /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
77 /// (or multiplicative inverse) element wise.
78 ///
79 #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)]
80 #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")]
81 pub fn recip(self) -> Self {
82 Self::new(TensorPrimitive::Float(B::float_recip(
83 self.primitive.tensor(),
84 )))
85 }
86
87 /// Applies element wise square operation.
88 ///
89 #[cfg_attr(doc, doc = r#"$y_i = x_i * x_i$"#)]
90 #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")]
91 pub fn square(self) -> Self {
92 self.powi_scalar(2)
93 }
94
95 /// Applies element wise root square operation.
96 ///
97 #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
98 #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
99 pub fn sqrt(self) -> Self {
100 Self::new(TensorPrimitive::Float(B::float_sqrt(
101 self.primitive.tensor(),
102 )))
103 }
104
105 /// Applies element wise cosine operation.
106 ///
107 #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)]
108 #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")]
109 pub fn cos(self) -> Self {
110 Self::new(TensorPrimitive::Float(B::float_cos(
111 self.primitive.tensor(),
112 )))
113 }
114
115 /// Applies element wise sine operation.
116 ///
117 #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)]
118 #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")]
119 pub fn sin(self) -> Self {
120 Self::new(TensorPrimitive::Float(B::float_sin(
121 self.primitive.tensor(),
122 )))
123 }
124
125 /// Applies element wise tangent operation.
126 ///
127 #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)]
128 #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")]
129 pub fn tan(self) -> Self {
130 Self::new(TensorPrimitive::Float(B::float_tan(
131 self.primitive.tensor(),
132 )))
133 }
134
135 /// Applies element wise hyperbolic cosine operation.
136 ///
137 #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)]
138 #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")]
139 ///
140 /// # Example
141 ///
142 /// ```rust
143 /// use burn_tensor::backend::Backend;
144 /// use burn_tensor::Tensor;
145 ///
146 /// fn example<B: Backend>() {
147 /// let device = Default::default();
148 ///
149 /// let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
150 /// println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621]
151 /// }
152 /// ```
153 pub fn cosh(self) -> Self {
154 Self::new(TensorPrimitive::Float(B::float_cosh(
155 self.primitive.tensor(),
156 )))
157 }
158
159 /// Applies element wise hyperbolic sine operation.
160 ///
161 #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)]
162 #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")]
163 ///
164 /// # Example
165 ///
166 /// ```rust
167 /// use burn_tensor::backend::Backend;
168 /// use burn_tensor::Tensor;
169 ///
170 /// fn example<B: Backend>() {
171 /// let device = Default::default();
172 ///
173 /// let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
174 /// println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269]
175 /// }
176 /// ```
177 pub fn sinh(self) -> Self {
178 Self::new(TensorPrimitive::Float(B::float_sinh(
179 self.primitive.tensor(),
180 )))
181 }
182
183 /// Applies element wise hyperbolic tangent operation.
184 ///
185 #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)]
186 #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")]
187 ///
188 /// # Example
189 ///
190 /// ```rust
191 /// use burn_tensor::backend::Backend;
192 /// use burn_tensor::Tensor;
193 ///
194 /// fn example<B: Backend>() {
195 /// let device = Default::default();
196 ///
197 /// let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);
198 /// println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640]
199 /// }
200 /// ```
201 pub fn tanh(self) -> Self {
202 Self::new(TensorPrimitive::Float(B::float_tanh(
203 self.primitive.tensor(),
204 )))
205 }
206
207 /// Applies element wise round operation.
208 ///
209 /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
210 /// strategy, with halfway cases rounded to the nearest even integer value.
211 pub fn round(self) -> Self {
212 Self::new(TensorPrimitive::Float(B::float_round(
213 self.primitive.tensor(),
214 )))
215 }
216
217 /// Applies element wise floor operation.
218 pub fn floor(self) -> Self {
219 Self::new(TensorPrimitive::Float(B::float_floor(
220 self.primitive.tensor(),
221 )))
222 }
223
224 /// Applies element wise ceil operation.
225 pub fn ceil(self) -> Self {
226 Self::new(TensorPrimitive::Float(B::float_ceil(
227 self.primitive.tensor(),
228 )))
229 }
230
231 /// Create a tensor from floats (f32) on a given device.
232 ///
233 /// # Example
234 ///
235 /// ```rust
236 /// use burn_tensor::backend::Backend;
237 /// use burn_tensor::Tensor;
238 ///
239 /// fn example<B: Backend>() {
240 /// let device = B::Device::default();
241 /// let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
242 /// let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
243 /// }
244 /// ```
245 pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
246 Self::from_data(floats.into().convert::<f32>(), device)
247 }
248
249 /// Returns a new tensor with the same shape and device as the current tensor and the data
250 /// cast to Integer.
251 ///
252 /// # Example
253 ///
254 /// ```rust
255 /// use burn_tensor::backend::Backend;
256 /// use burn_tensor::Tensor;
257 ///
258 /// fn example<B: Backend>() {
259 /// let device = Default::default();
260 /// let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
261 /// let int_tensor = float_tensor.int();
262 /// }
263 /// ```
264 pub fn int(self) -> Tensor<B, D, Int> {
265 Tensor::new(B::float_into_int(self.primitive.tensor()))
266 }
267
268 /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled random
269 /// values sampled from the given distribution.
270 pub fn random_like(&self, distribution: Distribution) -> Self {
271 Self::new(TensorPrimitive::Float(B::float_random(
272 self.shape(),
273 distribution,
274 &self.device(),
275 )))
276 .cast(self.dtype())
277 }
278
279 /// Calculate the variance along the given dimension.
280 pub fn var(self, dim: usize) -> Self {
281 stats::var(self, dim)
282 }
283
284 /// Calculate the variance along the given dimension without applying the Bessel’s correction.
285 pub fn var_bias(self, dim: usize) -> Self {
286 stats::var_bias(self, dim)
287 }
288
289 /// Calculate the variance along the given dimension and also returns the mean.
290 pub fn var_mean(self, dim: usize) -> (Self, Self) {
291 let mean = self.clone().mean_dim(dim);
292 let var = stats::var_with_mean(self, mean.clone(), dim);
293 (var, mean)
294 }
295
296 /// Calculate the variance along the given dimension without applying the Bessel’s correction and also returns the mean.
297 pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
298 let mean = self.clone().mean_dim(dim);
299 let var = stats::var_with_mean_bias(self, mean.clone(), dim);
300 (var, mean)
301 }
302
303 /// Converts a tensor to the specified floating point data type.
304 ///
305 /// This is always a no-op when casting to the current dtype.
306 ///
307 /// # Warning
308 /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
309 /// have the same floating point precision data type for operations multiple input tensors (e.g., binary ops).
310 pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
311 let dtype = dtype.into();
312 let self_type: FloatDType = self.dtype().into();
313 if dtype == self_type {
314 // no-op.
315 return self;
316 }
317
318 Tensor::new(TensorPrimitive::Float(B::float_cast(
319 self.primitive.tensor(),
320 dtype,
321 )))
322 }
323
324 /// Detach the current tensor from the autodiff graph.
325 ///
326 /// This function does nothing when autodiff is not enabled.
327 /// This can be used in batchers or elsewhere to ensure that previous operations are not
328 /// considered in the autodiff graph.
329 pub fn detach(self) -> Self {
330 Self::new(TensorPrimitive::Float(B::float_detach(
331 self.primitive.tensor(),
332 )))
333 }
334
335 /// Mark the tensor to keep gradients during the backward pass.
336 ///
337 /// This function does nothing when autodiff is not enabled.
338 pub fn require_grad(self) -> Self {
339 self.set_require_grad(true)
340 }
341
342 /// Returns true if the tensor requires gradients during the backward pass.
343 pub fn is_require_grad(&self) -> bool {
344 match &self.primitive {
345 TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
346 TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
347 }
348 }
349
350 /// Mark the tensor as tracked or untracked depending on the require_grad argument.
351 /// When tracked, the gradients will be available after the backward pass.
352 ///
353 /// This function does nothing when autodiff is not enabled.
354 pub fn set_require_grad(self, require_grad: bool) -> Self {
355 let primitive = match self.primitive {
356 TensorPrimitive::Float(tensor) => {
357 TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
358 }
359 TensorPrimitive::QFloat(tensor) => {
360 TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
361 }
362 };
363 Self::new(primitive)
364 }
365
366 /// Applies the relu function to the tensor.
367 pub(crate) fn relu(self) -> Self {
368 Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
369 }
370
371 /// Calculate covaraince matrix between different entries alongside a given dimension.
372 ///
373 /// # Arguments
374 ///
375 /// * `size` - The size of the square matrix.
376 /// * `correction_factor` - Is usually 1 for samples and 0 for population.
377 pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
378 let n = self.dims()[dim];
379 let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
380 centered
381 .clone()
382 .transpose()
383 .matmul(centered)
384 .div_scalar(n as f32 - correction_factor as f32)
385 }
386
387 /// Convert the tensor to a lower precision data type based on the quantization scheme.
388 ///
389 /// # Arguments
390 ///
391 /// * `scheme` - The quantization scheme.
392 /// * `qparams` - The pre-computed quantization parameters.
393 ///
394 /// # Returns
395 ///
396 /// The quantized tensor.
397 pub fn quantize(
398 self,
399 scheme: &QuantScheme,
400 qparams: QuantizationParameters<B>,
401 ) -> Tensor<B, D> {
402 Tensor::new(TensorPrimitive::QFloat(B::quantize(
403 self.primitive.tensor(),
404 scheme,
405 qparams.into(),
406 )))
407 }
408
409 /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
410 ///
411 /// # Arguments
412 ///
413 /// * `scheme` - The quantization scheme.
414 ///
415 /// # Returns
416 ///
417 /// The quantized tensor.
418 ///
419 /// # Notes
420 /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
421 pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {
422 Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
423 self.primitive.tensor(),
424 scheme,
425 )))
426 }
427
428 /// Convert the tensor back to a higher precision data type.
429 ///
430 /// If the tensor is not quantized, its value is simply returned.
431 ///
432 /// # Returns
433 ///
434 /// The dequantized tensor.
435 pub fn dequantize(self) -> Tensor<B, D> {
436 Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
437 }
438
439 /// Checks element wise if the tensor is close to another tensor.
440 ///
441 /// The tolerance is defined by the following equation:
442 ///
443 /// ```text
444 /// abs(a - b) <= (atol + rtol * abs(b))
445 ///
446 /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
447 /// and `atol` is the absolute tolerance.
448 /// ```
449 ///
450 /// # Arguments
451 ///
452 /// * `other` - The tensor to compare with.
453 /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
454 /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
455 ///
456 /// # Returns
457 ///
458 /// A boolean tensor with the same shape as the input tensors.
459 ///
460 /// # Example
461 ///
462 /// ```rust
463 /// use burn_tensor::backend::Backend;
464 /// use burn_tensor::{Tensor, Shape};
465 ///
466 /// fn example<B: Backend>() {
467 /// let device = B::Device::default();
468 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
469 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
470 /// let tensor = tensor1.is_close(tensor2, None, None);
471 /// println!("{tensor}");
472 /// // [[true, true, true], [true, true, true]]
473 /// }
474 /// ```
475 pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
476 let rtol = rtol.unwrap_or(DEFAULT_RTOL);
477 let atol = atol.unwrap_or(DEFAULT_ATOL);
478
479 // check finite difference is close
480 let is_close_finite_val = self
481 .clone()
482 .sub(other.clone())
483 .abs()
484 .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))
485 .bool_and(self.clone().is_finite())
486 .bool_and(other.clone().is_finite());
487
488 // check if both are infinite and have same sign
489 let inf_same_sign = self
490 .clone()
491 .is_finite()
492 .bool_not()
493 .bool_and(other.clone().is_finite().bool_not())
494 .bool_and(self.equal(other));
495
496 is_close_finite_val.bool_or(inf_same_sign)
497 }
498
499 /// Checks if all elements are close to another tensor.
500 ///
501 /// The tolerance is defined by the following equation:
502 ///
503 /// ```text
504 ///
505 /// abs(a - b) <= (atol + rtol * abs(b))
506 ///
507 /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
508 /// and `atol` is the absolute tolerance.
509 ///
510 /// ```
511 ///
512 /// # Arguments
513 ///
514 /// * `other` - The tensor to compare with.
515 /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
516 /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
517 ///
518 /// # Returns
519 ///
520 /// A boolean scalar.
521 ///
522 /// # Remarks
523 ///
524 /// # Example
525 ///
526 /// ```rust
527 /// use burn_tensor::backend::Backend;
528 /// use burn_tensor::{Tensor, Shape};
529 ///
530 /// fn example<B: Backend>() {
531 /// let device = B::Device::default();
532 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
533 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
534 /// let result = tensor1.all_close(tensor2, None, None);
535 /// println!("{}", result);
536 /// // true
537 /// }
538 /// ```
539 pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
540 self.is_close(other, rtol, atol)
541 .all()
542 .into_scalar()
543 .to_bool()
544 }
545
546 /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
547 ///
548 /// # Returns
549 ///
550 /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
551 ///
552 /// # Example
553 ///
554 /// ```rust
555 /// use burn_tensor::backend::Backend;
556 /// use burn_tensor::{Tensor, Bool, Shape};
557 ///
558 /// fn example<B: Backend>() {
559 /// let device = B::Device::default();
560 /// let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
561 /// let tensor = tensor.is_nan();
562 /// println!("{tensor}");
563 /// // [[false, true, false], [false, false, false]]
564 /// }
565 /// ```
566 pub fn is_nan(self) -> Tensor<B, D, Bool> {
567 Tensor::new(B::float_is_nan(self.primitive.tensor()))
568 }
569
570 /// Checks if the tensor contains any NaN values.
571 ///
572 /// # Returns
573 ///
574 /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
575 ///
576 /// # Example
577 ///
578 /// ```rust
579 /// use burn_tensor::backend::Backend;
580 /// use burn_tensor::{Tensor, Bool, Shape};
581 ///
582 /// fn example<B: Backend>() {
583 /// let device = B::Device::default();
584 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
585 /// let tensor = tensor.contains_nan();
586 /// println!("{tensor}");
587 /// // [true]
588 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
589 /// let tensor = tensor.contains_nan();
590 /// println!("{tensor}");
591 /// // [false]
592 /// }
593 /// ```
594 pub fn contains_nan(self) -> Tensor<B, 1, Bool> {
595 // Summing the tensor will result in NaN if the tensor contains any NaN values
596 // This is faster than checking each element individually
597 // because it rolls up the NaN values into a single value
598 let sum = self.sum();
599
600 sum.is_nan()
601 }
602
603 /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
604 ///
605 /// # Returns
606 ///
607 /// A boolean tensor where `true` indicates that the value is infinite
608 ///
609 /// # Example
610 ///
611 /// ```rust
612 /// use burn_tensor::backend::Backend;
613 /// use burn_tensor::{Tensor, Bool, Shape};
614 ///
615 /// fn example<B: Backend>() {
616 /// let device = B::Device::default();
617 /// let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
618 /// let tensor = tensor.is_finite();
619 /// println!("{tensor}");
620 /// // [[false, true, false], [false, false, false]]
621 /// }
622 /// ```
623 pub fn is_inf(self) -> Tensor<B, D, Bool> {
624 Tensor::new(B::float_is_inf(self.primitive.tensor()))
625 }
626
627 /// Returns a new tensor with boolean elements indicating whether each element of the input is finite
628 ///
629 /// # Returns
630 ///
631 /// A boolean tensor where `true` indicates that the value is finite and `false` indicates
632 /// either INF, -INF or NAN
633 ///
634 /// # Example
635 ///
636 /// ```rust
637 /// use burn_tensor::backend::Backend;
638 /// use burn_tensor::{Tensor, Bool, Shape};
639 ///
640 /// fn example<B: Backend>() {
641 /// let device = B::Device::default();
642 /// let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);
643 /// let tensor = tensor.is_finite();
644 /// println!("{tensor}");
645 /// // [[true, false, true], [false, true, true]]
646 /// }
647 /// ```
648 pub fn is_finite(self) -> Tensor<B, D, Bool> {
649 self.clone()
650 .is_nan()
651 .bool_not()
652 .bool_and(self.is_inf().bool_not())
653 }
654
655 /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
656 /// using the given locations in [-1, 1].
657 ///
658 /// Interpolation is bilinear.
659 /// Padding is border: out of bounds locations will be clamped to the nearest border
660 ///
661 /// # Arguments
662 ///
663 /// * `tensor` - The tensor being sampled from, shape (N, C, H_in, W_in)
664 /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
665 /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
666 /// * `method` - How to interpolate between samples
667 ///
668 /// # Returns
669 ///
670 /// A tensor with shape (N, C, H_out, W_out)
671 pub fn grid_sample_2d(self, grid: Tensor<B, D>, method: InterpolateMode) -> Tensor<B, D> {
672 Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(
673 self.primitive.tensor(),
674 grid.primitive.tensor(),
675 method,
676 )))
677 }
678
679 /// Computes the cross product of `self` and another tensor along a given dimension.
680 ///
681 /// Both `self` and `other` **must have size 3** along the specified `dim`,
682 /// because the cross product is only defined in three-dimensional space.
683 ///
684 /// # Arguments
685 ///
686 /// * `other` - The other tensor to take the cross product with.
687 /// * `dim` - The dimension along which to compute the cross product.
688 ///
689 /// # Returns
690 ///
691 /// A tensor containing the cross product of `self` and `other` along `dim`.
692 pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {
693 let dim = canonicalize_dim(dim, D, false);
694 check!(TensorCheck::cross(&self, &other, dim));
695 Tensor::new(TensorPrimitive::Float(B::float_cross(
696 self.primitive.tensor(),
697 other.primitive.tensor(),
698 dim,
699 )))
700 }
701}