Skip to main content

burn_backend/tensor/ops/
numeric.rs

1use burn_std::Shape;
2
3use crate::{Backend, Distribution, Scalar, element::Element, tensor::BasicOps};
4
5/// Trait that lists all operations that can be applied to all numerical tensors.
6///
7/// # Warnings
8///
9/// This is an internal trait; use the public API provided by the
10#[cfg_attr(doc, doc = crate::doc_tensor!())]
11#[cfg_attr(not(doc), doc = "`Tensor`")]
12/// struct.
13pub trait Numeric<B: Backend>: BasicOps<B>
14where
15    Self::Elem: Element,
16{
17    /// Adds two tensors together.
18    ///
19    /// # Arguments
20    ///
21    /// * `lhs` - The left hand side tensor.
22    /// * `rhs` - The right hand side tensor.
23    ///
24    /// # Returns
25    ///
26    /// The sum of the two tensors.
27    ///
28    /// # Remarks
29    ///
30    /// This is a low-level function used internally by the library to call different backend functions
31    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
32    /// or use this function directly.
33    ///
34    /// For adding tensors, users should prefer the
35    #[cfg_attr(doc, doc = crate::doc_tensor!("add"))]
36    #[cfg_attr(not(doc), doc = "`Tensor::add`")]
37    /// function, which is more high-level and designed for public use.
38    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
39
40    /// Adds a scalar to a tensor element-wise.
41    ///
42    /// # Arguments
43    ///
44    /// * `lhs` - The left hand side tensor.
45    /// * `rhs` - The right hand side scalar.
46    ///
47    /// # Returns
48    ///
49    /// The sum of the tensor and the scalar.
50    ///
51    /// # Remarks
52    ///
53    /// This is a low-level function used internally by the library to call different backend functions
54    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
55    /// or use this function directly.
56    ///
57    /// For adding a scalar to a tensor, users should prefer the
58    #[cfg_attr(doc, doc = crate::doc_tensor!("add_scalar"))]
59    #[cfg_attr(not(doc), doc = "`Tensor::add_scalar`")]
60    /// function, which is more high-level and designed for public use.
61    fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
62
63    /// Subtracts two tensors.
64    ///
65    /// # Arguments
66    ///
67    /// * `lhs` - The left hand side tensor.
68    /// * `rhs` - The right hand side tensor.
69    ///
70    /// # Returns
71    ///
72    /// The difference of the two tensors.
73    ///
74    /// # Remarks
75    ///
76    /// This is a low-level function used internally by the library to call different backend functions
77    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
78    /// or use this function directly.
79    ///
80    /// For subtracting tensors, users should prefer the
81    #[cfg_attr(doc, doc = crate::doc_tensor!("sub"))]
82    #[cfg_attr(not(doc), doc = "`Tensor::sub`")]
83    /// function, which is more high-level and designed for public use.
84    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
85
86    /// Subtracts a scalar from a tensor element-wise.
87    ///
88    /// # Arguments
89    ///
90    /// * `lhs` - The left hand side tensor.
91    /// * `rhs` - The right hand side scalar.
92    ///
93    /// # Returns
94    ///
95    /// The difference of the tensor and the scalar.
96    ///
97    /// # Remarks
98    ///
99    /// This is a low-level function used internally by the library to call different backend functions
100    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
101    /// or use this function directly.
102    ///
103    /// For subtracting a scalar from a tensor, users should prefer the
104    #[cfg_attr(doc, doc = crate::doc_tensor!("sub_scalar"))]
105    #[cfg_attr(not(doc), doc = "`Tensor::sub_scalar`")]
106    /// function, which is more high-level and designed for public use.
107    fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
108
109    /// Divides two tensors.
110    ///
111    /// # Arguments
112    ///
113    /// * `lhs` - The left hand side tensor.
114    /// * `rhs` - The right hand side tensor.
115    ///
116    /// # Returns
117    ///
118    /// The quotient of the two tensors.
119    ///
120    /// # Remarks
121    ///
122    /// This is a low-level function used internally by the library to call different backend functions
123    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
124    /// or use this function directly.
125    ///
126    /// For dividing tensors, users should prefer the
127    #[cfg_attr(doc, doc = crate::doc_tensor!("div"))]
128    #[cfg_attr(not(doc), doc = "`Tensor::div`")]
129    /// function, which is more high-level and designed for public use.
130    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
131
132    /// Divides a tensor by a scalar element-wise.
133    ///
134    /// # Arguments
135    ///
136    /// * `lhs` - The left hand side tensor.
137    /// * `rhs` - The right hand side scalar.
138    ///
139    /// # Returns
140    ///
141    /// The quotient of the tensor and the scalar.
142    ///
143    /// # Remarks
144    ///
145    /// This is a low-level function used internally by the library to call different backend functions
146    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
147    /// or use this function directly.
148    ///
149    /// For dividing a tensor by a scalar, users should prefer the
150    #[cfg_attr(doc, doc = crate::doc_tensor!("div_scalar"))]
151    #[cfg_attr(not(doc), doc = "`Tensor::div_scalar`")]
152    /// function, which is more high-level and designed for public use.
153    fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
154
155    /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
156    /// less than that of the divisor.
157    ///
158    /// # Arguments
159    ///
160    /// * `lhs` - The dividend.
161    /// * `rhs` - The divisor.
162    ///
163    /// # Returns
164    ///
165    /// The modulo of the input tensor with the divisor.
166    ///
167    /// # Remarks
168    ///
169    /// This is a low-level function used internally by the library to call different backend functions
170    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
171    /// or use this function directly.
172    ///
173    /// For performing the modulo operation, users should prefer the
174    #[cfg_attr(doc, doc = crate::doc_tensor!("remainder"))]
175    #[cfg_attr(not(doc), doc = "`Tensor::remainder`")]
176    /// function, which is more high-level and designed for public use.
177    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
178
179    /// Computes the modulo element-wise with a scalar divisor. The result is the *signed* remainder of the division
180    /// and its absolute value is less than that of the divisor.
181    ///
182    /// # Arguments
183    ///
184    /// * `lhs` - The dividend tensor.
185    /// * `rhs` - The scalar divisor.
186    ///
187    /// # Returns
188    ///
189    /// The modulo of the input tensor with the scalar divisor.
190    ///
191    /// # Remarks
192    ///
193    /// This is a low-level function used internally by the library to call different backend functions
194    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
195    /// or use this function directly.
196    ///
197    /// For performing the modulo operation with a scalar, users should prefer the
198    #[cfg_attr(doc, doc = crate::doc_tensor!("remainder_scalar"))]
199    #[cfg_attr(not(doc), doc = "`Tensor::remainder_scalar`")]
200    /// function, which is more high-level and designed for public use.
201    fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
202
203    /// Multiplies two tensors.
204    ///
205    /// # Arguments
206    ///
207    /// * `lhs` - The left hand side tensor.
208    /// * `rhs` - The right hand side tensor.
209    ///
210    /// # Returns
211    ///
212    /// The product of the two tensors.
213    ///
214    /// # Remarks
215    ///
216    /// This is a low-level function used internally by the library to call different backend functions
217    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
218    /// or use this function directly.
219    ///
220    /// For multiplying tensors, users should prefer the
221    #[cfg_attr(doc, doc = crate::doc_tensor!("mul"))]
222    #[cfg_attr(not(doc), doc = "`Tensor::mul`")]
223    /// function, which is more high-level and designed for public use.
224    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
225
226    /// Multiplies a tensor by a scalar element-wise.
227    ///
228    /// # Arguments
229    ///
230    /// * `lhs` - The left hand side tensor.
231    /// * `rhs` - The right hand side scalar.
232    ///
233    /// # Returns
234    ///
235    /// The product of the tensor and the scalar.
236    ///
237    /// # Remarks
238    ///
239    /// This is a low-level function used internally by the library to call different backend functions
240    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
241    /// or use this function directly.
242    ///
243    /// For multiplying a tensor by a scalar, users should prefer the
244    #[cfg_attr(doc, doc = crate::doc_tensor!("mul_scalar"))]
245    #[cfg_attr(not(doc), doc = "`Tensor::mul_scalar`")]
246    /// function, which is more high-level and designed for public use.
247    fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
248
249    /// Negates a tensor.
250    ///
251    /// # Arguments
252    ///
253    /// * `tensor` - The tensor to negate.
254    ///
255    /// # Returns
256    ///
257    /// The negated tensor.
258    ///
259    /// # Remarks
260    ///
261    /// This is a low-level function used internally by the library to call different backend functions
262    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
263    /// or use this function directly.
264    ///
265    /// For negating a tensor, users should prefer the
266    #[cfg_attr(doc, doc = crate::doc_tensor!("neg"))]
267    #[cfg_attr(not(doc), doc = "`Tensor::neg`")]
268    /// function, which is more high-level and designed for public use.
269    fn neg(tensor: Self::Primitive) -> Self::Primitive;
270
271    /// Returns the signs of the elements of a tensor.
272    ///
273    /// # Arguments
274    ///
275    /// * `tensor` - The tensor.
276    ///
277    /// # Returns
278    ///
279    /// The signs of the elements of the tensor.
280    ///
281    /// # Remarks
282    ///
283    /// This is a low-level function used internally by the library to call different backend functions
284    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
285    /// or use this function directly.
286    ///
287    /// For getting the signs of the elements of a tensor, users should prefer the
288    #[cfg_attr(doc, doc = crate::doc_tensor!("sign"))]
289    #[cfg_attr(not(doc), doc = "`Tensor::sign`")]
290    /// function, which is more high-level and designed for public use.
291    fn sign(tensor: Self::Primitive) -> Self::Primitive;
292
293    /// Sums all the elements of the tensor.
294    ///
295    /// # Arguments
296    ///
297    /// * `tensor` - The tensor to sum.
298    ///
299    /// # Returns
300    ///
301    /// The sum of all the elements of the tensor.
302    ///
303    /// # Remarks
304    ///
305    /// This is a low-level function used internally by the library to call different backend functions
306    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
307    /// or use this function directly.
308    ///
309    /// For summing all the elements of a tensor, users should prefer the
310    #[cfg_attr(doc, doc = crate::doc_tensor!("sum"))]
311    #[cfg_attr(not(doc), doc = "`Tensor::sum`")]
312    /// function, which is more high-level and designed for public use.
313    fn sum(tensor: Self::Primitive) -> Self::Primitive;
314
315    /// Sums all the elements of the tensor along a dimension.
316    ///
317    /// # Arguments
318    ///
319    /// * `tensor` - The tensor to sum.
320    /// * `dim` - The dimension along which to sum.
321    ///
322    /// # Returns
323    ///
324    /// The sum of all the elements of the tensor along the specified dimension.
325    ///
326    /// # Remarks
327    ///
328    /// This is a low-level function used internally by the library to call different backend functions
329    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
330    /// or use this function directly.
331    ///
332    /// For summing all the elements of a tensor along a dimension, users should prefer the
333    #[cfg_attr(doc, doc = crate::doc_tensor!("sum_dim"))]
334    #[cfg_attr(not(doc), doc = "`Tensor::sum_dim`")]
335    /// function, which is more high-level and designed for public use.
336    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
337
338    /// Computes the product of all the elements of the tensor.
339    ///
340    /// # Arguments
341    ///
342    /// * `tensor` - The tensor to compute the product of.
343    ///
344    /// # Returns
345    ///
346    /// The product of all the elements of the tensor.
347    ///
348    /// # Remarks
349    ///
350    /// This is a low-level function used internally by the library to call different backend functions
351    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
352    /// or use this function directly.
353    ///
354    /// For computing the product of all the elements of a tensor, users should prefer the
355    #[cfg_attr(doc, doc = crate::doc_tensor!("prod"))]
356    #[cfg_attr(not(doc), doc = "`Tensor::prod`")]
357    /// function, which is more high-level and designed for public use.
358    fn prod(tensor: Self::Primitive) -> Self::Primitive;
359
360    /// Computes the product of all the elements of the tensor along a dimension.
361    ///
362    /// # Arguments
363    ///
364    /// * `tensor` - The tensor to compute the product of.
365    /// * `dim` - The dimension along which to compute the product.
366    ///
367    /// # Returns
368    ///
369    /// The product of all the elements of the tensor along the specified dimension.
370    ///
371    /// # Remarks
372    ///
373    /// This is a low-level function used internally by the library to call different backend functions
374    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
375    /// or use this function directly.
376    ///
377    /// For computing the product of all the elements of a tensor along a dimension, users should prefer the
378    #[cfg_attr(doc, doc = crate::doc_tensor!("prod_dim"))]
379    #[cfg_attr(not(doc), doc = "`Tensor::prod_dim`")]
380    /// function, which is more high-level and designed for public use.
381    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
382
383    /// Computes the mean of all the elements of the tensor.
384    ///
385    /// # Arguments
386    ///
387    /// * `tensor` - The tensor to compute the mean of.
388    ///
389    /// # Returns
390    ///
391    /// The mean of all the elements of the tensor.
392    ///
393    /// # Remarks
394    ///
395    /// This is a low-level function used internally by the library to call different backend functions
396    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
397    /// or use this function directly.
398    ///
399    /// For computing the mean of all the elements of a tensor, users should prefer the
400    #[cfg_attr(doc, doc = crate::doc_tensor!("mean"))]
401    #[cfg_attr(not(doc), doc = "`Tensor::mean`")]
402    /// function, which is more high-level and designed for public use.
403    fn mean(tensor: Self::Primitive) -> Self::Primitive;
404
405    /// Computes the mean of all the elements of the tensor along a dimension.
406    ///
407    /// # Arguments
408    ///
409    /// * `tensor` - The tensor to compute the mean of.
410    /// * `dim` - The dimension along which to compute the mean.
411    ///
412    /// # Returns
413    ///
414    /// The mean of all the elements of the tensor along the specified dimension.
415    ///
416    /// # Remarks
417    ///
418    /// This is a low-level function used internally by the library to call different backend functions
419    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
420    /// or use this function directly.
421    ///
422    /// For computing the mean of all the elements of a tensor along a dimension, users should prefer the
423    #[cfg_attr(doc, doc = crate::doc_tensor!("mean_dim"))]
424    #[cfg_attr(not(doc), doc = "`Tensor::mean_dim`")]
425    /// function, which is more high-level and designed for public use.
426    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
427
428    /// Computes the cumulative sum of elements along a dimension.
429    ///
430    /// # Arguments
431    ///
432    /// * `tensor` - The tensor to compute the cumulative sum of.
433    /// * `dim` - The dimension along which to compute the cumulative sum.
434    ///
435    /// # Returns
436    ///
437    /// A tensor with the same shape as the input tensor, where each element is the cumulative sum
438    /// of all elements up to and including that position along the specified dimension.
439    ///
440    /// # Remarks
441    ///
442    /// This is a low-level function used internally by the library to call different backend functions
443    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
444    /// or use this function directly.
445    ///
446    /// For computing the cumulative sum of elements along a dimension, users should prefer the
447    #[cfg_attr(doc, doc = crate::doc_tensor!("cumsum"))]
448    #[cfg_attr(not(doc), doc = "`Tensor::cumsum`")]
449    /// function, which is more high-level and designed for public use.
450    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
451
452    /// Computes the cumulative product of elements along a dimension.
453    ///
454    /// # Arguments
455    ///
456    /// * `tensor` - The tensor to compute the cumulative product of.
457    /// * `dim` - The dimension along which to compute the cumulative product.
458    ///
459    /// # Returns
460    ///
461    /// A tensor with the same shape as the input tensor, where each element is the cumulative product
462    /// of all elements up to and including that position along the specified dimension.
463    ///
464    /// # Remarks
465    ///
466    /// This is a low-level function used internally by the library to call different backend functions
467    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
468    /// or use this function directly.
469    ///
470    /// For computing the cumulative product of elements along a dimension, users should prefer the
471    #[cfg_attr(doc, doc = crate::doc_tensor!("cumprod"))]
472    #[cfg_attr(not(doc), doc = "`Tensor::cumprod`")]
473    /// function, which is more high-level and designed for public use.
474    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
475
476    /// Calculates the absolute value of all elements of a tensor.
477    ///
478    /// # Arguments
479    ///
480    /// * `tensor` - The tensor to apply abs to.
481    ///
482    /// # Returns
483    ///
484    /// A tensor with absolute values.
485    ///
486    /// # Remarks
487    ///
488    /// This is a low-level function used internally by the library to call different backend functions
489    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
490    /// or use this function directly.
491    ///
492    /// For calculating the absolute value of the elements of a tensor, users should prefer the
493    #[cfg_attr(doc, doc = crate::doc_tensor!("abs"))]
494    #[cfg_attr(not(doc), doc = "`Tensor::abs`")]
495    /// function, which is more high-level and designed for public use.
496    fn abs(tensor: Self::Primitive) -> Self::Primitive;
497
498    /// Element-wise power of a tensor to a float tensor.
499    ///
500    /// # Arguments
501    /// * `lhs` - The tensor to apply power to.
502    /// * `rhs` - The power to apply to the tensor.
503    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
504
505    /// Element-wise power of a tensor.
506    ///
507    /// # Arguments
508    /// * `lhs` - The tensor to apply power to.
509    /// * `rhs` - The power to apply to the tensor.
510    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
511
512    /// Element-wise power of a tensor to a scalar float.
513    ///
514    /// # Arguments
515    /// * `lhs` - The tensor to apply power to.
516    /// * `rhs` - The scalar power to apply to the tensor.
517    fn powf_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
518
519    /// Element-wise power of a tensor to a scalar int.
520    ///
521    /// # Arguments
522    /// * `lhs` - The tensor to apply power to.
523    /// * `rhs` - The scalar power to apply to the tensor.
524    fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
525
526    /// Create a random tensor.
527    ///
528    /// # Arguments
529    ///
530    /// * `shape` - The shape of the output tensor.
531    /// * `distribution` - The distribution used to sample.
532    /// * `device` - The device to use.
533    ///
534    /// # Returns
535    ///
536    /// A new tensor.
537    ///
538    /// # Remarks
539    ///
540    /// This is a low-level function used internally by the library to call different backend functions
541    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
542    /// or use this function directly.
543    ///
544    /// Users should prefer the
545    #[cfg_attr(doc, doc = crate::doc_tensor!("random"))]
546    #[cfg_attr(not(doc), doc = "`Tensor::random`")]
547    /// function, which is more high-level and designed for public use.
548    fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive;
549
550    /// Applies the matrix multiplication operation to `lhs` and `rhs`.
551    ///
552    /// ```math
553    /// C = AB
554    /// ```
555    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
556}