burn_tensor/tensor/api/float.rs
use alloc::vec::Vec;
use core::convert::TryInto;

use crate::check::TensorCheck;
use crate::quantization::{QuantizationParameters, QuantizationScheme};
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Distribution, Shape, TensorData};
use crate::Tensor;
use crate::{check, FloatDType};
use crate::{Int, TensorPrimitive};

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
{
    /// Executes an operation on the tensor and modifies its value.
    ///
    /// # Notes
    ///
    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
    /// no other reference pointing to the same tensor.
    ///
    /// Wrapping operations with `inplace` is not an optimization; it is mainly there if you
    /// want to mutate a tensor by using owned operations. A plausible usage would be to
    /// update the weights of a mutable model reference.
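    ///
    /// # Example
    ///
    /// A minimal sketch of an in-place parameter update using owned operations (the learning
    /// rate is an arbitrary illustrative value):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>(weight: &mut Tensor<B, 2>, grad: Tensor<B, 2>) {
    ///     // Replace the weights with the result of an owned (move) operation.
    ///     weight.inplace(|w| w - grad.mul_scalar(0.01));
    /// }
    /// ```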
    pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
        let mut tensor_owned = Tensor::empty([0; D], &self.device());
        core::mem::swap(&mut tensor_owned, self);

        let mut tensor_new = func(tensor_owned);
        core::mem::swap(&mut tensor_new, self);
    }

    /// Applies element wise exponential operation.
    ///
    /// `y = e^x`
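    ///
    /// # Example
    ///
    /// A minimal usage sketch (expected values shown approximately in the comment):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0], &device);
    ///     let exp = tensor.exp(); // ~[1.0, 2.7183, 7.3891]
    /// }
    /// ```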
    pub fn exp(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_exp(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise natural log operation *ln*.
    ///
    /// `y = log(x)`
    pub fn log(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log(
            self.primitive.tensor(),
        )))
    }

    /// Applies the natural logarithm of one plus the input tensor, element-wise.
    ///
    /// `y = log(x+1)`
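    ///
    /// # Example
    ///
    /// A minimal usage sketch; `log1p` stays accurate for inputs close to zero, where adding
    /// one and taking the logarithm directly would lose precision (expected values shown
    /// approximately in the comment):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([0.0, 1e-7, 1.0], &device);
    ///     let out = tensor.log1p(); // ~[0.0, 1e-7, 0.6931]
    /// }
    /// ```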
    pub fn log1p(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_log1p(
            self.primitive.tensor(),
        )))
    }

    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.
    ///
    /// `y = erf(x)`
    pub fn erf(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_erf(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise reciprocal operation.
    pub fn recip(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_recip(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise square root operation.
    pub fn sqrt(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sqrt(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise cosine operation.
    pub fn cos(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_cos(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise sine operation.
    pub fn sin(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_sin(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise hyperbolic tangent operation.
    pub fn tanh(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_tanh(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise round operation.
    ///
    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
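    ///
    /// # Example
    ///
    /// A minimal usage sketch illustrating the halfway cases:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.5, 2.5, 0.5, 2.4], &device);
    ///     // Halfway cases go to the nearest even integer: [2.0, 2.0, 0.0, 2.0]
    ///     let rounded = tensor.round();
    /// }
    /// ```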
    pub fn round(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_round(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise floor operation.
    pub fn floor(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_floor(
            self.primitive.tensor(),
        )))
    }

    /// Applies element wise ceil operation.
    pub fn ceil(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_ceil(
            self.primitive.tensor(),
        )))
    }

    /// Create a tensor from floats (f32) on a given device.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// }
    /// ```
    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {
        Self::from_data(floats.into().convert::<f32>(), device)
    }

    /// Returns a new tensor with the same shape and device as the current tensor and the data
    /// cast to Integer.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     let int_tensor = float_tensor.int();
    /// }
    /// ```
    pub fn int(self) -> Tensor<B, D, Int> {
        Tensor::new(B::float_into_int(self.primitive.tensor()))
    }

    /// Returns a new tensor with the same shape and device as the current tensor, filled with
    /// random values sampled from the given distribution.
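    ///
    /// # Example
    ///
    /// A minimal usage sketch drawing uniform noise with the same shape as an existing tensor:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // Same shape and device as `tensor`, values drawn uniformly from [0, 1).
    ///     let noise = tensor.random_like(Distribution::Uniform(0.0, 1.0));
    /// }
    /// ```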
    pub fn random_like(&self, distribution: Distribution) -> Self {
        Tensor::new(TensorPrimitive::Float(B::float_random(
            self.shape(),
            distribution,
            &self.device(),
        )))
    }

    /// Create a one hot tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let one_hot = Tensor::<B, 1>::one_hot(2, 10, &device);
    ///     println!("{}", one_hot.to_data());
    ///     // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    /// }
    /// ```
    pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self {
        check!(TensorCheck::one_hot_index(index, num_classes));

        let mut dims = [1; D];
        dims[D - 1] = num_classes;
        let shape = Shape::new(dims);
        let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect();
        let tensor = Tensor::zeros(shape, device);
        let mut ranges: [core::ops::Range<usize>; D] = ranges.try_into().unwrap();
        ranges[D - 1] = index..index + 1;

        tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device))
    }

    /// Applies the matrix multiplication operation.
    ///
    /// `C = AB`
    ///
    /// # Panics
    ///
    /// If the two tensors don't have a compatible shape.
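    ///
    /// # Example
    ///
    /// A minimal usage sketch (the expected product is shown in the comment):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let a = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     let b = Tensor::<B, 2>::from_floats([[5.0, 6.0], [7.0, 8.0]], &device);
    ///     // [[19.0, 22.0], [43.0, 50.0]]
    ///     let c = a.matmul(b);
    /// }
    /// ```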
    pub fn matmul(self, other: Self) -> Self {
        check!(TensorCheck::matmul(&self, &other));
        Self::new(TensorPrimitive::Float(B::float_matmul(
            self.primitive.tensor(),
            other.primitive.tensor(),
        )))
    }

    /// Calculate the variance along the given dimension.
    pub fn var(self, dim: usize) -> Self {
        stats::var(self, dim)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction.
    pub fn var_bias(self, dim: usize) -> Self {
        stats::var_bias(self, dim)
    }

    /// Calculate the variance along the given dimension and also return the mean.
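    ///
    /// # Example
    ///
    /// A minimal usage sketch (expected values shown in the comment, with Bessel's correction
    /// applied to the variance):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///     // var: [[1.0], [1.0]], mean: [[2.0], [5.0]]
    ///     let (var, mean) = tensor.var_mean(1);
    /// }
    /// ```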
    pub fn var_mean(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean(self, mean.clone(), dim);
        (var, mean)
    }

    /// Calculate the variance along the given dimension without applying Bessel's correction and also return the mean.
    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {
        let mean = self.clone().mean_dim(dim);
        let var = stats::var_with_mean_bias(self, mean.clone(), dim);
        (var, mean)
    }

    /// Converts a tensor to the specified floating point data type.
    ///
    /// # Warning
    ///
    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
    /// have the same floating point precision data type for operations with multiple input tensors (e.g., binary ops).
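    ///
    /// # Example
    ///
    /// A minimal sketch, assuming half precision (`FloatDType::F16`) is available on the
    /// target backend:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{FloatDType, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    ///     // Cast the tensor values to half precision.
    ///     let half = tensor.cast(FloatDType::F16);
    /// }
    /// ```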
    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::float_cast(
            self.primitive.tensor(),
            dtype.into(),
        )))
    }

    /// Detach the current tensor from the autodiff graph.
    ///
    /// This function does nothing when autodiff is not enabled.
    /// This can be used in batchers or elsewhere to ensure that previous operations are not
    /// considered in the autodiff graph.
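    ///
    /// # Example
    ///
    /// A minimal usage sketch:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>(tensor: Tensor<B, 2>) {
    ///     // Use the values without tracking the operations that produced them.
    ///     let detached = tensor.detach();
    /// }
    /// ```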
    pub fn detach(self) -> Self {
        Self::new(TensorPrimitive::Float(B::float_detach(
            self.primitive.tensor(),
        )))
    }

    /// Mark the tensor to keep gradients during the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
    pub fn require_grad(self) -> Self {
        self.set_require_grad(true)
    }

    /// Returns true if the tensor requires gradients during the backward pass.
    pub fn is_require_grad(&self) -> bool {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),
            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),
        }
    }

    /// Mark the tensor as tracked or untracked depending on the `require_grad` argument.
    /// When tracked, the gradients will be available after the backward pass.
    ///
    /// This function does nothing when autodiff is not enabled.
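    ///
    /// # Example
    ///
    /// A minimal usage sketch:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device).set_require_grad(true);
    ///     // Whether gradients will be tracked during the backward pass
    ///     // (false on backends without autodiff).
    ///     let tracked = tensor.is_require_grad();
    /// }
    /// ```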
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let primitive = match self.primitive {
            TensorPrimitive::Float(tensor) => {
                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))
            }
            TensorPrimitive::QFloat(tensor) => {
                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))
            }
        };
        Self::new(primitive)
    }

    /// Applies the relu function to the tensor.
    pub(crate) fn relu(self) -> Self {
        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))
    }

    /// Calculate the covariance matrix between different entries along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension over which the entries are treated as observations.
    /// * `correction_factor` - Usually 1 for samples and 0 for the population.
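    ///
    /// # Example
    ///
    /// A minimal usage sketch with two variables (rows) and three observations each
    /// (expected result shown in the comment):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let data = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]], &device);
    ///     // Sample covariance (correction factor 1): [[1.0, 2.0], [2.0, 4.0]]
    ///     let cov = data.cov(1, 1);
    /// }
    /// ```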
    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {
        let n = self.dims()[dim];
        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);
        centered
            .clone()
            .transpose()
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize(
        self,
        scheme: &QuantizationScheme,
        qparams: QuantizationParameters<B>,
    ) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize(
            self.primitive.tensor(),
            scheme,
            qparams.into(),
        )))
    }

    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// # Arguments
    ///
    /// * `scheme` - The quantization scheme.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    pub fn quantize_dynamic(self, scheme: &QuantizationScheme) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(
            self.primitive.tensor(),
            scheme,
        )))
    }

    /// Convert the tensor back to a higher precision data type.
    ///
    /// If the tensor is not quantized, its value is simply returned.
    ///
    /// # Returns
    ///
    /// The dequantized tensor.
    pub fn dequantize(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))
    }
}