burn_tensor/tensor/api/base.rs
1#![allow(clippy::single_range_in_vec_init)]
2
3use alloc::vec::Vec;
4
5use alloc::format;
6use alloc::string::String;
7use alloc::vec;
8
9use burn_common::stub::RwLock;
10use core::future::Future;
11use core::iter::repeat;
12use core::{fmt::Debug, ops::Range};
13use serde::{Deserialize, Deserializer};
14
15use serde::{Serialize, Serializer};
16
17use super::{Slice, SliceArg, TensorMetadata, Transaction};
18use crate::indexing::{AsIndex, canonicalize_dim, wrap_index};
19use crate::{
20 Bool, ElementConversion, Float, Int, Shape, TensorData, TensorKind, backend::Backend, check,
21 ops::Device,
22};
23use crate::{DType, Element, TensorPrimitive};
24use crate::{cast::ToElement, check::TensorCheck};
25
26/// A tensor with a given backend, shape and data type.
27///
28/// # Indexing
29/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
30/// or [`select`](Tensor::select) for numeric types.
31///
32/// ## Example
33///
34/// ```rust
35/// use burn_tensor::backend::Backend;
36/// use burn_tensor::Tensor;
37/// use burn_tensor::Int;
38///
39/// fn example<B: Backend>() {
40/// let device = Default::default();
41///
42/// let tensor = Tensor::<B, 2>::from_data(
43/// [
44/// [3.0, 4.9, 2.0],
45/// [2.0, 1.9, 3.0],
46/// [6.0, 1.5, 7.0],
47/// [3.0, 4.9, 9.0],
48/// ],
49/// &device,
50/// );
51///
52/// // Slice the tensor to get the second and third rows:
53/// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
54/// // The resulting tensor will have dimensions [2, 3].
55/// let slice = tensor.clone().slice([1..3]);
56/// println!("{slice}");
57///
58/// // Slice the tensor to get the first two rows and the first 2 columns:
59/// // [[3.0, 4.9], [2.0, 1.9]]
60/// // The resulting tensor will have dimensions [2, 2].
61/// let slice = tensor.clone().slice([0..2, 0..2]);
62/// println!("{slice}");
63///
64/// // Index the tensor along the dimension 1 to get the elements 0 and 2:
65/// // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
66/// // The resulting tensor will have dimensions [4, 2]
67/// let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
68/// let indexed = tensor.select(1, indices);
69/// println!("{indexed}");
70/// }
71/// ```
72#[derive(new, Clone, Debug)]
73pub struct Tensor<B, const D: usize, K = Float>
74where
75 B: Backend,
76 K: TensorKind<B>,
77{
78 pub(crate) primitive: K::Primitive,
79}
80
81impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
82where
83 B: Backend,
84 K: BasicOps<B>,
85 T: Into<TensorData>,
86{
87 fn from(value: T) -> Self {
88 Tensor::from_data(value.into(), &Default::default())
89 }
90}
91
92impl<B, const D: usize, K> Tensor<B, D, K>
93where
94 B: Backend,
95 K: BasicOps<B>,
96 K::Elem: Element,
97{
98 /// Executes an operation on the tensor and modifies its value.
99 ///
100 /// # Notes
101 ///
102 /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
103 /// no other reference pointing to the same tensor.
104 ///
105 /// Wrapping operations with inplace is not an optimization, it's mainly there if you
106 /// want to mutate a tensor by using owned operations. A plausible usage would be to
107 /// update the weights of a mutable model reference.
108 pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
109 let mut tensor_owned = Tensor::empty([0; D], &self.device());
110 core::mem::swap(&mut tensor_owned, self);
111
112 let mut tensor_new = func(tensor_owned);
113 core::mem::swap(&mut tensor_new, self);
114 }
115
116 /// Converts the tensor into a primitive tensor.
117 pub fn into_primitive(self) -> K::Primitive {
118 self.primitive
119 }
120
121 /// Converts from a primitive tensor into a tensor.
122 pub fn from_primitive(tensor: K::Primitive) -> Self {
123 Self::new(tensor)
124 }
125
126 /// Returns the number of dimensions of the tensor.
127 pub fn rank(&self) -> usize {
128 self.primitive.rank()
129 }
130
131 /// Returns the tensor primitive data type.
132 ///
133 /// # Note
134 /// Some element types are encoded in different primitive types depending on the backend
135 /// (e.g., bool could be encoded as `u8` or `u32`).
136 pub fn dtype(&self) -> DType {
137 self.primitive.dtype()
138 }
139
140 /// Create an empty tensor of the given shape.
141 ///
142 /// # Arguments
143 ///
144 /// - `shape`: The shape of the tensor.
145 /// - `device`: The device where the tensor will be created.
146 ///
147 /// # Example
148 /// ```rust
149 /// use burn_tensor::backend::Backend;
150 /// use burn_tensor::Tensor;
151 ///
152 /// fn example<B: Backend>() {
153 /// let device = Default::default();
154 /// // Create an empty tensor with dimensions [2, 3, 4].
155 /// let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
156 /// }
157 /// ```
158 pub fn empty<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
159 let shape = shape.into();
160 check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
161 Self::new(K::empty(shape, device, K::Elem::dtype()))
162 }
163
164 /// Create a tensor of the given shape where each element is equal to the provided value.
165 ///
166 /// # Example
167 ///
168 /// ```rust
169 /// use burn_tensor::backend::Backend;
170 /// use burn_tensor::{Tensor, Shape};
171 ///
172 /// fn example<B: Backend>() {
173 /// let device = B::Device::default();
174 /// let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
175 /// println!("{tensor}");
176 /// // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
177 /// }
178 /// ```
179 pub fn full<S: Into<Shape>, E: ElementConversion>(
180 shape: S,
181 fill_value: E,
182 device: &B::Device,
183 ) -> Self {
184 let shape = shape.into();
185 check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
186 Self::new(K::full(shape, fill_value, device, K::Elem::dtype()))
187 }
188
189 /// Returns a new tensor with the same shape, dtype, and device as the current tensor,
190 /// filled with the provided value.
191 ///
192 /// # Example
193 ///
194 /// ```rust
195 /// use burn_tensor::backend::Backend;
196 /// use burn_tensor::{Tensor, Shape};
197 ///
198 /// fn example<B: Backend>() {
199 /// let device = B::Device::default();
200 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
201 /// let tensor = tensor.full_like(5.0);
202 /// println!("{tensor}");
203 /// // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
204 /// }
205 /// ```
206 pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
207 Self::new(K::full(
208 self.shape(),
209 fill_value,
210 &self.device(),
211 self.dtype(),
212 ))
213 }
214
215 /// Returns the dimensions of the current tensor.
216 ///
217 /// # Example
218 /// ```rust
219 /// use burn_tensor::backend::Backend;
220 /// use burn_tensor::Tensor;
221 ///
222 /// fn example<B: Backend>() {
223 /// let device = Default::default();
224 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
225 /// let dims = tensor.dims(); // [2, 3, 4]
226 /// println!("{dims:?}");
227 /// }
228 /// ```
229 pub fn dims(&self) -> [usize; D] {
230 Self::shape(self).dims()
231 }
232
233 /// Returns the shape of the current tensor.
234 ///
235 /// # Example
236 /// ```rust
237 /// use burn_tensor::backend::Backend;
238 /// use burn_tensor::Tensor;
239 ///
240 /// fn example<B: Backend>() {
241 /// let device = Default::default();
242 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
243 /// // Shape { dims: [2, 3, 4] }
244 /// let shape = tensor.shape();
245 /// }
246 /// ```
247 pub fn shape(&self) -> Shape {
248 self.primitive.shape()
249 }
250
251 /// Reshape the tensor to have the given shape.
252 ///
253 /// The tensor has the same data and number of elements as the input.
254 ///
255 /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
256 /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
257 ///
258 /// A `0` in the shape instructs to keep the current dimension from the original tensor,
259 /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
260 /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
261 /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
262 /// with [1, 3, 4] dimensions to [1, 12].
263 ///
264 /// # Arguments
265 /// - `shape`: The new shape of the tensor.
266 ///
267 /// # Panics
268 /// - If the tensor contains more than one `-1` in the shape.
269 /// - If the tensor contains values that are not positive (other than -1).
270 /// - If the shape does not match the number of elements of the original shape.
271 ///
272 /// # Example
273 ///
274 /// ```rust
275 /// use burn_tensor::backend::Backend;
276 /// use burn_tensor::Tensor;
277 ///
278 /// fn example<B: Backend>() {
279 /// let device = Default::default();
280 /// // Create a tensor with dimensions [2, 3, 4]
281 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
282 /// // Reshape it to [2, 12], where 12 is inferred from the number of elements.
283 /// let reshaped = tensor.reshape([2, -1]);
284 /// println!("{reshaped}");
285 /// }
286 /// ```
287 pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
288 // Convert reshape args to shape
289 let shape = shape.into_shape(&self);
290 Tensor::new(K::reshape(self.primitive, shape))
291 }
292
293 /// Transpose the tensor.
294 ///
295 /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
296 /// applied on the last two dimensions. For example, the transpose of a tensor with shape
297 /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
298 ///
299 /// See also [`permute`](Tensor::permute).
300 ///
301 /// # Arguments
302 ///
303 /// * `tensor` - The tensor to transpose.
304 ///
305 /// # Returns
306 ///
307 /// The transposed tensor.
308 ///
309 /// # Example
310 ///
311 /// ```rust
312 /// use burn_tensor::backend::Backend;
313 /// use burn_tensor::Tensor;
314 ///
315 /// fn example<B: Backend>() {
316 /// let device = Default::default();
317 /// // Create a 2D tensor of shape [2, 3]
318 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
319 ///
320 /// // Transpose the tensor:
321 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
322 /// // The resulting tensor will have dimensions [3, 2].
323 /// let transposed = tensor.transpose();
324 /// println!("{transposed}");
325 /// }
326 /// ```
327 pub fn transpose(self) -> Tensor<B, D, K> {
328 Tensor::new(K::transpose(self.primitive))
329 }
330
331 /// Alias for `transpose`.
332 #[inline(always)]
333 pub fn t(self) -> Tensor<B, D, K> {
334 self.transpose()
335 }
336
337 /// Swaps two dimensions of a tensor.
338 ///
339 /// This is a no-op when `dim1 == dim2`, assuming both are within bounds.
340 ///
341 /// # Arguments
342 ///
343 /// * `tensor` - The tensor to swap the dimensions of.
344 /// * `dim1` - The first dimension to swap, supports negative indexing.
345 /// * `dim2` - The second dimension to swap, supports negative indexing.
346 ///
347 /// # Returns
348 ///
349 /// The tensor with the dimensions swapped.
350 ///
351 /// # Panics
352 ///
353 /// When dimensions are out of bounds.
354 ///
355 /// # Example
356 ///
357 /// ```rust
358 /// use burn_tensor::backend::Backend;
359 /// use burn_tensor::Tensor;
360 ///
361 /// fn example<B: Backend>() {
362 /// let device = Default::default();
363 /// // Create a 2D tensor of shape [2, 3]
364 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
365 ///
366 /// // Swap the dimensions 0 and -1 (equivalent to `tensor.transpose()`):
367 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
368 /// // The resulting tensor will have dimensions [3, 2].
369 /// let swapped = tensor.swap_dims(0, -1);
370 /// println!("{swapped}");
371 /// }
372 /// ```
373 pub fn swap_dims<Dim1, Dim2>(self, dim1: Dim1, dim2: Dim2) -> Tensor<B, D, K>
374 where
375 Dim1: AsIndex,
376 Dim2: AsIndex,
377 {
378 let dim1 = canonicalize_dim(dim1, D, false);
379 let dim2 = canonicalize_dim(dim2, D, false);
380 check!(TensorCheck::swap_dims::<D>(dim1, dim2));
381 if dim1 == dim2 {
382 self
383 } else {
384 Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
385 }
386 }
387
388 /// Permute the dimensions of the tensor.
389 ///
390 /// This is a no-op when the resolved `axes` match the current order.
391 ///
392 /// # Arguments
393 ///
394 /// * `axes` - The new order of the dimensions. The length of the axes
395 /// must be equal to the number of dimensions of the tensor.
396 /// The values must be unique and in the range of the number of dimensions.
397 /// The values can be negative, in which case they are used as an offset from the end.
398 ///
399 /// # Returns
400 ///
401 /// The tensor with the dimensions permuted.
402 ///
403 /// # Example
404 ///
405 /// ```rust
406 /// use burn_tensor::backend::Backend;
407 /// use burn_tensor::Tensor;
408 ///
409 /// fn example<B: Backend>() {
410 /// let device = Default::default();
411 /// // Create a 2D tensor of shape [3, 2]
412 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
413 ///
414 /// // Permute the dimensions 1 and 0:
415 /// // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
416 /// // The resulting tensor will have dimensions [3, 2].
417 /// let permuted = tensor.permute([1, 0]);
418 /// println!("{permuted}");
419 /// }
420 /// ```
421 pub fn permute<Dim>(self, axes: [Dim; D]) -> Tensor<B, D, K>
422 where
423 Dim: AsIndex,
424 {
425 let mut no_op = true;
426 let mut fixed_axes = [0; D];
427 for (i, axis) in axes.into_iter().enumerate() {
428 let dim = canonicalize_dim(axis, D, false);
429 no_op &= dim == i;
430 fixed_axes[i] = dim;
431 }
432
433 if no_op {
434 self
435 } else {
436 check!(TensorCheck::permute(fixed_axes));
437 Tensor::new(K::permute(self.primitive, &fixed_axes))
438 }
439 }
440
441 /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
442 ///
443 /// Other dimensions of input that are not explicitly moved remain in their original order and appear
444 /// at the positions not specified in destination.
445 ///
446 /// # Arguments
447 ///
448 /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
449 /// The values can be negative, in which case they are used as an offset from the end.
450 ///
451 /// * `dst` - Destination positions for each of the original dims. These must also be unique.
452 ///
453 /// # Panics
454 ///
455 /// - If the source and destination dimensions are not of the same length.
456 /// - If the source and destination vectors contain duplicate values.
457 /// - If the source and destination vectors contain values that are out of bounds.
458 ///
459 /// # Returns
460 ///
461 /// The tensor with the dimensions moved.
462 ///
463 /// # Example
464 ///
465 /// ```rust
466 /// use burn_tensor::backend::Backend;
467 /// use burn_tensor::Tensor;
468 ///
469 /// fn example<B: Backend>() {
470 /// let device = Default::default();
471 /// // Create a 3D tensor of shape [3, 2, 1]
472 /// let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
473 ///
474 /// // Move the dimensions 0 and 1:
475 /// // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
476 /// // The resulting tensor will have dimensions [2, 3, 1].
477 /// let moved = tensor.movedim(1, 0);
478 /// println!("{moved}");
479 /// }
480 /// ```
481 ///
482 /// # Note
483 ///
484 /// This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op
485 /// for it
486 pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
487 let source_dims = src.into_dim_vec::<D>();
488 let destination_dims = dst.into_dim_vec::<D>();
489
490 check!(TensorCheck::movedim_args_length(
491 &source_dims,
492 &destination_dims
493 ));
494
495 let mut m = [-1; D];
496 for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
497 m[d] = s as isize;
498 }
499 let mut axes: [isize; D] = [0; D];
500 let mut source_i = 0;
501 for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
502 *item = if m[dest_i] != -1 {
503 m[dest_i]
504 } else {
505 while source_dims.contains(&source_i) {
506 source_i += 1;
507 }
508 let result = source_i as isize;
509 source_i += 1;
510 result
511 };
512 }
513
514 self.permute(axes)
515 }
516
517 /// Reverse the order of elements in the tensor along the given dimensions.
518 ///
519 /// # Arguments
520 ///
521 /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
522 /// The values can be negative, in which case they are used as an offset from the end.
523 ///
524 /// # Returns
525 ///
526 /// The tensor with the axes flipped.
527 ///
528 /// # Example
529 ///
530 /// ```rust
531 /// use burn_tensor::backend::Backend;
532 /// use burn_tensor::Tensor;
533 ///
534 /// fn example<B: Backend>() {
535 /// let device = Default::default();
536 /// // Create a 2D tensor with dimensions [4, 3]
537 /// let tensor = Tensor::<B, 2>::from_data(
538 /// [
539 /// [3.0, 4.9, 2.0],
540 /// [2.0, 1.9, 3.0],
541 /// [4.0, 5.9, 8.0],
542 /// [1.4, 5.8, 6.0],
543 /// ],
544 /// &device,
545 /// );
546 ///
547 /// // Flip the elements in dimensions 0 and 1:
548 /// // [[6.0, 5.8, 1.4],
549 /// // [8.0, 5.9, 4.0],
550 /// // [3.0, 1.9, 2.0],
551 /// // [2.0, 4.9, 3.0]]
552 /// // The resulting tensor will have dimensions [4, 3].
553 /// let flipped = tensor.flip([0, 1]);
554 /// println!("{flipped}");
555 /// }
556 /// ```
557 pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
558 // Convert the axes to usize and handle negative values without using vector
559 let mut transformed_axes: [usize; N] = [0; N];
560 for (i, &x) in axes.iter().enumerate() {
561 transformed_axes[i] = if x < 0 {
562 (D as isize + x) as usize
563 } else {
564 x as usize
565 };
566 }
567
568 // Check if the axes are valid
569 check!(TensorCheck::flip(D, &transformed_axes));
570
571 Tensor::new(K::flip(self.primitive, &transformed_axes))
572 }
573
574 /// Flatten the tensor along a given range of dimensions.
575 ///
576 /// This function collapses the specified range of dimensions into a single dimension,
577 /// effectively flattening the tensor in that range.
578 ///
579 /// # Arguments
580 ///
581 /// - `start_dim`: The starting dimension of the range to be flattened,
582 /// supports negative indexing.
583 /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),
584 /// supports negative indexing.
585 ///
586 /// # Type Parameters
587 ///
588 /// - `D2`: The resulting number of dimensions in the flattened tensor.
589 ///
590 /// # Returns
591 ///
592 /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
593 ///
594 /// # Example
595 ///
596 /// ```rust
597 ///
598 /// use burn_tensor::backend::Backend;
599 /// use burn_tensor::{Tensor, Shape};
600 ///
601 /// fn example<B: Backend>() {
602 /// let device = Default::default();
603 /// // Create a 3D tensor with dimensions [2, 3, 4]
604 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
605 ///
606 /// // Flatten the tensor from dimensions 1 to 2 (inclusive).
607 /// // The resulting tensor will have dimensions [2, 12]
608 /// let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
609 /// println!("{flattened}");
610 /// }
611 /// ```
612 pub fn flatten<const D2: usize>(
613 self,
614 start_dim: impl AsIndex,
615 end_dim: impl AsIndex,
616 ) -> Tensor<B, D2, K> {
617 let start_dim = canonicalize_dim(start_dim, D, false);
618 let end_dim = canonicalize_dim(end_dim, D, false);
619 check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));
620
621 let current_dims = self.shape().dims;
622 let mut new_dims: [usize; D2] = [0; D2];
623 let mut flatten_dims = 1;
624
625 for i in current_dims[start_dim..=end_dim].iter() {
626 flatten_dims *= i;
627 }
628
629 new_dims[..start_dim].copy_from_slice(¤t_dims[..start_dim]);
630 new_dims[start_dim] = flatten_dims;
631 new_dims[start_dim + 1..].copy_from_slice(¤t_dims[end_dim + 1..]);
632
633 Tensor::new(K::reshape(self.primitive, new_dims.into()))
634 }
635
636 /// Squeeze the tensor along all dimensions, removing dimensions
637 /// of size one, and effectively reducing the rank of the tensor.
638 ///
639 /// # Type Parameters
640 ///
641 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
642 ///
643 /// # Returns
644 ///
645 /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
646 ///
647 /// # Example
648 ///
649 /// ```rust
650 ///
651 /// use burn_tensor::backend::Backend;
652 /// use burn_tensor::{Tensor, Shape};
653 ///
654 /// fn example<B: Backend>() {
655 /// let device = Default::default();
656 /// // Create a 4D tensor with dimensions [1, 3, 1, 3]
657 /// let tensor = Tensor::<B, 4>::from_data(
658 /// [[[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]]],
659 /// &device,
660 /// );
661 ///
662 /// // Squeeze the tensor dimensions.
663 /// // The resulting tensor will have dimensions [3, 3].
664 /// let squeezed = tensor.squeeze::<2>();
665 /// println!("{squeezed}");
666 /// }
667 /// ```
668 pub fn squeeze<const D2: usize>(self) -> Tensor<B, D2, K> {
669 let new_dims = self
670 .shape()
671 .dims
672 .iter()
673 .filter_map(|&dim| if dim == 1 { None } else { Some(dim) })
674 .collect::<Vec<_>>();
675 check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
676
677 Tensor::new(K::reshape(self.primitive, new_dims.into()))
678 }
679
680 /// Squeeze the tensor along the given dimension, removing the specified dimension
681 /// of size one, and effectively reducing the rank of the tensor by one.
682 ///
683 /// # Arguments
684 ///
685 /// - `dim`: The dimension to be squeezed.
686 ///
687 /// # Type Parameters
688 ///
689 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
690 ///
691 /// # Panics
692 ///
693 /// If the size in the squeezed dimension is not 1.
694 ///
695 /// # Returns
696 ///
697 /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
698 ///
699 /// # Example
700 ///
701 /// ```rust
702 ///
703 /// use burn_tensor::backend::Backend;
704 /// use burn_tensor::{Tensor, Shape};
705 ///
706 /// fn example<B: Backend>() {
707 /// let device = Default::default();
708 /// // Create a 3D tensor with dimensions [3, 1, 3]
709 /// let tensor = Tensor::<B, 3>::from_data(
710 /// [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
711 /// &device,
712 /// );
713 ///
714 /// // Squeeze the dimension 1.
715 /// // The resulting tensor will have dimensions [3, 3].
716 /// let squeezed = tensor.squeeze_dim::<2>(1);
717 /// println!("{squeezed}");
718 /// }
719 /// ```
720 pub fn squeeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
721 check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));
722
723 let current_dims = self.shape().dims;
724 let mut new_dims: [usize; D2] = [0; D2];
725
726 new_dims[..dim].copy_from_slice(¤t_dims[..dim]);
727 new_dims[dim..].copy_from_slice(¤t_dims[dim + 1..]);
728
729 check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
730 Tensor::new(K::reshape(self.primitive, new_dims.into()))
731 }
732
733 /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
734 /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
735 /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
736 /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
737 /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
738 /// from the back.
739 ///
740 /// # Arguments
741 ///
742 /// - `dims`: The dimension(s) to be squeezed.
743 ///
744 /// # Type Parameters
745 ///
746 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
747 ///
748 /// # Returns
749 ///
750 /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
751 ///
752 /// # Example
753 ///
754 /// ```rust
755 ///
756 /// use burn_tensor::backend::Backend;
757 /// use burn_tensor::{Tensor, Shape};
758 ///
759 /// fn example<B: Backend>() {
760 /// let device = Default::default();
761 /// // Create a 4D tensor with dimensions [2, 1, 4, 1]
762 /// let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
763 ///
764 /// // Squeeze the dimensions 1 and 3.
765 /// // The resulting tensor will have dimensions [2, 4].
766 /// let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
767 /// println!("{squeezed}");
768 /// }
769 /// ```
770 pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
771 let current_dims = self.shape().dims;
772 let mut dim_indices: Vec<usize>;
773
774 // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
775 if dims.is_empty() {
776 dim_indices = current_dims
777 .iter()
778 .enumerate()
779 .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
780 .collect();
781 } else {
782 // If negative dims, count from the back
783 dim_indices = dims
784 .iter()
785 .map(|&d| {
786 if d < 0 {
787 (current_dims.len() as isize + d) as usize
788 } else {
789 d as usize
790 }
791 })
792 .collect();
793 }
794
795 // Sort indices and remove duplicates
796 dim_indices.sort_unstable();
797 dim_indices.dedup();
798
799 // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
800 check!(TensorCheck::squeeze_dims_input::<D2>(
801 &dim_indices,
802 ¤t_dims
803 ));
804
805 // Calculate new dimensions
806 let mut new_dims = Vec::new();
807 for (index, &dim_size) in current_dims.iter().enumerate() {
808 // Exclude the dimension if it's explicitly marked for squeezing
809 if dim_indices.contains(&index) {
810 check!(TensorCheck::squeeze::<D2>(index, ¤t_dims));
811 continue;
812 }
813 new_dims.push(dim_size);
814 }
815
816 // Check that after squeezing, we still respect the D2 size
817 check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
818
819 Tensor::new(K::reshape(self.primitive, new_dims.into()))
820 }
821
822 /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
823 ///
824 /// # Type Parameters
825 ///
826 /// - `D2`: The resulting number of dimensions in the unsqueezed tensor.
827 ///
828 /// # Panics
829 ///
830 /// If the output size `D2` is smaller than the current number of dimensions.
831 ///
832 /// # Returns
833 ///
834 /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
835 ///
836 /// # Example
837 ///
838 /// ```rust
839 /// use burn_tensor::backend::Backend;
840 /// use burn_tensor::{Tensor, Shape};
841 ///
842 /// fn example<B: Backend>() {
843 /// let device = Default::default();
844 /// // Create a 2D tensor with dimensions [3, 3]
845 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
846 /// // Unsqueeze the tensor up to 4 dimensions.
847 /// // The resulting tensor will have dimensions [1, 1, 3, 3].
848 /// let unsqueezed = tensor.unsqueeze::<4>();
849 /// println!("{unsqueezed}");
850 /// }
851 /// ```
852 pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
853 check!(TensorCheck::unsqueeze::<D, D2>());
854
855 let mut dims = [1; D2];
856 let num_ones = D2 - D;
857 let shape = self.shape();
858
859 dims[num_ones..(D + num_ones)].copy_from_slice(&shape[..D]);
860
861 let shape = Shape::new(dims);
862 self.reshape(shape)
863 }
864
865 /// Creates a new tensor with a dimension of size one inserted at the specified position.
866 ///
867 /// # Example
868 ///
869 /// ```rust
870 /// use burn_tensor::backend::Backend;
871 /// use burn_tensor::{Tensor, Shape};
872 ///
873 /// fn example<B: Backend>() {
874 /// let device = Default::default();
875 /// // Create a 2D tensor with dimensions [3, 3]
876 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
877 /// // Unsqueeze the dimension 1.
878 /// // The resulting tensor will have dimensions [3, 1, 3].
879 /// let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
880 /// println!("{unsqueezed}");
881 /// }
882 /// ```
883 pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
884 check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));
885
886 let mut dims = [1; D2];
887 let shape = self.shape();
888
889 dims[0..dim].copy_from_slice(&shape[0..dim]);
890
891 if dim < D {
892 dims[dim] = 1;
893 dims[(dim + 1)..].copy_from_slice(&shape[dim..]);
894 } else {
895 dims[dim] = 1;
896 }
897
898 let shape = Shape::new(dims);
899 self.reshape(shape)
900 }
901
    /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
    ///
    /// Indices refer to positions in the output tensor (of rank `D2`). They can be negative,
    /// in which case they are counted from the last to the first dimension. The axes can
    /// contain duplicates, in which case the number of dimensions inserted at that index is
    /// the number of duplicates.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions.
    ///
    /// # Panics
    ///
    /// - If an axis is rejected by `TensorCheck::unsqueeze_dims::<D2>`.
    /// - NOTE(review): the number of axes is not checked against `D2 - D` explicitly; a
    ///   mismatch appears to surface as a `copy_from_slice` length panic. Confirm.
    /// - NOTE(review): the mapping of a negative index depends on how many negative indices
    ///   precede it in `axes` (see `neg_offset`). Tracing D=3, D2=5 shows that `[-2, -1]`
    ///   yields `[3, 4, 5, 1, 1]`, while `[-1, -2]` maps to positions `[4, 2]` and panics in
    ///   `copy_from_slice`. Confirm the intended semantics.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 4, 5]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
    ///     // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
    ///     // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
    ///     let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K> {
        // Output shape starts as all ones; the original dims are copied into the gaps.
        let mut new_dims = [1; D2];
        let old_dims = self.shape().dims;

        // Part 1: convert the negative indices to positive output positions.
        // `neg_offset` decreases with every negative index, so repeated `-1`s map to
        // D2-1, D2-2, ... in encounter order.
        let mut neg_offset = D2;
        let mut dim_indices = axes
            .iter()
            .map(|d| {
                // check if the dimension is in the acceptable range
                check!(TensorCheck::unsqueeze_dims::<{ D2 }>(*d));
                (if *d < 0 {
                    neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)
                    d + neg_offset as isize + 1
                } else {
                    *d
                }) as usize
            })
            .collect::<Vec<usize>>();

        // Sort the indices so the chunks between inserted axes can be copied left to right.
        dim_indices.sort_unstable();

        // Part 2: copy the chunks of the original dims between consecutive new axes.
        // `prev_idx` is the next free output position; `current_left_b..current_right_b`
        // is the range of `old_dims` copied for the current gap.
        let mut prev_idx: usize = 0;
        let mut current_left_b: usize = 0;
        let mut current_right_b: usize = 0;
        let mut offset: usize = 0;
        dim_indices.iter().for_each(|d| {
            // Check if there is space for at least one original dimension before `d`.
            if prev_idx < *d {
                current_right_b = *d - offset;
                // Copy the chunk of the dims (clamped to the end of `old_dims`).
                if current_right_b < D {
                    new_dims[prev_idx..*d]
                        .copy_from_slice(&old_dims[current_left_b..current_right_b]);
                } else {
                    new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
                }
                prev_idx = *d + 1;
                // NOTE(review): by induction `offset` always equals `current_left_b`; when the
                // copy above is clamped, it can exceed the number of elements actually copied.
                offset += current_right_b - current_left_b;
                current_left_b = current_right_b;
            } else {
                // It's sorted, so the only reason this would happen
                // is if multiple indices are the same (or adjacent to a previous axis).
                prev_idx += 1;
            }
        });
        // Copy over anything past the index of the last new dimension.
        if current_left_b < D {
            new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
        }

        // Lastly, create the shape and reshape.
        let shape = Shape::new(new_dims);
        self.reshape(shape)
    }
