burn_backend/tensor/ops/numeric.rs
1use burn_std::Shape;
2
3use crate::{
4 Backend, Distribution,
5 element::{Element, ElementConversion},
6 tensor::{BasicOps, IntTensor},
7};
8
9/// Trait that lists all operations that can be applied to all numerical tensors.
10///
11/// # Warnings
12///
13/// This is an internal trait, use the public API provided by the
14#[cfg_attr(doc, doc = crate::doc_tensor!())]
15#[cfg_attr(not(doc), doc = "`Tensor`")]
16/// struct.
17pub trait Numeric<B: Backend>: BasicOps<B>
18where
19 Self::Elem: Element,
20{
21 /// Adds two tensors together.
22 ///
23 /// # Arguments
24 ///
25 /// * `lhs` - The left hand side tensor.
26 /// * `rhs` - The right hand side tensor.
27 ///
28 /// # Returns
29 ///
30 /// The sum of the two tensors.
31 ///
32 /// # Remarks
33 ///
34 /// This is a low-level function used internally by the library to call different backend functions
35 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
36 /// or use this function directly.
37 ///
38 /// For adding tensors, users should prefer the
39 #[cfg_attr(doc, doc = crate::doc_tensor!("add"))]
40 #[cfg_attr(not(doc), doc = "`Tensor::add`")]
41 /// function, which is more high-level and designed for public use.
42 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
43
44 /// Adds a scalar to a tensor element-wise.
45 ///
46 /// # Arguments
47 ///
48 /// * `lhs` - The left hand side tensor.
49 /// * `rhs` - The right hand side scalar.
50 ///
51 /// # Returns
52 ///
53 /// The sum of the tensor and the scalar.
54 ///
55 /// # Remarks
56 ///
57 /// This is a low-level function used internally by the library to call different backend functions
58 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
59 /// or use this function directly.
60 ///
61 /// For adding a scalar to a tensor, users should prefer the
62 #[cfg_attr(doc, doc = crate::doc_tensor!("add_scalar"))]
63 #[cfg_attr(not(doc), doc = "`Tensor::add_scalar`")]
64 /// function, which is more high-level and designed for public use.
65 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
66
67 /// Subtracts two tensors.
68 ///
69 /// # Arguments
70 ///
71 /// * `lhs` - The left hand side tensor.
72 /// * `rhs` - The right hand side tensor.
73 ///
74 /// # Returns
75 ///
76 /// The difference of the two tensors.
77 ///
78 /// # Remarks
79 ///
80 /// This is a low-level function used internally by the library to call different backend functions
81 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
82 /// or use this function directly.
83 ///
84 /// For subtracting tensors, users should prefer the
85 #[cfg_attr(doc, doc = crate::doc_tensor!("sub"))]
86 #[cfg_attr(not(doc), doc = "`Tensor::sub`")]
87 /// function, which is more high-level and designed for public use.
88 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
89
90 /// Subtracts a scalar from a tensor element-wise.
91 ///
92 /// # Arguments
93 ///
94 /// * `lhs` - The left hand side tensor.
95 /// * `rhs` - The right hand side scalar.
96 ///
97 /// # Returns
98 ///
99 /// The difference of the tensor and the scalar.
100 ///
101 /// # Remarks
102 ///
103 /// This is a low-level function used internally by the library to call different backend functions
104 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
105 /// or use this function directly.
106 ///
107 /// For subtracting a scalar from a tensor, users should prefer the
108 #[cfg_attr(doc, doc = crate::doc_tensor!("sub_scalar"))]
109 #[cfg_attr(not(doc), doc = "`Tensor::sub_scalar`")]
110 /// function, which is more high-level and designed for public use.
111 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
112
113 /// Divides two tensors.
114 ///
115 /// # Arguments
116 ///
117 /// * `lhs` - The left hand side tensor.
118 /// * `rhs` - The right hand side tensor.
119 ///
120 /// # Returns
121 ///
122 /// The quotient of the two tensors.
123 ///
124 /// # Remarks
125 ///
126 /// This is a low-level function used internally by the library to call different backend functions
127 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
128 /// or use this function directly.
129 ///
130 /// For dividing tensors, users should prefer the
131 #[cfg_attr(doc, doc = crate::doc_tensor!("div"))]
132 #[cfg_attr(not(doc), doc = "`Tensor::div`")]
133 /// function, which is more high-level and designed for public use.
134 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
135
136 /// Divides a tensor by a scalar element-wise.
137 ///
138 /// # Arguments
139 ///
140 /// * `lhs` - The left hand side tensor.
141 /// * `rhs` - The right hand side scalar.
142 ///
143 /// # Returns
144 ///
145 /// The quotient of the tensor and the scalar.
146 ///
147 /// # Remarks
148 ///
149 /// This is a low-level function used internally by the library to call different backend functions
150 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
151 /// or use this function directly.
152 ///
153 /// For dividing a tensor by a scalar, users should prefer the
154 #[cfg_attr(doc, doc = crate::doc_tensor!("div_scalar"))]
155 #[cfg_attr(not(doc), doc = "`Tensor::div_scalar`")]
156 /// function, which is more high-level and designed for public use.
157 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
158
159 /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
160 /// less than that of the divisor.
161 ///
162 /// # Arguments
163 ///
164 /// * `lhs` - The dividend.
165 /// * `rhs` - The divisor.
166 ///
167 /// # Returns
168 ///
169 /// The modulo of the input tensor with the divisor.
170 ///
171 /// # Remarks
172 ///
173 /// This is a low-level function used internally by the library to call different backend functions
174 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
175 /// or use this function directly.
176 ///
177 /// For performing the modulo operation, users should prefer the
178 #[cfg_attr(doc, doc = crate::doc_tensor!("remainder"))]
179 #[cfg_attr(not(doc), doc = "`Tensor::remainder`")]
180 /// function, which is more high-level and designed for public use.
181 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
182
183 /// Computes the modulo element-wise between a tensor and a scalar. The result is the *signed* remainder of the
184 /// division and its absolute value is less than that of the divisor.
185 ///
186 /// # Arguments
187 ///
188 /// * `lhs` - The dividend tensor.
189 /// * `rhs` - The scalar divisor.
190 ///
191 /// # Returns
192 ///
193 /// The modulo of the input tensor with the scalar divisor.
194 ///
195 /// # Remarks
196 ///
197 /// This is a low-level function used internally by the library to call different backend functions
198 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
199 /// or use this function directly.
200 ///
201 /// For performing the modulo operation with a scalar, users should prefer the
202 #[cfg_attr(doc, doc = crate::doc_tensor!("remainder_scalar"))]
203 #[cfg_attr(not(doc), doc = "`Tensor::remainder_scalar`")]
204 /// function, which is more high-level and designed for public use.
205 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
206
207 /// Multiplies two tensors.
208 ///
209 /// # Arguments
210 ///
211 /// * `lhs` - The left hand side tensor.
212 /// * `rhs` - The right hand side tensor.
213 ///
214 /// # Returns
215 ///
216 /// The product of the two tensors.
217 ///
218 /// # Remarks
219 ///
220 /// This is a low-level function used internally by the library to call different backend functions
221 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
222 /// or use this function directly.
223 ///
224 /// For multiplying tensors, users should prefer the
225 #[cfg_attr(doc, doc = crate::doc_tensor!("mul"))]
226 #[cfg_attr(not(doc), doc = "`Tensor::mul`")]
227 /// function, which is more high-level and designed for public use.
228 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
229
230 /// Multiplies a tensor by a scalar element-wise.
231 ///
232 /// # Arguments
233 ///
234 /// * `lhs` - The left hand side tensor.
235 /// * `rhs` - The right hand side scalar.
236 ///
237 /// # Returns
238 ///
239 /// The product of the tensor and the scalar.
240 ///
241 /// # Remarks
242 ///
243 /// This is a low-level function used internally by the library to call different backend functions
244 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
245 /// or use this function directly.
246 ///
247 /// For multiplying a tensor by a scalar, users should prefer the
248 #[cfg_attr(doc, doc = crate::doc_tensor!("mul_scalar"))]
249 #[cfg_attr(not(doc), doc = "`Tensor::mul_scalar`")]
250 /// function, which is more high-level and designed for public use.
251 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
252
253 /// Negates a tensor.
254 ///
255 /// # Arguments
256 ///
257 /// * `tensor` - The tensor to negate.
258 ///
259 /// # Returns
260 ///
261 /// The negated tensor.
262 ///
263 /// # Remarks
264 ///
265 /// This is a low-level function used internally by the library to call different backend functions
266 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
267 /// or use this function directly.
268 ///
269 /// For negating a tensor, users should prefer the
270 #[cfg_attr(doc, doc = crate::doc_tensor!("neg"))]
271 #[cfg_attr(not(doc), doc = "`Tensor::neg`")]
272 /// function, which is more high-level and designed for public use.
273 fn neg(tensor: Self::Primitive) -> Self::Primitive;
274
275 /// Returns the signs of the elements of a tensor.
276 ///
277 /// # Arguments
278 ///
279 /// * `tensor` - The tensor.
280 ///
281 /// # Returns
282 ///
283 /// The signs of the elements of the tensor.
284 ///
285 /// # Remarks
286 ///
287 /// This is a low-level function used internally by the library to call different backend functions
288 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
289 /// or use this function directly.
290 ///
291 /// For getting the signs of the elements of a tensor, users should prefer the
292 #[cfg_attr(doc, doc = crate::doc_tensor!("sign"))]
293 #[cfg_attr(not(doc), doc = "`Tensor::sign`")]
294 /// function, which is more high-level and designed for public use.
295 fn sign(tensor: Self::Primitive) -> Self::Primitive;
296
297 /// Sums all the elements of the tensor.
298 ///
299 /// # Arguments
300 ///
301 /// * `tensor` - The tensor to sum.
302 ///
303 /// # Returns
304 ///
305 /// The sum of all the elements of the tensor.
306 ///
307 /// # Remarks
308 ///
309 /// This is a low-level function used internally by the library to call different backend functions
310 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
311 /// or use this function directly.
312 ///
313 /// For summing all the elements of a tensor, users should prefer the
314 #[cfg_attr(doc, doc = crate::doc_tensor!("sum"))]
315 #[cfg_attr(not(doc), doc = "`Tensor::sum`")]
316 /// function, which is more high-level and designed for public use.
317 fn sum(tensor: Self::Primitive) -> Self::Primitive;
318
319 /// Sums all the elements of the tensor along a dimension.
320 ///
321 /// # Arguments
322 ///
323 /// * `tensor` - The tensor to sum.
324 /// * `dim` - The dimension along which to sum.
325 ///
326 /// # Returns
327 ///
328 /// The sum of all the elements of the tensor along the specified dimension.
329 ///
330 /// # Remarks
331 ///
332 /// This is a low-level function used internally by the library to call different backend functions
333 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
334 /// or use this function directly.
335 ///
336 /// For summing all the elements of a tensor along a dimension, users should prefer the
337 #[cfg_attr(doc, doc = crate::doc_tensor!("sum_dim"))]
338 #[cfg_attr(not(doc), doc = "`Tensor::sum_dim`")]
339 /// function, which is more high-level and designed for public use.
340 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
341
342 /// Computes the product of all the elements of the tensor.
343 ///
344 /// # Arguments
345 ///
346 /// * `tensor` - The tensor to compute the product of.
347 ///
348 /// # Returns
349 ///
350 /// The product of all the elements of the tensor.
351 ///
352 /// # Remarks
353 ///
354 /// This is a low-level function used internally by the library to call different backend functions
355 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
356 /// or use this function directly.
357 ///
358 /// For computing the product of all the elements of a tensor, users should prefer the
359 #[cfg_attr(doc, doc = crate::doc_tensor!("prod"))]
360 #[cfg_attr(not(doc), doc = "`Tensor::prod`")]
361 /// function, which is more high-level and designed for public use.
362 fn prod(tensor: Self::Primitive) -> Self::Primitive;
363
364 /// Computes the product of all the elements of the tensor along a dimension.
365 ///
366 /// # Arguments
367 ///
368 /// * `tensor` - The tensor to compute the product of.
369 /// * `dim` - The dimension along which to compute the product.
370 ///
371 /// # Returns
372 ///
373 /// The product of all the elements of the tensor along the specified dimension.
374 ///
375 /// # Remarks
376 ///
377 /// This is a low-level function used internally by the library to call different backend functions
378 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
379 /// or use this function directly.
380 ///
381 /// For computing the product of all the elements of a tensor along a dimension, users should prefer the
382 #[cfg_attr(doc, doc = crate::doc_tensor!("prod_dim"))]
383 #[cfg_attr(not(doc), doc = "`Tensor::prod_dim`")]
384 /// function, which is more high-level and designed for public use.
385 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
386
387 /// Computes the mean of all the elements of the tensor.
388 ///
389 /// # Arguments
390 ///
391 /// * `tensor` - The tensor to compute the mean of.
392 ///
393 /// # Returns
394 ///
395 /// The mean of all the elements of the tensor.
396 ///
397 /// # Remarks
398 ///
399 /// This is a low-level function used internally by the library to call different backend functions
400 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
401 /// or use this function directly.
402 ///
403 /// For computing the mean of all the elements of a tensor, users should prefer the
404 #[cfg_attr(doc, doc = crate::doc_tensor!("mean"))]
405 #[cfg_attr(not(doc), doc = "`Tensor::mean`")]
406 /// function, which is more high-level and designed for public use.
407 fn mean(tensor: Self::Primitive) -> Self::Primitive;
408
409 /// Computes the mean of all the elements of the tensor along a dimension.
410 ///
411 /// # Arguments
412 ///
413 /// * `tensor` - The tensor to compute the mean of.
414 /// * `dim` - The dimension along which to compute the mean.
415 ///
416 /// # Returns
417 ///
418 /// The mean of all the elements of the tensor along the specified dimension.
419 ///
420 /// # Remarks
421 ///
422 /// This is a low-level function used internally by the library to call different backend functions
423 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
424 /// or use this function directly.
425 ///
426 /// For computing the mean of all the elements of a tensor along a dimension, users should prefer the
427 #[cfg_attr(doc, doc = crate::doc_tensor!("mean_dim"))]
428 #[cfg_attr(not(doc), doc = "`Tensor::mean_dim`")]
429 /// function, which is more high-level and designed for public use.
430 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
431
432 /// Computes the cumulative sum of elements along a dimension.
433 ///
434 /// # Arguments
435 ///
436 /// * `tensor` - The tensor to compute the cumulative sum of.
437 /// * `dim` - The dimension along which to compute the cumulative sum.
438 ///
439 /// # Returns
440 ///
441 /// A tensor with the same shape as the input tensor, where each element is the cumulative sum
442 /// of all elements up to and including that position along the specified dimension.
443 ///
444 /// # Remarks
445 ///
446 /// This is a low-level function used internally by the library to call different backend functions
447 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
448 /// or use this function directly.
449 ///
450 /// For computing the cumulative sum of elements along a dimension, users should prefer the
451 #[cfg_attr(doc, doc = crate::doc_tensor!("cumsum"))]
452 #[cfg_attr(not(doc), doc = "`Tensor::cumsum`")]
453 /// function, which is more high-level and designed for public use.
454 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
455
456 /// Computes the cumulative product of elements along a dimension.
457 ///
458 /// # Arguments
459 ///
460 /// * `tensor` - The tensor to compute the cumulative product of.
461 /// * `dim` - The dimension along which to compute the cumulative product.
462 ///
463 /// # Returns
464 ///
465 /// A tensor with the same shape as the input tensor, where each element is the cumulative product
466 /// of all elements up to and including that position along the specified dimension.
467 ///
468 /// # Remarks
469 ///
470 /// This is a low-level function used internally by the library to call different backend functions
471 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
472 /// or use this function directly.
473 ///
474 /// For computing the cumulative product of elements along a dimension, users should prefer the
475 #[cfg_attr(doc, doc = crate::doc_tensor!("cumprod"))]
476 #[cfg_attr(not(doc), doc = "`Tensor::cumprod`")]
477 /// function, which is more high-level and designed for public use.
478 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
479
480 /// Computes the cumulative minimum of elements along a dimension.
481 ///
482 /// # Arguments
483 ///
484 /// * `tensor` - The tensor to compute the cumulative minimum of.
485 /// * `dim` - The dimension along which to compute the cumulative minimum.
486 ///
487 /// # Returns
488 ///
489 /// A tensor with the same shape as the input tensor, where each element is the minimum
490 /// of all elements up to and including that position along the specified dimension.
491 ///
492 /// # Remarks
493 ///
494 /// This is a low-level function used internally by the library to call different backend functions
495 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
496 /// or use this function directly.
497 ///
498 /// For computing the cumulative minimum of elements along a dimension, users should prefer the
499 #[cfg_attr(doc, doc = crate::doc_tensor!("cummin"))]
500 #[cfg_attr(not(doc), doc = "`Tensor::cummin`")]
501 /// function, which is more high-level and designed for public use.
502 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
503
504 /// Computes the cumulative maximum of elements along a dimension.
505 ///
506 /// # Arguments
507 ///
508 /// * `tensor` - The tensor to compute the cumulative maximum of.
509 /// * `dim` - The dimension along which to compute the cumulative maximum.
510 ///
511 /// # Returns
512 ///
513 /// A tensor with the same shape as the input tensor, where each element is the maximum
514 /// of all elements up to and including that position along the specified dimension.
515 ///
516 /// # Remarks
517 ///
518 /// This is a low-level function used internally by the library to call different backend functions
519 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
520 /// or use this function directly.
521 ///
522 /// For computing the cumulative maximum of elements along a dimension, users should prefer the
523 #[cfg_attr(doc, doc = crate::doc_tensor!("cummax"))]
524 #[cfg_attr(not(doc), doc = "`Tensor::cummax`")]
525 /// function, which is more high-level and designed for public use.
526 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
527
528 /// Element-wise greater than comparison between two tensors.
529 ///
530 /// # Arguments
531 ///
532 /// * `lhs` - The left hand side tensor.
533 /// * `rhs` - The right hand side tensor.
534 ///
535 /// # Returns
536 ///
537 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
538 /// corresponding element of the left hand side tensor is greater than the corresponding element
539 /// of the right hand side tensor, and false otherwise.
540 ///
541 /// # Remarks
542 ///
543 /// This is a low-level function used internally by the library to call different backend functions
544 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
545 /// or use this function directly.
546 ///
547 /// For element-wise greater than comparison between two tensors, users should prefer the
548 #[cfg_attr(doc, doc = crate::doc_tensor!("greater"))]
549 #[cfg_attr(not(doc), doc = "`Tensor::greater`")]
550 /// function, which is more high-level and designed for public use.
551 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
552
553 /// Element-wise greater than comparison between a tensor and a scalar.
554 ///
555 /// # Arguments
556 ///
557 /// * `lhs` - The left hand side tensor.
558 /// * `rhs` - The right hand side scalar.
559 ///
560 /// # Returns
561 ///
562 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
563 /// corresponding element of the left hand side tensor is greater than the right hand side
564 /// scalar, and false otherwise.
565 ///
566 /// # Remarks
567 ///
568 /// This is a low-level function used internally by the library to call different backend functions
569 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
570 /// or use this function directly.
571 ///
572 /// For element-wise greater than comparison between a tensor and a scalar, users should prefer the
573 #[cfg_attr(doc, doc = crate::doc_tensor!("greater_elem"))]
574 #[cfg_attr(not(doc), doc = "`Tensor::greater_elem`")]
575 /// function, which is more high-level and designed for public use.
576 fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
577
578 /// Element-wise greater than or equal comparison between two tensors.
579 ///
580 /// # Arguments
581 ///
582 /// * `lhs` - The left hand side tensor.
583 /// * `rhs` - The right hand side tensor.
584 ///
585 /// # Returns
586 ///
587 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
588 /// corresponding element of the left hand side tensor is greater than or equal to the
589 /// corresponding element of the right hand side tensor, and false otherwise.
590 ///
591 /// # Remarks
592 ///
593 /// This is a low-level function used internally by the library to call different backend functions
594 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
595 /// or use this function directly.
596 ///
597 /// For element-wise greater than or equal comparison between two tensors, users should prefer the
598 #[cfg_attr(doc, doc = crate::doc_tensor!("greater_equal"))]
599 #[cfg_attr(not(doc), doc = "`Tensor::greater_equal`")]
600 /// function, which is more high-level and designed for public use.
601 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
602
603 /// Element-wise greater than or equal comparison between a tensor and a scalar.
604 ///
605 /// # Arguments
606 ///
607 /// * `lhs` - The left hand side tensor.
608 /// * `rhs` - The right hand side scalar.
609 ///
610 /// # Returns
611 ///
612 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
613 /// corresponding element of the left hand side tensor is greater than or equal to the right
614 /// hand side scalar, and false otherwise.
615 ///
616 /// # Remarks
617 ///
618 /// This is a low-level function used internally by the library to call different backend functions
619 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
620 /// or use this function directly.
621 ///
622 /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer the
623 #[cfg_attr(doc, doc = crate::doc_tensor!("greater_equal_elem"))]
624 #[cfg_attr(not(doc), doc = "`Tensor::greater_equal_elem`")]
625 /// function, which is more high-level and designed for public use.
626 fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
627
628 /// Element-wise less than comparison between two tensors.
629 ///
630 /// # Arguments
631 ///
632 /// * `lhs` - The left hand side tensor.
633 /// * `rhs` - The right hand side tensor.
634 ///
635 /// # Returns
636 ///
637 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
638 /// corresponding element of the left hand side tensor is less than the corresponding element of
639 /// the right hand side tensor, and false otherwise.
640 ///
641 /// # Remarks
642 ///
643 /// This is a low-level function used internally by the library to call different backend functions
644 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
645 /// or use this function directly.
646 ///
647 /// For element-wise less than comparison between two tensors, users should prefer the
648 #[cfg_attr(doc, doc = crate::doc_tensor!("lower"))]
649 #[cfg_attr(not(doc), doc = "`Tensor::lower`")]
650 /// function, which is more high-level and designed for public use.
651 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
652
653 /// Element-wise less than comparison between a tensor and a scalar.
654 ///
655 /// # Arguments
656 ///
657 /// * `lhs` - The left hand side tensor.
658 /// * `rhs` - The right hand side scalar.
659 ///
660 /// # Returns
661 ///
662 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
663 /// corresponding element of the left hand side tensor is less than the right hand side scalar,
664 /// and false otherwise.
665 ///
666 /// # Remarks
667 ///
668 /// This is a low-level function used internally by the library to call different backend functions
669 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
670 /// or use this function directly.
671 ///
672 /// For element-wise less than comparison between a tensor and a scalar, users should prefer the
673 #[cfg_attr(doc, doc = crate::doc_tensor!("lower_elem"))]
674 #[cfg_attr(not(doc), doc = "`Tensor::lower_elem`")]
675 /// function, which is more high-level and designed for public use.
676 fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
677
678 /// Element-wise less than or equal comparison between two tensors.
679 ///
680 /// # Arguments
681 ///
682 /// * `lhs` - The left hand side tensor.
683 /// * `rhs` - The right hand side tensor.
684 ///
685 /// # Returns
686 ///
687 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
688 /// corresponding element of the left hand side tensor is less than or equal to the corresponding
689 /// element of the right hand side tensor, and false otherwise.
690 ///
691 /// # Remarks
692 ///
693 /// This is a low-level function used internally by the library to call different backend functions
694 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
695 /// or use this function directly.
696 ///
697 /// For element-wise less than or equal comparison between two tensors, users should prefer the
698 #[cfg_attr(doc, doc = crate::doc_tensor!("lower_equal"))]
699 #[cfg_attr(not(doc), doc = "`Tensor::lower_equal`")]
700 /// function, which is more high-level and designed for public use.
701 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
702
703 /// Element-wise less than or equal comparison between a tensor and a scalar.
704 ///
705 /// # Arguments
706 ///
707 /// * `lhs` - The left hand side tensor.
708 /// * `rhs` - The right hand side scalar.
709 ///
710 /// # Returns
711 ///
712 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
713 /// corresponding element of the left hand side tensor is less than or equal to the right hand
714 /// side scalar, and false otherwise.
715 ///
716 /// # Remarks
717 ///
718 /// This is a low-level function used internally by the library to call different backend functions
719 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
720 /// or use this function directly.
721 ///
722 /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer the
723 #[cfg_attr(doc, doc = crate::doc_tensor!("lower_equal_elem"))]
724 #[cfg_attr(not(doc), doc = "`Tensor::lower_equal_elem`")]
725 /// function, which is more high-level and designed for public use.
726 fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
727
728 /// Gets the indices of the maximum elements of a tensor along an axis.
729 ///
730 /// # Arguments
731 ///
732 /// * `tensor` - The tensor to get the indices of the maximum elements from.
733 /// * `dim` - The axis along which to get the indices of the maximum elements.
734 ///
735 /// # Returns
736 ///
737 /// A tensor with the same shape as the input tensor, where each element is the index of the
738 /// maximum element of the input tensor at the corresponding index along the specified axis.
739 ///
740 /// # Remarks
741 ///
742 /// This is a low-level function used internally by the library to call different backend functions
743 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
744 /// or use this function directly.
745 ///
746 /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the
747 #[cfg_attr(doc, doc = crate::doc_tensor!("argmax"))]
748 #[cfg_attr(not(doc), doc = "`Tensor::argmax`")]
749 /// function, which is more high-level and designed for public use.
750 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B>;
751
752 /// Gets the indices of the minimum elements of a tensor along an axis.
753 ///
754 /// # Arguments
755 ///
756 /// * `tensor` - The tensor to get the indices of the minimum elements from.
757 /// * `dim` - The axis along which to get the indices of the minimum elements.
758 ///
759 /// # Returns
760 ///
761 /// A tensor with the same shape as the input tensor, where each element is the index of the
762 /// minimum element of the input tensor at the corresponding index along the specified axis.
763 ///
764 /// # Remarks
765 ///
766 /// This is a low-level function used internally by the library to call different backend functions
767 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
768 /// or use this function directly.
769 ///
770 /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
771 #[cfg_attr(doc, doc = crate::doc_tensor!("argmin"))]
772 #[cfg_attr(not(doc), doc = "`Tensor::argmin`")]
773 /// function, which is more high-level and designed for public use.
774 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B>;
775
776 /// Gets the maximum element of a tensor.
777 ///
778 /// # Arguments
779 ///
780 /// * `tensor` - The tensor to get the maximum element from.
781 ///
782 /// # Returns
783 ///
784 /// A single-element tensor containing the maximum element of the input tensor.
785 ///
786 /// # Remarks
787 ///
788 /// This is a low-level function used internally by the library to call different backend functions
789 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
790 /// or use this function directly.
791 ///
792 /// For getting the maximum element of a tensor, users should prefer the
793 #[cfg_attr(doc, doc = crate::doc_tensor!("max"))]
794 #[cfg_attr(not(doc), doc = "`Tensor::max`")]
795 /// function, which is more high-level and designed for public use.
796 fn max(tensor: Self::Primitive) -> Self::Primitive;
797
798 /// Gets the maximum elements of a tensor along an axis.
799 ///
800 /// # Arguments
801 ///
802 /// * `tensor` - The tensor to get the maximum elements from.
803 /// * `dim` - The axis along which to get the maximum elements.
804 ///
805 /// # Returns
806 ///
807 /// A tensor with the same rank as the input tensor, but with the given dim reduced to a size of 1.
808 /// Each element is the maximum element of the input tensor along that dim.
809 ///
810 /// # Remarks
811 ///
812 /// This is a low-level function used internally by the library to call different backend functions
813 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
814 /// or use this function directly.
815 ///
816 /// For getting the maximum elements of a tensor along an axis, users should prefer the
817 #[cfg_attr(doc, doc = crate::doc_tensor!("max_dim"))]
818 #[cfg_attr(not(doc), doc = "`Tensor::max_dim`")]
819 /// function, which is more high-level and designed for public use.
820 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
821
822 /// Gets the maximum elements and indices of a tensor along an axis.
823 ///
824 /// # Arguments
825 ///
826 /// * `tensor` - The tensor to get the maximum elements from.
827 /// * `dim` - The axis along which to get the maximum elements.
828 ///
829 /// # Returns
830 ///
831 /// A tuple containing the maximum elements of the input tensor along the specified axis, and an
832 /// integer tensor where each element is the index of the corresponding maximum element along
833 /// that axis.
834 ///
835 /// # Remarks
836 ///
837 /// This is a low-level function used internally by the library to call different backend functions
838 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
839 /// or use this function directly.
840 ///
841 /// For getting the maximum elements and indices of a tensor along an axis, users should prefer the
842 #[cfg_attr(doc, doc = crate::doc_tensor!("max_dim_with_indices"))]
843 #[cfg_attr(not(doc), doc = "`Tensor::max_dim_with_indices`")]
844 /// function, which is more high-level and designed for public use.
845 fn max_dim_with_indices(tensor: Self::Primitive, dim: usize)
846 -> (Self::Primitive, IntTensor<B>);
847
848 /// Gets the maximum absolute element of a tensor.
849 ///
850 /// # Arguments
851 ///
852 /// * `tensor` - The tensor to get the maximum absolute element from.
853 ///
854 /// # Returns
855 ///
856 /// A single-element tensor containing the maximum absolute element of the input tensor.
857 ///
858 /// # Remarks
859 ///
860 /// This is a low-level function used internally by the library to call different backend functions
861 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
862 /// or use this function directly.
863 ///
864 /// For getting the maximum absolute element of a tensor, users should prefer the
865 #[cfg_attr(doc, doc = crate::doc_tensor!("max_abs"))]
866 #[cfg_attr(not(doc), doc = "`Tensor::max_abs`")]
867 /// function, which is more high-level and designed for public use.
868 fn max_abs(tensor: Self::Primitive) -> Self::Primitive;
869
870 /// Gets the maximum absolute elements of a tensor along an axis.
871 ///
872 /// # Arguments
873 ///
874 /// * `tensor` - The tensor to get the maximum absolute elements from.
875 /// * `dim` - The axis along which to get the maximum absolute elements.
876 ///
877 /// # Returns
878 ///
879 /// A tensor with the same rank as the input tensor, but with the given dim reduced to a size of 1.
880 /// Each element is the maximum absolute element of the input tensor along that dim.
881 ///
882 /// # Remarks
883 ///
884 /// This is a low-level function used internally by the library to call different backend functions
885 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
886 /// or use this function directly.
887 ///
888 /// For getting the maximum absolute elements of a tensor along an axis, users should prefer the
889 #[cfg_attr(doc, doc = crate::doc_tensor!("max_abs_dim"))]
890 #[cfg_attr(not(doc), doc = "`Tensor::max_abs_dim`")]
891 /// function, which is more high-level and designed for public use.
892 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
893
894 /// Gets the minimum element of a tensor.
895 ///
896 /// # Arguments
897 ///
898 /// * `tensor` - The tensor to get the minimum element from.
899 ///
900 /// # Returns
901 ///
902 /// A single-element tensor containing the minimum element of the input tensor.
903 ///
904 /// # Remarks
905 ///
906 /// This is a low-level function used internally by the library to call different backend functions
907 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
908 /// or use this function directly.
909 ///
910 /// For getting the minimum element of a tensor, users should prefer the
911 #[cfg_attr(doc, doc = crate::doc_tensor!("min"))]
912 #[cfg_attr(not(doc), doc = "`Tensor::min`")]
913 /// function, which is more high-level and designed for public use.
914 fn min(tensor: Self::Primitive) -> Self::Primitive;
915
916 /// Gets the minimum elements of a tensor along an axis.
917 ///
918 /// # Arguments
919 ///
920 /// * `tensor` - The tensor to get the minimum elements from.
921 /// * `dim` - The axis along which to get the minimum elements.
922 ///
923 /// # Returns
924 ///
925 /// A tensor with the same rank as the input tensor, but with the given dim reduced to a size of 1.
926 /// Each element is the minimum element of the input tensor along that dim.
927 ///
928 /// # Remarks
929 ///
930 /// This is a low-level function used internally by the library to call different backend functions
931 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
932 /// or use this function directly.
933 ///
934 /// For getting the minimum elements of a tensor along an axis, users should prefer the
935 #[cfg_attr(doc, doc = crate::doc_tensor!("min_dim"))]
936 #[cfg_attr(not(doc), doc = "`Tensor::min_dim`")]
937 /// function, which is more high-level and designed for public use.
938 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
939
940 /// Gets the minimum elements and indices of a tensor along an axis.
941 ///
942 /// # Arguments
943 ///
944 /// * `tensor` - The tensor to get the minimum elements from.
945 /// * `dim` - The axis along which to get the minimum elements.
946 ///
947 /// # Returns
948 ///
949 /// A tuple containing the minimum elements of the input tensor along the specified axis, and an
950 /// integer tensor holding the index of each of those minimum elements along that axis.
951 ///
952 /// # Remarks
953 ///
954 /// This is a low-level function used internally by the library to call different backend functions
955 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
956 /// or use this function directly.
957 ///
958 /// For getting the minimum elements and indices of a tensor along an axis, users should prefer the
959 #[cfg_attr(doc, doc = crate::doc_tensor!("min_dim_with_indices"))]
960 #[cfg_attr(not(doc), doc = "`Tensor::min_dim_with_indices`")]
961 /// function, which is more high-level and designed for public use.
962 fn min_dim_with_indices(tensor: Self::Primitive, dim: usize)
963 -> (Self::Primitive, IntTensor<B>);
964
965 /// Clamps the tensor between the given min and max values.
966 ///
967 /// # Arguments
968 ///
969 /// * `tensor` - The tensor to clamp.
970 /// * `min` - The minimum value.
971 /// * `max` - The maximum value.
972 ///
973 /// # Returns
974 ///
975 /// A new tensor with the values clamped between the given min and max values.
976 ///
977 /// # Remarks
978 /// This is a low-level function used internally by the library to call different backend functions
979 /// with static dispatch. It is not designed for direct usage by users.
980 ///
981 /// For clamping a tensor between the given min and max values, users should prefer the
982 #[cfg_attr(doc, doc = crate::doc_tensor!("clamp"))]
983 #[cfg_attr(not(doc), doc = "`Tensor::clamp`")]
984 /// function, which is more high-level and designed for public use.
985 fn clamp(tensor: Self::Primitive, min: Self::Elem, max: Self::Elem) -> Self::Primitive;
986
987 /// Clamps a tensor so that its values do not fall under a minimum value.
988 ///
989 /// # Arguments
990 ///
991 /// * `tensor` - The tensor to clamp.
992 /// * `min` - The minimum value.
993 ///
994 /// # Returns
995 ///
996 /// A new tensor with the values clamped to be at least the given min value.
997 ///
998 /// # Remarks
999 ///
1000 /// This is a low-level function used internally by the library to call different backend functions
1001 /// with static dispatch. It is not designed for direct usage by users.
1002 ///
1003 /// For clamping a tensor under a minimum value, users should prefer the
1004 #[cfg_attr(doc, doc = crate::doc_tensor!("clamp_min"))]
1005 #[cfg_attr(not(doc), doc = "`Tensor::clamp_min`")]
1006 /// function, which is more high-level and designed for public use.
1007 fn clamp_min(tensor: Self::Primitive, min: Self::Elem) -> Self::Primitive;
1008
1009 /// Clamps a tensor so that its values do not go over a maximum value.
1010 ///
1011 /// # Arguments
1012 ///
1013 /// * `tensor` - The tensor to clamp.
1014 /// * `max` - The maximum value.
1015 ///
1016 /// # Returns
1017 ///
1018 /// A new tensor with the values clamped to be at most the given max value.
1019 ///
1020 /// # Remarks
1021 ///
1022 /// This is a low-level function used internally by the library to call different backend functions
1023 /// with static dispatch. It is not designed for direct usage by users.
1024 ///
1025 /// For clamping a tensor over a maximum value, users should prefer the
1026 #[cfg_attr(doc, doc = crate::doc_tensor!("clamp_max"))]
1027 #[cfg_attr(not(doc), doc = "`Tensor::clamp_max`")]
1028 /// function, which is more high-level and designed for public use.
1029 fn clamp_max(tensor: Self::Primitive, max: Self::Elem) -> Self::Primitive;
1030
1031 /// Calculates the absolute value of all elements of a tensor.
1032 ///
1033 /// # Arguments
1034 ///
1035 /// * `tensor` - The tensor to apply abs to.
1036 ///
1037 /// # Returns
1038 ///
1039 /// A tensor with absolute values.
1040 ///
1041 /// # Remarks
1042 ///
1043 /// This is a low-level function used internally by the library to call different backend functions
1044 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1045 /// or use this function directly.
1046 ///
1047 /// For calculating the absolute value of the elements of a tensor, users should prefer the
1048 #[cfg_attr(doc, doc = crate::doc_tensor!("abs"))]
1049 #[cfg_attr(not(doc), doc = "`Tensor::abs`")]
1050 /// function, which is more high-level and designed for public use.
1051 fn abs(tensor: Self::Primitive) -> Self::Primitive;
1052
1053 /// Element-wise power of a tensor to a float tensor.
1054 ///
1055 /// # Arguments
1056 /// * `lhs` - The tensor to apply power to.
1057 /// * `rhs` - The tensor holding the power to apply to each element of `lhs`.
1058 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
1059
1060 /// Element-wise power of a tensor to another tensor.
1061 ///
1062 /// # Arguments
1063 /// * `lhs` - The tensor to apply power to.
1064 /// * `rhs` - The tensor holding the power to apply to each element of `lhs`.
1065 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
1066
1067 /// Element-wise power of a tensor to a scalar float.
1068 ///
1069 /// # Arguments
1070 /// * `lhs` - The tensor to apply power to.
1071 /// * `rhs` - The scalar power to apply to every element of `lhs`.
1072 fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
1073
1074 /// Element-wise power of a tensor to a scalar int.
1075 ///
1076 /// # Arguments
1077 /// * `lhs` - The tensor to apply power to.
1078 /// * `rhs` - The scalar power to apply to every element of `lhs`.
1079 fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
1080
1081 /// Creates a tensor of the given shape with values sampled from a distribution.
1082 ///
1083 /// # Arguments
1084 ///
1085 /// * `shape` - The shape of the output tensor.
1086 /// * `distribution` - The distribution used to sample.
1087 /// * `device` - The device to use.
1088 ///
1089 /// # Returns
1090 ///
1091 /// A new tensor.
1092 ///
1093 /// # Remarks
1094 ///
1095 /// This is a low-level function used internally by the library to call different backend functions
1096 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1097 /// or use this function directly.
1098 ///
1099 /// For creating a random tensor, users should prefer the
1100 #[cfg_attr(doc, doc = crate::doc_tensor!("random"))]
1101 #[cfg_attr(not(doc), doc = "`Tensor::random`")]
1102 /// function, which is more high-level and designed for public use.
1103 fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive;
1104
1105 /// Sorts the elements of the input `tensor` by value along a given dimension.
1106 ///
1107 /// This sort is unstable (i.e., may reorder equal elements).
1108 ///
1109 /// # Arguments
1110 ///
1111 /// * `tensor` - The input tensor.
1112 /// * `dim` - The axis along which to sort.
1113 /// * `descending` - The sorting order.
1114 ///
1115 /// # Returns
1116 ///
1117 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1118 ///
1119 /// # Remarks
1120 /// This is a low-level function used internally by the library to call different backend functions
1121 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1122 /// or use this function directly.
1123 ///
1124 /// For sorting the elements of a tensor, users should prefer the
1125 #[cfg_attr(doc, doc = crate::doc_tensor!("sort"))]
1126 #[cfg_attr(not(doc), doc = "`Tensor::sort`")]
1127 /// function, which is more high-level and designed for public use.
1128 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive;
1129
1130 /// Sorts the elements of the input `tensor` by value along a given dimension, with indices.
1131 ///
1132 /// This sort is unstable (i.e., may reorder equal elements).
1133 ///
1134 /// # Arguments
1135 ///
1136 /// * `tensor` - The input tensor.
1137 /// * `dim` - The axis along which to sort.
1138 /// * `descending` - The sorting order.
1139 ///
1140 /// # Returns
1141 ///
1142 /// A tuple of a tensor with the same shape as the input tensor and the corresponding indices, where
1143 /// the elements are sorted by value and the indices map back to the original input tensor.
1144 ///
1145 /// # Remarks
1146 /// This is a low-level function used internally by the library to call different backend functions
1147 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1148 /// or use this function directly.
1149 ///
1150 /// For sorting the elements of a tensor, users should prefer the
1151 #[cfg_attr(doc, doc = crate::doc_tensor!("sort_with_indices"))]
1152 #[cfg_attr(not(doc), doc = "`Tensor::sort_with_indices`")]
1153 /// function, which is more high-level and designed for public use.
1154 fn sort_with_indices(
1155 tensor: Self::Primitive,
1156 dim: usize,
1157 descending: bool,
1158 ) -> (Self::Primitive, IntTensor<B>);
1159
1160 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1161 ///
1162 /// This sort is unstable (i.e., may reorder equal elements).
1163 ///
1164 /// # Arguments
1165 ///
1166 /// * `tensor` - The input tensor.
1167 /// * `dim` - The axis along which to sort.
1168 /// * `descending` - The sorting order.
1169 ///
1170 /// # Returns
1171 ///
1172 /// An integer tensor with the same shape as the input tensor, whose indices map back to the original input tensor.
1173 ///
1174 /// # Remarks
1175 /// This is a low-level function used internally by the library to call different backend functions
1176 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1177 /// or use this function directly.
1178 ///
1179 /// For getting the indices that sort a tensor, users should prefer the
1180 #[cfg_attr(doc, doc = crate::doc_tensor!("argsort"))]
1181 #[cfg_attr(not(doc), doc = "`Tensor::argsort`")]
1182 /// function, which is more high-level and designed for public use.
1183 fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B>;
1184
1185 /// Applies the matrix multiplication operation to `lhs` and `rhs` (`A` and `B` below, in that order).
1186 ///
1187 /// ```math
1188 /// C = AB
1189 /// ```
1190 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
1191}