burn_tensor/tensor/api/float.rs

use crate::Tensor;
use crate::check::TensorCheck;
use crate::quantization::{QuantizationParameters, QuantizationScheme};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, TensorData};
use crate::{FloatDType, check};
use crate::{Int, TensorPrimitive};

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Executes an operation on the tensor and modifies its value.
    ///
    /// # Notes
    ///
    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
    /// no other reference pointing to the same tensor.
    ///
    /// Wrapping operations with `inplace` is not an optimization; it's mainly useful when you
    /// want to mutate a tensor using owned operations. A plausible use case is updating the
    /// weights of a model held behind a mutable reference.
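    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let mut tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     // Replace the tensor's value with its element wise exponential.
    ///     tensor.inplace(|t| t.exp());
    /// }
    /// ```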
    pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
        let mut tensor_owned = Tensor::empty([0; D], &self.device());
        core::mem::swap(&mut tensor_owned, self);

        let mut tensor_new = func(tensor_owned);
        core::mem::swap(&mut tensor_new, self);
    }

    /// Applies element wise exponential operation.
    ///
    /// `y = e^x`
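    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0], &device);
    ///     // Element wise exponential: [1.0, e, e^2]
    ///     let exp = tensor.exp();
    /// }
    /// ```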
    pub fn exp(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_exp(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise natural log operation *ln*.
    ///
    /// `y = log(x)`
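    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 4.0], &device);
    ///     // Element wise natural logarithm: [0.0, ln(2), ln(4)]
    ///     let log = tensor.log();
    /// }
    /// ```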
    pub fn log(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log(
            self.primitive.tensor(),
        )))
    }

    /// Applies the natural logarithm of one plus the input tensor, element-wise.
    ///
    /// `y = log(x+1)`
    pub fn log1p(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log1p(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
    ///
    /// `y = erf(x)`
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
    /// (or multiplicative inverse) element wise.
    ///
    /// `y = 1/x`
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise square root operation.
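    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 4.0, 9.0], &device);
    ///     // Element wise square root: [1.0, 2.0, 3.0]
    ///     let sqrt = tensor.sqrt();
    /// }
    /// ```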
    pub fn sqrt(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sqrt(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise cosine operation.
    pub fn cos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise sine operation.
    pub fn sin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise tangent operation.
    pub fn tan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic cosine operation.
    pub fn cosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic sine operation.
    pub fn sinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic tangent operation.
    pub fn tanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
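    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.5, 2.5, 3.3], &device);
    ///     // Round half to even: [2.0, 2.0, 3.0]
    ///     let rounded = tensor.round();
    /// }
    /// ```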
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise floor operation.
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise ceil operation.
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        Tensor::new(B::float_into_int(self.primitive.tensor()))
    }

    /// Returns a new tensor with the same shape and device as the current tensor, filled with
    /// random values sampled from the given distribution.
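    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // New tensor with the same shape, sampled from `Distribution::Default`.
    ///     let noise = tensor.random_like(Distribution::Default);
    /// }
    /// ```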
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Tensor::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
        )))
    }

    /// Applies the matrix multiplication operation.
    ///
    /// `C = AB`
    ///
    /// # Panics
    ///
    /// If the two tensors don't have a compatible shape.
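    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let a = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     let b = Tensor::<B, 2>::from_floats([[5.0, 6.0], [7.0, 8.0]], &device);
    ///     // 2x2 matrix product: [[19.0, 22.0], [43.0, 50.0]]
    ///     let c = a.matmul(b);
    /// }
    /// ```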
    pub fn matmul(self, other: Self) -> Self {
        check!(TensorCheck::matmul(&self, &other));
        match (self.primitive, other.primitive) {
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
                Self::new(TensorPrimitive::QFloat(B::q_matmul(lhs, rhs)))
            }
            (lhs, rhs) => Self::new(TensorPrimitive::Float(B::float_matmul(
                lhs.tensor(),
                rhs.tensor(),
            ))),
        }
    }

    /// Calculate the variance along the given dimension.
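    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]], &device);
    ///     // Sample variance (Bessel's correction) along dimension 1.
    ///     let var = tensor.var(1);
    /// }
    /// ```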
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
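    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]], &device);
    ///     // Variance and mean along dimension 1, computed in a single call.
    ///     let (var, mean) = tensor.var_mean(1);
    /// }
    /// ```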
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Converts a tensor to the specified floating point data type.
    ///
    /// # Warning
    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
    /// have the same floating point precision data type for operations with multiple input tensors (e.g., binary ops).
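    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{FloatDType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // Cast to half precision (assuming the backend supports `FloatDType::F16`).
    ///     let tensor = tensor.cast(FloatDType::F16);
    /// }
    /// ```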
    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_cast(
            self.primitive.tensor(),
            dtype.into(),
        )))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
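    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device).exp();
    ///     // Further operations on `detached` are not tracked by autodiff
    ///     // (a no-op on backends without autodiff).
    ///     let detached = tensor.detach();
    /// }
    /// ```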
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
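    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let weights = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // With an autodiff backend, gradients for this tensor are kept during backward.
    ///     let weights = weights.require_grad();
    /// }
    /// ```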
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the `require_grad` argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which the covariance is computed (the observation dimension).
    /// * `correction_factor` - Usually 1 for samples and 0 for the population.
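    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Two variables (rows) with three observations each (dimension 1).
    ///     let x = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]], &device);
    ///     // 2x2 covariance matrix with Bessel's correction (correction_factor = 1).
    ///     let cov = x.cov(1, 1);
    /// }
    /// ```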
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize(
        self,
        scheme: &QuantizationScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            qparams.into(),
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    ///
    /// # Notes
    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
    pub fn quantize_dynamic(self, scheme: &QuantizationScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }
}