burn_tensor/tensor/api/float.rs

use alloc::vec::Vec;
use core::convert::TryInto;

use crate::check::TensorCheck;
use crate::quantization::{QuantizationParameters, QuantizationScheme};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, Shape, TensorData};
use crate::Tensor;
use crate::{check, FloatDType};
use crate::{Int, TensorPrimitive};

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Executes an operation on the tensor and modifies its value.
    ///
    /// # Notes
    ///
    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
    /// no other reference pointing to the same tensor.
    ///
    /// Wrapping operations with `inplace` is not an optimization; it is mainly useful when you
    /// want to mutate a tensor using owned operations. A plausible usage would be to
    /// update the weights of a mutable model reference.
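    ///
    /// # Example
    ///
    /// A minimal sketch of mutating a tensor in place through an owned operation:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let mut tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     // Replace the tensor's value with the result of the owned `exp` operation.
    ///     tensor.inplace(|t| t.exp());
    /// }
    /// ```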
    pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
        let mut tensor_owned = Tensor::empty([0; D], &self.device());
        core::mem::swap(&mut tensor_owned, self);

        let mut tensor_new = func(tensor_owned);
        core::mem::swap(&mut tensor_new, self);
    }

    /// Applies the element-wise exponential operation.
    ///
    /// `y = e^x`
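    ///
    /// # Example
    ///
    /// A minimal usage sketch:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0], &device);
    ///     // Each element x is mapped to e^x.
    ///     let exp = tensor.exp();
    /// }
    /// ```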
    pub fn exp(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_exp(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise natural log operation *ln*.
    ///
    /// `y = log(x)`
    pub fn log(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log(
            self.primitive.tensor(),
        )))
    }

    /// Applies the natural logarithm of one plus the input tensor, element-wise.
    ///
    /// `y = log(x+1)`
    pub fn log1p(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log1p(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element-wise.
    ///
    /// `y = erf(x)`
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise reciprocal operation.
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise square root operation.
    pub fn sqrt(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sqrt(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise cosine operation.
    pub fn cos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cos(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise sine operation.
    pub fn sin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sin(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise hyperbolic tangent operation.
    pub fn tanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
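    ///
    /// # Example
    ///
    /// A minimal sketch illustrating the round-half-to-even behaviour:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.4, 2.5, 3.5], &device);
    ///     // Halfway cases go to the nearest even integer: [1.0, 2.0, 4.0].
    ///     let rounded = tensor.round();
    /// }
    /// ```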
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise floor operation.
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies the element-wise ceil operation.
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        Tensor::new(B::float_into_int(self.primitive.tensor()))
    }

    /// Returns a new tensor with the same shape and device as the current tensor, filled with
    /// random values sampled from the given distribution.
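    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the default distribution:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // Same shape and device as `tensor`, but with freshly sampled values.
    ///     let noise = tensor.random_like(Distribution::Default);
    /// }
    /// ```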
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Tensor::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
        )))
    }

    /// Create a one hot tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let one_hot = Tensor::<B, 1>::one_hot(2, 10, &device);
    ///     println!("{}", one_hot.to_data());
    ///     // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    /// }
    /// ```
    pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self {
        check!(TensorCheck::one_hot_index(index, num_classes));

        let mut dims = [1; D];
        dims[D - 1] = num_classes;
        let shape = Shape::new(dims);
        let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect();
        let tensor = Tensor::zeros(shape, device);
        let mut ranges: [core::ops::Range<usize>; D] = ranges.try_into().unwrap();
        ranges[D - 1] = index..index + 1;

        tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device))
    }

    /// Applies the matrix multiplication operation.
    ///
    /// `C = AB`
    ///
    /// # Panics
    ///
    /// If the two tensors don't have a compatible shape.
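    ///
    /// # Example
    ///
    /// A minimal sketch multiplying two 2x2 matrices:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let a = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     let b = Tensor::<B, 2>::from_floats([[5.0, 6.0], [7.0, 8.0]], &device);
    ///     // The inner dimensions must match: (2x2) * (2x2) -> (2x2).
    ///     let c = a.matmul(b);
    /// }
    /// ```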
    pub fn matmul(self, other: Self) -> Self {
        check!(TensorCheck::matmul(&self, &other));
        Self::new(TensorPrimitive::Float(B::float_matmul(
            self.primitive.tensor(),
            other.primitive.tensor(),
        )))
    }

    /// Calculate the variance along the given dimension.
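    ///
    /// # Example
    ///
    /// A minimal sketch computing the variance over the last dimension:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]], &device);
    ///     // Sample variance (with Bessel's correction) along dimension 1.
    ///     let var = tensor.var(1);
    /// }
    /// ```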
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Converts a tensor to the specified floating point data type.
    ///
    /// # Warning
    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
    /// have the same floating point precision data type for operations with multiple input tensors (e.g., binary ops).
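    ///
    /// # Example
    ///
    /// A minimal sketch casting to half precision (assuming the `FloatDType::F16` variant):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{FloatDType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     // The values are preserved; only the floating point precision changes.
    ///     let tensor = tensor.cast(FloatDType::F16);
    /// }
    /// ```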
    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_cast(
            self.primitive.tensor(),
            dtype.into(),
        )))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
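    ///
    /// # Example
    ///
    /// A minimal sketch marking a tensor as tracked:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     // With an autodiff backend, gradients will be available after the backward pass.
    ///     let tensor = tensor.set_require_grad(true);
    /// }
    /// ```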
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which the covariance is computed.
    /// * `correction_factor` - Usually 1 for samples and 0 for population.
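    ///
    /// # Example
    ///
    /// A minimal sketch computing the sample covariance of a 2D tensor along dimension 1:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]], &device);
    ///     // Covariance with Bessel's correction (correction_factor = 1).
    ///     let cov = tensor.cov(1, 1);
    /// }
    /// ```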
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
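    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the scheme and parameters have already been computed:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::quantization::{QuantizationParameters, QuantizationScheme};
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>(scheme: &QuantizationScheme, qparams: QuantizationParameters<B>) {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     // Store the values in the lower precision format described by the scheme.
    ///     let quantized = tensor.quantize(scheme, qparams);
    /// }
    /// ```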
    pub fn quantize(
        self,
        scheme: &QuantizationScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            qparams.into(),
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
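    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the scheme has already been chosen; the parameters are
    /// computed on the fly from the tensor's values:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::quantization::QuantizationScheme;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>(scheme: &QuantizationScheme) {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let quantized = tensor.quantize_dynamic(scheme);
    ///     // Convert back to a floating point representation.
    ///     let dequantized = quantized.dequantize();
    /// }
    /// ```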
    pub fn quantize_dynamic(self, scheme: &QuantizationScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }
}