burn_tensor/tensor/api/float.rs
use crate::Tensor;
use crate::check::TensorCheck;
use crate::quantization::{QuantizationParameters, QuantizationScheme};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, TensorData};
use crate::{FloatDType, check};
use crate::{Int, TensorPrimitive};

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Executes an operation on the tensor and modifies its value.
    ///
    /// # Notes
    ///
    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
    /// no other reference pointing to the same tensor.
    ///
    /// Wrapping operations with inplace is not an optimization, it's mainly there if you
    /// want to mutate a tensor by using owned operations. A plausible usage would be to
    /// update the weights of a mutable model reference.
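    ///
    /// # Example
    ///
    /// An illustrative sketch, assuming the usual scalar ops (`mul_scalar`, `add_scalar`)
    /// are available on the tensor:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let mut tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     // Mutate the tensor in place using owned operations.
    ///     tensor.inplace(|t| t.mul_scalar(2.0).add_scalar(1.0));
    /// }
    /// ```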
    pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
        let mut tensor_owned = Tensor::empty([0; D], &self.device());
        core::mem::swap(&mut tensor_owned, self);

        let mut tensor_new = func(tensor_owned);
        core::mem::swap(&mut tensor_new, self);
    }

    /// Applies element wise exponential operation.
    ///
    /// `y = e^x`
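    ///
    /// # Example
    ///
    /// An illustrative sketch (the values shown in the comment are approximate):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0], &device);
    ///     let _exp = tensor.exp(); // ≈ [1.0, 2.7183, 7.3891]
    /// }
    /// ```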
    pub fn exp(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_exp(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise natural log operation *ln*.
    ///
    /// `y = log(x)`
    pub fn log(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log(
            self.primitive.tensor(),
        )))
    }

    /// Applies the natural logarithm of one plus the input tensor, element-wise.
    ///
    /// `y = log(x+1)`
    pub fn log1p(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log1p(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
    ///
    /// `y = erf(x)`
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)
    /// (or multiplicative inverse) element wise.
    ///
    /// `y = 1/x`
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise square root operation.
    pub fn sqrt(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sqrt(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise cosine operation.
    pub fn cos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise sine operation.
    pub fn sin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise tangent operation.
    pub fn tan(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tan(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic cosine operation.
    pub fn cosh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cosh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic sine operation.
    pub fn sinh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sinh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic tangent operation.
    pub fn tanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
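    ///
    /// # Example
    ///
    /// An illustrative sketch of the round-half-to-even behaviour:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.5, 2.5, 0.6, 2.7], &device);
    ///     // Halfway cases go to the nearest even integer: 1.5 -> 2.0, 2.5 -> 2.0.
    ///     let _rounded = tensor.round(); // [2.0, 2.0, 1.0, 3.0]
    /// }
    /// ```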
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise floor operation.
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise ceil operation.
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        Tensor::new(B::float_into_int(self.primitive.tensor()))
    }

    /// Returns a new tensor with the same shape and device as the current tensor, filled with
    /// random values sampled from the given distribution.
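    ///
    /// # Example
    ///
    /// An illustrative sketch, assuming the `Distribution` variants re-exported by this crate
    /// (e.g. `Distribution::Default`):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // New tensor with the same shape and device, values sampled from the distribution.
    ///     let _noise = tensor.random_like(Distribution::Default);
    /// }
    /// ```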
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Tensor::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
        )))
    }

    /// Applies the matrix multiplication operation.
    ///
    /// `C = AB`
    ///
    /// # Panics
    ///
    /// If the two tensors don't have a compatible shape.
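    ///
    /// # Example
    ///
    /// An illustrative sketch; the result in the comment is what standard matrix
    /// multiplication yields:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let a = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     let b = Tensor::<B, 2>::from_floats([[5.0, 6.0], [7.0, 8.0]], &device);
    ///     let _c = a.matmul(b); // [[19.0, 22.0], [43.0, 50.0]]
    /// }
    /// ```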
    pub fn matmul(self, other: Self) -> Self {
        check!(TensorCheck::matmul(&self, &other));
        match (self.primitive, other.primitive) {
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
                Self::new(TensorPrimitive::QFloat(B::q_matmul(lhs, rhs)))
            }
            (lhs, rhs) => Self::new(TensorPrimitive::Float(B::float_matmul(
                lhs.tensor(),
                rhs.tensor(),
            ))),
        }
    }

    /// Calculate the variance along the given dimension.
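    ///
    /// # Example
    ///
    /// An illustrative sketch; the variance applies Bessel's correction
    /// (see [`var_bias`](Self::var_bias) for the biased estimator):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]], &device);
    ///     // Sample variance along dimension 1: ≈ 1.0 for the first row, ≈ 4.0 for the second.
    ///     let _var = tensor.var(1);
    /// }
    /// ```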
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction
    /// and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Converts a tensor to the specified floating point data type.
    ///
    /// # Warning
    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
    /// have the same floating point precision data type for operations with multiple input tensors (e.g., binary ops).
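    ///
    /// # Example
    ///
    /// An illustrative sketch, assuming a half-precision variant such as `FloatDType::F16`
    /// is available:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{FloatDType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     // Cast the tensor to half precision.
    ///     let _half = tensor.cast(FloatDType::F16);
    /// }
    /// ```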
    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_cast(
            self.primitive.tensor(),
            dtype.into(),
        )))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
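    ///
    /// # Example
    ///
    /// An illustrative sketch; with a plain (non-autodiff) backend this is a no-op:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     // Operations on `detached` are no longer tracked by the autodiff graph.
    ///     let detached = tensor.detach();
    ///     let _ = detached.exp();
    /// }
    /// ```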
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the require_grad argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which the covariance is computed.
    /// * `correction_factor` - Usually 1 for samples and 0 for the population.
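    ///
    /// # Example
    ///
    /// An illustrative sketch of the sample covariance between two variables (the values in
    /// the comment are approximate):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Each row is a variable, each column an observation.
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]], &device);
    ///     // Sample covariance (correction factor of 1) over dimension 1.
    ///     let _cov = tensor.cov(1, 1); // ≈ [[1.0, 2.0], [2.0, 4.0]]
    /// }
    /// ```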
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize(
        self,
        scheme: &QuantizationScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            qparams.into(),
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    ///
    /// # Notes
    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).
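    ///
    /// # Example
    ///
    /// An illustrative sketch that takes an already-built scheme as a parameter to avoid
    /// assuming any particular scheme variant:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::quantization::QuantizationScheme;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>(scheme: &QuantizationScheme) {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // Quantize with min-max calibration, then recover a float tensor.
    ///     let quantized = tensor.quantize_dynamic(scheme);
    ///     let _restored = quantized.dequantize();
    /// }
    /// ```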
    pub fn quantize_dynamic(self, scheme: &QuantizationScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }
}