burn_backend/tensor/ops/base.rs
1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 Backend, BackendTypes, ExecutionError, Scalar, TensorData, TensorMetadata,
6 element::Element,
7 ops::TransactionPrimitive,
8 tensor::{IndexingUpdateOp, IntTensor, TensorKind},
9};
10
11/// Trait for the one basic op that still requires Backend
12///
13/// # Warnings
14///
15/// This is an internal trait, use the public API provided by the
16#[cfg_attr(doc, doc = crate::doc_tensor!())]
17#[cfg_attr(not(doc), doc = "`Tensor`")]
18pub trait TransactionOp<B: Backend>: TensorKind<B> {
19 /// Read the data from the tensor using a transaction.
20 ///
21 /// # Remarks
22 ///
23 /// This is a low-level function used internally by the library to call different backend functions
24 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
25 /// or use this function directly.
26 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive);
27}
28/// Trait that list all operations that can be applied on all tensors.
29///
30/// # Warnings
31///
32/// This is an internal trait, use the public API provided by the
33#[cfg_attr(doc, doc = crate::doc_tensor!())]
34#[cfg_attr(not(doc), doc = "`Tensor`")]
35/// struct.
36pub trait BasicOps<B: BackendTypes>: TensorKind<B> {
37 /// The type of the tensor elements.
38 type Elem: Element;
39
40 /// Creates an empty tensor with the given shape.
41 ///
42 /// # Arguments
43 ///
44 /// * `shape` - The shape of the tensor.
45 /// * `device` - The device on which the tensor will be allocated.
46 /// * `dtype` - The target data type.
47 ///
48 /// # Returns
49 ///
50 /// The empty tensor.
51 ///
52 /// # Remarks
53 ///
54 /// This is a low-level function used internally by the library to call different backend functions
55 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
56 /// or use this function directly.
57 ///
58 /// For creating empty tensors, users should prefer the
59 #[cfg_attr(doc, doc = crate::doc_tensor!("empty"))]
60 #[cfg_attr(not(doc), doc = "`Tensor::empty`")]
61 /// function, which is more high-level and designed for public use.
62 fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
63
64 /// Creates a tensor filled with zeros.
65 ///
66 /// # Arguments
67 ///
68 /// * `shape` - The shape of the tensor.
69 /// * `device` - The device on which the tensor will be allocated.
70 /// * `dtype` - The target data type.
71 ///
72 /// # Returns
73 ///
74 /// The tensor filled with zeros.
75 ///
76 /// # Remarks
77 ///
78 /// This is a low-level function used internally by the library to call different backend functions
79 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
80 /// or use this function directly.
81 ///
82 /// For creating a tensor filled with zeros, users should prefer the
83 #[cfg_attr(doc, doc = crate::doc_tensor!("zeros"))]
84 #[cfg_attr(not(doc), doc = "`Tensor::zeros`")]
85 /// function, which is more high-level and designed for public use.
86 fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
87
88 /// Creates a tensor filled with ones.
89 ///
90 /// # Arguments
91 ///
92 /// * `shape` - The shape of the tensor.
93 /// * `device` - The device on which the tensor will be allocated.
94 /// * `dtype` - The target data type.
95 ///
96 /// # Returns
97 ///
98 /// The tensor filled with ones.
99 ///
100 /// # Remarks
101 ///
102 /// This is a low-level function used internally by the library to call different backend functions
103 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
104 /// or use this function directly.
105 ///
106 /// For creating a tensor filled with ones, users should prefer the
107 #[cfg_attr(doc, doc = crate::doc_tensor!("ones"))]
108 #[cfg_attr(not(doc), doc = "`Tensor::ones`")]
109 /// function, which is more high-level and designed for public use.
110 fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
111
112 /// Creates a tensor of the given shape where each element is equal to the provided value.
113 ///
114 /// # Arguments
115 ///
116 /// * `shape` - The shape of the tensor.
117 /// * `fill_value` - The value with which to fill the tensor.
118 /// * `device` - The device on which the tensor will be allocated.
119 /// * `dtype` - The target data type.
120 ///
121 /// # Returns
122 ///
123 /// The tensor filled with the specified value.
124 ///
125 /// # Remarks
126 ///
127 /// This is a low-level function used internally by the library to call different backend functions
128 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
129 /// or use this function directly.
130 ///
131 /// For creating full tensors, users should prefer the
132 #[cfg_attr(doc, doc = crate::doc_tensor!("full"))]
133 #[cfg_attr(not(doc), doc = "`Tensor::full`")]
134 /// function, which is more high-level and designed for public use.
135 fn full(shape: Shape, fill_value: Scalar, device: &B::Device, dtype: DType) -> Self::Primitive;
136
137 /// Reshapes the tensor.
138 ///
139 /// # Arguments
140 ///
141 /// * `tensor` - The tensor.
142 /// * `shape` - The new shape of the tensor.
143 ///
144 /// # Returns
145 ///
146 /// The reshaped tensor.
147 ///
148 /// # Remarks
149 ///
150 /// This is a low-level function used internally by the library to call different backend functions
151 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
152 /// or use this function directly.
153 ///
154 /// For reshaping a tensor, users should prefer the
155 #[cfg_attr(doc, doc = crate::doc_tensor!("reshape"))]
156 #[cfg_attr(not(doc), doc = "`Tensor::reshape`")]
157 /// function, which is more high-level and designed for public use.
158 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
159
160 /// Transposes a tensor.
161 ///
162 /// # Arguments
163 ///
164 /// * `tensor` - The tensor to transpose.
165 ///
166 /// # Returns
167 ///
168 /// The transposed tensor.
169 fn transpose(tensor: Self::Primitive) -> Self::Primitive;
170
171 /// Swaps two dimensions of a tensor.
172 ///
173 /// # Arguments
174 ///
175 /// * `tensor` - The tensor to swap the dimensions of.
176 /// * `dim1` - The first dimension to swap.
177 /// * `dim2` - The second dimension to swap.
178 ///
179 /// # Returns
180 ///
181 /// The tensor with the dimensions swapped.
182 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
183
184 /// Permutes the dimensions of a tensor.
185 ///
186 /// # Arguments
187 ///
188 /// * `tensor` - The tensor to permute the dimensions of.
189 /// * `axes` - The new order of the dimensions.
190 ///
191 /// # Returns
192 ///
193 /// The tensor with the dimensions permuted.
194 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
195
196 /// Flips the tensor along the given axes.
197 ///
198 /// # Arguments
199 ///
200 /// * `tensor` - The tensor to flip.
201 /// * `axes` - The axes to flip the tensor along.
202 ///
203 /// # Returns
204 ///
205 /// The tensor with the axes flipped.
206 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
207
208 /// Select tensor elements corresponding to the given slices.
209 ///
210 /// # Arguments
211 ///
212 /// * `tensor` - The tensor.
213 /// * `slices` - The slices specifying ranges and steps for each dimension.
214 ///
215 /// # Returns
216 ///
217 /// The selected elements.
218 ///
219 /// # Remarks
220 ///
221 /// This is a low-level function used internally by the library to call different backend functions
222 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
223 /// or use this function directly.
224 ///
225 /// For selecting elements of a tensor, users should prefer the
226 #[cfg_attr(doc, doc = crate::doc_tensor!("slice"))]
227 #[cfg_attr(not(doc), doc = "`Tensor::slice`")]
228 /// function, which is more high-level and designed for public use.
229 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive;
230
231 /// Assigns the given value to the tensor elements corresponding to the given slices.
232 ///
233 /// # Arguments
234 ///
235 /// * `tensor` - The tensor.
236 /// * `slices` - The slices specifying which elements to assign, including support for steps.
237 /// * `value` - The value to assign.
238 ///
239 /// # Returns
240 ///
241 /// The tensor with the assigned values.
242 ///
243 /// # Remarks
244 ///
245 /// This is a low-level function used internally by the library to call different backend functions
246 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
247 /// or use this function directly.
248 ///
249 /// For assigning values to elements of a tensor, users should prefer the
250 #[cfg_attr(doc, doc = crate::doc_tensor!("slice_assign"))]
251 #[cfg_attr(not(doc), doc = "`Tensor::slice_assign`")]
252 /// function, which is more high-level and designed for public use.
253 fn slice_assign(
254 tensor: Self::Primitive,
255 slices: &[Slice],
256 value: Self::Primitive,
257 ) -> Self::Primitive;
258
259 /// Select tensor elements along the given dimension corresponding to the given indices.
260 ///
261 /// # Arguments
262 ///
263 /// * `tensor` - The tensor to select from.
264 /// * `dim` - The dimension along which to select.
265 /// * `indices` - The indices of the elements to select.
266 ///
267 /// # Returns
268 ///
269 /// The selected tensor elements.
270 ///
271 /// # Remarks
272 ///
273 /// This is a low-level function used internally by the library to call different backend functions
274 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
275 /// or use this function directly.
276 ///
277 /// For selecting elements from a tensor along an axis, users should prefer the
278 #[cfg_attr(doc, doc = crate::doc_tensor!("select"))]
279 #[cfg_attr(not(doc), doc = "`Tensor::select`")]
280 /// function, which is more high-level and designed for public use.
281 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive;
282
283 /// Assign the selected elements along the given dimension corresponding to the given indices
284 /// from the value tensor.
285 ///
286 /// # Arguments
287 ///
288 /// * `tensor` - The tensor to assign elements to.
289 /// * `dim` - The axis along which to assign elements.
290 /// * `indices` - The indices of the elements to assign.
291 /// * `values` - The values to assign to the tensor.
292 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
293 ///
294 /// # Returns
295 ///
296 /// A tensor with the same shape as the input tensor, where each element is taken from the
297 /// corresponding element of the input tensor at the corresponding index along the specified axis,
298 /// except for the elements at the specified indices, which are taken from the corresponding
299 /// element of the values tensor.
300 ///
301 /// # Remarks
302 ///
303 /// This is a low-level function used internally by the library to call different backend functions
304 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
305 /// or use this function directly.
306 ///
307 /// For assigning elements to a tensor along an axis, users should prefer the
308 #[cfg_attr(doc, doc = crate::doc_tensor!("select_assign"))]
309 #[cfg_attr(not(doc), doc = "`Tensor::select_assign`")]
310 /// function, which is more high-level and designed for public use.
311 fn select_assign(
312 tensor: Self::Primitive,
313 dim: usize,
314 indices: IntTensor<B>,
315 values: Self::Primitive,
316 update: IndexingUpdateOp,
317 ) -> Self::Primitive;
318
319 /// Selects elements from a tensor based on a boolean mask.
320 ///
321 /// # Arguments
322 ///
323 /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
324 /// * `mask` - The boolean mask to use for selecting elements.
325 /// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
326 ///
327 /// # Returns
328 ///
329 /// A tensor with the same shape as the input tensors, where each element is taken from the
330 /// corresponding element of the left hand side tensor if the corresponding element of the mask
331 /// is true, and from the corresponding element of the right hand side tensor otherwise.
332 ///
333 /// # Remarks
334 ///
335 /// This is a low-level function used internally by the library to call different backend functions
336 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
337 /// or use this function directly.
338 ///
339 /// For selecting elements from a tensor based on a boolean mask, users should prefer the
340 #[cfg_attr(doc, doc = crate::doc_tensor!("mask_where"))]
341 #[cfg_attr(not(doc), doc = "`Tensor::mask_where`")]
342 /// function, which is more high-level and designed for public use.
343 fn mask_where(
344 tensor: Self::Primitive,
345 mask: B::BoolTensorPrimitive,
346 source: Self::Primitive,
347 ) -> Self::Primitive;
348
349 /// Fills elements of a tensor based on a boolean mask.
350 ///
351 /// # Arguments
352 ///
353 /// * `tensor` - The tensor where will be overwritten with the value
354 /// when the corresponding element of the mask is true.
355 /// * `mask` - The boolean mask to use for filling elements.
356 /// * `value` - The value to fill elements with when the corresponding element of the mask is true.
357 ///
358 /// # Returns
359 ///
360 /// A tensor with the same shape as the input tensors, where each element is taken from the
361 /// corresponding element unmodified if the corresponding element of the mask is false, and
362 /// filled with the value otherwise.
363 ///
364 /// # Remarks
365 ///
366 /// This is a low-level function used internally by the library to call different backend functions
367 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
368 /// or use this function directly.
369 ///
370 /// For filling elements of a tensor based on a boolean mask, users should prefer the
371 #[cfg_attr(doc, doc = crate::doc_tensor!("mask_fill"))]
372 #[cfg_attr(not(doc), doc = "`Tensor::mask_fill`")]
373 /// function, which is more high-level and designed for public use.
374 fn mask_fill(
375 tensor: Self::Primitive,
376 mask: B::BoolTensorPrimitive,
377 value: Scalar,
378 ) -> Self::Primitive;
379
380 /// Gathers elements from a tensor along an axis.
381 ///
382 /// # Arguments
383 ///
384 /// * `dim` - The axis along which to gather elements.
385 /// * `tensor` - The tensor to gather elements from.
386 /// * `indices` - The indices of the elements to gather.
387 ///
388 /// # Returns
389 ///
390 /// A tensor with the same shape as the input tensor, where each element is taken from the
391 /// corresponding element of the input tensor at the corresponding index along the specified axis.
392 ///
393 /// # Remarks
394 ///
395 /// This is a low-level function used internally by the library to call different backend functions
396 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
397 /// or use this function directly.
398 ///
399 /// For gathering elements from a tensor along an axis, users should prefer the
400 #[cfg_attr(doc, doc = crate::doc_tensor!("gather"))]
401 #[cfg_attr(not(doc), doc = "`Tensor::gather`")]
402 /// function, which is more high-level and designed for public use.
403 fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive;
404
405 /// Scatters elements into a tensor along an axis.
406 ///
407 /// # Arguments
408 ///
409 /// * `dim` - The axis along which to scatter elements.
410 /// * `tensor` - The tensor to scatter elements into.
411 /// * `indices` - The indices of the elements to scatter.
412 /// * `values` - The values to scatter into the tensor.
413 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
414 ///
415 /// # Returns
416 ///
417 /// A tensor with the same shape as the input tensor, where each element is taken from the
418 /// corresponding element of the input tensor at the corresponding index along the specified axis,
419 /// except for the elements at the specified indices, which are taken from the corresponding
420 /// element of the values tensor.
421 ///
422 /// # Remarks
423 ///
424 /// This is a low-level function used internally by the library to call different backend functions
425 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
426 /// or use this function directly.
427 ///
428 /// For scattering elements into a tensor along an axis, users should prefer the
429 #[cfg_attr(doc, doc = crate::doc_tensor!("scatter"))]
430 #[cfg_attr(not(doc), doc = "`Tensor::scatter`")]
431 /// function, which is more high-level and designed for public use.
432 fn scatter(
433 dim: usize,
434 tensor: Self::Primitive,
435 indices: IntTensor<B>,
436 values: Self::Primitive,
437 update: IndexingUpdateOp,
438 ) -> Self::Primitive;
439
440 /// Multi-dimensional scatter: update `data` at multi-index locations specified by `indices`.
441 fn scatter_nd(
442 data: Self::Primitive,
443 indices: IntTensor<B>,
444 values: Self::Primitive,
445 reduction: IndexingUpdateOp,
446 ) -> Self::Primitive;
447
448 /// Multi-dimensional gather: collect slices from `data` at multi-index locations.
449 fn gather_nd(data: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive;
450
451 /// Returns the device on which the tensor is allocated.
452 ///
453 /// # Arguments
454 ///
455 /// * `tensor` - The tensor.
456 ///
457 /// # Returns
458 ///
459 /// The device on which the tensor is allocated.
460 ///
461 /// # Remarks
462 ///
463 /// This is a low-level function used internally by the library to call different backend functions
464 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
465 /// or use this function directly.
466 ///
467 /// For getting the device of a tensor, users should prefer the
468 #[cfg_attr(doc, doc = crate::doc_tensor!("device"))]
469 #[cfg_attr(not(doc), doc = "`Tensor::device`")]
470 /// function, which is more high-level and designed for public use.
471 fn device(tensor: &Self::Primitive) -> B::Device;
472
473 /// Moves the tensor to the given device.
474 ///
475 /// # Arguments
476 ///
477 /// * `tensor` - The tensor.
478 /// * `device` - The device on which the tensor will be moved.
479 ///
480 /// # Returns
481 ///
482 /// The tensor on the given device.
483 ///
484 /// # Remarks
485 ///
486 /// This is a low-level function used internally by the library to call different backend functions
487 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
488 /// or use this function directly.
489 ///
490 /// For moving a tensor to a device, users should prefer the
491 #[cfg_attr(doc, doc = crate::doc_tensor!("to_device"))]
492 #[cfg_attr(not(doc), doc = "`Tensor::to_device`")]
493 /// function, which is more high-level and designed for public use.
494 #[allow(clippy::wrong_self_convention)]
495 fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
496
497 /// Extracts the data from the tensor asynchronously.
498 ///
499 /// # Arguments
500 ///
501 /// * `tensor` - The tensor.
502 ///
503 /// # Returns
504 ///
505 /// The data of the tensor.
506 ///
507 /// # Remarks
508 ///
509 /// This is a low-level function used internally by the library to call different backend functions
510 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
511 /// or use this function directly.
512 ///
513 /// For extracting the data of a tensor, users should prefer the
514 #[cfg_attr(doc, doc = crate::doc_tensor!("into_data"))]
515 #[cfg_attr(not(doc), doc = "`Tensor::into_data`")]
516 /// function, which is more high-level and designed for public use.
517 #[allow(clippy::wrong_self_convention)]
518 fn into_data_async(
519 tensor: Self::Primitive,
520 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
521
522 /// Creates a tensor from the given data enforcing the provided data type.
523 ///
524 /// # Arguments
525 ///
526 /// * `data` - The data of the tensor.
527 /// * `device` - The device on which the tensor will be allocated.
528 /// * `dtype` - The target data type.
529 ///
530 /// # Remarks
531 ///
532 /// This is a low-level function used internally by the library to call different backend functions
533 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
534 /// or use this function directly.
535 ///
536 /// For creating a tensor from data, users should prefer the
537 #[cfg_attr(doc, doc = crate::doc_tensor!("from_data"))]
538 #[cfg_attr(not(doc), doc = "`Tensor::from_data`")]
539 /// function, which is more high-level and designed for public use.
540 fn from_data(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;
541
542 /// Repeat the tensor along the given dimension.
543 ///
544 /// # Arguments
545 ///
546 /// * `tensor` - The tensor.
547 /// * `dim` - The dimension along which the tensor will be repeated.
548 /// * `times` - The number of times the tensor will be repeated.
549 ///
550 /// # Returns
551 ///
552 /// The repeated tensor.
553 ///
554 /// # Remarks
555 ///
556 /// This is a low-level function used internally by the library to call different backend functions
557 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
558 /// or use this function directly.
559 ///
560 /// For repeating a tensor, users should prefer the
561 #[cfg_attr(doc, doc = crate::doc_tensor!("repeat_dim"))]
562 #[cfg_attr(not(doc), doc = "`Tensor::repeat_dim`")]
563 /// function, which is more high-level and designed for public use.
564 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
565
566 /// Concatenates the given tensors along the given dimension.
567 ///
568 /// # Arguments
569 ///
570 /// * `vectors` - The tensors to concatenate.
571 /// * `dim` - The dimension along which the tensors will be concatenated.
572 ///
573 /// # Returns
574 ///
575 /// The concatenated tensor.
576 ///
577 /// # Remarks
578 ///
579 /// This is a low-level function used internally by the library to call different backend functions
580 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
581 /// or use this function directly.
582 ///
583 /// For concatenating tensors, users should prefer the
584 #[cfg_attr(doc, doc = crate::doc_tensor!("cat"))]
585 #[cfg_attr(not(doc), doc = "`Tensor::cat`")]
586 /// function, which is more high-level and designed for public use.
587 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
588
589 /// Equates the given tensors.
590 ///
591 /// # Arguments
592 ///
593 /// * `lhs` - The left hand side tensor.
594 /// * `rhs` - The right hand side tensor.
595 ///
596 /// # Returns
597 ///
598 /// The tensor of booleans indicating whether the corresponding elements are equal.
599 ///
600 /// # Remarks
601 ///
602 /// This is a low-level function used internally by the library to call different backend functions
603 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
604 /// or use this function directly.
605 ///
606 /// For equating tensors, users should prefer the
607 #[cfg_attr(doc, doc = crate::doc_tensor!("equal"))]
608 #[cfg_attr(not(doc), doc = "`Tensor::equal`")]
609 /// function, which is more high-level and designed for public use.
610 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
611
612 /// Element-wise equality between two tensors.
613 ///
614 /// # Arguments
615 ///
616 /// * `lhs` - The left hand side tensor.
617 /// * `rhs` - The right hand side scalar.
618 ///
619 /// # Returns
620 ///
621 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
622 /// corresponding elements of the input tensors are equal, and false otherwise.
623 ///
624 /// # Remarks
625 ///
626 /// This is a low-level function used internally by the library to call different backend functions
627 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
628 /// or use this function directly.
629 ///
630 /// For element-wise equality between two tensors, users should prefer the
631 #[cfg_attr(doc, doc = crate::doc_tensor!("equal_elem"))]
632 #[cfg_attr(not(doc), doc = "`Tensor::equal_elem`")]
633 /// function, which is more high-level and designed for public use.
634 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
635
636 /// Applies element-wise non-equality comparison between the given tensors.
637 ///
638 /// # Arguments
639 ///
640 /// * `lhs` - The left hand side tensor.
641 /// * `rhs` - The right hand side tensor.
642 ///
643 /// # Returns
644 ///
645 /// The tensor of booleans indicating whether the corresponding elements are equal.
646 ///
647 /// # Remarks
648 ///
649 /// This is a low-level function used internally by the library to call different backend functions
650 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
651 /// or use this function directly.
652 ///
653 /// For non-equality comparison of tensors, users should prefer the
654 #[cfg_attr(doc, doc = crate::doc_tensor!("not_equal"))]
655 #[cfg_attr(not(doc), doc = "`Tensor::not_equal`")]
656 /// function, which is more high-level and designed for public use.
657 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
658
659 /// Element-wise non-equality between two tensors.
660 ///
661 /// # Arguments
662 ///
663 /// * `lhs` - The left hand side tensor.
664 /// * `rhs` - The right hand side scalar.
665 ///
666 /// # Returns
667 ///
668 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
669 /// corresponding elements of the input tensors are equal, and false otherwise.
670 ///
671 /// # Remarks
672 ///
673 /// This is a low-level function used internally by the library to call different backend functions
674 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
675 /// or use this function directly.
676 ///
677 /// For element-wise non-equality between two tensors, users should prefer the
678 #[cfg_attr(doc, doc = crate::doc_tensor!("not_equal_elem"))]
679 #[cfg_attr(not(doc), doc = "`Tensor::not_equal_elem`")]
680 /// function, which is more high-level and designed for public use.
681 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;
682
683 /// Returns the name of the element type.
684 fn elem_type_name() -> &'static str {
685 core::any::type_name::<Self::Elem>()
686 }
687
688 /// Returns the tensor data type.
689 fn dtype(tensor: &Self::Primitive) -> DType {
690 tensor.dtype()
691 }
692
693 /// Tests if any element in the `tensor` evaluates to True.
694 ///
695 /// # Arguments
696 ///
697 /// * `tensor` - The tensor to test.
698 ///
699 /// # Returns
700 ///
701 /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
702 ///
703 /// # Remarks
704 ///
705 /// This is a low-level function used internally by the library to call different backend functions
706 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
707 /// or use this function directly. Users should prefer the
708 #[cfg_attr(doc, doc = crate::doc_tensor!("any"))]
709 #[cfg_attr(not(doc), doc = "`Tensor::any`")]
710 /// function, which is more high-level and designed for public use.
711 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
712
713 /// Tests if any element in the tensor evaluates to True along a given dimension dim.
714 ///
715 /// # Arguments
716 ///
717 /// * tensor - The tensor to test.
718 /// * dim - The axis along which to test.
719 ///
720 /// # Returns
721 ///
722 /// A boolean tensor with the same size as input tensor, except in the dim axis where the size is 1.
723 /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
724 ///
725 /// # Remarks
726 ///
727 /// This is a low-level function used internally by the library to call different backend functions
728 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
729 /// or use this function directly. Users should prefer the
730 #[cfg_attr(doc, doc = crate::doc_tensor!("any_dim"))]
731 #[cfg_attr(not(doc), doc = "`Tensor::any_dim`")]
732 /// function, which is more high-level and designed for public use.
733 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
734
735 /// Tests if all elements in the `tensor` evaluate to True.
736 ///
737 /// # Arguments
738 ///
739 /// * `tensor` - The tensor to test.
740 ///
741 /// # Returns
742 ///
743 /// A boolean tensor with a single element, True if all elements in the input tensor evaluates to True, False otherwise.
744 ///
745 /// # Remarks
746 ///
747 /// This is a low-level function used internally by the library to call different backend functions
748 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
749 /// or use this function directly. Users should prefer the
750 #[cfg_attr(doc, doc = crate::doc_tensor!("all"))]
751 #[cfg_attr(not(doc), doc = "`Tensor::all`")]
752 /// function, which is more high-level and designed for public use.
753 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
754
755 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
756 ///
757 /// # Arguments
758 ///
759 /// * `tensor` - The tensor to test.
760 ///
761 /// # Returns
762 ///
763 /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
764 /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
765 ///
766 /// # Remarks
767 ///
768 /// This is a low-level function used internally by the library to call different backend functions
769 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
770 /// or use this function directly. Users should prefer the
771 #[cfg_attr(doc, doc = crate::doc_tensor!("all_dim"))]
772 #[cfg_attr(not(doc), doc = "`Tensor::all_dim`")]
773 /// function, which is more high-level and designed for public use.
774 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
775
776 /// Broadcasts the given tensor to the specified shape.
777 ///
778 /// # Arguments
779 ///
780 /// * `tensor` - The tensor to broadcast.
781 /// * `shape` - The shape to broadcast to.
782 ///
783 /// # Returns
784 ///
785 /// The broadcasted tensor.
786 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
787
788 /// Unfold windows along a dimension.
789 ///
790 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
791 /// where windows are advanced by `step` at each index.
792 ///
793 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
794 ///
795 /// # Warning
796 ///
797 /// For the `ndarray` and `candle` backends; this is not a view but a full copy.
798 ///
799 /// # Arguments
800 ///
801 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
802 /// * `dim` - the dimension to unfold.
803 /// * `size` - the size of each unfolded window.
804 /// * `step` - the step between each window.
805 ///
806 /// # Returns
807 ///
808 /// A tensor view with shape ``[pre=..., windows, post=..., size]``.
809 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive;
810}