burn_backend/tensor/ops/numeric.rs
1use burn_std::Shape;
2
3use crate::{Backend, Distribution, Scalar, element::Element, tensor::BasicOps};
4
5/// Trait that list all operations that can be applied on all numerical tensors.
6///
7/// # Warnings
8///
9/// This is an internal trait, use the public API provided by the
10#[cfg_attr(doc, doc = crate::doc_tensor!())]
11#[cfg_attr(not(doc), doc = "`Tensor`")]
12/// struct.
13pub trait Numeric<B: Backend>: BasicOps<B>
14where
15 Self::Elem: Element,
16{
17 /// Adds two tensors together.
18 ///
19 /// # Arguments
20 ///
21 /// * `lhs` - The left hand side tensor.
22 /// * `rhs` - The right hand side tensor.
23 ///
24 /// # Returns
25 ///
26 /// The sum of the two tensors.
27 ///
28 /// # Remarks
29 ///
30 /// This is a low-level function used internally by the library to call different backend functions
31 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
32 /// or use this function directly.
33 ///
34 /// For adding tensors, users should prefer the
35 #[cfg_attr(doc, doc = crate::doc_tensor!("add"))]
36 #[cfg_attr(not(doc), doc = "`Tensor::add`")]
37 /// function, which is more high-level and designed for public use.
38 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
39
40 /// Adds a scalar to a tensor element-wise.
41 ///
42 /// # Arguments
43 ///
44 /// * `lhs` - The left hand side tensor.
45 /// * `rhs` - The right hand side scalar.
46 ///
47 /// # Returns
48 ///
49 /// The sum of the tensor and the scalar.
50 ///
51 /// # Remarks
52 ///
53 /// This is a low-level function used internally by the library to call different backend functions
54 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
55 /// or use this function directly.
56 ///
57 /// For adding a scalar to a tensor, users should prefer the
58 #[cfg_attr(doc, doc = crate::doc_tensor!("add_scalar"))]
59 #[cfg_attr(not(doc), doc = "`Tensor::add_scalar`")]
60 /// function, which is more high-level and designed for public use.
61 fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
62
63 /// Subtracts two tensors.
64 ///
65 /// # Arguments
66 ///
67 /// * `lhs` - The left hand side tensor.
68 /// * `rhs` - The right hand side tensor.
69 ///
70 /// # Returns
71 ///
72 /// The difference of the two tensors.
73 ///
74 /// # Remarks
75 ///
76 /// This is a low-level function used internally by the library to call different backend functions
77 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
78 /// or use this function directly.
79 ///
80 /// For subtracting tensors, users should prefer the
81 #[cfg_attr(doc, doc = crate::doc_tensor!("sub"))]
82 #[cfg_attr(not(doc), doc = "`Tensor::sub`")]
83 /// function, which is more high-level and designed for public use.
84 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
85
86 /// Subtracts a scalar from a tensor element-wise.
87 ///
88 /// # Arguments
89 ///
90 /// * `lhs` - The left hand side tensor.
91 /// * `rhs` - The right hand side scalar.
92 ///
93 /// # Returns
94 ///
95 /// The difference of the tensor and the scalar.
96 ///
97 /// # Remarks
98 ///
99 /// This is a low-level function used internally by the library to call different backend functions
100 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
101 /// or use this function directly.
102 ///
103 /// For subtracting a scalar from a tensor, users should prefer the
104 #[cfg_attr(doc, doc = crate::doc_tensor!("sub_scalar"))]
105 #[cfg_attr(not(doc), doc = "`Tensor::sub_scalar`")]
106 /// function, which is more high-level and designed for public use.
107 fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
108
109 /// Divides two tensors.
110 ///
111 /// # Arguments
112 ///
113 /// * `lhs` - The left hand side tensor.
114 /// * `rhs` - The right hand side tensor.
115 ///
116 /// # Returns
117 ///
118 /// The quotient of the two tensors.
119 ///
120 /// # Remarks
121 ///
122 /// This is a low-level function used internally by the library to call different backend functions
123 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
124 /// or use this function directly.
125 ///
126 /// For dividing tensors, users should prefer the
127 #[cfg_attr(doc, doc = crate::doc_tensor!("div"))]
128 #[cfg_attr(not(doc), doc = "`Tensor::div`")]
129 /// function, which is more high-level and designed for public use.
130 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
131
132 /// Divides a tensor by a scalar element-wise.
133 ///
134 /// # Arguments
135 ///
136 /// * `lhs` - The left hand side tensor.
137 /// * `rhs` - The right hand side scalar.
138 ///
139 /// # Returns
140 ///
141 /// The quotient of the tensor and the scalar.
142 ///
143 /// # Remarks
144 ///
145 /// This is a low-level function used internally by the library to call different backend functions
146 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
147 /// or use this function directly.
148 ///
149 /// For dividing a tensor by a scalar, users should prefer the
150 #[cfg_attr(doc, doc = crate::doc_tensor!("div_scalar"))]
151 #[cfg_attr(not(doc), doc = "`Tensor::div_scalar`")]
152 /// function, which is more high-level and designed for public use.
153 fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
154
155 /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
156 /// less than that of the divisor.
157 ///
158 /// # Arguments
159 ///
160 /// * `lhs` - The dividend.
161 /// * `rhs` - The divisor.
162 ///
163 /// # Returns
164 ///
165 /// The modulo of the input tensor with the divisor.
166 ///
167 /// # Remarks
168 ///
169 /// This is a low-level function used internally by the library to call different backend functions
170 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
171 /// or use this function directly.
172 ///
173 /// For performing the modulo operation, users should prefer the
174 #[cfg_attr(doc, doc = crate::doc_tensor!("remainder"))]
175 #[cfg_attr(not(doc), doc = "`Tensor::remainder`")]
176 /// function, which is more high-level and designed for public use.
177 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
178
179 /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
180 /// less than that of the divisor.
181 ///
182 /// # Arguments
183 ///
184 /// * `lhs` - The dividend.
185 /// * `rhs` - The divisor.
186 ///
187 /// # Returns
188 ///
189 /// The modulo of the input tensor with the divisor.
190 ///
191 /// # Remarks
192 ///
193 /// This is a low-level function used internally by the library to call different backend functions
194 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
195 /// or use this function directly.
196 ///
197 /// For performing the modulo operation, users should prefer the
198 #[cfg_attr(doc, doc = crate::doc_tensor!("remainder_scalar"))]
199 #[cfg_attr(not(doc), doc = "`Tensor::remainder_scalar`")]
200 /// function, which is more high-level and designed for public use.
201 fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
202
203 /// Multiplies two tensors.
204 ///
205 /// # Arguments
206 ///
207 /// * `lhs` - The left hand side tensor.
208 /// * `rhs` - The right hand side tensor.
209 ///
210 /// # Returns
211 ///
212 /// The product of the two tensors.
213 ///
214 /// # Remarks
215 ///
216 /// This is a low-level function used internally by the library to call different backend functions
217 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
218 /// or use this function directly.
219 ///
220 /// For multiplying tensors, users should prefer the
221 #[cfg_attr(doc, doc = crate::doc_tensor!("mul"))]
222 #[cfg_attr(not(doc), doc = "`Tensor::mul`")]
223 /// function, which is more high-level and designed for public use.
224 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
225
226 /// Multiplies a tensor by a scalar element-wise.
227 ///
228 /// # Arguments
229 ///
230 /// * `lhs` - The left hand side tensor.
231 /// * `rhs` - The right hand side scalar.
232 ///
233 /// # Returns
234 ///
235 /// The product of the tensor and the scalar.
236 ///
237 /// # Remarks
238 ///
239 /// This is a low-level function used internally by the library to call different backend functions
240 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
241 /// or use this function directly.
242 ///
243 /// For multiplying a tensor by a scalar, users should prefer the
244 #[cfg_attr(doc, doc = crate::doc_tensor!("mul_scalar"))]
245 #[cfg_attr(not(doc), doc = "`Tensor::mul_scalar`")]
246 /// function, which is more high-level and designed for public use.
247 fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
248
249 /// Negates a tensor.
250 ///
251 /// # Arguments
252 ///
253 /// * `tensor` - The tensor to negate.
254 ///
255 /// # Returns
256 ///
257 /// The negated tensor.
258 ///
259 /// # Remarks
260 ///
261 /// This is a low-level function used internally by the library to call different backend functions
262 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
263 /// or use this function directly.
264 ///
265 /// For negating a tensor, users should prefer the
266 #[cfg_attr(doc, doc = crate::doc_tensor!("neg"))]
267 #[cfg_attr(not(doc), doc = "`Tensor::neg`")]
268 /// function, which is more high-level and designed for public use.
269 fn neg(tensor: Self::Primitive) -> Self::Primitive;
270
271 /// Returns the signs of the elements of a tensor.
272 ///
273 /// # Arguments
274 ///
275 /// * `tensor` - The tensor.
276 ///
277 /// # Returns
278 ///
279 /// The signs of the elements of the tensor.
280 ///
281 /// # Remarks
282 ///
283 /// This is a low-level function used internally by the library to call different backend functions
284 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
285 /// or use this function directly.
286 ///
287 /// For getting the signs of the elements of a tensor, users should prefer the
288 #[cfg_attr(doc, doc = crate::doc_tensor!("sign"))]
289 #[cfg_attr(not(doc), doc = "`Tensor::sign`")]
290 /// function, which is more high-level and designed for public use.
291 fn sign(tensor: Self::Primitive) -> Self::Primitive;
292
293 /// Sums all the elements of the tensor.
294 ///
295 /// # Arguments
296 ///
297 /// * `tensor` - The tensor to sum.
298 ///
299 /// # Returns
300 ///
301 /// The sum of all the elements of the tensor.
302 ///
303 /// # Remarks
304 ///
305 /// This is a low-level function used internally by the library to call different backend functions
306 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
307 /// or use this function directly.
308 ///
309 /// For summing all the elements of a tensor, users should prefer the
310 #[cfg_attr(doc, doc = crate::doc_tensor!("sum"))]
311 #[cfg_attr(not(doc), doc = "`Tensor::sum`")]
312 /// function, which is more high-level and designed for public use.
313 fn sum(tensor: Self::Primitive) -> Self::Primitive;
314
315 /// Sums all the elements of the tensor along a dimension.
316 ///
317 /// # Arguments
318 ///
319 /// * `tensor` - The tensor to sum.
320 /// * `dim` - The dimension along which to sum.
321 ///
322 /// # Returns
323 ///
324 /// The sum of all the elements of the tensor along the specified dimension.
325 ///
326 /// # Remarks
327 ///
328 /// This is a low-level function used internally by the library to call different backend functions
329 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
330 /// or use this function directly.
331 ///
332 /// For summing all the elements of a tensor along a dimension, users should prefer the
333 #[cfg_attr(doc, doc = crate::doc_tensor!("sum_dim"))]
334 #[cfg_attr(not(doc), doc = "`Tensor::sum_dim`")]
335 /// function, which is more high-level and designed for public use.
336 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
337
338 /// Computes the product of all the elements of the tensor.
339 ///
340 /// # Arguments
341 ///
342 /// * `tensor` - The tensor to compute the product of.
343 ///
344 /// # Returns
345 ///
346 /// The product of all the elements of the tensor.
347 ///
348 /// # Remarks
349 ///
350 /// This is a low-level function used internally by the library to call different backend functions
351 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
352 /// or use this function directly.
353 ///
354 /// For computing the product of all the elements of a tensor, users should prefer the
355 #[cfg_attr(doc, doc = crate::doc_tensor!("prod"))]
356 #[cfg_attr(not(doc), doc = "`Tensor::prod`")]
357 /// function, which is more high-level and designed for public use.
358 fn prod(tensor: Self::Primitive) -> Self::Primitive;
359
360 /// Computes the product of all the elements of the tensor along a dimension.
361 ///
362 /// # Arguments
363 ///
364 /// * `tensor` - The tensor to compute the product of.
365 /// * `dim` - The dimension along which to compute the product.
366 ///
367 /// # Returns
368 ///
369 /// The product of all the elements of the tensor along the specified dimension.
370 ///
371 /// # Remarks
372 ///
373 /// This is a low-level function used internally by the library to call different backend functions
374 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
375 /// or use this function directly.
376 ///
377 /// For computing the product of all the elements of a tensor along a dimension, users should prefer the
378 #[cfg_attr(doc, doc = crate::doc_tensor!("prod_dim"))]
379 #[cfg_attr(not(doc), doc = "`Tensor::prod_dim`")]
380 /// function, which is more high-level and designed for public use.
381 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
382
383 /// Computes the mean of all the elements of the tensor.
384 ///
385 /// # Arguments
386 ///
387 /// * `tensor` - The tensor to compute the mean of.
388 ///
389 /// # Returns
390 ///
391 /// The mean of all the elements of the tensor.
392 ///
393 /// # Remarks
394 ///
395 /// This is a low-level function used internally by the library to call different backend functions
396 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
397 /// or use this function directly.
398 ///
399 /// For computing the mean of all the elements of a tensor, users should prefer the
400 #[cfg_attr(doc, doc = crate::doc_tensor!("mean"))]
401 #[cfg_attr(not(doc), doc = "`Tensor::mean`")]
402 /// function, which is more high-level and designed for public use.
403 fn mean(tensor: Self::Primitive) -> Self::Primitive;
404
405 /// Computes the mean of all the elements of the tensor along a dimension.
406 ///
407 /// # Arguments
408 ///
409 /// * `tensor` - The tensor to compute the mean of.
410 /// * `dim` - The dimension along which to compute the mean.
411 ///
412 /// # Returns
413 ///
414 /// The mean of all the elements of the tensor along the specified dimension.
415 ///
416 /// # Remarks
417 ///
418 /// This is a low-level function used internally by the library to call different backend functions
419 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
420 /// or use this function directly.
421 ///
422 /// For computing the mean of all the elements of a tensor along a dimension, users should prefer the
423 #[cfg_attr(doc, doc = crate::doc_tensor!("mean_dim"))]
424 #[cfg_attr(not(doc), doc = "`Tensor::mean_dim`")]
425 /// function, which is more high-level and designed for public use.
426 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
427
428 /// Computes the cumulative sum of elements along a dimension.
429 ///
430 /// # Arguments
431 ///
432 /// * `tensor` - The tensor to compute the cumulative sum of.
433 /// * `dim` - The dimension along which to compute the cumulative sum.
434 ///
435 /// # Returns
436 ///
437 /// A tensor with the same shape as the input tensor, where each element is the cumulative sum
438 /// of all elements up to and including that position along the specified dimension.
439 ///
440 /// # Remarks
441 ///
442 /// This is a low-level function used internally by the library to call different backend functions
443 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
444 /// or use this function directly.
445 ///
446 /// For computing the cumulative sum of elements along a dimension, users should prefer the
447 #[cfg_attr(doc, doc = crate::doc_tensor!("cumsum"))]
448 #[cfg_attr(not(doc), doc = "`Tensor::cumsum`")]
449 /// function, which is more high-level and designed for public use.
450 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
451
452 /// Computes the cumulative product of elements along a dimension.
453 ///
454 /// # Arguments
455 ///
456 /// * `tensor` - The tensor to compute the cumulative product of.
457 /// * `dim` - The dimension along which to compute the cumulative product.
458 ///
459 /// # Returns
460 ///
461 /// A tensor with the same shape as the input tensor, where each element is the cumulative product
462 /// of all elements up to and including that position along the specified dimension.
463 ///
464 /// # Remarks
465 ///
466 /// This is a low-level function used internally by the library to call different backend functions
467 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
468 /// or use this function directly.
469 ///
470 /// For computing the cumulative product of elements along a dimension, users should prefer the
471 #[cfg_attr(doc, doc = crate::doc_tensor!("cumprod"))]
472 #[cfg_attr(not(doc), doc = "`Tensor::cumprod`")]
473 /// function, which is more high-level and designed for public use.
474 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
475
476 /// Calculate absolute value on all elements of a tensor
477 ///
478 /// # Arguments
479 ///
480 /// * `tensor` - The tensor to apply abs to.
481 ///
482 /// # Returns
483 ///
484 /// A tensor with absolute values.
485 ///
486 /// # Remarks
487 ///
488 /// This is a low-level function used internally by the library to call different backend functions
489 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
490 /// or use this function directly.
491 ///
492 /// For calculating abs of the elements of a tensor, users should prefer the
493 #[cfg_attr(doc, doc = crate::doc_tensor!("abs"))]
494 #[cfg_attr(not(doc), doc = "`Tensor::abs`")]
495 /// function, which is more high-level and designed for public use.
496 fn abs(tensor: Self::Primitive) -> Self::Primitive;
497
498 /// Element-wise power of a tensor to a float tensor
499 ///
500 /// # Arguments
501 /// * `tensor` - The tensor to apply power to.
502 /// * `power` - The power to apply to the tensor.
503 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
504
505 /// Element-wise power of a tensor
506 ///
507 /// # Arguments
508 /// * `tensor` - The tensor to apply power to.
509 /// * `power` - The power to apply to the tensor.
510 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
511
512 /// Element-wise power of a tensor to a scalar float
513 ///
514 /// # Arguments
515 /// * `tensor` - The tensor to apply power to.
516 /// * `power` - The power to apply to the tensor.
517 fn powf_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
518
519 /// Element-wise power of a tensor to a scalar int
520 ///
521 /// # Arguments
522 /// * `tensor` - The tensor to apply power to.
523 /// * `power` - The power to apply to the tensor.
524 fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;
525
526 /// Create a random tensor.
527 ///
528 /// # Arguments
529 ///
530 /// * `shape` - The shape of the output tensor.
531 /// * `distribution` - The distribution used to sample.
532 /// * `device` - The device to use.
533 ///
534 /// # Returns
535 ///
536 /// A new tensor.
537 ///
538 /// # Remarks
539 ///
540 /// This is a low-level function used internally by the library to call different backend functions
541 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
542 /// or use this function directly.
543 ///
544 /// Users should prefer the
545 #[cfg_attr(doc, doc = crate::doc_tensor!("random"))]
546 #[cfg_attr(not(doc), doc = "`Tensor::random`")]
547 /// function, which is more high-level and designed for public use.
548 fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive;
549
550 /// Applies the matrix multiplication operation.
551 ///
552 /// ```math
553 /// C = AB
554 /// ```
555 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
556}