burn_backend/tensor/ops/base.rs
1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 Backend, ExecutionError, TensorData, TensorMetadata,
6 element::{Element, ElementConversion},
7 ops::TransactionPrimitive,
8 tensor::{IndexingUpdateOp, IntTensor, TensorKind},
9};
10
11/// Trait that lists all operations that can be applied on all tensors.
12///
13/// # Warnings
14///
15/// This is an internal trait, use the public API provided by the
16#[cfg_attr(doc, doc = crate::doc_tensor!())]
17#[cfg_attr(not(doc), doc = "`Tensor`")]
18/// struct.
19pub trait BasicOps<B: Backend>: TensorKind<B> {
20 /// The type of the tensor elements.
21 type Elem: Element;
22
23 /// Creates an empty tensor with the given shape.
24 ///
25 /// # Arguments
26 ///
27 /// * `shape` - The shape of the tensor.
28 /// * `device` - The device on which the tensor will be allocated.
29 /// * `dtype` - The target data type.
30 ///
31 /// # Returns
32 ///
33 /// The empty tensor.
34 ///
35 /// # Remarks
36 ///
37 /// This is a low-level function used internally by the library to call different backend functions
38 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
39 /// or use this function directly.
40 ///
41 /// For creating empty tensors, users should prefer the
42 #[cfg_attr(doc, doc = crate::doc_tensor!("empty"))]
43 #[cfg_attr(not(doc), doc = "`Tensor::empty`")]
44 /// function, which is more high-level and designed for public use.
45 fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
46
47 /// Creates a tensor filled with zeros.
48 ///
49 /// # Arguments
50 ///
51 /// * `shape` - The shape of the tensor.
52 /// * `device` - The device on which the tensor will be allocated.
53 /// * `dtype` - The target data type.
54 ///
55 /// # Returns
56 ///
57 /// The tensor filled with zeros.
58 ///
59 /// # Remarks
60 ///
61 /// This is a low-level function used internally by the library to call different backend functions
62 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
63 /// or use this function directly.
64 ///
65 /// For creating a tensor filled with zeros, users should prefer the
66 #[cfg_attr(doc, doc = crate::doc_tensor!("zeros"))]
67 #[cfg_attr(not(doc), doc = "`Tensor::zeros`")]
68 /// function, which is more high-level and designed for public use.
69 fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
70
71 /// Creates a tensor filled with ones.
72 ///
73 /// # Arguments
74 ///
75 /// * `shape` - The shape of the tensor.
76 /// * `device` - The device on which the tensor will be allocated.
77 /// * `dtype` - The target data type.
78 ///
79 /// # Returns
80 ///
81 /// The tensor filled with ones.
82 ///
83 /// # Remarks
84 ///
85 /// This is a low-level function used internally by the library to call different backend functions
86 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
87 /// or use this function directly.
88 ///
89 /// For creating a tensor filled with ones, users should prefer the
90 #[cfg_attr(doc, doc = crate::doc_tensor!("ones"))]
91 #[cfg_attr(not(doc), doc = "`Tensor::ones`")]
92 /// function, which is more high-level and designed for public use.
93 fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
94
95 /// Creates a tensor of the given shape where each element is equal to the provided value.
96 ///
97 /// # Arguments
98 ///
99 /// * `shape` - The shape of the tensor.
100 /// * `fill_value` - The value with which to fill the tensor.
101 /// * `device` - The device on which the tensor will be allocated.
102 /// * `dtype` - The target data type.
103 ///
104 /// # Returns
105 ///
106 /// The tensor filled with the specified value.
107 ///
108 /// # Remarks
109 ///
110 /// This is a low-level function used internally by the library to call different backend functions
111 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
112 /// or use this function directly.
113 ///
114 /// For creating full tensors, users should prefer the
115 #[cfg_attr(doc, doc = crate::doc_tensor!("full"))]
116 #[cfg_attr(not(doc), doc = "`Tensor::full`")]
117 /// function, which is more high-level and designed for public use.
118 fn full<E: ElementConversion>(
119 shape: Shape,
120 fill_value: E,
121 device: &B::Device,
122 dtype: DType,
123 ) -> Self::Primitive;
124
125 /// Reshapes the tensor.
126 ///
127 /// # Arguments
128 ///
129 /// * `tensor` - The tensor.
130 /// * `shape` - The new shape of the tensor.
131 ///
132 /// # Returns
133 ///
134 /// The reshaped tensor.
135 ///
136 /// # Remarks
137 ///
138 /// This is a low-level function used internally by the library to call different backend functions
139 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
140 /// or use this function directly.
141 ///
142 /// For reshaping a tensor, users should prefer the
143 #[cfg_attr(doc, doc = crate::doc_tensor!("reshape"))]
144 #[cfg_attr(not(doc), doc = "`Tensor::reshape`")]
145 /// function, which is more high-level and designed for public use.
146 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
147
148 /// Transposes a tensor.
149 ///
150 /// # Arguments
151 ///
152 /// * `tensor` - The tensor to transpose.
153 ///
154 /// # Returns
155 ///
156 /// The transposed tensor.
157 fn transpose(tensor: Self::Primitive) -> Self::Primitive;
158
159 /// Swaps two dimensions of a tensor.
160 ///
161 /// # Arguments
162 ///
163 /// * `tensor` - The tensor to swap the dimensions of.
164 /// * `dim1` - The first dimension to swap.
165 /// * `dim2` - The second dimension to swap.
166 ///
167 /// # Returns
168 ///
169 /// The tensor with the dimensions swapped.
170 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
171
172 /// Permutes the dimensions of a tensor.
173 ///
174 /// # Arguments
175 ///
176 /// * `tensor` - The tensor to permute the dimensions of.
177 /// * `axes` - The new order of the dimensions.
178 ///
179 /// # Returns
180 ///
181 /// The tensor with the dimensions permuted.
182 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
183
184 /// Flips the tensor along the given axes.
185 ///
186 /// # Arguments
187 ///
188 /// * `tensor` - The tensor to flip.
189 /// * `axes` - The axes to flip the tensor along.
190 ///
191 /// # Returns
192 ///
193 /// The tensor with the axes flipped.
194 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
195
196 /// Select tensor elements corresponding to the given slices.
197 ///
198 /// # Arguments
199 ///
200 /// * `tensor` - The tensor.
201 /// * `slices` - The slices specifying ranges and steps for each dimension.
202 ///
203 /// # Returns
204 ///
205 /// The selected elements.
206 ///
207 /// # Remarks
208 ///
209 /// This is a low-level function used internally by the library to call different backend functions
210 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
211 /// or use this function directly.
212 ///
213 /// For selecting elements of a tensor, users should prefer the
214 #[cfg_attr(doc, doc = crate::doc_tensor!("slice"))]
215 #[cfg_attr(not(doc), doc = "`Tensor::slice`")]
216 /// function, which is more high-level and designed for public use.
217 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive;
218
219 /// Assigns the given value to the tensor elements corresponding to the given slices.
220 ///
221 /// # Arguments
222 ///
223 /// * `tensor` - The tensor.
224 /// * `slices` - The slices specifying which elements to assign, including support for steps.
225 /// * `value` - The value to assign.
226 ///
227 /// # Returns
228 ///
229 /// The tensor with the assigned values.
230 ///
231 /// # Remarks
232 ///
233 /// This is a low-level function used internally by the library to call different backend functions
234 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
235 /// or use this function directly.
236 ///
237 /// For assigning values to elements of a tensor, users should prefer the
238 #[cfg_attr(doc, doc = crate::doc_tensor!("slice_assign"))]
239 #[cfg_attr(not(doc), doc = "`Tensor::slice_assign`")]
240 /// function, which is more high-level and designed for public use.
241 fn slice_assign(
242 tensor: Self::Primitive,
243 slices: &[Slice],
244 value: Self::Primitive,
245 ) -> Self::Primitive;
246
247 /// Select tensor elements along the given dimension corresponding to the given indices.
248 ///
249 /// # Arguments
250 ///
251 /// * `tensor` - The tensor to select from.
252 /// * `dim` - The dimension along which to select.
253 /// * `indices` - The indices of the elements to select.
254 ///
255 /// # Returns
256 ///
257 /// The selected tensor elements.
258 ///
259 /// # Remarks
260 ///
261 /// This is a low-level function used internally by the library to call different backend functions
262 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
263 /// or use this function directly.
264 ///
265 /// For selecting elements from a tensor along an axis, users should prefer the
266 #[cfg_attr(doc, doc = crate::doc_tensor!("select"))]
267 #[cfg_attr(not(doc), doc = "`Tensor::select`")]
268 /// function, which is more high-level and designed for public use.
269 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive;
270
271 /// Assign the selected elements along the given dimension corresponding to the given indices
272 /// from the value tensor.
273 ///
274 /// # Arguments
275 ///
276 /// * `tensor` - The tensor to assign elements to.
277 /// * `dim` - The axis along which to assign elements.
278 /// * `indices` - The indices of the elements to assign.
279 /// * `values` - The values to assign to the tensor.
280 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
281 ///
282 /// # Returns
283 ///
284 /// A tensor with the same shape as the input tensor, where each element is taken from the
285 /// corresponding element of the input tensor at the corresponding index along the specified axis,
286 /// except for the elements at the specified indices, which are taken from the corresponding
287 /// element of the values tensor.
288 ///
289 /// # Remarks
290 ///
291 /// This is a low-level function used internally by the library to call different backend functions
292 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
293 /// or use this function directly.
294 ///
295 /// For assigning elements to a tensor along an axis, users should prefer the
296 #[cfg_attr(doc, doc = crate::doc_tensor!("select_assign"))]
297 #[cfg_attr(not(doc), doc = "`Tensor::select_assign`")]
298 /// function, which is more high-level and designed for public use.
299 fn select_assign(
300 tensor: Self::Primitive,
301 dim: usize,
302 indices: IntTensor<B>,
303 values: Self::Primitive,
304 update: IndexingUpdateOp,
305 ) -> Self::Primitive;
306
307 /// Selects elements from a tensor based on a boolean mask.
308 ///
309 /// # Arguments
310 ///
311 /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
312 /// * `mask` - The boolean mask to use for selecting elements.
313 /// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
314 ///
315 /// # Returns
316 ///
317 /// A tensor with the same shape as the input tensors, where each element is taken from the
318 /// corresponding element of the left hand side tensor if the corresponding element of the mask
319 /// is true, and from the corresponding element of the right hand side tensor otherwise.
320 ///
321 /// # Remarks
322 ///
323 /// This is a low-level function used internally by the library to call different backend functions
324 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
325 /// or use this function directly.
326 ///
327 /// For selecting elements from a tensor based on a boolean mask, users should prefer the
328 #[cfg_attr(doc, doc = crate::doc_tensor!("mask_where"))]
329 #[cfg_attr(not(doc), doc = "`Tensor::mask_where`")]
330 /// function, which is more high-level and designed for public use.
331 fn mask_where(
332 tensor: Self::Primitive,
333 mask: B::BoolTensorPrimitive,
334 source: Self::Primitive,
335 ) -> Self::Primitive;
336
337 /// Fills elements of a tensor based on a boolean mask.
338 ///
339 /// # Arguments
340 ///
341 /// * `tensor` - The tensor whose elements will be overwritten with the value
342 /// when the corresponding element of the mask is true.
343 /// * `mask` - The boolean mask to use for filling elements.
344 /// * `value` - The value to fill elements with when the corresponding element of the mask is true.
345 ///
346 /// # Returns
347 ///
348 /// A tensor with the same shape as the input tensors, where each element is taken from the
349 /// corresponding element unmodified if the corresponding element of the mask is false, and
350 /// filled with the value otherwise.
351 ///
352 /// # Remarks
353 ///
354 /// This is a low-level function used internally by the library to call different backend functions
355 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
356 /// or use this function directly.
357 ///
358 /// For filling elements of a tensor based on a boolean mask, users should prefer the
359 #[cfg_attr(doc, doc = crate::doc_tensor!("mask_fill"))]
360 #[cfg_attr(not(doc), doc = "`Tensor::mask_fill`")]
361 /// function, which is more high-level and designed for public use.
362 fn mask_fill(
363 tensor: Self::Primitive,
364 mask: B::BoolTensorPrimitive,
365 value: Self::Elem,
366 ) -> Self::Primitive;
367
368 /// Gathers elements from a tensor along an axis.
369 ///
370 /// # Arguments
371 ///
372 /// * `dim` - The axis along which to gather elements.
373 /// * `tensor` - The tensor to gather elements from.
374 /// * `indices` - The indices of the elements to gather.
375 ///
376 /// # Returns
377 ///
378 /// A tensor with the same shape as the input tensor, where each element is taken from the
379 /// corresponding element of the input tensor at the corresponding index along the specified axis.
380 ///
381 /// # Remarks
382 ///
383 /// This is a low-level function used internally by the library to call different backend functions
384 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
385 /// or use this function directly.
386 ///
387 /// For gathering elements from a tensor along an axis, users should prefer the
388 #[cfg_attr(doc, doc = crate::doc_tensor!("gather"))]
389 #[cfg_attr(not(doc), doc = "`Tensor::gather`")]
390 /// function, which is more high-level and designed for public use.
391 fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive;
392
393 /// Scatters elements into a tensor along an axis.
394 ///
395 /// # Arguments
396 ///
397 /// * `dim` - The axis along which to scatter elements.
398 /// * `tensor` - The tensor to scatter elements into.
399 /// * `indices` - The indices of the elements to scatter.
400 /// * `values` - The values to scatter into the tensor.
401 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
402 ///
403 /// # Returns
404 ///
405 /// A tensor with the same shape as the input tensor, where each element is taken from the
406 /// corresponding element of the input tensor at the corresponding index along the specified axis,
407 /// except for the elements at the specified indices, which are taken from the corresponding
408 /// element of the values tensor.
409 ///
410 /// # Remarks
411 ///
412 /// This is a low-level function used internally by the library to call different backend functions
413 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
414 /// or use this function directly.
415 ///
416 /// For scattering elements into a tensor along an axis, users should prefer the
417 #[cfg_attr(doc, doc = crate::doc_tensor!("scatter"))]
418 #[cfg_attr(not(doc), doc = "`Tensor::scatter`")]
419 /// function, which is more high-level and designed for public use.
420 fn scatter(
421 dim: usize,
422 tensor: Self::Primitive,
423 indices: IntTensor<B>,
424 values: Self::Primitive,
425 update: IndexingUpdateOp,
426 ) -> Self::Primitive;
427
428 /// Returns the device on which the tensor is allocated.
429 ///
430 /// # Arguments
431 ///
432 /// * `tensor` - The tensor.
433 ///
434 /// # Returns
435 ///
436 /// The device on which the tensor is allocated.
437 ///
438 /// # Remarks
439 ///
440 /// This is a low-level function used internally by the library to call different backend functions
441 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
442 /// or use this function directly.
443 ///
444 /// For getting the device of a tensor, users should prefer the
445 #[cfg_attr(doc, doc = crate::doc_tensor!("device"))]
446 #[cfg_attr(not(doc), doc = "`Tensor::device`")]
447 /// function, which is more high-level and designed for public use.
448 fn device(tensor: &Self::Primitive) -> B::Device;
449
450 /// Moves the tensor to the given device.
451 ///
452 /// # Arguments
453 ///
454 /// * `tensor` - The tensor.
455 /// * `device` - The device on which the tensor will be moved.
456 ///
457 /// # Returns
458 ///
459 /// The tensor on the given device.
460 ///
461 /// # Remarks
462 ///
463 /// This is a low-level function used internally by the library to call different backend functions
464 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
465 /// or use this function directly.
466 ///
467 /// For moving a tensor to a device, users should prefer the
468 #[cfg_attr(doc, doc = crate::doc_tensor!("to_device"))]
469 #[cfg_attr(not(doc), doc = "`Tensor::to_device`")]
470 /// function, which is more high-level and designed for public use.
471 #[allow(clippy::wrong_self_convention)]
472 fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
473
474 /// Extracts the data from the tensor asynchronously.
475 ///
476 /// # Arguments
477 ///
478 /// * `tensor` - The tensor.
479 ///
480 /// # Returns
481 ///
482 /// The data of the tensor.
483 ///
484 /// # Remarks
485 ///
486 /// This is a low-level function used internally by the library to call different backend functions
487 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
488 /// or use this function directly.
489 ///
490 /// For extracting the data of a tensor, users should prefer the
491 #[cfg_attr(doc, doc = crate::doc_tensor!("into_data"))]
492 #[cfg_attr(not(doc), doc = "`Tensor::into_data`")]
493 /// function, which is more high-level and designed for public use.
494 #[allow(clippy::wrong_self_convention)]
495 fn into_data_async(
496 tensor: Self::Primitive,
497 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
498
499 /// Read the data from the tensor using a transaction.
500 ///
501 /// # Remarks
502 ///
503 /// This is a low-level function used internally by the library to call different backend functions
504 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
505 /// or use this function directly.
506 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive);
507
508 /// Creates a tensor from the given data.
509 ///
510 /// # Arguments
511 ///
512 /// * `data` - The data of the tensor.
513 /// * `device` - The device on which the tensor will be allocated.
514 ///
515 /// # Returns
516 ///
517 /// The tensor.
518 ///
519 /// # Remarks
520 ///
521 /// This is a low-level function used internally by the library to call different backend functions
522 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
523 /// or use this function directly.
524 ///
525 /// For creating a tensor from data, users should prefer the
526 #[cfg_attr(doc, doc = crate::doc_tensor!("from_data"))]
527 #[cfg_attr(not(doc), doc = "`Tensor::from_data`")]
528 /// function, which is more high-level and designed for public use.
529 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;

530 /// Creates a tensor from the given data enforcing the given data type.
531 ///
532 /// # Remarks
533 ///
534 /// This is a low-level function used internally by the library to call different backend functions
535 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
536 /// or use this function directly.
537 ///
538 /// For creating a tensor from data, users should prefer the
539 #[cfg_attr(doc, doc = crate::doc_tensor!("from_data_dtype"))]
540 #[cfg_attr(not(doc), doc = "`Tensor::from_data_dtype`")]
541 /// function, which is more high-level and designed for public use.
542 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;
543
544 /// Repeat the tensor along the given dimension.
545 ///
546 /// # Arguments
547 ///
548 /// * `tensor` - The tensor.
549 /// * `dim` - The dimension along which the tensor will be repeated.
550 /// * `times` - The number of times the tensor will be repeated.
551 ///
552 /// # Returns
553 ///
554 /// The repeated tensor.
555 ///
556 /// # Remarks
557 ///
558 /// This is a low-level function used internally by the library to call different backend functions
559 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
560 /// or use this function directly.
561 ///
562 /// For repeating a tensor, users should prefer the
563 #[cfg_attr(doc, doc = crate::doc_tensor!("repeat_dim"))]
564 #[cfg_attr(not(doc), doc = "`Tensor::repeat_dim`")]
565 /// function, which is more high-level and designed for public use.
566 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
567
568 /// Concatenates the given tensors along the given dimension.
569 ///
570 /// # Arguments
571 ///
572 /// * `vectors` - The tensors to concatenate.
573 /// * `dim` - The dimension along which the tensors will be concatenated.
574 ///
575 /// # Returns
576 ///
577 /// The concatenated tensor.
578 ///
579 /// # Remarks
580 ///
581 /// This is a low-level function used internally by the library to call different backend functions
582 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
583 /// or use this function directly.
584 ///
585 /// For concatenating tensors, users should prefer the
586 #[cfg_attr(doc, doc = crate::doc_tensor!("cat"))]
587 #[cfg_attr(not(doc), doc = "`Tensor::cat`")]
588 /// function, which is more high-level and designed for public use.
589 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
590
591 /// Applies element-wise equality comparison between the given tensors.
592 ///
593 /// # Arguments
594 ///
595 /// * `lhs` - The left hand side tensor.
596 /// * `rhs` - The right hand side tensor.
597 ///
598 /// # Returns
599 ///
600 /// The tensor of booleans indicating whether the corresponding elements are equal.
601 ///
602 /// # Remarks
603 ///
604 /// This is a low-level function used internally by the library to call different backend functions
605 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
606 /// or use this function directly.
607 ///
608 /// For equality comparison of tensors, users should prefer the
609 #[cfg_attr(doc, doc = crate::doc_tensor!("equal"))]
610 #[cfg_attr(not(doc), doc = "`Tensor::equal`")]
611 /// function, which is more high-level and designed for public use.
612 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
613
614 /// Element-wise equality comparison between a tensor and a scalar value.
615 ///
616 /// # Arguments
617 ///
618 /// * `lhs` - The left hand side tensor.
619 /// * `rhs` - The right hand side scalar value.
620 ///
621 /// # Returns
622 ///
623 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
624 /// corresponding element of `lhs` is equal to `rhs`, and false otherwise.
625 ///
626 /// # Remarks
627 ///
628 /// This is a low-level function used internally by the library to call different backend functions
629 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
630 /// or use this function directly.
631 ///
632 /// For element-wise equality between a tensor and a scalar, users should prefer the
633 #[cfg_attr(doc, doc = crate::doc_tensor!("equal_elem"))]
634 #[cfg_attr(not(doc), doc = "`Tensor::equal_elem`")]
635 /// function, which is more high-level and designed for public use.
636 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
637
638 /// Applies element-wise non-equality comparison between the given tensors.
639 ///
640 /// # Arguments
641 ///
642 /// * `lhs` - The left hand side tensor.
643 /// * `rhs` - The right hand side tensor.
644 ///
645 /// # Returns
646 ///
647 /// The tensor of booleans indicating whether the corresponding elements are not equal.
648 ///
649 /// # Remarks
650 ///
651 /// This is a low-level function used internally by the library to call different backend functions
652 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
653 /// or use this function directly.
654 ///
655 /// For non-equality comparison of tensors, users should prefer the
656 #[cfg_attr(doc, doc = crate::doc_tensor!("not_equal"))]
657 #[cfg_attr(not(doc), doc = "`Tensor::not_equal`")]
658 /// function, which is more high-level and designed for public use.
659 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
660
661 /// Element-wise non-equality comparison between a tensor and a scalar value.
662 ///
663 /// # Arguments
664 ///
665 /// * `lhs` - The left hand side tensor.
666 /// * `rhs` - The right hand side scalar value.
667 ///
668 /// # Returns
669 ///
670 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
671 /// corresponding element of `lhs` is not equal to `rhs`, and false otherwise.
672 ///
673 /// # Remarks
674 ///
675 /// This is a low-level function used internally by the library to call different backend functions
676 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
677 /// or use this function directly.
678 ///
679 /// For element-wise non-equality between a tensor and a scalar, users should prefer the
680 #[cfg_attr(doc, doc = crate::doc_tensor!("not_equal_elem"))]
681 #[cfg_attr(not(doc), doc = "`Tensor::not_equal_elem`")]
682 /// function, which is more high-level and designed for public use.
683 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
684
685 /// Returns the name of the element type.
686 fn elem_type_name() -> &'static str {
687 core::any::type_name::<Self::Elem>()
688 }
689
690 /// Returns the tensor data type.
691 fn dtype(tensor: &Self::Primitive) -> DType {
692 tensor.dtype()
693 }
694
695 /// Tests if any element in the `tensor` evaluates to True.
696 ///
697 /// # Arguments
698 ///
699 /// * `tensor` - The tensor to test.
700 ///
701 /// # Returns
702 ///
703 /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
704 ///
705 /// # Remarks
706 ///
707 /// This is a low-level function used internally by the library to call different backend functions
708 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
709 /// or use this function directly. Users should prefer the
710 #[cfg_attr(doc, doc = crate::doc_tensor!("any"))]
711 #[cfg_attr(not(doc), doc = "`Tensor::any`")]
712 /// function, which is more high-level and designed for public use.
713 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
714
715 /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
716 ///
717 /// # Arguments
718 ///
719 /// * `tensor` - The tensor to test.
720 /// * `dim` - The axis along which to test.
721 ///
722 /// # Returns
723 ///
724 /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
725 /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
726 ///
727 /// # Remarks
728 ///
729 /// This is a low-level function used internally by the library to call different backend functions
730 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
731 /// or use this function directly. Users should prefer the
732 #[cfg_attr(doc, doc = crate::doc_tensor!("any_dim"))]
733 #[cfg_attr(not(doc), doc = "`Tensor::any_dim`")]
734 /// function, which is more high-level and designed for public use.
735 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
736
737 /// Tests if all elements in the `tensor` evaluate to True.
738 ///
739 /// # Arguments
740 ///
741 /// * `tensor` - The tensor to test.
742 ///
743 /// # Returns
744 ///
745 /// A boolean tensor with a single element, True if all elements in the input tensor evaluate to True, False otherwise.
746 ///
747 /// # Remarks
748 ///
749 /// This is a low-level function used internally by the library to call different backend functions
750 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
751 /// or use this function directly. Users should prefer the
752 #[cfg_attr(doc, doc = crate::doc_tensor!("all"))]
753 #[cfg_attr(not(doc), doc = "`Tensor::all`")]
754 /// function, which is more high-level and designed for public use.
755 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
756
757 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
758 ///
759 /// # Arguments
760 ///
761 /// * `tensor` - The tensor to test.
 /// * `dim` - The axis along which to test.