981
982 /// Roll operation along a specific dimension; wrapping around the elements.
983 ///
984 /// ## Parameters
985 ///
986 /// - `shift`: The roll extent; supports negative values and wraps around.
987 /// - `dim`: The dimension to roll; supports negative indexing.
988 ///
989 /// ## Returns
990 ///
991 /// A new tensor with the specified dimension rolled by the given shift amount.
992 pub fn roll_dim<Shift, Dim>(self, shift: Shift, dim: Dim) -> Self
993 where
994 Shift: AsIndex,
995 Dim: AsIndex,
996 {
997 let dim = canonicalize_dim(dim, D, false);
998 let size = self.shape().dims[dim];
999 if size == 0 {
1000 // If the dimension is empty, return the tensor as is.
1001 return self;
1002 }
1003
1004 let shift = wrap_index(shift, size);
1005 if shift == 0 {
1006 // If the shift is zero, return the tensor as is.
1007 return self;
1008 }
1009
1010 self.unchecked_roll_dim(shift, dim)
1011 }
1012
1013 /// Internal implementation of `roll_dim` that does not canonicalize dimensions or shifts.
1014 ///
1015 /// ## Parameters
1016 ///
1017 /// - `shift`: The number of positions to shift; must be (0 < shift < size).
1018 /// - `dim`: The dimension to roll; must be a valid index for the tensor's shape.
1019 ///
1020 /// ## Returns
1021 ///
1022 /// A new tensor with the specified dimension rolled by the given shift amount.
1023 #[inline(always)]
1024 fn unchecked_roll_dim(self, shift: usize, dim: usize) -> Self {
1025 #[cfg(debug_assertions)]
1026 {
1027 let size = self.shape().dims[dim];
1028 assert!(
1029 0 < shift && shift < size,
1030 "Expected: 0 < shift < size: found shift={shift}, size={size}",
1031 );
1032 assert!(
1033 dim < self.shape().num_dims(),
1034 "Expected: dim < num_dims: found dim={dim}, num_dims={size}",
1035 );
1036 }
1037
1038 Tensor::cat(
1039 vec![
1040 self.clone().slice_dim(dim, shift..),
1041 self.slice_dim(dim, ..shift),
1042 ],
1043 dim,
1044 )
1045 }
1046
1047 /// Roll operation.
1048 ///
1049 /// Note: unlike ``pytorch``, `dims` and `shifts` must have the same length.
1050 ///
1051 /// A given `dim` may be rolled multiple times, and the shifts will be applied sequentially.
1052 ///
1053 /// ## Parameters
1054 ///
1055 /// - `shifts`: A slice of shifts corresponding to each dimension;
1056 /// supports negative values and wraps around.
1057 /// - `dims`: A slice of dimensions to roll; supports negative indexing.
1058 ///
1059 /// ## Returns
1060 ///
1061 /// A new tensor with the specified dimensions rolled by the given shifts.
1062 pub fn roll<Shift, Dim>(self, shifts: &[Shift], dims: &[Dim]) -> Self
1063 where
1064 Shift: AsIndex,
1065 Dim: AsIndex,
1066 {
1067 assert_eq!(
1068 dims.len(),
1069 shifts.len(),
1070 "Dimensions and shifts must align; found dims={dims:#?}, shifts={shifts:#?}",
1071 );
1072
1073 // This is a fair amount of complexity, which could be replaced
1074 // by a simple canonicalization of `dims` and wrapping of `shifts`.
1075 // The work is done here to ensure that any roll operation
1076 // which could be a no-op is a no-op; simplifying the accounting
1077 // needed by backend-specific implementations of the inner roll op.
1078
1079 let item_count = dims.len();
1080
1081 let shape = self.shape().dims;
1082
1083 // Accumulate the effective shifts for each dimension.
1084 let mut accumulated_shifts: Vec<isize> = vec![0; shape.len()];
1085 for i in 0..item_count {
1086 let dim = canonicalize_dim(dims[i], D, false);
1087 accumulated_shifts[dim] += shifts[i].index();
1088 }
1089
1090 // Do this after we've checked the validity of `dims` and `shifts`.
1091 if self.shape().num_elements() == 0 {
1092 // If the tensor is empty, return it as is.
1093 return self;
1094 }
1095
1096 // Wrap the accumulated shifts, and filter out empty dimensions.
1097 let mut effective_dims: Vec<usize> = Vec::with_capacity(item_count);
1098 let mut effective_shifts: Vec<usize> = Vec::with_capacity(item_count);
1099 for dim in 0..shape.len() {
1100 // `wrap_index` should inline, and has a fast-exit path for zero shifts.
1101 let shift = wrap_index(accumulated_shifts[dim], shape[dim]);
1102 if shift == 0 {
1103 continue;
1104 }
1105
1106 effective_dims.push(dim);
1107 effective_shifts.push(shift);
1108 }
1109
1110 // If no shifts are needed, return the original tensor.
1111 if effective_shifts.is_empty() {
1112 return self;
1113 }
1114
1115 // At this point:
1116 // - `dims` contains the effective dimensions to roll, in index order,
1117 // - `shifts` contains the effective usize shifts for each dimension.
1118 // - Every shift is non-zero, and less than the size of the corresponding dimension.
1119 self.unchecked_roll(&effective_shifts, &effective_dims)
1120 }
1121
1122 /// `roll` internal implementation.
1123 ///
1124 /// ## Parameters
1125 ///
1126 /// - `shifts`: A slice of shifts corresponding to each dimension;
1127 /// must be non-empty, the same length as `dims`, and all ``1..<size>``.
1128 /// - `dims`: A slice of dimensions to roll; must be non-empty;
1129 /// the same length as `shifts`, and must not contain repeats.
1130 ///
1131 /// ## Panics
1132 ///
1133 /// Panics if the shifts and dimensions do not align, or if dimensions contain repeats.
1134 ///
1135 /// ## Returns
1136 ///
1137 /// A new tensor with the specified dimensions rolled by the given shifts.
1138 #[inline(always)]
1139 fn unchecked_roll(self, shifts: &[usize], dims: &[usize]) -> Self {
1140 #[cfg(debug_assertions)]
1141 {
1142 assert!(!shifts.is_empty());
1143 assert_eq!(
1144 shifts.len(),
1145 dims.len(),
1146 "Shifts and dimensions must align; found {} shifts and {} dims",
1147 shifts.len(),
1148 dims.len()
1149 );
1150
1151 let mut unique_dims = dims.to_vec();
1152 unique_dims.dedup();
1153
1154 assert_eq!(
1155 unique_dims.len(),
1156 dims.len(),
1157 "Dimensions must not contain repeats; found {} unique dims and {} total dims",
1158 unique_dims.len(),
1159 dims.len()
1160 )
1161 }
1162
1163 let x = self.unchecked_roll_dim(shifts[0], dims[0]);
1164
1165 if dims.len() == 1 {
1166 x
1167 } else {
1168 x.unchecked_roll(&shifts[1..], &dims[1..])
1169 }
1170 }
1171
1172 /// Returns a tensor containing the elements selected from the given slices.
1173 ///
1174 /// This method provides flexible tensor slicing with support for various range types,
1175 /// negative indices, and stepped slicing. The method accepts both single slices and
1176 /// arrays of slices, with the [`s!`] macro providing convenient syntax for complex patterns.
1177 ///
1178 /// # Arguments
1179 ///
1180 /// * `slices` - Can be:
1181 /// - A single range for 1D slicing (e.g., `0..5`, `..`, `2..`)
1182 /// - An array of ranges (e.g., `[0..2, 1..4]`)
1183 /// - The [`s!`] macro output for advanced slicing with steps
1184 ///
1185 /// # Behavior
1186 ///
1187 /// - Supports partial and full slicing in any number of dimensions
1188 /// - Handles negative indices by wrapping from the end (-1 is the last element)
1189 /// - Automatically clamps ranges that exceed tensor dimensions
1190 /// - Supports stepped slicing for selecting every nth element
1191 /// - Negative steps reverse the selection order
1192 ///
1193 /// # Panics
1194 ///
1195 /// - If the number of slices exceeds the tensor's dimensions
1196 /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1) without negative step
1197 /// - If a step is zero
1198 ///
1199 /// # Examples
1200 ///
1201 /// ```rust
1202 /// use burn_tensor::backend::Backend;
1203 /// use burn_tensor::{Tensor, Shape, s};
1204 ///
1205 /// fn example<B: Backend>() {
1206 /// let device = B::Device::default();
1207 ///
1208 /// // Single dimension slicing - no brackets needed!
1209 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..10, &device);
1210 /// let slice = tensor.clone().slice(2..8); // Simple range
1211 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![2, 3, 4, 5, 6, 7]);
1212 ///
1213 /// // Using s! macro for single dimension with step
1214 /// let slice = tensor.clone().slice(s![0..10;2]); // Every 2nd element
1215 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![0, 2, 4, 6, 8]);
1216 ///
1217 /// // Reverse a dimension with negative step
1218 /// let slice = tensor.slice(s![..;-1]); // Reverse entire tensor
1219 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
1220 ///
1221 /// // Multi-dimensional slicing
1222 /// let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
1223 ///
1224 /// // Array syntax for simple ranges
1225 /// let slice = tensor.clone().slice([1..3, 2..5]);
1226 /// assert_eq!(slice.dims(), [2, 3]);
1227 ///
1228 /// // Advanced multi-dimensional with s! macro
1229 /// let slice = tensor.clone().slice(s![0..4;2, ..;-1]); // Every 2nd row, reverse columns
1230 /// assert_eq!(slice.dims(), [2, 6]);
1231 ///
1232 /// // Complex 3D example with mixed slice types
1233 /// let tensor = Tensor::<B, 3>::ones(Shape::new([4, 6, 8]), &device);
1234 /// let slice = tensor.slice(s![1..3, ..;2, -3..]); // Rows 1-2, every 2nd col, last 3 depth
1235 /// assert_eq!(slice.dims(), [2, 3, 3]);
1236 ///
1237 /// // Using negative indices
1238 /// let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
1239 /// let slice = tensor.slice(s![-2.., ..-1]); // Last 2 rows, all but last column
1240 /// assert_eq!(slice.dims(), [2, 5]);
1241 /// }
1242 /// ```
1243 ///
1244 /// # See Also
1245 ///
1246 /// - [`s!`] - The recommended macro for creating complex slice specifications
1247 /// - [`slice_assign`](Self::slice_assign) - Assign values to a slice
1248 /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
1249 /// - [`slice_dim`](Self::slice_dim) - Slice a single dimension
1250 ///
1251 /// [`s!`]: crate::s!
1252 pub fn slice<const D2: usize, S>(self, slices: S) -> Self
1253 where
1254 S: SliceArg<D2>,
1255 {
1256 let shape = self.shape();
1257 let slices = slices.into_slices(shape.clone());
1258
1259 // Validate slices
1260 check!(TensorCheck::slice::<D, D2>(&shape, &slices));
1261
1262 // Use the slice method that supports steps
1263 Self::new(K::slice(self.primitive, &slices))
1264 }
1265
1266 /// Assigns values to a slice of the tensor and returns the updated tensor.
1267 ///
1268 /// This method supports advanced slicing with steps, including negative steps for reverse
1269 /// assignment. Like `slice`, it accepts both single slices and arrays, with the [`s!`] macro
1270 /// providing powerful syntax for complex patterns.
1271 ///
1272 /// # Arguments
1273 ///
1274 /// * `slices` - Slice specification (same format as `slice` method)
1275 /// * `values` - Tensor with values to assign (must match slice dimensions)
1276 ///
1277 /// # Panics
1278 ///
1279 /// - If slices exceed tensor dimensions
1280 /// - If values dimensions don't match the selected slice shape
1281 /// - If a step is zero
1282 ///
1283 /// # Examples
1284 ///
1285 /// ```rust
1286 /// use burn_tensor::backend::Backend;
1287 /// use burn_tensor::{Tensor, s};
1288 ///
1289 /// fn example<B: Backend>() {
1290 /// let device = B::Device::default();
1291 ///
1292 /// // Simple assignment to a sub-region
1293 /// let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
1294 /// let values = Tensor::<B, 2>::ones([2, 3], &device);
1295 /// tensor = tensor.slice_assign([1..3, 2..5], values);
1296 /// // Now tensor[1..3, 2..5] contains ones
1297 ///
1298 /// // Single dimension assignment with step
1299 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1300 /// let values = Tensor::<B, 1>::ones([5], &device);
1301 /// tensor = tensor.slice_assign(s![0..10;2], values);
1302 /// // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
1303 ///
1304 /// // Reverse assignment with negative step
1305 /// let mut tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1306 /// let values = Tensor::<B, 1>::from_data([10.0, 11.0, 12.0, 13.0, 14.0], &device);
1307 /// tensor = tensor.slice_assign(s![..;-1], values);
1308 /// // Assigns in reverse: [14, 13, 12, 11, 10]
1309 ///
1310 /// // Complex multi-dimensional assignment
1311 /// let mut tensor = Tensor::<B, 3>::zeros([4, 6, 8], &device);
1312 /// let values = Tensor::<B, 3>::ones([2, 3, 3], &device);
1313 /// tensor = tensor.slice_assign(s![0..4;2, ..;2, -3..], values);
1314 /// // Assigns to every 2nd row, every 2nd column, last 3 in depth
1315 ///
1316 /// // Mixed syntax example
1317 /// let mut tensor = Tensor::<B, 2>::zeros([8, 8], &device);
1318 /// let pattern = Tensor::<B, 2>::ones([4, 4], &device);
1319 /// tensor = tensor.slice_assign(s![..;2, ..;2], pattern);
1320 /// // Creates a checkerboard pattern with ones
1321 /// }
1322 /// ```
1323 ///
1324 /// # See Also
1325 ///
1326 /// - [`s!`] - The recommended macro for creating complex slice specifications
1327 /// - [`slice`](Self::slice) - Extract a slice from a tensor
1328 /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
1329 ///
1330 /// [`s!`]: crate::s!
1331 pub fn slice_assign<const D2: usize, S>(self, slices: S, values: Self) -> Self
1332 where
1333 S: SliceArg<D2>,
1334 {
1335 let shape = self.shape();
1336 let slices = slices.into_slices(shape.clone());
1337
1338 check!(TensorCheck::slice_assign::<D, D2>(
1339 &shape,
1340 &values.shape(),
1341 &slices
1342 ));
1343
1344 Self::new(K::slice_assign(self.primitive, &slices, values.primitive))
1345 }
1346
1347 /// Fills a slice of the tensor with a constant value and returns the updated tensor.
1348 ///
1349 /// Like other slice methods, accepts both single slices and arrays. However, this method
1350 /// currently **does not support stepped slicing** - use [`slice_assign`](Self::slice_assign)
1351 /// with a constant tensor for stepped patterns.
1352 ///
1353 /// # Arguments
1354 ///
1355 /// * `slices` - Slice specification (same format as `slice` method, but no steps)
1356 /// * `value` - The value to fill the slice with
1357 ///
1358 /// # Panics
1359 ///
1360 /// - If slices exceed tensor dimensions
1361 /// - If any slice has a step != 1 (not yet supported)
1362 ///
1363 /// # Examples
1364 ///
1365 /// ```rust
1366 /// use burn_tensor::backend::Backend;
1367 /// use burn_tensor::{Tensor, s};
1368 ///
1369 /// fn example<B: Backend>() {
1370 /// let device = B::Device::default();
1371 ///
1372 /// // Simple fill for a single dimension
1373 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1374 /// tensor = tensor.slice_fill(2..5, 1.0);
1375 /// // Now tensor is [0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
1376 ///
1377 /// // Multi-dimensional fill
1378 /// let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
1379 /// tensor = tensor.slice_fill([1..3, 2..5], -1.0);
1380 /// // Fills the rectangle at rows 1-2, columns 2-4 with -1
1381 ///
1382 /// // Using negative indices
1383 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1384 /// tensor = tensor.slice_fill(-3.., 2.0);
1385 /// // Fills the last 3 elements with 2.0
1386 ///
1387 /// // Complex multi-dimensional example
1388 /// let mut tensor = Tensor::<B, 3>::ones([4, 6, 8], &device);
1389 /// tensor = tensor.slice_fill(s![1..3, .., -2..], 0.0);
1390 /// // Sets rows 1-2, all columns, last 2 in depth to 0
1391 ///
1392 /// // Stepped slicing is supported
1393 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1394 /// tensor = tensor.slice_fill(s![0..10;2], 1.0);
1395 /// // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
1396 /// }
1397 /// ```
1398 ///
1399 /// # See Also
1400 ///
1401 /// - [`s!`] - The macro for creating slice specifications with steps
1402 /// - [`slice`](Self::slice) - Extract a slice from a tensor
1403 /// - [`slice_assign`](Self::slice_assign) - Assign tensor values to a slice
1404 ///
1405 /// [`s!`]: crate::s!
1406 pub fn slice_fill<const D2: usize, S, E: ElementConversion>(self, slices: S, value: E) -> Self
1407 where
1408 S: SliceArg<D2>,
1409 {
1410 let shape = self.shape();
1411 let slices = slices.into_slices(shape.clone());
1412
1413 check!(TensorCheck::slice::<D, D2>(&shape, &slices));
1414
1415 Self::new(K::slice_fill(self.primitive, &slices, value.elem()))
1416 }
1417
1418 /// Returns a new tensor with the specified dimension sliced.
1419 ///
1420 /// # Arguments
1421 ///
1422 /// * `dim`: The dimension to slice.
1423 /// * `slice`: The slice specification for the dimension. Can be a range (e.g., `2..5`),
1424 /// slice with step (via `s!` macro, e.g., `s![0..10;2]`), or any type that implements `Into<Slice>`.
1425 ///
1426 /// # Returns
1427 ///
1428 /// A new tensor with the specified dimension sliced.
1429 ///
1430 /// # Panics
1431 ///
1432 /// If the slice is out of bounds for the specified dimension.
1433 ///
1434 /// # Examples
1435 ///
1436 /// ```rust
1437 /// # use burn_tensor::{Tensor, s};
1438 /// # use burn_tensor::backend::Backend;
1439 /// #
1440 /// # fn example<B: Backend>() {
1441 /// # let device = B::Device::default();
1442 /// let tensor = Tensor::<B, 3>::zeros([3, 4, 5], &device);
1443 ///
1444 /// // Simple range slicing
1445 /// let sliced = tensor.clone().slice_dim(1, 1..3);
1446 /// assert_eq!(sliced.shape().dims, [3, 2, 5]);
1447 ///
1448 /// // Slicing with step - take every 2nd element
1449 /// let sliced = tensor.clone().slice_dim(2, s![0..5;2]);
1450 /// assert_eq!(sliced.shape().dims, [3, 4, 3]); // Takes indices 0, 2, 4
1451 ///
1452 /// // Reverse slicing with negative step
1453 /// let sliced = tensor.clone().slice_dim(1, s![..;-1]);
1454 /// assert_eq!(sliced.shape().dims, [3, 4, 5]); // Reverses dimension 1
1455 ///
1456 /// // Select from index 2 with step 3
1457 /// let sliced = tensor.clone().slice_dim(0, s![2..;3]);
1458 /// assert_eq!(sliced.shape().dims, [1, 4, 5]); // Takes only index 2
1459 ///
1460 /// // Select single index (reduces dimension to size 1)
1461 /// let sliced = tensor.slice_dim(0, 1);
1462 /// assert_eq!(sliced.shape().dims, [1, 4, 5]);
1463 /// # }
1464 /// ```
1465 ///
1466 /// # See Also
1467 ///
1468 /// - [`slice`](Self::slice) - Slice multiple dimensions simultaneously
1469 /// - [`s!`] - The macro for creating complex slice specifications
1470 ///
1471 /// [`s!`]: crate::s!
1472 pub fn slice_dim<S>(self, dim: usize, slice: S) -> Self
1473 where
1474 S: Into<Slice>,
1475 {
1476 check!(TensorCheck::check_dim::<D>(dim));
1477 let slice: Slice = slice.into();
1478
1479 Self::new(K::slice_dim(self.primitive, dim, &slice))
1480 }
1481
1482 /// Returns the device of the current tensor.
1483 pub fn device(&self) -> B::Device {
1484 K::device(&self.primitive)
1485 }
1486
1487 /// Move the tensor to the given device.
1488 pub fn to_device(self, device: &B::Device) -> Self {
1489 Self::new(K::to_device(self.primitive, device))
1490 }
1491
1492 /// Select tensor elements along the given dimension corresponding to the given indices.
1493 ///
1494 /// # Arguments
1495 ///
1496 /// * `dim` - The dimension to select from. Supports negative indexing.
1497 /// * `indices` - The indices of the elements to select.
1498 ///
1499 /// # Example
1500 ///
1501 /// ```rust
1502 /// use burn_tensor::backend::Backend;
1503 /// use burn_tensor::{Tensor, Int};
1504 ///
1505 /// fn example<B: Backend>() {
1506 /// let device = B::Device::default();
1507 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [4.0, 5.0, 6.0]], &device);
1508 /// let indices = Tensor::<B, 1, Int>::from_data([0], &device);
1509 /// let tensor = tensor.select(0, indices);
1510 /// println!("{tensor}");
1511 /// // [[1.0, -2.0, 3.0]]
1512 /// }
1513 /// ```
1514 pub fn select(self, dim: impl AsIndex, indices: Tensor<B, 1, Int>) -> Self {
1515 let dim = canonicalize_dim(dim, D, false);
1516 check!(TensorCheck::select::<D>(dim));
1517 Self::new(K::select(self.primitive, dim, indices))
1518 }
1519
1520 /// Assign the selected elements along the given dimension corresponding to the given indices
1521 /// from the value tensor to the original tensor using sum reduction.
1522 ///
1523 /// # Note
1524 /// For booleans, the sum operator is logical or.
1525 ///
1526 /// # Arguments
1527 ///
1528 /// * `dim` - The dimension along which to select. Supports negative indexing.
1529 /// * `indices` - The indices to select from the tensor.
1530 /// * `values` - The values to assign to the selected indices.
1531 ///
1532 /// # Example
1533 ///
1534 /// Example using a 3D tensor:
1535 ///
1536 /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
1537 /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
1538 /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
1539 /// `input[i, j, indices[k]] += values[i, j, k]; // dim = -1 (same as dim = 2)`
1540 ///
1541 /// # Warning
1542 ///
1543 /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1544 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1545 pub fn select_assign(
1546 self,
1547 dim: impl AsIndex,
1548 indices: Tensor<B, 1, Int>,
1549 values: Tensor<B, D, K>,
1550 ) -> Self {
1551 let dim = canonicalize_dim(dim, D, false);
1552 check!(TensorCheck::select_assign::<D>(
1553 dim,
1554 &indices.shape(),
1555 &values.shape()
1556 ));
1557
1558 Self::new(K::select_assign(
1559 self.primitive,
1560 dim,
1561 indices,
1562 values.primitive,
1563 ))
1564 }
1565
/// Converts the data of the current tensor.
///
/// This is the blocking counterpart of [`into_data_async`](Self::into_data_async). It drives
/// that future to completion via `crate::try_read_sync`.
///
/// # Note
///
/// For better performance, prefer using a [Transaction](Transaction) when reading multiple
/// tensors at once. This may improve laziness, especially if executed on a different
/// thread in native environments.
///
/// # Panics
///
/// If `try_read_sync` cannot produce the data synchronously. The panic message names
/// platforms that cannot block on futures, such as WASM, as one cause.
pub fn into_data(self) -> TensorData {
// NOTE(review): the message below is a multi-line string literal, so its line breaks (and
// any leading whitespace on the continuation lines) are part of the runtime panic text.
crate::try_read_sync(self.into_data_async()).expect(
"Failed to read tensor data synchronously.
This can happen on platforms that don't support blocking futures like WASM.
If possible, try using into_data_async instead.",
)
}
1580
1581 /// Converts the data of the current tensor.
1582 ///
1583 /// # Note
1584 ///
1585 /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
1586 /// tensors at once. This may improve laziness, especially if executed on a different
1587 /// thread in native environments.
1588 pub fn to_data(&self) -> TensorData {
1589 self.clone().into_data()
1590 }
1591
1592 /// Returns the data of the current tensor.
1593 pub async fn into_data_async(self) -> TensorData {
1594 K::into_data_async(self.primitive).await
1595 }
1596
1597 /// Returns the data of the current tensor.
1598 pub async fn to_data_async(&self) -> TensorData {
1599 self.clone().into_data_async().await
1600 }
1601
1602 /// Create a tensor from the given data on the given device.
1603 pub fn from_data<T>(data: T, device: &B::Device) -> Self
1604 where
1605 T: Into<TensorData>,
1606 {
1607 let data = data.into();
1608 check!(TensorCheck::creation_ops::<D>(
1609 "From Data",
1610 data.shape.as_slice()
1611 ));
1612 Self::new(K::from_data(data, device))
1613 }
1614
1615 /// Create a tensor from the given data on the given device enforcing the given data type.
1616 pub fn from_data_dtype<T>(data: T, device: &B::Device, dtype: DType) -> Self
1617 where
1618 T: Into<TensorData>,
1619 {
1620 let data = data.into();
1621 check!(TensorCheck::creation_ops::<D>(
1622 "From Data",
1623 data.shape.as_slice()
1624 ));
1625 Self::new(K::from_data_dtype(data, device, dtype))
1626 }
1627
1628 /// Repeat the tensor along the given dimension.
1629 ///
1630 /// The output tensor has the same shape, except along the given dimension.
1631 ///
1632 /// # Arguments
1633 /// - `dim`: The dimension to repeat.
1634 /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
1635 ///
1636 /// # Returns
1637 ///
1638 /// A new tensor with the given dimension repeated `times` times.
1639 ///
1640 /// # Example
1641 ///
1642 /// ```rust
1643 /// use burn_tensor::backend::Backend;
1644 /// use burn_tensor::Tensor;
1645 ///
1646 /// fn example<B: Backend>() {
1647 /// let device = Default::default();
1648 /// // Create a 2D tensor with dimensions [3, 2]
1649 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1650 ///
1651 /// // Repeat the tensor along the dimension 0 twice.
1652 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1653 /// // The resulting tensor will have dimensions [6, 2].
1654 /// let repeated = tensor.repeat_dim(0, 2);
1655 /// println!("{repeated}");
1656 /// }
1657 /// ```
1658 pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
1659 Self::new(K::repeat_dim(self.primitive, dim, times))
1660 }
1661
1662 /// Repeat the tensor along the given dimensions.
1663 /// # Arguments
1664 /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
1665 ///
1666 /// # Returns
1667 ///
1668 /// A new tensor with the given dimensions repeated `times` times.
1669 ///
1670 /// # Panics
1671 ///
1672 /// If `sizes` contains more elements than the number of dimensions.
1673 ///
1674 /// # Example
1675 ///
1676 /// ```rust
1677 ///
1678 /// use burn_tensor::backend::Backend;
1679 /// use burn_tensor::Tensor;
1680 ///
1681 /// fn example<B: Backend>() {
1682 /// let device = Default::default();
1683 /// // Create a 2D tensor with dimensions [3, 2]
1684 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1685 ///
1686 /// // Repeat the tensor along the dimension 0 twice and the dimension 0 once.
1687 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1688 /// // The resulting tensor will have dimensions [6, 2].
1689 /// let repeated = tensor.repeat(&[2, 1]);
1690 /// }
1691 /// ```
1692 pub fn repeat(self, sizes: &[usize]) -> Self {
1693 let mut tensor = self;
1694 for (dim, ×) in sizes.iter().enumerate() {
1695 if times > 1 {
1696 tensor = tensor.repeat_dim(dim, times);
1697 }
1698 }
1699 tensor
1700 }
1701
1702 /// Applies element-wise equal comparison.
1703 ///
1704 /// # Returns
1705 /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
1706 ///
1707 /// # Panics
1708 ///
1709 /// If the two tensors don't have the same shape.
1710 ///
1711 /// # Example
1712 ///
1713 /// ```rust
1714 /// use burn_tensor::backend::Backend;
1715 /// use burn_tensor::Tensor;
1716 ///
1717 /// fn example<B: Backend>() {
1718 /// let device = Default::default();
1719 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1720 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1721 /// // Compare the elements of the two 2D tensors with dimensions [3, 2].
1722 /// // [[false, true], [true, true], [true, true]]
1723 /// let equal = t1.equal(t2);
1724 /// println!("{equal}");
1725 /// }
1726 /// ```
1727 pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
1728 check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
1729 Tensor::new(K::equal(self.primitive, other.primitive))
1730 }
1731
1732 /// Applies element-wise non-equality comparison.
1733 ///
1734 /// # Returns
1735 /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
1736 ///
1737 /// # Panics
1738 ///
1739 /// If the two tensors don't have the same shape.
1740 ///
1741 /// # Example
1742 ///
1743 /// ```rust
1744 /// use burn_tensor::backend::Backend;
1745 /// use burn_tensor::Tensor;
1746 ///
1747 /// fn example<B: Backend>() {
1748 /// let device = Default::default();
1749 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1750 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1751 /// // Compare the elements of the two 2D tensors for inequality.
1752 /// // [[true, false], [false, false], [false, false]]
1753 /// let not_equal = t1.not_equal(t2);
1754 /// println!("{not_equal}");
1755 /// }
1756 /// ```
1757 pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
1758 check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
1759 Tensor::new(K::not_equal(self.primitive, other.primitive))
1760 }
1761
1762 /// Concatenates all tensors into a new one along the given dimension.
1763 ///
1764 /// # Panics
1765 ///
1766 /// - If `dim` is higher than the rank.
1767 /// - If `tensors` is an empty vector.
1768 /// - If all tensors don't have the same shape (the dimension `dim` is ignored).
1769 ///
1770 /// # Example
1771 ///
1772 /// ```rust
1773 /// use burn_tensor::backend::Backend;
1774 /// use burn_tensor::Tensor;
1775 ///
1776 /// fn example<B: Backend>() {
1777 /// let device = Default::default();
1778 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
1779 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1780 ///
1781 /// // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
1782 /// // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
1783 /// // The resulting tensor will have shape [2, 7].
1784 /// let concat = Tensor::cat(vec![t1, t2], 1);
1785 /// println!("{concat}");
1786 /// }
1787 /// ```
1788 pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
1789 check!(TensorCheck::cat(&tensors, dim));
1790
1791 Self::new(K::cat(
1792 tensors.into_iter().map(|vector| vector.primitive).collect(),
1793 dim,
1794 ))
1795 }
1796
1797 /// Concatenates all tensors into a new one along a new dimension.
1798 ///
1799 /// # Panics
1800 ///
1801 /// - If all tensors don't have the same shape.
1802 /// - If given dimension is not with range of 0..D2
1803 ///
1804 /// # Example
1805 ///
1806 /// ```rust
1807 /// use burn_tensor::backend::Backend;
1808 /// use burn_tensor::Tensor;
1809 ///
1810 /// fn example<B: Backend>() {
1811 /// let device = Default::default();
1812 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
1813 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1814 /// let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1815 ///
1816 /// // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
1817 /// // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
1818 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
1819 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
1820 /// // The resulting tensor will have shape [3, 2, 3].
1821 /// let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0);
1822 /// println!("{stacked}");
1823 /// }
1824 /// ```
1825 pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
1826 check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
1827 let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
1828 Tensor::<B, D2, K>::cat(tensors, dim)
1829 }
1830
1831 /// Iterate over slices of tensors alongside a given dimension.
1832 ///
1833 /// # Panics
1834 ///
1835 /// If given dimension is greater than or equal to tensor rank.
1836 ///
1837 /// # Returns
1838 ///
1839 /// A tensor iterator.
1840 ///
1841 /// # Example
1842 ///
1843 /// ```rust
1844 /// use burn_tensor::backend::Backend;
1845 /// use burn_tensor::Tensor;
1846 /// fn example<B: Backend>() {
1847 /// let device = Default::default();
1848 /// let tensor = Tensor::<B,2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
1849 /// // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
1850 /// let iter = tensor.iter_dim(0);
1851 /// for (i,tensor) in iter.enumerate() {
1852 /// println!("Tensor {}: {}", i, tensor);
1853 /// // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
1854 /// // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
1855 /// }
1856 /// }
1857 /// ```
1858 pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
1859 check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
1860 DimIter::new(self, dim)
1861 }
1862
1863 /// Returns a new tensor with the given dimension narrowed to the given range.
1864 ///
1865 /// # Panics
1866 ///
1867 /// - If the dimension is greater than the number of dimensions of the tensor.
1868 /// - If the given range exceeds the number of elements on the given dimension.
1869 ///
1870 /// # Returns
1871 ///
1872 /// A new tensor with the given dimension narrowed to the given range.
1873 ///
1874 /// # Example
1875 ///
1876 /// ```rust
1877 /// use burn_tensor::backend::Backend;
1878 /// use burn_tensor::Tensor;
1879 ///
1880 /// fn example<B: Backend>() {
1881 /// let device = Default::default();
1882 /// // Create a 2D tensor with dimensions [4, 3]
1883 /// let tensor = Tensor::<B, 2>::from_data(
1884 /// [
1885 /// [3.0, 4.9, 2.0],
1886 /// [2.0, 1.9, 3.0],
1887 /// [6.0, 1.5, 7.0],
1888 /// [3.0, 4.9, 9.0],
1889 /// ],
1890 /// &device,
1891 /// );
1892 /// // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
1893 /// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
1894 /// // The resulting tensor will have dimensions [3, 3].
1895 /// let narrowed = tensor.narrow(0, 1, 3);
1896 /// println!("{narrowed}");
1897 /// }
1898 /// ```
1899 pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
1900 check!(TensorCheck::dim_ops::<D>("narrow", dim));
1901 check!(TensorCheck::narrow(&self, dim, start, length));
1902 let dims = self.dims();
1903
1904 let ranges: [Range<usize>; D] = dims
1905 .iter()
1906 .enumerate()
1907 .map(|(i, d)| {
1908 if i == dim {
1909 start..(start + length)
1910 } else {
1911 0..*d
1912 }
1913 })
1914 .collect::<Vec<_>>()
1915 .try_into()
1916 .unwrap();
1917
1918 Self::slice(self, ranges)
1919 }
1920
1921 /// Attempts to split the tensor into a specified number of chunks along a given dimension.
1922 /// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
1923 ///
1924 /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
1925 /// Otherwise all chunks will be of equal size except for the last one.
1926 ///
1927 /// # Panics
1928 ///
1929 /// If the dimension is greater than the number of dimensions of the tensor.
1930 ///
1931 /// # Returns
1932 /// A vector of tensors.
1933 ///
1934 /// # Example
1935 ///
1936 /// ```rust
1937 /// use burn_tensor::backend::Backend;
1938 /// use burn_tensor::Tensor;
1939 ///
1940 /// fn example<B: Backend>() {
1941 /// let device = Default::default();
1942 /// // Create a 2D tensor with dimensions [4, 3]
1943 /// let tensor = Tensor::<B, 2>::from_data(
1944 /// [
1945 /// [3.0, 4.9, 2.0],
1946 /// [2.0, 1.9, 3.0],
1947 /// [6.0, 1.5, 7.0],
1948 /// [3.0, 4.9, 9.0],
1949 /// ],
1950 /// &device,
1951 /// );
1952 /// // Split the tensor along the dimension 1 into 2 chunks.
1953 /// // The first chuck will have shape [4, 2]:
1954 /// // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
1955 /// // The second chunk will have shape [4, 1]:
1956 /// // [[2.0], [3.0], [7.0], [9.0]]
1957 /// let chunks = tensor.chunk(2, 1);
1958 /// println!("{chunks:?}");
1959 /// }
1960 /// ```
1961 pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
1962 check!(TensorCheck::dim_ops::<D>("chunk", dim));
1963 let size = self.shape().dims[dim];
1964 if size < chunks {
1965 return (0..size)
1966 .map(|i| Self::narrow(self.clone(), dim, i, 1))
1967 .collect();
1968 }
1969
1970 let mut tensors = Vec::with_capacity(chunks);
1971 let mut sum_chunk_size = 0;
1972 if size.is_multiple_of(chunks) {
1973 let chunk_size = size / chunks;
1974 for _ in 0..chunks {
1975 tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
1976 sum_chunk_size += chunk_size;
1977 }
1978 } else {
1979 let chunk_size = (size / chunks) + 1; // assumes not divisible
1980 for _ in 0..chunks - 1 {
1981 tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
1982 sum_chunk_size += chunk_size;
1983 }
1984 let remainder = size % chunk_size;
1985 tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, remainder));
1986 }
1987
1988 tensors
1989 }
1990
1991 /// Splits the tensor into chunks of a specified size along a given dimension.
1992 /// Each chunk is a view of the original tensor.
1993 ///
1994 /// If the tensor size along the given dimension is not divisible by `split_size`,
1995 /// then the last chunk will be smaller.
1996 ///
1997 /// # Panics
1998 ///
1999 /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
2000 ///
2001 /// # Returns
2002 ///
2003 /// A vector of tensors.
2004 ///
2005 /// # Example
2006 /// ```rust
2007 /// use burn_tensor::backend::Backend;
2008 /// use burn_tensor::Tensor;
2009 ///
2010 /// fn example<B: Backend>() {
2011 /// let device = Default::default();
2012 /// // Create a 1D tensor with 5 elements
2013 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2014 /// // Split the tensor into chunks of size 2 along dimension 0
2015 /// let chunks = tensor.split(2, 0);
2016 /// // The result is a vector of tensors:
2017 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
2018 /// println!("{:?}", chunks);
2019 /// }
2020 /// ```
2021 pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
2022 check!(TensorCheck::split::<D>(&self.shape(), split_size, dim));
2023 let size = self.shape().dims[dim];
2024 let mut tensors = Vec::new();
2025
2026 let mut start = 0;
2027 while start < size {
2028 let length = usize::min(split_size, size - start);
2029 tensors.push(Self::narrow(self.clone(), dim, start, length));
2030 start += length;
2031 }
2032
2033 tensors
2034 }
2035
2036 /// Splits the tensor into chunks with the specified sizes along a given dimension.
2037 /// Each chunk is a view of the original tensor.
2038 ///
2039 /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2040 /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2041 ///
2042 /// # Panics
2043 ///
2044 /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
2045 /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`.
2046 ///
2047 /// # Returns
2048 ///
2049 /// A vector of tensors.
2050 ///
2051 /// # Example
2052 /// ```rust
2053 /// use burn_tensor::backend::Backend;
2054 /// use burn_tensor::Tensor;
2055 ///
2056 /// fn example<B: Backend>() {
2057 /// let device = Default::default();
2058 /// // Create a 1D tensor with 5 elements
2059 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2060 /// // Split the tensor into chunks with sizes [2, 3] along dimension 0
2061 /// let chunks = tensor.split_with_sizes(vec![2, 3], 0);
2062 /// // The result is a vector of tensors:
2063 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
2064 /// println!("{:?}", chunks);
2065 /// }
2066 /// ```
2067 pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
2068 check!(TensorCheck::split_with_sizes::<D>(
2069 &self.shape(),
2070 &split_sizes,
2071 dim
2072 ));
2073 let mut tensors = Vec::new();
2074
2075 let mut start = 0;
2076 for length in split_sizes {
2077 if length == 0 {
2078 continue;
2079 }
2080 tensors.push(Self::narrow(self.clone(), dim, start, length));
2081 start += length;
2082 }
2083
2084 tensors
2085 }
2086
2087 /// Tests if any element in the `tensor` evaluates to True.
2088 ///
2089 /// # Arguments
2090 ///
2091 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2092 ///
2093 /// # Returns
2094 ///
2095 /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
2096 /// evaluates to True, False otherwise.
2097 ///
2098 /// # Example
2099 ///
2100 /// ```rust
2101 /// use burn_tensor::backend::Backend;
2102 /// use burn_tensor::{Tensor, Bool};
2103 ///
2104 /// fn example<B: Backend>() {
2105 /// let device = Default::default();
2106 /// let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);
2107 /// let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);
2108 ///
2109 /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2110 /// let any_tensor = tensor.any();
2111 /// println!("{}", any_tensor);
2112 /// // Tensor { data: [true], ... }
2113 ///
2114 /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2115 /// let any_tensor_two = tensor_two.any();
2116 /// println!("{}", any_tensor_two);
2117 /// // Tensor { data: [false], ... }
2118 /// }
2119 /// ```
2120 pub fn any(self) -> Tensor<B, 1, Bool> {
2121 Tensor::new(K::any(self.primitive))
2122 }
2123
2124 /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
2125 ///
2126 /// # Arguments
2127 ///
2128 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2129 /// * `dim` - The axis along which to test.
2130 ///
2131 /// # Returns
2132 ///
2133 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2134 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
2135 /// evaluates to True, False otherwise.
2136 ///
2137 /// # Example
2138 ///
2139 /// ```rust
2140 /// use burn_tensor::backend::Backend;
2141 /// use burn_tensor::{Tensor, Bool};
2142 ///
2143 /// fn example<B: Backend>() {
2144 /// let device = Default::default();
2145 /// let tensor =
2146 /// Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
2147 /// // Check if any element in the tensor evaluates to True along the dimension 1.
2148 /// // [[true], [true]],
2149 /// let any_dim = tensor.clone().any_dim(1);
2150 /// println!("{any_dim}");
2151 /// }
2152 /// ```
2153 pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2154 Tensor::new(K::any_dim(self.primitive, dim))
2155 }
2156
2157 /// Tests if all elements in the `tensor` evaluate to True.
2158 ///
2159 /// # Arguments
2160 ///
2161 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2162 ///
2163 /// # Returns
2164 ///
2165 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
2166 /// evaluate to True, False otherwise.
2167 ///
2168 /// # Example
2169 ///
2170 /// ```rust
2171 /// use burn_tensor::backend::Backend;
2172 /// use burn_tensor::{Tensor, Bool};
2173 ///
2174 /// fn example<B: Backend>() {
2175 /// let device = Default::default();
2176 /// let tensor =
2177 /// Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
2178 /// // Check if all elements in the tensor evaluate to True (which is not the case).
2179 /// // [false]
2180 /// let all = tensor.all();
2181 /// println!("{all}");
2182 /// }
2183 /// ```
2184 pub fn all(self) -> Tensor<B, 1, Bool> {
2185 Tensor::new(K::all(self.primitive))
2186 }
2187
2188 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2189 ///
2190 /// # Arguments
2191 ///
2192 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2193 /// * `dim` - The axis along which to test.
2194 ///
2195 /// # Returns
2196 ///
2197 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2198 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
2199 /// evaluates to True, False otherwise.
2200 ///
2201 /// # Example
2202 ///
2203 /// ```rust
2204 /// use burn_tensor::backend::Backend;
2205 /// use burn_tensor::{Tensor, Bool};
2206 ///
2207 /// fn example<B: Backend>() {
2208 /// let device = Default::default();
2209 /// let tensor =
2210 /// Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
2211 /// // Check if all elements in the tensor evaluate to True along the dimension 1.
2212 /// // [[true, true, false]]
2213 /// let all_dim = tensor.clone().all_dim(0);
2214 /// println!("{all_dim}");
2215 /// }
2216 /// ```
2217 pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2218 Tensor::new(K::all_dim(self.primitive, dim))
2219 }
2220
2221 /// Convert the tensor into a scalar.
2222 ///
2223 /// # Panics
2224 ///
2225 /// - If the tensor doesn't have one element.
2226 /// - If the backend fails to read the tensor data synchronously.
2227 ///
2228 /// # Returns
2229 ///
2230 /// The scalar value of the tensor.
2231 ///
2232 /// # Example
2233 ///
2234 /// ```rust
2235 /// use burn_tensor::backend::Backend;
2236 /// use burn_tensor::Tensor;
2237 ///
2238 /// fn example<B: Backend>() {
2239 /// let device = Default::default();
2240 /// let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
2241 /// // Convert the tensor with a single element into a scalar.
2242 /// let scalar = tensor.into_scalar();
2243 /// println!("{scalar}");
2244 /// }
2245 /// ```
2246 pub fn into_scalar(self) -> K::Elem {
2247 crate::try_read_sync(self.into_scalar_async()).expect(
2248 "Failed to read tensor data synchronously. This can happen on platforms
2249 that don't support blocking futures like WASM. Try into_scalar_async instead.",
2250 )
2251 }
2252
2253 /// Convert the tensor into a scalar.
2254 ///
2255 /// # Panics
2256 ///
2257 /// If the tensor doesn't have one element.
2258 pub async fn into_scalar_async(self) -> K::Elem {
2259 check!(TensorCheck::into_scalar::<D>(&self.shape()));
2260
2261 self.into_data_async().await.iter().next().unwrap()
2262 }
2263
2264 /// Broadcast the tensor to the given shape.
2265 ///
2266 /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size
2267 /// (which can be inferred with `-1`).
2268 ///
2269 /// # Arguments
2270 ///
2271 /// * `shape` - The shape to broadcast the tensor to.
2272 /// Can contain -1 for dimensions that should be inferred.
2273 /// The number of elements in the shape must be greater or equal as
2274 /// the number of dimensions of the tensor.
2275 ///
2276 /// # Panics
2277 ///
2278 /// If the tensor cannot be broadcasted to the given shape.
2279 ///
2280 /// # Returns
2281 ///
2282 /// A new tensor with the given shape.
2283 ///
2284 /// # Example
2285 ///
2286 /// ```rust
2287 /// use burn_tensor::backend::Backend;
2288 /// use burn_tensor::Tensor;
2289 ///
2290 /// fn example<B: Backend>() {
2291 /// let device = Default::default();
2292 /// // Create a 2D tensor with dimensions [3, 1]
2293 /// let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
2294 /// // Expand the tensor to a new shape [3, 4]
2295 /// // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
2296 /// let expanded = tensor.expand([3, 4]);
2297 /// println!("{}", expanded);
2298 /// }
2299 /// ```
2300 pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
2301 let shape = shape.into_shape(&self.shape());
2302 check!(TensorCheck::expand::<D, D2>(
2303 "expand",
2304 &self.shape(),
2305 &shape,
2306 ));
2307
2308 Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
2309 }
2310
2311 /// Unfold windows along a dimension.
2312 ///
2313 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
2314 /// where windows are advanced by `step` at each index.
2315 ///
2316 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
2317 ///
2318 /// The new view will have the unfolded dimension replaced by two dimensions;
2319 /// one in the position of the original dimension, with size equal to the number of windows,
2320 /// and one appended to the right-most position, with size equal to `size`.
2321 ///
2322 /// # Warning
2323 ///
2324 /// For the `ndarray` and `candle` backends; this is not a view but a copy
2325 /// with duplicated data.
2326 ///
2327 /// # Arguments
2328 ///
2329 /// * `dim` - the dimension to unfold.
2330 /// * `size` - the size of each unfolded window.
2331 /// * `step` - the step between each window.
2332 ///
2333 /// # Returns
2334 ///
2335 /// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
2336 pub fn unfold<const D2: usize, I: AsIndex>(
2337 self,
2338 dim: I,
2339 size: usize,
2340 step: usize,
2341 ) -> Tensor<B, D2, K> {
2342 let dim = canonicalize_dim(dim, D, false);
2343 check!(TensorCheck::unfold::<D, D2>(
2344 "unfold",
2345 &self.shape(),
2346 dim,
2347 size,
2348 step,
2349 ));
2350 Tensor::<B, D2, K>::new(K::unfold(self.primitive, dim, size, step))
2351 }
2352}
2353
2354/// Iterator given by (Tensor::iter_dim).
2355pub struct DimIter<B, const D: usize, K>
2356where
2357 B: Backend,
2358 K: BasicOps<B>,
2359{
2360 start: usize,
2361 end: usize,
2362 dim: usize,
2363 ranges: [Range<usize>; D],
2364 tensor: Tensor<B, D, K>,
2365}
2366
2367impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
2368 type Item = Tensor<B, D, K>;
2369
2370 fn next(&mut self) -> Option<Self::Item> {
2371 if self.start >= self.end {
2372 return None;
2373 }
2374
2375 let mut ranges = self.ranges.clone();
2376 ranges[self.dim] = self.start..(self.start + 1);
2377
2378 let slice = self.tensor.clone().slice(ranges);
2379 self.start += 1;
2380
2381 Some(slice)
2382 }
2383}
2384
2385impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
2386 fn next_back(&mut self) -> Option<Self::Item> {
2387 if self.start >= self.end {
2388 return None;
2389 }
2390
2391 let mut ranges = self.ranges.clone();
2392 ranges[self.dim] = (self.end - 1)..self.end;
2393
2394 let slice = self.tensor.clone().slice(ranges);
2395 self.end = self.end.saturating_sub(1);
2396
2397 Some(slice)
2398 }
2399}
2400
2401impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
2402 fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
2403 let dims = tensor.dims();
2404 let ranges = dims
2405 .iter()
2406 .map(|&dim| 0..dim)
2407 .collect::<Vec<Range<usize>>>();
2408 let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
2409 Self {
2410 end: dims[dim],
2411 ranges,
2412 start: 0,
2413 dim,
2414 tensor,
2415 }
2416 }
2417}
2418
// Formatting helpers used by the `Display` implementation of `Tensor`.
impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    <K as BasicOps<B>>::Elem: Debug,
{
    /// Appends a newline followed by `indent` spaces to `acc`.
    #[inline]
    fn push_newline_indent(acc: &mut String, indent: usize) {
        acc.push('\n');
        for _ in 0..indent {
            acc.push(' ');
        }
    }
    /// Appends the elements with indices `range.0..range.1` along the innermost dimension
    /// `depth` to `acc`, separated by `", "`.
    ///
    /// The indices of the outer dimensions come from `multi_index`, and `multi_index[depth]` is
    /// overwritten for every element. Each element is read through its own one-element slice
    /// and a synchronous read. If that read returns nothing, `"<Tensor data not available>"`
    /// is written in its place.
    ///
    /// `precision` is only honored for `Float` tensors. `Bool` elements are printed through
    /// `to_bool`, and everything else uses the element's `Debug` output.
    fn fmt_inner_tensor(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        range: (usize, usize),
        precision: Option<usize>,
    ) {
        let (start, end) = range;
        for i in start..end {
            // The separator is keyed on the absolute index: the first element of a row gets
            // none, while the second half of a summarized row (which starts past 0) gets one
            // right after the `", ..."` marker.
            if i > 0 {
                acc.push_str(", ");
            }
            multi_index[depth] = i;
            // One-element range per axis at `multi_index` (the closure's `i` is the axis index,
            // shadowing the loop index).
            let range: [Range<usize>; D] =
                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);

            let data =
                burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());

            if let Some(data) = data {
                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
                match (precision, K::name()) {
                    (Some(p), "Float") => acc.push_str(&format!("{elem:.p$}")),
                    (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
                    _ => acc.push_str(&format!("{elem:?}")),
                }
            } else {
                acc.push_str("<Tensor data not available>");
            }
        }
    }

    /// Appends the sub-tensors with indices `range.0..range.1` along dimension `depth` to `acc`.
    ///
    /// Each sub-tensor is wrapped in `[...]` and formatted recursively with
    /// `display_recursive`. Siblings are separated by `,` plus a newline indented by
    /// `depth + 1` spaces. Unlike the inner formatter, the separator is keyed on
    /// `i > start`; the caller emits the separator around a summarization marker itself.
    fn fmt_outer_tensor(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        print_options: &PrintOptions,
        summarize: bool,
        range: (usize, usize),
    ) {
        let (start, end) = range;
        for i in start..end {
            if i > start {
                acc.push(',');
                Self::push_newline_indent(acc, depth + 1);
            }
            acc.push('[');
            multi_index[depth] = i;
            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
            acc.push(']');
        }
    }

    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
    ///
    /// This function is designed to work with tensors of any dimensionality.
    /// It traverses the tensor dimensions recursively, converting the elements
    /// to strings and appending them to the accumulator string with the
    /// appropriate formatting.
    ///
    /// # Arguments
    ///
    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
    /// * `depth` - The current depth of the tensor dimensions being processed.
    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
    /// * `print_options` - Supplies `edge_items` (entries kept at each end when summarizing) and
    ///   `precision` (forwarded to the element formatter).
    /// * `summarize` - When true, any dimension longer than `2 * edge_items` is shortened to its
    ///   first and last `edge_items` entries with a `...` marker in between.
    fn display_recursive(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        print_options: &PrintOptions,
        summarize: bool,
    ) {
        let edge_items = print_options.edge_items;

        // Only the outermost call emits the enclosing brackets; nested levels are bracketed
        // by `fmt_outer_tensor`.
        if depth == 0 {
            acc.push('[');
        }

        // NOTE(review): `len() - 1` would underflow for a rank-0 tensor; presumably `D >= 1`
        // here — confirm.
        if depth == self.dims().len() - 1 {
            // if we are at the innermost dimension, just push its elements into the accumulator
            if summarize && self.dims()[depth] > 2 * edge_items {
                // print the starting `edge_items` elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (0, edge_items),
                    print_options.precision,
                );
                acc.push_str(", ...");
                // print the last `edge_items` elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (self.dims()[depth] - edge_items, self.dims()[depth]),
                    print_options.precision,
                );
            } else {
                // print all the elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (0, self.dims()[depth]),
                    print_options.precision,
                );
            }
        } else {
            // otherwise, iterate through the current dimension and recursively display the inner tensors
            if summarize && self.dims()[depth] > 2 * edge_items {
                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (0, edge_items),
                );

                // Elision marker on its own line between the leading and trailing blocks.
                acc.push(',');
                Self::push_newline_indent(acc, depth + 1);
                acc.push_str("...");
                Self::push_newline_indent(acc, depth + 1);

                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (self.dims()[depth] - edge_items, self.dims()[depth]),
                );
            } else {
                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (0, self.dims()[depth]),
                );
            }
        }

        if depth == 0 {
            acc.push(']');
        }
    }
}
2585
/// Options for Tensor pretty printing
#[derive(Clone, Debug)]
pub struct PrintOptions {
    /// Total element count above which the printed data is summarized (strictly greater).
    pub threshold: usize,

