// burn_backend/tensor/ops/base.rs
use alloc::vec::Vec;
use burn_std::{DType, Shape, Slice};

use crate::{
    Backend, BackendTypes, ExecutionError, Scalar, TensorData, TensorMetadata,
    element::Element,
    ops::TransactionPrimitive,
    tensor::{IndexingUpdateOp, IntTensor, TensorKind},
};

11/// Trait for the one basic op that still requires Backend
12///
13/// # Warnings
14///
15/// This is an internal trait, use the public API provided by the
16#[cfg_attr(doc, doc = crate::doc_tensor!())]
17#[cfg_attr(not(doc), doc = "`Tensor`")]
18pub trait TransactionOp<B: Backend>: TensorKind<B> {
19 /// Read the data from the tensor using a transaction.
20 ///
21 /// # Remarks
22 ///
23 /// This is a low-level function used internally by the library to call different backend functions
24 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
25 /// or use this function directly.
26 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive);
27}
28/// Trait that list all operations that can be applied on all tensors.
29///
30/// # Warnings
31///
32/// This is an internal trait, use the public API provided by the
33#[cfg_attr(doc, doc = crate::doc_tensor!())]
34#[cfg_attr(not(doc), doc = "`Tensor`")]
35/// struct.
36pub trait BasicOps<B: BackendTypes>: TensorKind<B> {
37 /// The type of the tensor elements.
38 type Elem: Element;
39
40 /// Creates an empty tensor with the given shape.
41 ///
42 /// # Arguments
43 ///
44 /// * `shape` - The shape of the tensor.
45 /// * `device` - The device on which the tensor will be allocated.
46 /// * `dtype` - The target data type.
47 ///
48 /// # Returns
49 ///
50 /// The empty tensor.
51 ///
52 /// # Remarks
53 ///
54 /// This is a low-level function used internally by the library to call different backend functions
55 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
56 /// or use this function directly.
57 ///
58 /// For creating empty tensors, users should prefer the
59 #[cfg_attr(doc, doc = crate::doc_tensor!("empty"))]
60 #[cfg_attr(not(doc), doc = "`Tensor::empty`")]
61 /// function, which is more high-level and designed for public use.
62 fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
63
64 /// Creates a tensor filled with zeros.
65 ///
66 /// # Arguments
67 ///
68 /// * `shape` - The shape of the tensor.
69 /// * `device` - The device on which the tensor will be allocated.
70 /// * `dtype` - The target data type.
71 ///
72 /// # Returns
73 ///
74 /// The tensor filled with zeros.
75 ///
76 /// # Remarks
77 ///
78 /// This is a low-level function used internally by the library to call different backend functions
79 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
80 /// or use this function directly.
81 ///
82 /// For creating a tensor filled with zeros, users should prefer the
83 #[cfg_attr(doc, doc = crate::doc_tensor!("zeros"))]
84 #[cfg_attr(not(doc), doc = "`Tensor::zeros`")]
85 /// function, which is more high-level and designed for public use.
86 fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
87
88 /// Creates a tensor filled with ones.
89 ///
90 /// # Arguments
91 ///
92 /// * `shape` - The shape of the tensor.
93 /// * `device` - The device on which the tensor will be allocated.
94 /// * `dtype` - The target data type.
95 ///
96 /// # Returns
97 ///
98 /// The tensor filled with ones.
99 ///
100 /// # Remarks
101 ///
102 /// This is a low-level function used internally by the library to call different backend functions
103 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
104 /// or use this function directly.
105 ///
106 /// For creating a tensor filled with ones, users should prefer the
107 #[cfg_attr(doc, doc = crate::doc_tensor!("ones"))]
108 #[cfg_attr(not(doc), doc = "`Tensor::ones`")]
109 /// function, which is more high-level and designed for public use.
110 fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
111
112 /// Creates a tensor of the given shape where each element is equal to the provided value.
113 ///
114 /// # Arguments
115 ///
116 /// * `shape` - The shape of the tensor.
117 /// * `fill_value` - The value with which to fill the tensor.
118 /// * `device` - The device on which the tensor will be allocated.
119 /// * `dtype` - The target data type.
120 ///
121 /// # Returns
122 ///
123 /// The tensor filled with the specified value.
124 ///
125 /// # Remarks
126 ///
127 /// This is a low-level function used internally by the library to call different backend functions
128 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
129 /// or use this function directly.
130 ///
131 /// For creating full tensors, users should prefer the
132 #[cfg_attr(doc, doc = crate::doc_tensor!("full"))]
133 #[cfg_attr(not(doc), doc = "`Tensor::full`")]
134 /// function, which is more high-level and designed for public use.
135 fn full(shape: Shape, fill_value: Scalar, device: &B::Device, dtype: DType) -> Self::Primitive;
136
137 /// Reshapes the tensor.
138 ///
139 /// # Arguments
140 ///
141 /// * `tensor` - The tensor.
142 /// * `shape` - The new shape of the tensor.
143 ///
144 /// # Returns
145 ///
146 /// The reshaped tensor.
147 ///
148 /// # Remarks
149 ///
150 /// This is a low-level function used internally by the library to call different backend functions
151 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
152 /// or use this function directly.
153 ///
154 /// For reshaping a tensor, users should prefer the
155 #[cfg_attr(doc, doc = crate::doc_tensor!("reshape"))]
156 #[cfg_attr(not(doc), doc = "`Tensor::reshape`")]
157 /// function, which is more high-level and designed for public use.
158 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
159
160 /// Transposes a tensor.
161 ///
162 /// # Arguments
163 ///
164 /// * `tensor` - The tensor to transpose.
165 ///
166 /// # Returns
167 ///
168 /// The transposed tensor.
169 fn transpose(tensor: Self::Primitive) -> Self::Primitive;
170
171 /// Swaps two dimensions of a tensor.
172 ///
173 /// # Arguments
174 ///
175 /// * `tensor` - The tensor to swap the dimensions of.
176 /// * `dim1` - The first dimension to swap.
177 /// * `dim2` - The second dimension to swap.
178 ///
179 /// # Returns
180 ///
181 /// The tensor with the dimensions swapped.
182 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
183
184 /// Permutes the dimensions of a tensor.
185 ///
186 /// # Arguments
187 ///
188 /// * `tensor` - The tensor to permute the dimensions of.
189 /// * `axes` - The new order of the dimensions.
190 ///
191 /// # Returns
192 ///
193 /// The tensor with the dimensions permuted.
194 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
195
196 /// Flips the tensor along the given axes.
197 ///
198 /// # Arguments
199 ///
200 /// * `tensor` - The tensor to flip.
201 /// * `axes` - The axes to flip the tensor along.
202 ///
203 /// # Returns
204 ///
205 /// The tensor with the axes flipped.
206 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
207
208 /// Select tensor elements corresponding to the given slices.
209 ///
210 /// # Arguments
211 ///
212 /// * `tensor` - The tensor.
213 /// * `slices` - The slices specifying ranges and steps for each dimension.
214 ///
215 /// # Returns
216 ///
217 /// The selected elements.
218 ///
219 /// # Remarks
220 ///
221 /// This is a low-level function used internally by the library to call different backend functions
222 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
223 /// or use this function directly.
224 ///
225 /// For selecting elements of a tensor, users should prefer the
226 #[cfg_attr(doc, doc = crate::doc_tensor!("slice"))]
227 #[cfg_attr(not(doc), doc = "`Tensor::slice`")]
228 /// function, which is more high-level and designed for public use.
229 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive;
230
231 /// Assigns the given value to the tensor elements corresponding to the given slices.
232 ///
233 /// # Arguments
234 ///
235 /// * `tensor` - The tensor.
236 /// * `slices` - The slices specifying which elements to assign, including support for steps.
237 /// * `value` - The value to assign.
238 ///
239 /// # Returns
240 ///
241 /// The tensor with the assigned values.
242 ///
243 /// # Remarks
244 ///
245 /// This is a low-level function used internally by the library to call different backend functions
246 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
247 /// or use this function directly.
248 ///
249 /// For assigning values to elements of a tensor, users should prefer the
250 #[cfg_attr(doc, doc = crate::doc_tensor!("slice_assign"))]
251 #[cfg_attr(not(doc), doc = "`Tensor::slice_assign`")]
252 /// function, which is more high-level and designed for public use.
253 fn slice_assign(
254 tensor: Self::Primitive,
255 slices: &[Slice],
256 value: Self::Primitive,
257 ) -> Self::Primitive;
258
259 /// Select tensor elements along the given dimension corresponding to the given indices.
260 ///
261 /// # Arguments
262 ///
263 /// * `tensor` - The tensor to select from.
264 /// * `dim` - The dimension along which to select.
265 /// * `indices` - The indices of the elements to select.
266 ///
267 /// # Returns
268 ///
269 /// The selected tensor elements.
270 ///
271 /// # Remarks
272 ///
273 /// This is a low-level function used internally by the library to call different backend functions
274 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
275 /// or use this function directly.
276 ///
277 /// For selecting elements from a tensor along an axis, users should prefer the
278 #[cfg_attr(doc, doc = crate::doc_tensor!("select"))]
279 #[cfg_attr(not(doc), doc = "`Tensor::select`")]
280 /// function, which is more high-level and designed for public use.
281 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive;
282
283 /// Assign the selected elements along the given dimension corresponding to the given indices
284 /// from the value tensor.
285 ///
286 /// # Arguments
287 ///
288 /// * `tensor` - The tensor to assign elements to.
289 /// * `dim` - The axis along which to assign elements.
290 /// * `indices` - The indices of the elements to assign.
291 /// * `values` - The values to assign to the tensor.
292 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
293 ///
294 /// # Returns
295 ///
296 /// A tensor with the same shape as the input tensor, where each element is taken from the
297 /// corresponding element of the input tensor at the corresponding index along the specified axis,
298 /// except for the elements at the specified indices, which are taken from the corresponding
299 /// element of the values tensor.
300 ///
301 /// # Remarks
302 ///
303 /// This is a low-level function used internally by the library to call different backend functions
304 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
305 /// or use this function directly.
306 ///
307 /// For assigning elements to a tensor along an axis, users should prefer the
308 #[cfg_attr(doc, doc = crate::doc_tensor!("select_assign"))]
309 #[cfg_attr(not(doc), doc = "`Tensor::select_assign`")]
310 /// function, which is more high-level and designed for public use.
311 fn select_assign(
312 tensor: Self::Primitive,
313 dim: usize,
314 indices: IntTensor<B>,
315 values: Self::Primitive,
316 update: IndexingUpdateOp,
317 ) -> Self::Primitive;
318
319 /// Selects elements from a tensor based on a boolean mask.
320 ///
321 /// # Arguments
322 ///
323 /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
324 /// * `mask` - The boolean mask to use for selecting elements.
325 /// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
326 ///
327 /// # Returns
328 ///
329 /// A tensor with the same shape as the input tensors, where each element is taken from the
330 /// corresponding element of the left hand side tensor if the corresponding element of the mask
331 /// is true, and from the corresponding element of the right hand side tensor otherwise.
332 ///
333 /// # Remarks
334 ///
335 /// This is a low-level function used internally by the library to call different backend functions
336 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
337 /// or use this function directly.
338 ///
339 /// For selecting elements from a tensor based on a boolean mask, users should prefer the
340 #[cfg_attr(doc, doc = crate::doc_tensor!("mask_where"))]
341 #[cfg_attr(not(doc), doc = "`Tensor::mask_where`")]
342 /// function, which is more high-level and designed for public use.
343 fn mask_where(
344 tensor: Self::Primitive,
345 mask: B::BoolTensorPrimitive,
346 source: Self::Primitive,
347 ) -> Self::Primitive;
348
349 /// Fills elements of a tensor based on a boolean mask.
350 ///
351 /// # Arguments
352 ///
353 /// * `tensor` - The tensor where will be overwritten with the value
354 /// when the corresponding element of the mask is true.
355 /// * `mask` - The boolean mask to use for filling elements.
356 /// * `value` - The value to fill elements with when the corresponding element of the mask is true.
357 ///
358 /// # Returns
359 ///
360 /// A tensor with the same shape as the input tensors, where each element is taken from the
361 /// corresponding element unmodified if the corresponding element of the mask is false, and
362 /// filled with the value otherwise.
363 ///
364 /// # Remarks
365 ///
366 /// This is a low-level function used internally by the library to call different backend functions
367 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
368 /// or use this function directly.
369 ///
370 /// For filling elements of a tensor based on a boolean mask, users should prefer the
371 #[cfg_attr(doc, doc = crate::doc_tensor!("mask_fill"))]
372 #[cfg_attr(not(doc), doc = "`Tensor::mask_fill`")]
373 /// function, which is more high-level and designed for public use.
374 fn mask_fill(
375 tensor: Self::Primitive,
376 mask: B::BoolTensorPrimitive,
377 value: Scalar,
378 ) -> Self::Primitive;
379
380 /// Gathers elements from a tensor along an axis.
381 ///
382 /// # Arguments
383 ///
384 /// * `dim` - The axis along which to gather elements.
385 /// * `tensor` - The tensor to gather elements from.
386 /// * `indices` - The indices of the elements to gather.
387 ///
388 /// # Returns
389 ///
390 /// A tensor with the same shape as the input tensor, where each element is taken from the
391 /// corresponding element of the input tensor at the corresponding index along the specified axis.
392 ///
393 /// # Remarks
394 ///
395 /// This is a low-level function used internally by the library to call different backend functions
396 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
397 /// or use this function directly.
398 ///
399 /// For gathering elements from a tensor along an axis, users should prefer the
400 #[cfg_attr(doc, doc = crate::doc_tensor!("gather"))]
401 #[cfg_attr(not(doc), doc = "`Tensor::gather`")]
402 /// function, which is more high-level and designed for public use.
403 fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive;
404
405 /// Scatters elements into a tensor along an axis.
406 ///
407 /// # Arguments
408 ///
409 /// * `dim` - The axis along which to scatter elements.
410 /// * `tensor` - The tensor to scatter elements into.
411 /// * `indices` - The indices of the elements to scatter.
412 /// * `values` - The values to scatter into the tensor.
413 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
414 ///
415 /// # Returns
416 ///
417 /// A tensor with the same shape as the input tensor, where each element is taken from the
418 /// corresponding element of the input tensor at the corresponding index along the specified axis,
419 /// except for the elements at the specified indices, which are taken from the corresponding
420 /// element of the values tensor.
421 ///
422 /// # Remarks
423 ///
424 /// This is a low-level function used internally by the library to call different backend functions
425 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
426 /// or use this function directly.
427 ///
428 /// For scattering elements into a tensor along an axis, users should prefer the
429 #[cfg_attr(doc, doc = crate::doc_tensor!("scatter"))]
430 #[cfg_attr(not(doc), doc = "`Tensor::scatter`")]
431 /// function, which is more high-level and designed for public use.
432 fn scatter(
433 dim: usize,
434 tensor: Self::Primitive,
435 indices: IntTensor<B>,
436 values: Self::Primitive,
437 update: IndexingUpdateOp,
438 ) -> Self::Primitive;
439
440 /// Returns the device on which the tensor is allocated.
441 ///
442 /// # Arguments
443 ///
444 /// * `tensor` - The tensor.
445 ///
446 /// # Returns
447 ///
448 /// The device on which the tensor is allocated.
449 ///
450 /// # Remarks
451 ///
452 /// This is a low-level function used internally by the library to call different backend functions
453 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
454 /// or use this function directly.
455 ///
456 /// For getting the device of a tensor, users should prefer the
457 #[cfg_attr(doc, doc = crate::doc_tensor!("device"))]
458 #[cfg_attr(not(doc), doc = "`Tensor::device`")]
459 /// function, which is more high-level and designed for public use.
460 fn device(tensor: &Self::Primitive) -> B::Device;
461
462 /// Moves the tensor to the given device.
463 ///
464 /// # Arguments
465 ///
466 /// * `tensor` - The tensor.
467 /// * `device` - The device on which the tensor will be moved.
468 ///
469 /// # Returns
470 ///
471 /// The tensor on the given device.
472 ///
473 /// # Remarks
474 ///
475 /// This is a low-level function used internally by the library to call different backend functions
476 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
477 /// or use this function directly.
478 ///
479 /// For moving a tensor to a device, users should prefer the
480 #[cfg_attr(doc, doc = crate::doc_tensor!("to_device"))]
481 #[cfg_attr(not(doc), doc = "`Tensor::to_device`")]
482 /// function, which is more high-level and designed for public use.
483 #[allow(clippy::wrong_self_convention)]
484 fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
485
486 /// Extracts the data from the tensor asynchronously.
487 ///
488 /// # Arguments
489 ///
490 /// * `tensor` - The tensor.
491 ///
492 /// # Returns
493 ///
494 /// The data of the tensor.
495 ///
496 /// # Remarks
497 ///
498 /// This is a low-level function used internally by the library to call different backend functions
499 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
500 /// or use this function directly.
501 ///
502 /// For extracting the data of a tensor, users should prefer the
503 #[cfg_attr(doc, doc = crate::doc_tensor!("into_data"))]
504 #[cfg_attr(not(doc), doc = "`Tensor::into_data`")]
505 /// function, which is more high-level and designed for public use.
506 #[allow(clippy::wrong_self_convention)]
507 fn into_data_async(
508 tensor: Self::Primitive,
509 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
510
511 /// Creates a tensor from the given data enforcing the provided data type.
512 ///
513 /// # Arguments
514 ///
515 /// * `data` - The data of the tensor.
516 /// * `device` - The device on which the tensor will be allocated.
517 /// * `dtype` - The target data type.
518 ///
519 /// # Remarks
520 ///
521 /// This is a low-level function used internally by the library to call different backend functions
522 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
523 /// or use this function directly.
524 ///
525 /// For creating a tensor from data, users should prefer the
526 #[cfg_attr(doc, doc = crate::doc_tensor!("from_data"))]
527 #[cfg_attr(not(doc), doc = "`Tensor::from_data`")]
528 /// function, which is more high-level and designed for public use.
529 fn from_data(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;
530
531 /// Repeat the tensor along the given dimension.
532 ///
533 /// # Arguments
534 ///
535 /// * `tensor` - The tensor.
536 /// * `dim` - The dimension along which the tensor will be repeated.
537 /// * `times` - The number of times the tensor will be repeated.
538 ///
539 /// # Returns
540 ///
541 /// The repeated tensor.
542 ///
543 /// # Remarks
544 ///
545 /// This is a low-level function used internally by the library to call different backend functions
546 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
547 /// or use this function directly.
548 ///
549 /// For repeating a tensor, users should prefer the
550 #[cfg_attr(doc, doc = crate::doc_tensor!("repeat_dim"))]
551 #[cfg_attr(not(doc), doc = "`Tensor::repeat_dim`")]
552 /// function, which is more high-level and designed for public use.
553 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
554
555 /// Concatenates the given tensors along the given dimension.
556 ///
557 /// # Arguments
558 ///
559 /// * `vectors` - The tensors to concatenate.
560 /// * `dim` - The dimension along which the tensors will be concatenated.
561 ///
562 /// # Returns
563 ///
564 /// The concatenated tensor.
565 ///
566 /// # Remarks
567 ///
568 /// This is a low-level function used internally by the library to call different backend functions
569 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
570 /// or use this function directly.
571 ///
572 /// For concatenating tensors, users should prefer the
573 #[cfg_attr(doc, doc = crate::doc_tensor!("cat"))]
574 #[cfg_attr(not(doc), doc = "`Tensor::cat`")]
575 /// function, which is more high-level and designed for public use.
576 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
577
578 /// Equates the given tensors.
579 ///
580 /// # Arguments
581 ///
582 /// * `lhs` - The left hand side tensor.
583 /// * `rhs` - The right hand side tensor.
584 ///
585 /// # Returns
586 ///
587 /// The tensor of booleans indicating whether the corresponding elements are equal.
588 ///
589 /// # Remarks
590 ///
591 /// This is a low-level function used internally by the library to call different backend functions
592 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
593 /// or use this function directly.
594 ///
595 /// For equating tensors, users should prefer the
596 #[cfg_attr(doc, doc = crate::doc_tensor!("equal"))]
597 #[cfg_attr(not(doc), doc = "`Tensor::equal`")]
598 /// function, which is more high-level and designed for public use.
599 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
600
601 /// Element-wise equality between two tensors.
602 ///
603 /// # Arguments
604 ///
605 /// * `lhs` - The left hand side tensor.
606 /// * `rhs` - The right hand side scalar.
607 ///
608 /// # Returns
609 ///
610 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
611 /// corresponding elements of the input tensors are equal, and false otherwise.
612 ///
613 /// # Remarks
614 ///
615 /// This is a low-level function used internally by the library to call different backend functions
616 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
617 /// or use this function directly.
618 ///
619 /// For element-wise equality between two tensors, users should prefer the
620 #[cfg_attr(doc, doc = crate::doc_tensor!("equal_elem"))]
621 #[cfg_attr(not(doc), doc = "`Tensor::equal_elem`")]
622 /// function, which is more high-level and designed for public use.
623 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
624
625 /// Applies element-wise non-equality comparison between the given tensors.
626 ///
627 /// # Arguments
628 ///
629 /// * `lhs` - The left hand side tensor.
630 /// * `rhs` - The right hand side tensor.
631 ///
632 /// # Returns
633 ///
634 /// The tensor of booleans indicating whether the corresponding elements are equal.
635 ///
636 /// # Remarks
637 ///
638 /// This is a low-level function used internally by the library to call different backend functions
639 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
640 /// or use this function directly.
641 ///
642 /// For non-equality comparison of tensors, users should prefer the
643 #[cfg_attr(doc, doc = crate::doc_tensor!("not_equal"))]
644 #[cfg_attr(not(doc), doc = "`Tensor::not_equal`")]
645 /// function, which is more high-level and designed for public use.
646 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
647
648 /// Element-wise non-equality between two tensors.
649 ///
650 /// # Arguments
651 ///
652 /// * `lhs` - The left hand side tensor.
653 /// * `rhs` - The right hand side scalar.
654 ///
655 /// # Returns
656 ///
657 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
658 /// corresponding elements of the input tensors are equal, and false otherwise.
659 ///
660 /// # Remarks
661 ///
662 /// This is a low-level function used internally by the library to call different backend functions
663 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
664 /// or use this function directly.
665 ///
666 /// For element-wise non-equality between two tensors, users should prefer the
667 #[cfg_attr(doc, doc = crate::doc_tensor!("not_equal_elem"))]
668 #[cfg_attr(not(doc), doc = "`Tensor::not_equal_elem`")]
669 /// function, which is more high-level and designed for public use.
670 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
671
672 /// Returns the name of the element type.
673 fn elem_type_name() -> &'static str {
674 core::any::type_name::<Self::Elem>()
675 }
676
677 /// Returns the tensor data type.
678 fn dtype(tensor: &Self::Primitive) -> DType {
679 tensor.dtype()
680 }
681
682 /// Tests if any element in the `tensor` evaluates to True.
683 ///
684 /// # Arguments
685 ///
686 /// * `tensor` - The tensor to test.
687 ///
688 /// # Returns
689 ///
690 /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
691 ///
692 /// # Remarks
693 ///
694 /// This is a low-level function used internally by the library to call different backend functions
695 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
696 /// or use this function directly. Users should prefer the
697 #[cfg_attr(doc, doc = crate::doc_tensor!("any"))]
698 #[cfg_attr(not(doc), doc = "`Tensor::any`")]
699 /// function, which is more high-level and designed for public use.
700 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
701
702 /// Tests if any element in the tensor evaluates to True along a given dimension dim.
703 ///
704 /// # Arguments
705 ///
706 /// * tensor - The tensor to test.
707 /// * dim - The axis along which to test.
708 ///
709 /// # Returns
710 ///
711 /// A boolean tensor with the same size as input tensor, except in the dim axis where the size is 1.
712 /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
713 ///
714 /// # Remarks
715 ///
716 /// This is a low-level function used internally by the library to call different backend functions
717 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
718 /// or use this function directly. Users should prefer the
719 #[cfg_attr(doc, doc = crate::doc_tensor!("any_dim"))]
720 #[cfg_attr(not(doc), doc = "`Tensor::any_dim`")]
721 /// function, which is more high-level and designed for public use.
722 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
723
724 /// Tests if all elements in the `tensor` evaluate to True.
725 ///
726 /// # Arguments
727 ///
728 /// * `tensor` - The tensor to test.
729 ///
730 /// # Returns
731 ///
732 /// A boolean tensor with a single element, True if all elements in the input tensor evaluates to True, False otherwise.
733 ///
734 /// # Remarks
735 ///
736 /// This is a low-level function used internally by the library to call different backend functions
737 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
738 /// or use this function directly. Users should prefer the
739 #[cfg_attr(doc, doc = crate::doc_tensor!("all"))]
740 #[cfg_attr(not(doc), doc = "`Tensor::all`")]
741 /// function, which is more high-level and designed for public use.
742 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
743
744 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
745 ///
746 /// # Arguments
747 ///
748 /// * `tensor` - The tensor to test.
749 ///
750 /// # Returns
751 ///
752 /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
753 /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
754 ///
755 /// # Remarks
756 ///
757 /// This is a low-level function used internally by the library to call different backend functions
758 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
759 /// or use this function directly. Users should prefer the
760 #[cfg_attr(doc, doc = crate::doc_tensor!("all_dim"))]
761 #[cfg_attr(not(doc), doc = "`Tensor::all_dim`")]
762 /// function, which is more high-level and designed for public use.
763 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
764
765 /// Broadcasts the given tensor to the specified shape.
766 ///
767 /// # Arguments
768 ///
769 /// * `tensor` - The tensor to broadcast.
770 /// * `shape` - The shape to broadcast to.
771 ///
772 /// # Returns
773 ///
774 /// The broadcasted tensor.
775 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
776
777 /// Unfold windows along a dimension.
778 ///
779 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
780 /// where windows are advanced by `step` at each index.
781 ///
782 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
783 ///
784 /// # Warning
785 ///
786 /// For the `ndarray` and `candle` backends; this is not a view but a full copy.
787 ///
788 /// # Arguments
789 ///
790 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
791 /// * `dim` - the dimension to unfold.
792 /// * `size` - the size of each unfolded window.
793 /// * `step` - the step between each window.
794 ///
795 /// # Returns
796 ///
797 /// A tensor view with shape ``[pre=..., windows, post=..., size]``.
798 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive;
799}