762 ///
763 /// # Returns
764 ///
765 /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
766 /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
767 ///
768 /// # Remarks
769 ///
770 /// This is a low-level function used internally by the library to call different backend functions
771 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
772 /// or use this function directly. Users should prefer the
773 #[cfg_attr(doc, doc = crate::doc_tensor!("all_dim"))]
774 #[cfg_attr(not(doc), doc = "`Tensor::all_dim`")]
775 /// function, which is more high-level and designed for public use.
776 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
777
778 /// Broadcasts the given tensor to the specified shape.
779 ///
780 /// # Arguments
781 ///
782 /// * `tensor` - The tensor to broadcast.
783 /// * `shape` - The shape to broadcast to.
784 ///
785 /// # Returns
786 ///
787 /// The broadcasted tensor.
788 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
789
790 /// Unfold windows along a dimension.
791 ///
792 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`,
793 /// where windows are advanced by `step` at each index.
794 ///
795 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
796 ///
797 /// # Warning
798 ///
799 /// For the `ndarray` and `candle` backends, this is not a view but a full copy.
800 ///
801 /// # Arguments
802 ///
803 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
804 /// * `dim` - The dimension to unfold.
805 /// * `size` - The size of each unfolded window.
806 /// * `step` - The step between each window.
807 ///
808 /// # Returns
809 ///
810 /// A tensor view with shape ``[pre=..., windows, post=..., size]``.
811 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive;
812}