    /// When summarizing, how many leading and trailing entries are shown in each dimension.
    pub edge_items: usize,

    /// Digits after the decimal point for float tensors; `None` falls back to `Debug` output.
    pub precision: Option<usize>,
}
2598
// Process-wide print settings. `set_print_options` overwrites them, and the `Display` impl for
// `Tensor` clones them each time a tensor is formatted. The initializer uses `const_default`
// because a static requires a `const` expression.
static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
2600
2601impl PrintOptions {
2602 /// Print options with default values
2603 pub const fn const_default() -> Self {
2604 Self {
2605 threshold: 1000,
2606 edge_items: 3,
2607 precision: None,
2608 }
2609 }
2610}
2611
2612impl Default for PrintOptions {
2613 fn default() -> Self {
2614 Self::const_default()
2615 }
2616}
2617
2618/// Set print options
2619pub fn set_print_options(options: PrintOptions) {
2620 let mut print_opts = PRINT_OPTS.write().unwrap();
2621 *print_opts = options;
2622}
2623
2624/// Pretty print tensors
2625impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
2626where
2627 B: Backend,
2628 B::IntElem: core::fmt::Display,
2629 K: BasicOps<B>,
2630 <K as BasicOps<B>>::Elem: Debug,
2631{
2632 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2633 writeln!(f, "Tensor {{")?;
2634
2635 {
2636 // Do not lock the mutex for the whole function
2637 let mut po = { PRINT_OPTS.read().unwrap().clone() };
2638
2639 // Override the precision if it is set from the formatter
2640 // This will be possible when the tensor is printed using the `{:.*}` syntax
2641 if let Some(precision) = f.precision() {
2642 po.precision = Some(precision);
2643 }
2644
2645 let mut acc = String::new();
2646 let mut multi_index = vec![0; D];
2647 let summarize = self.shape().num_elements() > po.threshold;
2648
2649 self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);
2650
2651 writeln!(f, " data:")?;
2652 write!(f, "{acc}")?;
2653 writeln!(f, ",")?;
2654 }
2655
2656 writeln!(f, " shape: {:?},", self.dims())?;
2657 writeln!(f, " device: {:?},", self.device())?;
2658 writeln!(f, " backend: {:?},", B::name(&self.device()))?;
2659 writeln!(f, " kind: {:?},", K::name())?;
2660
2661 let dtype = self.primitive.dtype();
2662
2663 writeln!(f, " dtype: {:?},", dtype.name())?;
2664 write!(f, "}}")
2665 }
2666}
2667
/// Trait that lists all operations that can be applied on all tensors.
2669///
2670/// # Warnings
2671///
2672/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
2673pub trait BasicOps<B: Backend>: TensorKind<B> {
2674 /// The type of the tensor elements.
2675 type Elem: Element;
2676
2677 /// Creates an empty tensor with the given shape.
2678 ///
2679 /// # Arguments
2680 ///
2681 /// * `shape` - The shape of the tensor.
2682 /// * `device` - The device on which the tensor will be allocated.
2683 /// * `dtype` - The target data type.
2684 ///
2685 /// # Returns
2686 ///
2687 /// The empty tensor.
2688 ///
2689 /// # Remarks
2690 ///
2691 /// This is a low-level function used internally by the library to call different backend functions
2692 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2693 /// or use this function directly.
2694 ///
2695 /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function,
2696 /// which is more high-level and designed for public use.
2697 fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
2698
2699 /// Creates a tensor of the given shape where each element is equal to the provided value.
2700 ///
2701 /// # Arguments
2702 ///
2703 /// * `shape` - The shape of the tensor.
2704 /// * `fill_value` - The value with which to fill the tensor.
2705 /// * `device` - The device on which the tensor will be allocated.
2706 /// * `dtype` - The target data type.
2707 ///
2708 /// # Returns
2709 ///
    /// The tensor with every element equal to `fill_value`.
2711 ///
2712 /// # Remarks
2713 ///
2714 /// This is a low-level function used internally by the library to call different backend functions
2715 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2716 /// or use this function directly.
2717 ///
2718 /// For creating full tensors, users should prefer the [Tensor::full](Tensor::full) function,
2719 /// which is more high-level and designed for public use.
2720 fn full<E: ElementConversion>(
2721 shape: Shape,
2722 fill_value: E,
2723 device: &B::Device,
2724 dtype: DType,
2725 ) -> Self::Primitive;
2726
2727 /// Reshapes the tensor.
2728 ///
2729 /// # Arguments
2730 ///
2731 /// * `tensor` - The tensor.
2732 /// * `shape` - The new shape of the tensor.
2733 ///
2734 /// # Returns
2735 ///
2736 /// The reshaped tensor.
2737 ///
2738 /// # Remarks
2739 ///
2740 /// This is a low-level function used internally by the library to call different backend functions
2741 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2742 /// or use this function directly.
2743 ///
2744 /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function,
2745 /// which is more high-level and designed for public use.
2746 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
2747
2748 /// Transposes a tensor.
2749 ///
2750 /// # Arguments
2751 ///
2752 /// * `tensor` - The tensor to transpose.
2753 ///
2754 /// # Returns
2755 ///
2756 /// The transposed tensor.
2757 fn transpose(tensor: Self::Primitive) -> Self::Primitive;
2758
2759 /// Swaps two dimensions of a tensor.
2760 ///
2761 /// # Arguments
2762 ///
2763 /// * `tensor` - The tensor to swap the dimensions of.
2764 /// * `dim1` - The first dimension to swap.
2765 /// * `dim2` - The second dimension to swap.
2766 ///
2767 /// # Returns
2768 ///
2769 /// The tensor with the dimensions swapped.
2770 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
2771
2772 /// Permutes the dimensions of a tensor.
2773 ///
2774 /// # Arguments
2775 ///
2776 /// * `tensor` - The tensor to permute the dimensions of.
2777 /// * `axes` - The new order of the dimensions.
2778 ///
2779 /// # Returns
2780 ///
2781 /// The tensor with the dimensions permuted.
2782 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2783
2784 /// Flips the tensor along the given axes.
2785 ///
2786 /// # Arguments
2787 ///
2788 /// * `tensor` - The tensor to flip.
2789 /// * `axes` - The axes to flip the tensor along.
2790 ///
2791 /// # Returns
2792 ///
2793 /// The tensor with the axes flipped.
2794 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2795
2796 /// Select tensor elements corresponding to the given slices.
2797 ///
2798 /// # Arguments
2799 ///
2800 /// * `tensor` - The tensor.
2801 /// * `slices` - The slices specifying ranges and steps for each dimension.
2802 ///
2803 /// # Returns
2804 ///
2805 /// The selected elements.
2806 ///
2807 /// # Remarks
2808 ///
2809 /// This is a low-level function used internally by the library to call different backend functions
2810 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2811 /// or use this function directly.
2812 ///
2813 /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function,
2814 /// which is more high-level and designed for public use.
2815 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive;
2816
2817 /// Assigns the given value to the tensor elements corresponding to the given slices.
2818 ///
2819 /// # Arguments
2820 ///
2821 /// * `tensor` - The tensor.
2822 /// * `slices` - The slices specifying which elements to assign, including support for steps.
2823 /// * `value` - The value to assign.
2824 ///
2825 /// # Returns
2826 ///
2827 /// The tensor with the assigned values.
2828 ///
2829 /// # Remarks
2830 ///
2831 /// This is a low-level function used internally by the library to call different backend functions
2832 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2833 /// or use this function directly.
2834 ///
2835 /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function,
2836 /// which is more high-level and designed for public use.
2837 fn slice_assign(
2838 tensor: Self::Primitive,
2839 slices: &[Slice],
2840 value: Self::Primitive,
2841 ) -> Self::Primitive;
2842
2843 /// Fills the tensor elements corresponding to the given slices with the given value.
2844 ///
2845 /// # Arguments
2846 ///
2847 /// * `tensor` - The tensor.
2848 /// * `slices` - The slices specifying ranges and steps for each dimension.
2849 /// * `value` - The value to fill.
2850 ///
2851 /// # Returns
2852 ///
2853 /// The tensor with the filled values at the specified slice positions.
2854 ///
2855 /// # Remarks
2856 ///
2857 /// This is a low-level function used internally by the library to call different backend functions
2858 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2859 /// or use this function directly.
2860 ///
2861 /// For filling values in a tensor, users should prefer the [Tensor::slice_fill](Tensor::slice_fill) function,
2862 /// which is more high-level and designed for public use.
2863 fn slice_fill(tensor: Self::Primitive, slices: &[Slice], value: Self::Elem) -> Self::Primitive {
2864 let slice_shape = tensor.shape().slice(slices).unwrap();
2865 let value = Self::from_data_dtype(
2866 TensorData::from([value]),
2867 &Self::device(&tensor),
2868 Self::dtype(&tensor),
2869 );
2870 let value = Self::expand(value, slice_shape);
2871 Self::slice_assign(tensor, slices, value)
2872 }
2873
2874 /// Slices the tensor along a given dimension.
2875 ///
2876 /// # Arguments
2877 ///
2878 /// * `tensor` - The tensor to slice.
2879 /// * `dim` - The dimension along which to slice.
2880 /// * `slice` - The slice information containing range and step.
2881 ///
2882 /// # Returns
2883 ///
2884 /// The sliced tensor.
2885 ///
2886 /// # Remarks
2887 ///
2888 /// This is a low-level function used internally by the library to call different backend functions
2889 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2890 /// or use this function directly.
2891 fn slice_dim(tensor: Self::Primitive, dim: usize, slice: &Slice) -> Self::Primitive {
2892 let mut slices = vec![Slice::full(); tensor.shape().num_dims()];
2893 slices[dim] = *slice;
2894
2895 Self::slice(tensor, &slices)
2896 }
2897
2898 /// Select tensor elements along the given dimension corresponding to the given indices.
2899 ///
2900 /// # Arguments
2901 ///
2902 /// * `tensor` - The tensor to select from.
2903 /// * `dim` - The dimension along which to select.
2904 /// * `indices` - The indices of the elements to select.
2905 ///
2906 /// # Returns
2907 ///
2908 /// The selected tensor elements.
2909 ///
2910 /// # Remarks
2911 ///
2912 /// This is a low-level function used internally by the library to call different backend functions
2913 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2914 /// or use this function directly.
2915 ///
2916 /// For selecting elements from a tensor along an axis, users should prefer the
2917 /// [Tensor::select](Tensor::select) function, which is more high-level and designed for public use.
2918 fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive;
2919
2920 /// Assign the selected elements along the given dimension corresponding to the given indices
2921 /// from the value tensor.
2922 ///
2923 /// # Arguments
2924 ///
2925 /// * `tensor` - The tensor to assign elements to.
2926 /// * `dim` - The axis along which to assign elements.
2927 /// * `indices` - The indices of the elements to assign.
2928 /// * `values` - The values to assign to the tensor.
2929 ///
2930 /// # Returns
2931 ///
2932 /// A tensor with the same shape as the input tensor, where each element is taken from the
2933 /// corresponding element of the input tensor at the corresponding index along the specified axis,
2934 /// except for the elements at the specified indices, which are taken from the corresponding
2935 /// element of the values tensor.
2936 ///
2937 /// # Remarks
2938 ///
2939 /// This is a low-level function used internally by the library to call different backend functions
2940 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2941 /// or use this function directly.
2942 ///
2943 /// For assigning elements to a tensor along an axis, users should prefer the
2944 /// [Tensor::select_assign](Tensor::select_assign) function, which is more high-level and designed for public use.
2945 fn select_assign(
2946 tensor: Self::Primitive,
2947 dim: usize,
2948 indices: Tensor<B, 1, Int>,
2949 values: Self::Primitive,
2950 ) -> Self::Primitive;
2951
2952 /// Returns the device on which the tensor is allocated.
2953 ///
2954 /// # Arguments
2955 ///
2956 /// * `tensor` - The tensor.
2957 ///
2958 /// # Returns
2959 ///
2960 /// The device on which the tensor is allocated.
2961 ///
2962 /// # Remarks
2963 ///
2964 /// This is a low-level function used internally by the library to call different backend functions
2965 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2966 /// or use this function directly.
2967 ///
2968 /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function,
2969 /// which is more high-level and designed for public use.
2970 fn device(tensor: &Self::Primitive) -> B::Device;
2971
2972 /// Moves the tensor to the given device.
2973 ///
2974 /// # Arguments
2975 ///
2976 /// * `tensor` - The tensor.
2977 /// * `device` - The device on which the tensor will be moved.
2978 ///
2979 /// # Returns
2980 ///
2981 /// The tensor on the given device.
2982 ///
2983 /// # Remarks
2984 ///
2985 /// This is a low-level function used internally by the library to call different backend functions
2986 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2987 /// or use this function directly.
2988 ///
2989 /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function,
2990 /// which is more high-level and designed for public use.
2991 #[allow(clippy::wrong_self_convention)]
2992 fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
2993
2994 /// Extracts the data from the tensor asynchronously.
2995 ///
2996 /// # Arguments
2997 ///
2998 /// * `tensor` - The tensor.
2999 ///
3000 /// # Returns
3001 ///
3002 /// The data of the tensor.
3003 ///
3004 /// # Remarks
3005 ///
3006 /// This is a low-level function used internally by the library to call different backend functions
3007 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3008 /// or use this function directly.
3009 ///
3010 /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
3011 /// which is more high-level and designed for public use.
3012 #[allow(clippy::wrong_self_convention)]
3013 fn into_data_async(tensor: Self::Primitive) -> impl Future<Output = TensorData> + Send;
3014
3015 /// Read the data from the tensor using a transaction.
3016 ///
3017 /// # Remarks
3018 ///
3019 /// This is a low-level function used internally by the library to call different backend functions
3020 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3021 /// or use this function directly.
3022 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive);
3023
3024 /// Creates a tensor from the given data.
3025 ///
3026 /// # Arguments
3027 ///
3028 /// * `data` - The data of the tensor.
3029 /// * `device` - The device on which the tensor will be allocated.
3030 ///
3031 /// # Returns
3032 ///
3033 /// The tensor.
3034 ///
3035 /// # Remarks
3036 ///
3037 /// This is a low-level function used internally by the library to call different backend functions
3038 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3039 /// or use this function directly.
3040 ///
3041 /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function,
3042 /// which is more high-level and designed for public use.
3043 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;
3044 /// Creates a tensor from the given data enforcing the given data type.
3045 ///
3046 /// # Remarks
3047 ///
3048 /// This is a low-level function used internally by the library to call different backend functions
3049 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3050 /// or use this function directly.
3051 ///
3052 /// For creating a tensor from data, users should prefer the [Tensor::from_data_dtype](Tensor::from_data_dtype)
3053 /// function, which is more high-level and designed for public use.
3054 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;
3055
3056 /// Repeat the tensor along the given dimension.
3057 ///
3058 /// # Arguments
3059 ///
3060 /// * `tensor` - The tensor.
3061 /// * `dim` - The dimension along which the tensor will be repeated.
3062 /// * `times` - The number of times the tensor will be repeated.
3063 ///
3064 /// # Returns
3065 ///
3066 /// The repeated tensor.
3067 ///
3068 /// # Remarks
3069 ///
3070 /// This is a low-level function used internally by the library to call different backend functions
3071 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3072 /// or use this function directly.
3073 ///
3074 /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
3075 /// which is more high-level and designed for public use.
3076 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
3077
3078 /// Concatenates the given tensors along the given dimension.
3079 ///
3080 /// # Arguments
3081 ///
3082 /// * `vectors` - The tensors to concatenate.
3083 /// * `dim` - The dimension along which the tensors will be concatenated.
3084 ///
3085 /// # Returns
3086 ///
3087 /// The concatenated tensor.
3088 ///
3089 /// # Remarks
3090 ///
3091 /// This is a low-level function used internally by the library to call different backend functions
3092 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3093 /// or use this function directly.
3094 ///
3095 /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function,
3096 /// which is more high-level and designed for public use.
3097 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
3098
3099 /// Equates the given tensors.
3100 ///
3101 /// # Arguments
3102 ///
3103 /// * `lhs` - The left hand side tensor.
3104 /// * `rhs` - The right hand side tensor.
3105 ///
3106 /// # Returns
3107 ///
3108 /// The tensor of booleans indicating whether the corresponding elements are equal.
3109 ///
3110 /// # Remarks
3111 ///
3112 /// This is a low-level function used internally by the library to call different backend functions
3113 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3114 /// or use this function directly.
3115 ///
3116 /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function,
3117 /// which is more high-level and designed for public use.
3118 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3119
3120 /// Applies element-wise non-equality comparison between the given tensors.
3121 ///
3122 /// # Arguments
3123 ///
3124 /// * `lhs` - The left hand side tensor.
3125 /// * `rhs` - The right hand side tensor.
3126 ///
3127 /// # Returns
3128 ///
3129 /// The tensor of booleans indicating whether the corresponding elements are equal.
3130 ///
3131 /// # Remarks
3132 ///
3133 /// This is a low-level function used internally by the library to call different backend functions
3134 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3135 /// or use this function directly.
3136 ///
3137 /// For non-equality comparison of tensors, users should prefer the [Tensor::not_equal](Tensor::not_equal)
3138 /// function, which is more high-level and designed for public use.
3139 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3140
3141 /// Returns the name of the element type.
3142 fn elem_type_name() -> &'static str {
3143 core::any::type_name::<Self::Elem>()
3144 }
3145
3146 /// Returns the tensor data type.
3147 fn dtype(tensor: &Self::Primitive) -> DType {
3148 tensor.dtype()
3149 }
3150
3151 /// Tests if any element in the `tensor` evaluates to True.
3152 ///
3153 /// # Arguments
3154 ///
3155 /// * `tensor` - The tensor to test.
3156 ///
3157 /// # Returns
3158 ///
3159 /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
3160 ///
3161 /// # Remarks
3162 ///
3163 /// This is a low-level function used internally by the library to call different backend functions
3164 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3165 /// or use this function directly. Users should prefer the [Tensor::any](Tensor::any) function
3166 /// which is more high-level and designed for public use.
3167 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
3168
3169 /// Tests if any element in the tensor evaluates to True along a given dimension dim.
3170 ///
3171 /// # Arguments
3172 ///
3173 /// * tensor - The tensor to test.
3174 /// * dim - The axis along which to test.
3175 ///
3176 /// # Returns
3177 ///
3178 /// A boolean tensor with the same size as input tensor, except in the dim axis where the size is 1.
3179 /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
3180 ///
3181 /// # Remarks
3182 ///
3183 /// This is a low-level function used internally by the library to call different backend functions
3184 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3185 /// or use this function directly. Users should prefer the [Tensor::any_dim](Tensor::any_dim) function,
3186 /// which is more high-level and designed for public use.
3187 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
3188
3189 /// Tests if all elements in the `tensor` evaluate to True.
3190 ///
3191 /// # Arguments
3192 ///
3193 /// * `tensor` - The tensor to test.
3194 ///
3195 /// # Returns
3196 ///
3197 /// A boolean tensor with a single element, True if all elements in the input tensor evaluates to True, False otherwise.
3198 ///
3199 /// # Remarks
3200 ///
3201 /// This is a low-level function used internally by the library to call different backend functions
3202 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3203 /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function,
3204 /// which is more high-level and designed for public use.
3205 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
3206
3207 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
3208 ///
3209 /// # Arguments
3210 ///
3211 /// * `tensor` - The tensor to test.
3212 ///
3213 /// # Returns
3214 ///
3215 /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
3216 /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
3217 ///
3218 /// # Remarks
3219 ///
3220 /// This is a low-level function used internally by the library to call different backend functions
3221 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3222 /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function,
3223 /// which is more high-level and designed for public use.
3224 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
3225
3226 /// Broadcasts the given tensor to the specified shape.
3227 ///
3228 /// # Arguments
3229 ///
3230 /// * `tensor` - The tensor to broadcast.
3231 /// * `shape` - The shape to broadcast to.
3232 ///
3233 /// # Returns
3234 ///
3235 /// The broadcasted tensor.
3236 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
3237
3238 /// Unfold windows along a dimension.
3239 ///
3240 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
3241 /// where windows are advanced by `step` at each index.
3242 ///
3243 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
3244 ///
3245 /// # Warning
3246 ///
3247 /// For the `ndarray` and `candle` backends; this is not a view but a full copy.
3248 ///
3249 /// # Arguments
3250 ///
3251 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
3252 /// * `dim` - the dimension to unfold.
3253 /// * `size` - the size of each unfolded window.
3254 /// * `step` - the step between each window.
3255 ///
3256 /// # Returns
3257 ///
3258 /// A tensor view with shape ``[pre=..., windows, post=..., size]``.
3259 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive;
3260}
3261
3262impl<B: Backend> BasicOps<B> for Float {
3263 type Elem = B::FloatElem;
3264
3265 fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
3266 TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))
3267 }
3268
3269 fn full<E: ElementConversion>(
3270 shape: Shape,
3271 fill_value: E,
3272 device: &B::Device,
3273 dtype: DType,
3274 ) -> Self::Primitive {
3275 TensorPrimitive::Float(B::float_full(
3276 shape,
3277 fill_value.elem(),
3278 device,
3279 dtype.into(),
3280 ))
3281 }
3282
3283 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
3284 tr.register_float(tensor);
3285 }
3286
3287 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3288 match tensor {
3289 TensorPrimitive::Float(tensor) => {
3290 TensorPrimitive::Float(B::float_reshape(tensor, shape))
3291 }
3292 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
3293 }
3294 }
3295
3296 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
3297 match tensor {
3298 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
3299 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
3300 }
3301 }
3302
3303 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
3304 match tensor {
3305 TensorPrimitive::Float(tensor) => {
3306 TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
3307 }
3308 TensorPrimitive::QFloat(tensor) => {
3309 TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
3310 }
3311 }
3312 }
3313
3314 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
3315 match tensor {
3316 TensorPrimitive::Float(tensor) => {
3317 TensorPrimitive::Float(B::float_slice(tensor, slices))
3318 }
3319 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),
3320 }
3321 }
3322
3323 fn slice_assign(
3324 tensor: Self::Primitive,
3325 slices: &[Slice],
3326 value: Self::Primitive,
3327 ) -> Self::Primitive {
3328 TensorPrimitive::Float(B::float_slice_assign(
3329 tensor.tensor(),
3330 slices,
3331 value.tensor(),
3332 ))
3333 }
3334
3335 fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3336 match tensor {
3337 TensorPrimitive::Float(tensor) => {
3338 TensorPrimitive::Float(B::float_select(tensor, dim, indices.primitive))
3339 }
3340 TensorPrimitive::QFloat(tensor) => {
3341 TensorPrimitive::QFloat(B::q_select(tensor, dim, indices.primitive))
3342 }
3343 }
3344 }
3345
3346 fn select_assign(
3347 tensor: Self::Primitive,
3348 dim: usize,
3349 indices: Tensor<B, 1, Int>,
3350 values: Self::Primitive,
3351 ) -> Self::Primitive {
3352 // Select assign is ambiguous for QFloat
3353 TensorPrimitive::Float(B::float_select_assign(
3354 tensor.tensor(),
3355 dim,
3356 indices.primitive,
3357 values.tensor(),
3358 ))
3359 }
3360
3361 fn device(tensor: &Self::Primitive) -> Device<B> {
3362 match tensor {
3363 TensorPrimitive::Float(tensor) => B::float_device(tensor),
3364 TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
3365 }
3366 }
3367
3368 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
3369 match tensor {
3370 TensorPrimitive::Float(tensor) => {
3371 TensorPrimitive::Float(B::float_to_device(tensor, device))
3372 }
3373 TensorPrimitive::QFloat(tensor) => {
3374 TensorPrimitive::QFloat(B::q_to_device(tensor, device))
3375 }
3376 }
3377 }
3378
3379 async fn into_data_async(tensor: Self::Primitive) -> TensorData {
3380 match tensor {
3381 TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
3382 TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
3383 }
3384 }
3385
3386 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
3387 match &data.dtype {
3388 DType::QFloat(_scheme) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
3389 _ => TensorPrimitive::Float(B::float_from_data(data.convert::<B::FloatElem>(), device)),
3390 }
3391 }
3392
3393 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
3394 match dtype {
3395 DType::QFloat(_scheme) => {
3396 TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device))
3397 }
3398 _ if dtype.is_float() => {
3399 TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
3400 }
3401 _ => panic!("Expected float dtype, got {dtype:?}"),
3402 }
3403 }
3404
3405 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
3406 match tensor {
3407 TensorPrimitive::Float(tensor) => {
3408 TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
3409 }
3410 TensorPrimitive::QFloat(tensor) => {
3411 TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
3412 }
3413 }
3414 }
3415
3416 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
3417 match vectors.first().unwrap() {
3418 TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
3419 vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
3420 dim,
3421 )),
3422 TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
3423 vectors
3424 .into_iter()
3425 .map(|tensor| {
3426 if let TensorPrimitive::QFloat(t) = tensor {
3427 t
3428 } else {
3429 panic!("Concatenation only works with vector of QFloat")
3430 }
3431 })
3432 .collect(),
3433 dim,
3434 )),
3435 }
3436 }
3437
3438 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3439 B::float_equal(lhs.tensor(), rhs.tensor())
3440 }
3441
3442 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3443 B::float_not_equal(lhs.tensor(), rhs.tensor())
3444 }
3445
3446 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3447 B::float_any(tensor.tensor())
3448 }
3449
3450 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3451 B::float_any_dim(tensor.tensor(), dim)
3452 }
3453
3454 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3455 B::float_all(tensor.tensor())
3456 }
3457
3458 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3459 B::float_all_dim(tensor.tensor(), dim)
3460 }
3461
3462 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3463 match tensor {
3464 TensorPrimitive::Float(tensor) => {
3465 TensorPrimitive::Float(B::float_permute(tensor, axes))
3466 }
3467 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
3468 }
3469 }
3470
3471 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3472 TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
3473 }
3474
3475 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3476 match tensor {
3477 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
3478 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
3479 }
3480 }
3481
3482 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
3483 TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step))
3484 }
3485}
3486
3487impl<B: Backend> BasicOps<B> for Int {
3488 type Elem = B::IntElem;
3489
3490 fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
3491 B::int_empty(shape, device, dtype.into())
3492 }
3493
3494 fn full<E: ElementConversion>(
3495 shape: Shape,
3496 fill_value: E,
3497 device: &B::Device,
3498 dtype: DType,
3499 ) -> Self::Primitive {
3500 B::int_full(shape, fill_value.elem(), device, dtype.into())
3501 }
3502
3503 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
3504 tr.register_int(tensor);
3505 }
3506
3507 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3508 B::int_reshape(tensor, shape)
3509 }
3510
3511 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
3512 B::int_transpose(tensor)
3513 }
3514
3515 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
3516 B::int_swap_dims(tensor, dim1, dim2)
3517 }
3518
3519 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
3520 B::int_slice(tensor, slices)
3521 }
3522
3523 fn slice_assign(
3524 tensor: Self::Primitive,
3525 slices: &[Slice],
3526 value: Self::Primitive,
3527 ) -> Self::Primitive {
3528 B::int_slice_assign(tensor, slices, value)
3529 }
3530
3531 fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3532 B::int_select(tensor, dim, indices.primitive)
3533 }
3534
3535 fn select_assign(
3536 tensor: Self::Primitive,
3537 dim: usize,
3538 indices: Tensor<B, 1, Int>,
3539 values: Self::Primitive,
3540 ) -> Self::Primitive {
3541 B::int_select_assign(tensor, dim, indices.primitive, values)
3542 }
3543
3544 fn device(tensor: &Self::Primitive) -> Device<B> {
3545 B::int_device(tensor)
3546 }
3547
3548 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
3549 B::int_to_device(tensor, device)
3550 }
3551
3552 async fn into_data_async(tensor: Self::Primitive) -> TensorData {
3553 B::int_into_data(tensor).await
3554 }
3555
3556 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
3557 B::int_from_data(data.convert::<B::IntElem>(), device)
3558 }
3559
3560 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
3561 if !dtype.is_int() {
3562 panic!("Expected int dtype, got {dtype:?}")
3563 }
3564
3565 B::int_from_data(data.convert_dtype(dtype), device)
3566 }
3567
3568 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
3569 B::int_repeat_dim(tensor, dim, times)
3570 }
3571
3572 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3573 B::int_equal(lhs, rhs)
3574 }
3575
3576 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3577 B::int_not_equal(lhs, rhs)
3578 }
3579
3580 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
3581 B::int_cat(vectors, dim)
3582 }
3583
3584 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3585 B::int_any(tensor)
3586 }
3587
3588 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3589 B::int_any_dim(tensor, dim)
3590 }
3591
3592 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3593 B::int_all(tensor)
3594 }
3595
3596 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3597 B::int_all_dim(tensor, dim)
3598 }
3599
3600 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3601 B::int_permute(tensor, axes)
3602 }
3603
3604 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3605 B::int_expand(tensor, shape)
3606 }
3607
3608 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3609 B::int_flip(tensor, axes)
3610 }
3611
3612 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
3613 B::int_unfold(tensor, dim, size, step)
3614 }
3615}
3616
3617impl<B: Backend> BasicOps<B> for Bool {
3618 type Elem = B::BoolElem;
3619
3620 fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
3621 if dtype != Self::Elem::dtype() {
3622 panic!("Expected bool data type, got {dtype:?}");
3623 }
3624 B::bool_empty(shape, device)
3625 }
3626
3627 fn full<E: ElementConversion>(
3628 shape: Shape,
3629 fill_value: E,
3630 device: &B::Device,
3631 dtype: DType,
3632 ) -> Self::Primitive {
3633 if dtype != Self::Elem::dtype() {
3634 panic!("Expected bool data type, got {dtype:?}");
3635 }
3636 if fill_value.elem() {
3637 B::bool_ones(shape, device)
3638 } else {
3639 B::bool_zeros(shape, device)
3640 }
3641 }
3642
3643 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
3644 tr.register_bool(tensor);
3645 }
3646
3647 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3648 B::bool_reshape(tensor, shape)
3649 }
3650
3651 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
3652 B::bool_transpose(tensor)
3653 }
3654
3655 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
3656 B::bool_swap_dims(tensor, dim1, dim2)
3657 }
3658
3659 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
3660 B::bool_slice(tensor, slices)
3661 }
3662
3663 fn slice_assign(
3664 tensor: Self::Primitive,
3665 slices: &[Slice],
3666 value: Self::Primitive,
3667 ) -> Self::Primitive {
3668 B::bool_slice_assign(tensor, slices, value)
3669 }
3670
3671 fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3672 B::bool_select(tensor, dim, indices.primitive)
3673 }
3674
3675 fn select_assign(
3676 tensor: Self::Primitive,
3677 dim: usize,
3678 indices: Tensor<B, 1, Int>,
3679 values: Self::Primitive,
3680 ) -> Self::Primitive {
3681 B::bool_select_assign(tensor, dim, indices.primitive, values)
3682 }
3683
3684 fn device(tensor: &Self::Primitive) -> Device<B> {
3685 B::bool_device(tensor)
3686 }
3687
3688 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
3689 B::bool_to_device(tensor, device)
3690 }
3691
3692 async fn into_data_async(tensor: Self::Primitive) -> TensorData {
3693 B::bool_into_data(tensor).await
3694 }
3695
3696 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
3697 B::bool_from_data(data.convert::<B::BoolElem>(), device)
3698 }
3699
3700 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
3701 // Backends only use one bool representation dtype
3702 if dtype != B::BoolElem::dtype() {
3703 panic!("Expected bool dtype, got {dtype:?}")
3704 }
3705 B::bool_from_data(data.convert_dtype(dtype), device)
3706 }
3707
3708 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
3709 B::bool_repeat_dim(tensor, dim, times)
3710 }
3711
3712 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3713 B::bool_equal(lhs, rhs)
3714 }
3715
3716 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3717 B::bool_not_equal(lhs, rhs)
3718 }
3719
3720 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
3721 B::bool_cat(vectors, dim)
3722 }
3723
3724 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3725 B::bool_any(tensor)
3726 }
3727
3728 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3729 B::bool_any_dim(tensor, dim)
3730 }
3731
3732 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
3733 B::bool_all(tensor)
3734 }
3735
3736 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
3737 B::bool_all_dim(tensor, dim)
3738 }
3739
3740 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3741 B::bool_permute(tensor, axes)
3742 }
3743
3744 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
3745 B::bool_expand(tensor, shape)
3746 }
3747
3748 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
3749 B::bool_flip(tensor, axes)
3750 }
3751
3752 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
3753 B::bool_unfold(tensor, dim, size, step)
3754 }
3755}
3756
/// Conversion trait for the dimension arguments accepted by `tensor.movedim()`.
pub trait MovedimArgs {
    /// Turns the argument into the list of dimensions (`Vec<usize>`) consumed by
    /// `tensor.movedim()`, validated against the const parameter `D`.
    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}
