burn_backend/tensor/ops/base.rs
1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 Backend, ExecutionError, Scalar, TensorData, TensorMetadata,
6 element::Element,
7 ops::TransactionPrimitive,
8 tensor::{IndexingUpdateOp, IntTensor, TensorKind},
9};
10
11/// Trait that list all operations that can be applied on all tensors.
12///
13/// # Warnings
14///
15/// This is an internal trait, use the public API provided by the
16#[cfg_attr(doc, doc = crate::doc_tensor!())]
17#[cfg_attr(not(doc), doc = "`Tensor`")]
18/// struct.
19pub trait BasicOps<B: Backend>: TensorKind<B> {
20 /// The type of the tensor elements.
21 type Elem: Element;
22
23 /// Creates an empty tensor with the given shape.
24 ///
25 /// # Arguments
26 ///
27 /// * `shape` - The shape of the tensor.
28 /// * `device` - The device on which the tensor will be allocated.
29 /// * `dtype` - The target data type.
30 ///
31 /// # Returns
32 ///
33 /// The empty tensor.
34 ///
35 /// # Remarks
36 ///
37 /// This is a low-level function used internally by the library to call different backend functions
38 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
39 /// or use this function directly.
40 ///
41 /// For creating empty tensors, users should prefer the
42 #[cfg_attr(doc, doc = crate::doc_tensor!("empty"))]
43 #[cfg_attr(not(doc), doc = "`Tensor::empty`")]
44 /// function, which is more high-level and designed for public use.
45 fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
46
47 /// Creates a tensor filled with zeros.
48 ///
49 /// # Arguments
50 ///
51 /// * `shape` - The shape of the tensor.
52 /// * `device` - The device on which the tensor will be allocated.
53 /// * `dtype` - The target data type.
54 ///
55 /// # Returns
56 ///
57 /// The tensor filled with zeros.
58 ///
59 /// # Remarks
60 ///
61 /// This is a low-level function used internally by the library to call different backend functions
62 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
63 /// or use this function directly.
64 ///
65 /// For creating a tensor filled with zeros, users should prefer the
66 #[cfg_attr(doc, doc = crate::doc_tensor!("zeros"))]
67 #[cfg_attr(not(doc), doc = "`Tensor::zeros`")]
68 /// function, which is more high-level and designed for public use.
69 fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
70
71 /// Creates a tensor filled with ones.
72 ///
73 /// # Arguments
74 ///
75 /// * `shape` - The shape of the tensor.
76 /// * `device` - The device on which the tensor will be allocated.
77 /// * `dtype` - The target data type.
78 ///
79 /// # Returns
80 ///
81 /// The tensor filled with ones.
82 ///
83 /// # Remarks
84 ///
85 /// This is a low-level function used internally by the library to call different backend functions
86 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
87 /// or use this function directly.
88 ///
89 /// For creating a tensor filled with ones, users should prefer the
90 #[cfg_attr(doc, doc = crate::doc_tensor!("ones"))]
91 #[cfg_attr(not(doc), doc = "`Tensor::ones`")]
92 /// function, which is more high-level and designed for public use.
93 fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
94
95 /// Creates a tensor of the given shape where each element is equal to the provided value.
96 ///
97 /// # Arguments
98 ///
99 /// * `shape` - The shape of the tensor.
100 /// * `fill_value` - The value with which to fill the tensor.
101 /// * `device` - The device on which the tensor will be allocated.
102 /// * `dtype` - The target data type.
103 ///
104 /// # Returns
105 ///
106 /// The tensor filled with the specified value.
107 ///
108 /// # Remarks
109 ///
110 /// This is a low-level function used internally by the library to call different backend functions
111 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
112 /// or use this function directly.
113 ///
114 /// For creating full tensors, users should prefer the
115 #[cfg_attr(doc, doc = crate::doc_tensor!("full"))]
116 #[cfg_attr(not(doc), doc = "`Tensor::full`")]
117 /// function, which is more high-level and designed for public use.
118 fn full(shape: Shape, fill_value: Scalar, device: &B::Device, dtype: DType) -> Self::Primitive;
119
120 /// Reshapes the tensor.
121 ///
122 /// # Arguments
123 ///
124 /// * `tensor` - The tensor.
125 /// * `shape` - The new shape of the tensor.
126 ///
127 /// # Returns
128 ///
129 /// The reshaped tensor.
130 ///
131 /// # Remarks
132 ///
133 /// This is a low-level function used internally by the library to call different backend functions
134 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
135 /// or use this function directly.
136 ///
137 /// For reshaping a tensor, users should prefer the
138 #[cfg_attr(doc, doc = crate::doc_tensor!("reshape"))]
139 #[cfg_attr(not(doc), doc = "`Tensor::reshape`")]
140 /// function, which is more high-level and designed for public use.
141 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
142
143 /// Transposes a tensor.
144 ///
145 /// # Arguments
146 ///
147 /// * `tensor` - The tensor to transpose.
148 ///
149 /// # Returns
150 ///
151 /// The transposed tensor.
152 fn transpose(tensor: Self::Primitive) -> Self::Primitive;
153
154 /// Swaps two dimensions of a tensor.
155 ///
156 /// # Arguments
157 ///
158 /// * `tensor` - The tensor to swap the dimensions of.
159 /// * `dim1` - The first dimension to swap.
160 /// * `dim2` - The second dimension to swap.
161 ///
162 /// # Returns
163 ///
164 /// The tensor with the dimensions swapped.
165 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
166
167 /// Permutes the dimensions of a tensor.
168 ///
169 /// # Arguments
170 ///
171 /// * `tensor` - The tensor to permute the dimensions of.
172 /// * `axes` - The new order of the dimensions.
173 ///
174 /// # Returns
175 ///
176 /// The tensor with the dimensions permuted.
177 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
178
179 /// Flips the tensor along the given axes.
180 ///
181 /// # Arguments
182 ///
183 /// * `tensor` - The tensor to flip.
184 /// * `axes` - The axes to flip the tensor along.
185 ///
186 /// # Returns
187 ///
188 /// The tensor with the axes flipped.
189 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
190
191 /// Select tensor elements corresponding to the given slices.
192 ///
193 /// # Arguments
194 ///
195 /// * `tensor` - The tensor.
196 /// * `slices` - The slices specifying ranges and steps for each dimension.
197 ///
198 /// # Returns
199 ///
200 /// The selected elements.
201 ///
202 /// # Remarks
203 ///
204 /// This is a low-level function used internally by the library to call different backend functions
205 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
206 /// or use this function directly.
207 ///
208 /// For selecting elements of a tensor, users should prefer the
209 #[cfg_attr(doc, doc = crate::doc_tensor!("slice"))]
210 #[cfg_attr(not(doc), doc = "`Tensor::slice`")]
211 /// function, which is more high-level and designed for public use.
212 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive;
213
214 /// Assigns the given value to the tensor elements corresponding to the given slices.
215 ///
216 /// # Arguments
217 ///
218 /// * `tensor` - The tensor.
219 /// * `slices` - The slices specifying which elements to assign, including support for steps.
220 /// * `value` - The value to assign.
221 ///
222 /// # Returns
223 ///
224 /// The tensor with the assigned values.
225 ///
226 /// # Remarks
227 ///
228 /// This is a low-level function used internally by the library to call different backend functions
229 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
230 /// or use this function directly.
231 ///
232 /// For assigning values to elements of a tensor, users should prefer the
233 #[cfg_attr(doc, doc = crate::doc_tensor!("slice_assign"))]
234 #[cfg_attr(not(doc), doc = "`Tensor::slice_assign`")]
235 /// function, which is more high-level and designed for public use.
236 fn slice_assign(
237 tensor: Self::Primitive,
238 slices: &[Slice],
239 value: Self::Primitive,
240 ) -> Self::Primitive;
241
242 /// Select tensor elements along the given dimension corresponding to the given indices.
243 ///
244 /// # Arguments
245 ///
246 /// * `tensor` - The tensor to select from.
247 /// * `dim` - The dimension along which to select.
248 /// * `indices` - The indices of the elements to select.
249 ///
250 /// # Returns
251 ///
252 /// The selected tensor elements.
253 ///
254 /// # Remarks
255 ///
256 /// This is a low-level function used internally by the library to call different backend functions
257 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
258 /// or use this function directly.
259 ///
260 /// For selecting elements from a tensor along an axis, users should prefer the
261 #[cfg_attr(doc, doc = crate::doc_tensor!("select"))]
262 #[cfg_attr(not(doc), doc = "`Tensor::select`")]
263 /// function, which is more high-level and designed for public use.
264 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive;
265
266 /// Assign the selected elements along the given dimension corresponding to the given indices
267 /// from the value tensor.
268 ///
269 /// # Arguments
270 ///
271 /// * `tensor` - The tensor to assign elements to.
272 /// * `dim` - The axis along which to assign elements.
273 /// * `indices` - The indices of the elements to assign.
274 /// * `values` - The values to assign to the tensor.
275 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
276 ///
277 /// # Returns
278 ///
279 /// A tensor with the same shape as the input tensor, where each element is taken from the
280 /// corresponding element of the input tensor at the corresponding index along the specified axis,
281 /// except for the elements at the specified indices, which are taken from the corresponding
282 /// element of the values tensor.
283 ///
284 /// # Remarks
285 ///
286 /// This is a low-level function used internally by the library to call different backend functions
287 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
288 /// or use this function directly.
289 ///
290 /// For assigning elements to a tensor along an axis, users should prefer the
291 #[cfg_attr(doc, doc = crate::doc_tensor!("select_assign"))]
292 #[cfg_attr(not(doc), doc = "`Tensor::select_assign`")]
293 /// function, which is more high-level and designed for public use.
294 fn select_assign(
295 tensor: Self::Primitive,
296 dim: usize,
297 indices: IntTensor<B>,
298 values: Self::Primitive,
299 update: IndexingUpdateOp,
300 ) -> Self::Primitive;
301
302 /// Selects elements from a tensor based on a boolean mask.
303 ///
304 /// # Arguments
305 ///
306 /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
307 /// * `mask` - The boolean mask to use for selecting elements.
308 /// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
309 ///
310 /// # Returns
311 ///
312 /// A tensor with the same shape as the input tensors, where each element is taken from the
313 /// corresponding element of the left hand side tensor if the corresponding element of the mask
314 /// is true, and from the corresponding element of the right hand side tensor otherwise.
315 ///
316 /// # Remarks
317 ///
318 /// This is a low-level function used internally by the library to call different backend functions
319 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
320 /// or use this function directly.
321 ///
322 /// For selecting elements from a tensor based on a boolean mask, users should prefer the
323 #[cfg_attr(doc, doc = crate::doc_tensor!("mask_where"))]
324 #[cfg_attr(not(doc), doc = "`Tensor::mask_where`")]
325 /// function, which is more high-level and designed for public use.
326 fn mask_where(
327 tensor: Self::Primitive,
328 mask: B::BoolTensorPrimitive,
329 source: Self::Primitive,
330 ) -> Self::Primitive;
331
332 /// Fills elements of a tensor based on a boolean mask.
333 ///
334 /// # Arguments
335 ///
336 /// * `tensor` - The tensor where will be overwritten with the value
337 /// when the corresponding element of the mask is true.
338 /// * `mask` - The boolean mask to use for filling elements.
339 /// * `value` - The value to fill elements with when the corresponding element of the mask is true.
340 ///
341 /// # Returns
342 ///
343 /// A tensor with the same shape as the input tensors, where each element is taken from the
344 /// corresponding element unmodified if the corresponding element of the mask is false, and
345 /// filled with the value otherwise.
346 ///
347 /// # Remarks
348 ///
349 /// This is a low-level function used internally by the library to call different backend functions
350 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
351 /// or use this function directly.
352 ///
353 /// For filling elements of a tensor based on a boolean mask, users should prefer the
354 #[cfg_attr(doc, doc = crate::doc_tensor!("mask_fill"))]
355 #[cfg_attr(not(doc), doc = "`Tensor::mask_fill`")]
356 /// function, which is more high-level and designed for public use.
357 fn mask_fill(
358 tensor: Self::Primitive,
359 mask: B::BoolTensorPrimitive,
360 value: Scalar,
361 ) -> Self::Primitive;
362
363 /// Gathers elements from a tensor along an axis.
364 ///
365 /// # Arguments
366 ///
367 /// * `dim` - The axis along which to gather elements.
368 /// * `tensor` - The tensor to gather elements from.
369 /// * `indices` - The indices of the elements to gather.
370 ///
371 /// # Returns
372 ///
373 /// A tensor with the same shape as the input tensor, where each element is taken from the
374 /// corresponding element of the input tensor at the corresponding index along the specified axis.
375 ///
376 /// # Remarks
377 ///
378 /// This is a low-level function used internally by the library to call different backend functions
379 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
380 /// or use this function directly.
381 ///
382 /// For gathering elements from a tensor along an axis, users should prefer the
383 #[cfg_attr(doc, doc = crate::doc_tensor!("gather"))]
384 #[cfg_attr(not(doc), doc = "`Tensor::gather`")]
385 /// function, which is more high-level and designed for public use.
386 fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive;
387
388 /// Scatters elements into a tensor along an axis.
389 ///
390 /// # Arguments
391 ///
392 /// * `dim` - The axis along which to scatter elements.
393 /// * `tensor` - The tensor to scatter elements into.
394 /// * `indices` - The indices of the elements to scatter.
395 /// * `values` - The values to scatter into the tensor.
396 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
397 ///
398 /// # Returns
399 ///
400 /// A tensor with the same shape as the input tensor, where each element is taken from the
401 /// corresponding element of the input tensor at the corresponding index along the specified axis,
402 /// except for the elements at the specified indices, which are taken from the corresponding
403 /// element of the values tensor.
404 ///
405 /// # Remarks
406 ///
407 /// This is a low-level function used internally by the library to call different backend functions
408 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
409 /// or use this function directly.
410 ///
411 /// For scattering elements into a tensor along an axis, users should prefer the
412 #[cfg_attr(doc, doc = crate::doc_tensor!("scatter"))]
413 #[cfg_attr(not(doc), doc = "`Tensor::scatter`")]
414 /// function, which is more high-level and designed for public use.
415 fn scatter(
416 dim: usize,
417 tensor: Self::Primitive,
418 indices: IntTensor<B>,
419 values: Self::Primitive,
420 update: IndexingUpdateOp,
421 ) -> Self::Primitive;
422
423 /// Returns the device on which the tensor is allocated.
424 ///
425 /// # Arguments
426 ///
427 /// * `tensor` - The tensor.
428 ///
429 /// # Returns
430 ///
431 /// The device on which the tensor is allocated.
432 ///
433 /// # Remarks
434 ///
435 /// This is a low-level function used internally by the library to call different backend functions
436 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
437 /// or use this function directly.
438 ///
439 /// For getting the device of a tensor, users should prefer the
440 #[cfg_attr(doc, doc = crate::doc_tensor!("device"))]
441 #[cfg_attr(not(doc), doc = "`Tensor::device`")]
442 /// function, which is more high-level and designed for public use.
443 fn device(tensor: &Self::Primitive) -> B::Device;
444
445 /// Moves the tensor to the given device.
446 ///
447 /// # Arguments
448 ///
449 /// * `tensor` - The tensor.
450 /// * `device` - The device on which the tensor will be moved.
451 ///
452 /// # Returns
453 ///
454 /// The tensor on the given device.
455 ///
456 /// # Remarks
457 ///
458 /// This is a low-level function used internally by the library to call different backend functions
459 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
460 /// or use this function directly.
461 ///
462 /// For moving a tensor to a device, users should prefer the
463 #[cfg_attr(doc, doc = crate::doc_tensor!("to_device"))]
464 #[cfg_attr(not(doc), doc = "`Tensor::to_device`")]
465 /// function, which is more high-level and designed for public use.
466 #[allow(clippy::wrong_self_convention)]
467 fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
468
469 /// Extracts the data from the tensor asynchronously.
470 ///
471 /// # Arguments
472 ///
473 /// * `tensor` - The tensor.
474 ///
475 /// # Returns
476 ///
477 /// The data of the tensor.
478 ///
479 /// # Remarks
480 ///
481 /// This is a low-level function used internally by the library to call different backend functions
482 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
483 /// or use this function directly.
484 ///
485 /// For extracting the data of a tensor, users should prefer the
486 #[cfg_attr(doc, doc = crate::doc_tensor!("into_data"))]
487 #[cfg_attr(not(doc), doc = "`Tensor::into_data`")]
488 /// function, which is more high-level and designed for public use.
489 #[allow(clippy::wrong_self_convention)]
490 fn into_data_async(
491 tensor: Self::Primitive,
492 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
493
494 /// Read the data from the tensor using a transaction.
495 ///
496 /// # Remarks
497 ///
498 /// This is a low-level function used internally by the library to call different backend functions
499 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
500 /// or use this function directly.
501 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive);
502
503 /// Creates a tensor from the given data.
504 ///
505 /// # Arguments
506 ///
507 /// * `data` - The data of the tensor.
508 /// * `device` - The device on which the tensor will be allocated.
509 ///
510 /// # Returns
511 ///
512 /// The tensor.
513 ///
514 /// # Remarks
515 ///
516 /// This is a low-level function used internally by the library to call different backend functions
517 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
518 /// or use this function directly.
519 ///
520 /// For creating a tensor from data, users should prefer the
521 #[cfg_attr(doc, doc = crate::doc_tensor!("from_data"))]
522 #[cfg_attr(not(doc), doc = "`Tensor::from_data`")]
523 /// function, which is more high-level and designed for public use.
524 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;
525 /// Creates a tensor from the given data enforcing the given data type.
526 ///
527 /// # Remarks
528 ///
529 /// This is a low-level function used internally by the library to call different backend functions
530 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
531 /// or use this function directly.
532 ///
533 /// For creating a tensor from data, users should prefer the
534 #[cfg_attr(doc, doc = crate::doc_tensor!("from_data_dtype"))]
535 #[cfg_attr(not(doc), doc = "`Tensor::from_data_dtype`")]
536 /// function, which is more high-level and designed for public use.
537 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;
538
539 /// Repeat the tensor along the given dimension.
540 ///
541 /// # Arguments
542 ///
543 /// * `tensor` - The tensor.
544 /// * `dim` - The dimension along which the tensor will be repeated.
545 /// * `times` - The number of times the tensor will be repeated.
546 ///
547 /// # Returns
548 ///
549 /// The repeated tensor.
550 ///
551 /// # Remarks
552 ///
553 /// This is a low-level function used internally by the library to call different backend functions
554 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
555 /// or use this function directly.
556 ///
557 /// For repeating a tensor, users should prefer the
558 #[cfg_attr(doc, doc = crate::doc_tensor!("repeat_dim"))]
559 #[cfg_attr(not(doc), doc = "`Tensor::repeat_dim`")]
560 /// function, which is more high-level and designed for public use.
561 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
562
563 /// Concatenates the given tensors along the given dimension.
564 ///
565 /// # Arguments
566 ///
567 /// * `vectors` - The tensors to concatenate.
568 /// * `dim` - The dimension along which the tensors will be concatenated.
569 ///
570 /// # Returns
571 ///
572 /// The concatenated tensor.
573 ///
574 /// # Remarks
575 ///
576 /// This is a low-level function used internally by the library to call different backend functions
577 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
578 /// or use this function directly.
579 ///
580 /// For concatenating tensors, users should prefer the
581 #[cfg_attr(doc, doc = crate::doc_tensor!("cat"))]
582 #[cfg_attr(not(doc), doc = "`Tensor::cat`")]
583 /// function, which is more high-level and designed for public use.
584 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
585
586 /// Equates the given tensors.
587 ///
588 /// # Arguments
589 ///
590 /// * `lhs` - The left hand side tensor.
591 /// * `rhs` - The right hand side tensor.
592 ///
593 /// # Returns
594 ///
595 /// The tensor of booleans indicating whether the corresponding elements are equal.
596 ///
597 /// # Remarks
598 ///
599 /// This is a low-level function used internally by the library to call different backend functions
600 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
601 /// or use this function directly.
602 ///
603 /// For equating tensors, users should prefer the
604 #[cfg_attr(doc, doc = crate::doc_tensor!("equal"))]
605 #[cfg_attr(not(doc), doc = "`Tensor::equal`")]
606 /// function, which is more high-level and designed for public use.
607 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
608
609 /// Element-wise equality between two tensors.
610 ///
611 /// # Arguments
612 ///
613 /// * `lhs` - The left hand side tensor.
614 /// * `rhs` - The right hand side scalar.
615 ///
616 /// # Returns
617 ///
618 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
619 /// corresponding elements of the input tensors are equal, and false otherwise.
620 ///
621 /// # Remarks
622 ///
623 /// This is a low-level function used internally by the library to call different backend functions
624 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
625 /// or use this function directly.
626 ///
627 /// For element-wise equality between two tensors, users should prefer the
628 #[cfg_attr(doc, doc = crate::doc_tensor!("equal_elem"))]
629 #[cfg_attr(not(doc), doc = "`Tensor::equal_elem`")]
630 /// function, which is more high-level and designed for public use.
631 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
632
633 /// Applies element-wise non-equality comparison between the given tensors.
634 ///
635 /// # Arguments
636 ///
637 /// * `lhs` - The left hand side tensor.
638 /// * `rhs` - The right hand side tensor.
639 ///
640 /// # Returns
641 ///
642 /// The tensor of booleans indicating whether the corresponding elements are equal.
643 ///
644 /// # Remarks
645 ///
646 /// This is a low-level function used internally by the library to call different backend functions
647 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
648 /// or use this function directly.
649 ///
650 /// For non-equality comparison of tensors, users should prefer the
651 #[cfg_attr(doc, doc = crate::doc_tensor!("not_equal"))]
652 #[cfg_attr(not(doc), doc = "`Tensor::not_equal`")]
653 /// function, which is more high-level and designed for public use.
654 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
655
656 /// Element-wise non-equality between two tensors.
657 ///
658 /// # Arguments
659 ///
660 /// * `lhs` - The left hand side tensor.
661 /// * `rhs` - The right hand side scalar.
662 ///
663 /// # Returns
664 ///
665 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
666 /// corresponding elements of the input tensors are equal, and false otherwise.
667 ///
668 /// # Remarks
669 ///
670 /// This is a low-level function used internally by the library to call different backend functions
671 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
672 /// or use this function directly.
673 ///
674 /// For element-wise non-equality between two tensors, users should prefer the
675 #[cfg_attr(doc, doc = crate::doc_tensor!("not_equal_elem"))]
676 #[cfg_attr(not(doc), doc = "`Tensor::not_equal_elem`")]
677 /// function, which is more high-level and designed for public use.
678 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
679
680 /// Returns the name of the element type.
681 fn elem_type_name() -> &'static str {
682 core::any::type_name::<Self::Elem>()
683 }
684
685 /// Returns the tensor data type.
686 fn dtype(tensor: &Self::Primitive) -> DType {
687 tensor.dtype()
688 }
689
690 /// Tests if any element in the `tensor` evaluates to True.
691 ///
692 /// # Arguments
693 ///
694 /// * `tensor` - The tensor to test.
695 ///
696 /// # Returns
697 ///
698 /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
699 ///
700 /// # Remarks
701 ///
702 /// This is a low-level function used internally by the library to call different backend functions
703 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
704 /// or use this function directly. Users should prefer the
705 #[cfg_attr(doc, doc = crate::doc_tensor!("any"))]
706 #[cfg_attr(not(doc), doc = "`Tensor::any`")]
707 /// function, which is more high-level and designed for public use.
708 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
709
710 /// Tests if any element in the tensor evaluates to True along a given dimension dim.
711 ///
712 /// # Arguments
713 ///
714 /// * tensor - The tensor to test.
715 /// * dim - The axis along which to test.
716 ///
717 /// # Returns
718 ///
719 /// A boolean tensor with the same size as input tensor, except in the dim axis where the size is 1.
720 /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
721 ///
722 /// # Remarks
723 ///
724 /// This is a low-level function used internally by the library to call different backend functions
725 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
726 /// or use this function directly. Users should prefer the
727 #[cfg_attr(doc, doc = crate::doc_tensor!("any_dim"))]
728 #[cfg_attr(not(doc), doc = "`Tensor::any_dim`")]
729 /// function, which is more high-level and designed for public use.
730 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
731
732 /// Tests if all elements in the `tensor` evaluate to True.
733 ///
734 /// # Arguments
735 ///
736 /// * `tensor` - The tensor to test.
737 ///
738 /// # Returns
739 ///
740 /// A boolean tensor with a single element, True if all elements in the input tensor evaluates to True, False otherwise.
741 ///
742 /// # Remarks
743 ///
744 /// This is a low-level function used internally by the library to call different backend functions
745 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
746 /// or use this function directly. Users should prefer the
747 #[cfg_attr(doc, doc = crate::doc_tensor!("all"))]
748 #[cfg_attr(not(doc), doc = "`Tensor::all`")]
749 /// function, which is more high-level and designed for public use.
750 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
751
752 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
753 ///
754 /// # Arguments
755 ///
756 /// * `tensor` - The tensor to test.
757 ///
758 /// # Returns
759 ///
760 /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
761 /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
762 ///
763 /// # Remarks
764 ///
765 /// This is a low-level function used internally by the library to call different backend functions
766 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
767 /// or use this function directly. Users should prefer the
768 #[cfg_attr(doc, doc = crate::doc_tensor!("all_dim"))]
769 #[cfg_attr(not(doc), doc = "`Tensor::all_dim`")]
770 /// function, which is more high-level and designed for public use.
771 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
772
773 /// Broadcasts the given tensor to the specified shape.
774 ///
775 /// # Arguments
776 ///
777 /// * `tensor` - The tensor to broadcast.
778 /// * `shape` - The shape to broadcast to.
779 ///
780 /// # Returns
781 ///
782 /// The broadcasted tensor.
783 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
784
785 /// Unfold windows along a dimension.
786 ///
787 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
788 /// where windows are advanced by `step` at each index.
789 ///
790 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
791 ///
792 /// # Warning
793 ///
794 /// For the `ndarray` and `candle` backends; this is not a view but a full copy.
795 ///
796 /// # Arguments
797 ///
798 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
799 /// * `dim` - the dimension to unfold.
800 /// * `size` - the size of each unfolded window.
801 /// * `step` - the step between each window.
802 ///
803 /// # Returns
804 ///
805 /// A tensor view with shape ``[pre=..., windows, post=..., size]``.
806 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive;
807}