3762
3763impl MovedimArgs for Vec<i32> {
3764 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3765 let set = self
3766 .iter()
3767 .map(|&dim| {
3768 if dim < 0 {
3769 (D as i32 + dim) as usize
3770 } else {
3771 dim as usize
3772 }
3773 })
3774 .collect::<Vec<usize>>();
3775 check!(TensorCheck::movedim_args_vec::<D>(&set));
3776
3777 set
3778 }
3779}
3780
3781impl MovedimArgs for Vec<usize> {
3782 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3783 check!(TensorCheck::movedim_args_vec::<D>(&self));
3784 self
3785 }
3786}
3787
3788impl MovedimArgs for usize {
3789 #[allow(clippy::vec_init_then_push)]
3790 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3791 check!(TensorCheck::movedim_args_usize::<D>(self));
3792
3793 let mut set = Vec::with_capacity(1);
3794 set.push(self);
3795
3796 set
3797 }
3798}
3799
3800impl MovedimArgs for i32 {
3801 #[allow(clippy::vec_init_then_push)]
3802 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3803 check!(TensorCheck::movedim_args_i32::<D>(self));
3804
3805 let dim = if self < 0 {
3806 (D as i32 + self) as usize
3807 } else {
3808 self as usize
3809 };
3810
3811 let mut set = Vec::with_capacity(1);
3812 set.push(dim);
3813
3814 set
3815 }
3816}
3817
3818/// Trait used for reshape arguments.
3819pub trait ReshapeArgs<const D2: usize> {
3820 /// Converts to a shape.
3821 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3822 self,
3823 tensor: &Tensor<B, D, K>,
3824 ) -> Shape;
3825}
3826
3827impl<const D2: usize> ReshapeArgs<D2> for Shape {
3828 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3829 self,
3830 tensor: &Tensor<B, D, K>,
3831 ) -> Shape {
3832 check!(TensorCheck::reshape_args_usize::<D, D2>(
3833 &tensor.shape(),
3834 &self
3835 ));
3836
3837 self
3838 }
3839}
3840impl<const D2: usize> ReshapeArgs<D2> for [usize; D2] {
3841 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3842 self,
3843 tensor: &Tensor<B, D, K>,
3844 ) -> Shape {
3845 let shape = Shape::from(self);
3846
3847 check!(TensorCheck::reshape_args_usize::<D, D2>(
3848 &tensor.shape(),
3849 &shape
3850 ));
3851
3852 shape
3853 }
3854}
3855
3856impl<const D2: usize> ReshapeArgs<D2> for [i64; D2] {
3857 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3858 self,
3859 tensor: &Tensor<B, D, K>,
3860 ) -> Shape {
3861 // Validate the reshape arguments
3862 check!(TensorCheck::reshape_args_i64(&self));
3863
3864 // Temporary shape
3865 let mut new_shape: [i64; D2] = [1; D2];
3866
3867 // We need to find the index of the 0 dimension and
3868 // replace it with the actual dimension value.
3869 for (i, &s) in self.iter().enumerate() {
3870 if s != 0 {
3871 new_shape[i] = s;
3872 } else {
3873 new_shape[i] = tensor.dims()[i] as i64;
3874 }
3875 }
3876
3877 // Find the index of the inferred dimension (-1)
3878 let infer_index = new_shape.iter().position(|x| x == &-1);
3879
3880 // Handle the case where the dimension is inferred (via -1)
3881 if let Some(index) = infer_index {
3882 // Handle the case where the dimension is inferred
3883 let mut product = 1;
3884 for (i, &s) in new_shape.iter().enumerate() {
3885 if i != index {
3886 product *= s;
3887 }
3888 }
3889 let product_current = tensor.shape().num_elements() as i64;
3890
3891 new_shape[index] = product_current / product;
3892
3893 // Check if the reshape is valid
3894 if product_current % product != 0 {
3895 panic!(
3896 "Cannot reshape tensor of shape {:?} to shape {:?}",
3897 tensor.shape(),
3898 new_shape
3899 );
3900 }
3901 };
3902
3903 // Convert each element to usize
3904 let new_shape: [usize; D2] = new_shape.map(|x| x as usize);
3905
3906 Shape::from(new_shape)
3907 }
3908}
3909
3910impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
3911 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3912 self,
3913 tensor: &Tensor<B, D, K>,
3914 ) -> Shape {
3915 // Convert i32 array to i64 array and use existing implementation
3916 let i64_array: [i64; D2] = self.map(|x| x as i64);
3917 ReshapeArgs::into_shape(i64_array, tensor)
3918 }
3919}
3920
3921/// Trait used for broadcast arguments.
3922pub trait BroadcastArgs<const D1: usize, const D2: usize> {
3923 /// Converts to a shape.
3924 fn into_shape(self, shape: &Shape) -> Shape;
3925}
3926
3927impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
3928 fn into_shape(self, _shape: &Shape) -> Shape {
3929 self
3930 }
3931}
3932impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
3933 fn into_shape(self, _shape: &Shape) -> Shape {
3934 Shape::from(self)
3935 }
3936}
3937
3938impl<const D1: usize, const D2: usize, E: Element> BroadcastArgs<D1, D2> for [E; D2] {
3939 // Passing -1 as the size for a dimension means not changing the size of that dimension.
3940 fn into_shape(self, shape: &Shape) -> Shape {
3941 if self.len() < shape.num_dims() {
3942 panic!("Broadcast arguments must be greater than the number of dimensions");
3943 }
3944
3945 // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
3946 let new_shape: Vec<_> = self
3947 .iter()
3948 .rev()
3949 .map(|x| {
3950 let primitive = x.to_i64();
3951 if primitive < -1 || primitive == 0 {
3952 panic!("Broadcast arguments must be positive or -1");
3953 }
3954 primitive
3955 })
3956 .zip(shape.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
3957 .map(|(x, &y)| if x == -1 { y } else { x as usize })
3958 .collect::<Vec<_>>()
3959 .into_iter()
3960 .rev()
3961 .collect();
3962
3963 if new_shape.contains(&0) {
3964 panic!("Cannot substitute -1 for a non-existing dimension");
3965 }
3966
3967 let new_shape: [usize; D2] = new_shape.try_into().unwrap();
3968
3969 Shape::from(new_shape)
3970 }
3971}
3972
3973impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
3974where
3975 B: Backend,
3976 K: BasicOps<B>,
3977 K::Elem: Debug + Copy + Serialize,
3978{
3979 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
3980 let data = self.to_data();
3981 data.serialize(serializer)
3982 }
3983}
3984
3985impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
3986where
3987 B: Backend,
3988 K: BasicOps<B>,
3989 K::Elem: Debug + Copy + Deserialize<'de>,
3990{
3991 fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
3992 let tensor = Tensor::from_data(
3993 TensorData::deserialize(deserializer)?,
3994 &<B::Device as Default>::default(),
3995 );
3996 Ok(tensor)
3997 }
3998}
3999
4000#[cfg(test)]
4001mod tests {
4002 use crate::{Shape, s};
4003
4004 #[test]
4005 fn slice_range_single_dim_leading() {
4006 let shape = Shape::new([8, 4]);
4007
4008 // Half-open range
4009 let slices = shape.clone().into_slices([0..5]);
4010 assert_eq!(slices[0].to_range(8), 0..5);
4011 let slices = shape.clone().into_slices([-3..-1]);
4012 assert_eq!(slices[0].to_range(8), 5..7);
4013
4014 // Inclusive range
4015 let slices = shape.clone().into_slices([0..=4]);
4016 assert_eq!(slices[0].to_range(8), 0..5);
4017 let slices = shape.clone().into_slices([-2..=-1]);
4018 assert_eq!(slices[0].to_range(8), 6..8);
4019
4020 // Unbounded start
4021 let slices = shape.clone().into_slices([..3]);
4022 assert_eq!(slices[0].to_range(8), 0..3);
4023 let slices = shape.clone().into_slices([..-5]);
4024 assert_eq!(slices[0].to_range(8), 0..3);
4025
4026 // Unbounded end
4027 let slices = shape.clone().into_slices([5..]);
4028 assert_eq!(slices[0].to_range(8), 5..8);
4029 let slices = shape.clone().into_slices([-3..]);
4030 assert_eq!(slices[0].to_range(8), 5..8);
4031
4032 // Full range
4033 let slices = shape.into_slices([..]);
4034 assert_eq!(slices[0].to_range(8), 0..8);
4035 }
4036
4037 #[test]
4038 fn test_negative_slice_indices() {
4039 use crate::Slice;
4040
4041 // Test negative indices conversion
4042 let slice: Slice = (-3..-1).into();
4043 assert_eq!(slice.start, -3);
4044 assert_eq!(slice.end, Some(-1));
4045
4046 // Test to_range conversion with size 8
4047 let range = slice.to_range(8);
4048 assert_eq!(range, 5..7);
4049
4050 // Test with shape slice
4051 let shape = Shape::new([8, 4]);
4052 let result = shape.clone().into_slices([-3..-1]);
4053 assert_eq!(result[0].to_range(8), 5..7);
4054
4055 // Test more negative index cases
4056 let slice2: Slice = (-5..).into();
4057 assert_eq!(slice2.to_range(10), 5..10);
4058
4059 let slice3: Slice = (..-2).into();
4060 assert_eq!(slice3.to_range(10), 0..8);
4061
4062 // Test with s! macro - single dimension returns Slice directly
4063 let slice4 = s![-3..-1];
4064 assert_eq!(slice4.start, -3);
4065 assert_eq!(slice4.end, Some(-1));
4066 }
4067
4068 #[test]
4069 fn slice_range_multi_dim() {
4070 let shape = Shape::new([8, 4]);
4071
4072 // Multiple ways to provide ranges
4073 let slices = shape.clone().into_slices([0..5, 0..4]);
4074 assert_eq!(slices[0].to_range(8), 0..5);
4075 assert_eq!(slices[1].to_range(4), 0..4);
4076
4077 let slices = shape.clone().into_slices([0.., 0..]);
4078 assert_eq!(slices[0].to_range(8), 0..8);
4079 assert_eq!(slices[1].to_range(4), 0..4);
4080
4081 let slices = shape.clone().into_slices([0..=7, 0..=3]);
4082 assert_eq!(slices[0].to_range(8), 0..8);
4083 assert_eq!(slices[1].to_range(4), 0..4);
4084
4085 let slices = shape.clone().into_slices([0..5, 0..3]);
4086 assert_eq!(slices[0].to_range(8), 0..5);
4087 assert_eq!(slices[1].to_range(4), 0..3);
4088
4089 let slices = shape.into_slices([0.., 0..]);
4090 assert_eq!(slices[0].to_range(8), 0..8);
4091 assert_eq!(slices[1].to_range(4), 0..4);
4092 }
4093
4094 #[test]
4095 fn slice_range_multi_dim_index() {
4096 let shape = Shape::new([8, 4]);
4097
4098 // Indices (single integer) should also convert to correct range
4099 let slices = shape.clone().into_slices([0, 2]);
4100 assert_eq!(slices[0].to_range(8), 0..1);
4101 assert_eq!(slices[1].to_range(4), 2..3);
4102
4103 let slices = shape.into_slices([-1, -1]);
4104 assert_eq!(slices[0].to_range(8), 7..8);
4105 assert_eq!(slices[1].to_range(4), 3..4);
4106 }
4107
4108 #[test]
4109 fn slice_range_multi_dim_heterogeneous() {
4110 // Slice macro `s![]` can be used to provide different range types
4111 let shape = Shape::new([8, 4, 2]);
4112 let slice = s![0..5, .., -1];
4113 let slices = shape.into_slices(slice);
4114 assert_eq!(slices[0].to_range(8), 0..5);
4115 assert_eq!(slices[1].to_range(4), 0..4);
4116 assert_eq!(slices[2].to_range(2), 1..2);
4117
4118 let shape = Shape::new([8, 4, 2, 3]);
4119 let slice = s![..=4, 0..=3, .., -2..];
4120 let slices = shape.into_slices(slice);
4121 assert_eq!(slices[0].to_range(8), 0..5);
4122 assert_eq!(slices[1].to_range(4), 0..4);
4123 assert_eq!(slices[2].to_range(2), 0..2);
4124 assert_eq!(slices[3].to_range(3), 1..3);
4125
4126 let shape = Shape::new([3, 4]);
4127 let slice = s![1..-1, ..];
4128 let slices = shape.into_slices(slice);
4129 assert_eq!(slices[0].to_range(3), 1..2);
4130 assert_eq!(slices[1].to_range(4), 0..4);
4131 }
4132}