burn_tensor/tensor/api/base.rs
#![allow(clippy::single_range_in_vec_init)]
use crate::backend::ExecutionError;
use crate::check::unwrap_shape_reshape;

pub use burn_backend::tensor::BasicOps;

use alloc::vec::Vec;

use alloc::format;
use alloc::string::String;
use alloc::vec;

use burn_std::stub::RwLock;
use core::iter::repeat;
use core::{fmt::Debug, ops::Range};
use serde::{Deserialize, Deserializer};

use crate::{AsIndex, Slice, SliceArg, wrap_index};
use crate::{
    Bool, ElementConversion, Float, Int, Shape, TensorData, TensorKind, TensorMetadata,
    backend::Backend, check,
};
use crate::{DType, Element};
use crate::{IndexingUpdateOp, TensorCreationOptions};
use crate::{cast::ToElement, check::TensorCheck};
use serde::{Serialize, Serializer};
/// A tensor with a given backend, shape and data type.
///
/// # Indexing
/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
/// or [`select`](Tensor::select) for numeric types.
///
/// ## Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
/// use burn_tensor::Int;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///
///     let tensor = Tensor::<B, 2>::from_data(
///         [
///             [3.0, 4.9, 2.0],
///             [2.0, 1.9, 3.0],
///             [6.0, 1.5, 7.0],
///             [3.0, 4.9, 9.0],
///         ],
///         &device,
///     );
///
///     // Slice the tensor to get the second and third rows:
///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
///     // The resulting tensor will have dimensions [2, 3].
///     let slice = tensor.clone().slice([1..3]);
///     println!("{slice}");
///
///     // Slice the tensor to get the first two rows and the first 2 columns:
///     // [[3.0, 4.9], [2.0, 1.9]]
///     // The resulting tensor will have dimensions [2, 2].
///     let slice = tensor.clone().slice([0..2, 0..2]);
///     println!("{slice}");
///
///     // Index the tensor along the dimension 1 to get the elements 0 and 2:
///     // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
///     // The resulting tensor will have dimensions [4, 2].
///     let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
///     let indexed = tensor.select(1, indices);
///     println!("{indexed}");
/// }
/// ```
#[derive(new, Clone, Debug)]
pub struct Tensor<B, const D: usize, K = Float>
where
    B: Backend,
    K: TensorKind<B>,
{
    pub(crate) primitive: K::Primitive,
}

impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    T: Into<TensorData>,
{
    fn from(value: T) -> Self {
        Tensor::from_data(value.into(), &Default::default())
    }
}

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    K::Elem: Element,
{
    /// Executes an operation on the tensor and modifies its value.
    ///
    /// # Notes
    ///
    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
    /// no other reference pointing to the same tensor.
    ///
    /// Wrapping operations with `inplace` is not an optimization; it is mainly there if you
    /// want to mutate a tensor by using owned operations. A plausible usage would be to
    /// update the weights of a mutable model reference.
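    ///
    /// # Example
    ///
    /// A minimal sketch mutating a float tensor through an owned operation:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let mut tensor = Tensor::<B, 2>::ones([2, 3], &device);
    ///     // Add 1.0 to every element without giving up the binding.
    ///     tensor.inplace(|t| t + 1.0);
    ///     println!("{tensor}");
    ///     // [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]
    /// }
    /// ```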
    pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
        let mut tensor_owned = Tensor::empty([0; D], &self.device());
        core::mem::swap(&mut tensor_owned, self);

        let mut tensor_new = func(tensor_owned);
        core::mem::swap(&mut tensor_new, self);
    }

    /// Converts the tensor into a primitive tensor.
    pub fn into_primitive(self) -> K::Primitive {
        self.primitive
    }

    /// Converts from a primitive tensor into a tensor.
    pub fn from_primitive(tensor: K::Primitive) -> Self {
        Self::new(tensor)
    }

    /// Returns the number of dimensions of the tensor.
    pub fn rank(&self) -> usize {
        self.primitive.rank()
    }

    /// Returns the tensor primitive data type.
    ///
    /// # Note
    /// Some element types are encoded in different primitive types depending on the backend
    /// (e.g., bool could be encoded as `u8` or `u32`).
    pub fn dtype(&self) -> DType {
        self.primitive.dtype()
    }

    /// Create an empty tensor of the given shape.
    ///
    /// # Arguments
    ///
    /// - `shape`: The shape of the tensor.
    /// - `options`: The tensor creation options (the device, and optionally the data type).
    ///   A device reference converts into the default options for that device.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create an empty tensor with dimensions [2, 3, 4].
    ///     let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
    /// }
    /// ```
    pub fn empty<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
        let opt = options.into();
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
        Self::new(K::empty(shape, &opt.device, opt.dtype_or(K::Elem::dtype())))
    }

    /// Create a tensor of the given shape where each element is zero.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
    ///     println!("{tensor}");
    ///     // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
    /// }
    /// ```
    pub fn zeros<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
        let opt = options.into();
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
        Self::new(K::zeros(shape, &opt.device, opt.dtype_or(K::Elem::dtype())))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with zeros.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.zeros_like();
    ///     println!("{tensor}");
    ///     // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
    /// }
    /// ```
    pub fn zeros_like(&self) -> Self {
        Self::new(K::zeros(self.shape(), &self.device(), self.dtype()))
    }

    /// Create a tensor of the given shape where each element is one.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
    ///     println!("{tensor}");
    ///     // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    /// }
    /// ```
    pub fn ones<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
        let opt = options.into();
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
        Self::new(K::ones(shape, &opt.device, opt.dtype_or(K::Elem::dtype())))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with ones.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.ones_like();
    ///     println!("{tensor}");
    ///     // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    /// }
    /// ```
    pub fn ones_like(&self) -> Self {
        Self::new(K::ones(self.shape(), &self.device(), self.dtype()))
    }

    /// Create a tensor of the given shape where each element is equal to the provided value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
    ///     println!("{tensor}");
    ///     // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
    /// }
    /// ```
    pub fn full<S: Into<Shape>, E: ElementConversion>(
        shape: S,
        fill_value: E,
        options: impl Into<TensorCreationOptions<B>>,
    ) -> Self {
        let opt = options.into();
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
        Self::new(K::full(
            shape,
            fill_value,
            &opt.device,
            opt.dtype_or(K::Elem::dtype()),
        ))
    }

    /// Returns a new tensor with the same shape, dtype, and device as the current tensor,
    /// filled with the provided value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.full_like(5.0);
    ///     println!("{tensor}");
    ///     // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
    /// }
    /// ```
    pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
        Self::new(K::full(
            self.shape(),
            fill_value,
            &self.device(),
            self.dtype(),
        ))
    }

    /// Returns the dimensions of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///     let dims = tensor.dims(); // [2, 3, 4]
    ///     println!("{dims:?}");
    /// }
    /// ```
    pub fn dims(&self) -> [usize; D] {
        Self::shape(self).dims()
    }

    /// Returns the shape of the current tensor.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///     // Shape { dims: [2, 3, 4] }
    ///     let shape = tensor.shape();
    /// }
    /// ```
    pub fn shape(&self) -> Shape {
        self.primitive.shape()
    }

    /// Reshape the tensor to have the given shape.
    ///
    /// The tensor has the same data and number of elements as the input.
    ///
    /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
    /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
    ///
    /// A `0` in the shape instructs to keep the current dimension from the original tensor,
    /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
    /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
    /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
    /// with [1, 3, 4] dimensions to [1, 12].
    ///
    /// # Arguments
    /// - `shape`: The new shape of the tensor.
    ///
    /// # Panics
    /// - If the shape contains more than one `-1`.
    /// - If the shape contains negative values other than `-1` (a `0` keeps the original dimension).
    /// - If the new shape does not match the number of elements of the original shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a tensor with dimensions [2, 3, 4]
    ///     let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///     // Reshape it to [2, 12], where 12 is inferred from the number of elements.
    ///     let reshaped = tensor.reshape([2, -1]);
    ///     println!("{reshaped}");
    /// }
    /// ```
    pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
        // Convert reshape args to shape
        let shape = shape.into_shape::<D2>(self.shape());
        Tensor::new(K::reshape(self.primitive, shape))
    }

    /// Transpose the tensor.
    ///
    /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
    /// applied on the last two dimensions. For example, the transpose of a tensor with shape
    /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
    ///
    /// See also [`permute`](Tensor::permute).
    ///
    /// # Returns
    ///
    /// The transposed tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Transpose the tensor:
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let transposed = tensor.transpose();
    ///     println!("{transposed}");
    /// }
    /// ```
    pub fn transpose(self) -> Tensor<B, D, K> {
        Tensor::new(K::transpose(self.primitive))
    }

    /// Alias for `transpose`.
    #[inline(always)]
    pub fn t(self) -> Tensor<B, D, K> {
        self.transpose()
    }

    /// Swaps two dimensions of a tensor.
    ///
    /// This is a no-op when `dim1 == dim2`, assuming both are within bounds.
    ///
    /// # Arguments
    ///
    /// * `dim1` - The first dimension to swap, supports negative indexing.
    /// * `dim2` - The second dimension to swap, supports negative indexing.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    ///
    /// # Panics
    ///
    /// When dimensions are out of bounds.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [2, 3]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///
    ///     // Swap the dimensions 0 and -1 (equivalent to `tensor.transpose()`):
    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
    ///     // The resulting tensor will have dimensions [3, 2].
    ///     let swapped = tensor.swap_dims(0, -1);
    ///     println!("{swapped}");
    /// }
    /// ```
    pub fn swap_dims<Dim1, Dim2>(self, dim1: Dim1, dim2: Dim2) -> Tensor<B, D, K>
    where
        Dim1: AsIndex,
        Dim2: AsIndex,
    {
        let dim1 = dim1.expect_dim_index(D);
        let dim2 = dim2.expect_dim_index(D);
        check!(TensorCheck::swap_dims::<D>(dim1, dim2));
        if dim1 == dim2 {
            self
        } else {
            Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
        }
    }

    /// Permute the dimensions of the tensor.
    ///
    /// This is a no-op when the resolved `axes` match the current order.
    ///
    /// # Arguments
    ///
    /// * `axes` - The new order of the dimensions. The length of the axes
    ///   must be equal to the number of dimensions of the tensor.
    ///   The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor of shape [3, 2]
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
    ///
    ///     // Permute the dimensions 1 and 0:
    ///     // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
    ///     // The resulting tensor will have dimensions [2, 3].
    ///     let permuted = tensor.permute([1, 0]);
    ///     println!("{permuted}");
    /// }
    /// ```
    pub fn permute<Dim>(self, axes: [Dim; D]) -> Tensor<B, D, K>
    where
        Dim: AsIndex,
    {
        let mut no_op = true;
        let mut fixed_axes = [0; D];
        for (i, axis) in axes.into_iter().enumerate() {
            let dim = axis.expect_dim_index(D);
            no_op &= dim == i;
            fixed_axes[i] = dim;
        }

        if no_op {
            self
        } else {
            check!(TensorCheck::permute(fixed_axes));
            Tensor::new(K::permute(self.primitive, &fixed_axes))
        }
    }

    /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
    ///
    /// Other dimensions of input that are not explicitly moved remain in their original order and appear
    /// at the positions not specified in destination.
    ///
    /// # Arguments
    ///
    /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// * `dst` - Destination positions for each of the original dims. These must also be unique.
    ///
    /// # Panics
    ///
    /// - If the source and destination dimensions are not of the same length.
    /// - If the source and destination vectors contain duplicate values.
    /// - If the source and destination vectors contain values that are out of bounds.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions moved.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor of shape [3, 2, 1]
    ///     let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
    ///
    ///     // Move dimension 1 to position 0:
    ///     // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
    ///     // The resulting tensor will have dimensions [2, 3, 1].
    ///     let moved = tensor.movedim(1, 0);
    ///     println!("{moved}");
    /// }
    /// ```
    ///
    /// # Note
    ///
    /// This is syntactic sugar for `permute`. It is used widely enough that we define a
    /// separate op for it.
    pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
        let source_dims = src.into_dim_vec::<D>();
        let destination_dims = dst.into_dim_vec::<D>();

        check!(TensorCheck::movedim_args_length(
            &source_dims,
            &destination_dims
        ));

        let mut m = [-1; D];
        for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
            m[d] = s as isize;
        }
        let mut axes: [isize; D] = [0; D];
        let mut source_i = 0;
        for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
            *item = if m[dest_i] != -1 {
                m[dest_i]
            } else {
                while source_dims.contains(&source_i) {
                    source_i += 1;
                }
                let result = source_i as isize;
                source_i += 1;
                result
            };
        }

        self.permute(axes)
    }

    /// Reverse the order of elements in the tensor along the given dimensions.
    ///
    /// # Arguments
    ///
    /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
    ///   The values can be negative, in which case they are used as an offset from the end.
    ///
    /// # Returns
    ///
    /// The tensor with the axes flipped.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [4, 3]
    ///     let tensor = Tensor::<B, 2>::from_data(
    ///         [
    ///             [3.0, 4.9, 2.0],
    ///             [2.0, 1.9, 3.0],
    ///             [4.0, 5.9, 8.0],
    ///             [1.4, 5.8, 6.0],
    ///         ],
    ///         &device,
    ///     );
    ///
    ///     // Flip the elements in dimensions 0 and 1:
    ///     // [[6.0, 5.8, 1.4],
    ///     //  [8.0, 5.9, 4.0],
    ///     //  [3.0, 1.9, 2.0],
    ///     //  [2.0, 4.9, 3.0]]
    ///     // The resulting tensor will have dimensions [4, 3].
    ///     let flipped = tensor.flip([0, 1]);
    ///     println!("{flipped}");
    /// }
    /// ```
    pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
        // Convert the axes to usize and handle negative values without using a vector
        let mut transformed_axes: [usize; N] = [0; N];
        for (i, &x) in axes.iter().enumerate() {
            transformed_axes[i] = if x < 0 {
                (D as isize + x) as usize
            } else {
                x as usize
            };
        }

        // Check if the axes are valid
        check!(TensorCheck::flip(D, &transformed_axes));

        Tensor::new(K::flip(self.primitive, &transformed_axes))
    }

    /// Flatten the tensor along a given range of dimensions.
    ///
    /// This function collapses the specified range of dimensions into a single dimension,
    /// effectively flattening the tensor in that range.
    ///
    /// # Arguments
    ///
    /// - `start_dim`: The starting dimension of the range to be flattened,
    ///   supports negative indexing.
    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),
    ///   supports negative indexing.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the flattened tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [2, 3, 4]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
    ///
    ///     // Flatten the tensor from dimensions 1 to 2 (inclusive).
    ///     // The resulting tensor will have dimensions [2, 12]
    ///     let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
    ///     println!("{flattened}");
    /// }
    /// ```
    pub fn flatten<const D2: usize>(
        self,
        start_dim: impl AsIndex,
        end_dim: impl AsIndex,
    ) -> Tensor<B, D2, K> {
        let start_dim = start_dim.expect_dim_index(D);
        let end_dim = end_dim.expect_dim_index(D);
        check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));
        let new_shape = self.shape().flatten_dims(start_dim, end_dim);

        Tensor::new(K::reshape(self.primitive, new_shape))
    }

    /// Squeeze the tensor along all dimensions, removing dimensions
    /// of size one, and effectively reducing the rank of the tensor.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the size-one dimensions removed.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 4D tensor with dimensions [1, 3, 1, 3]
    ///     let tensor = Tensor::<B, 4>::from_data(
    ///         [[[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]]],
    ///         &device,
    ///     );
    ///
    ///     // Squeeze the tensor dimensions.
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let squeezed = tensor.squeeze::<2>();
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze<const D2: usize>(self) -> Tensor<B, D2, K> {
        let new_dims = self
            .shape()
            .dims
            .iter()
            .filter_map(|&dim| if dim == 1 { None } else { Some(dim) })
            .collect::<Vec<_>>();
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Squeeze the tensor along the given dimension, removing the specified dimension
    /// of size one, and effectively reducing the rank of the tensor by one.
    ///
    /// # Arguments
    ///
    /// - `dim`: The dimension to be squeezed.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Panics
    ///
    /// If the size in the squeezed dimension is not 1.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 1, 3]
    ///     let tensor = Tensor::<B, 3>::from_data(
    ///         [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
    ///         &device,
    ///     );
    ///
    ///     // Squeeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 3].
    ///     let squeezed = tensor.squeeze_dim::<2>(1);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));

        let current_dims = self.shape().dims;
        let mut new_dims: [usize; D2] = [0; D2];

        new_dims[..dim].copy_from_slice(&current_dims[..dim]);
        new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);

        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
    /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
    /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
    /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
    /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
    /// from the back.
    ///
    /// # Arguments
    ///
    /// - `dims`: The dimension(s) to be squeezed.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 4D tensor with dimensions [2, 1, 4, 1]
    ///     let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
    ///
    ///     // Squeeze the dimensions 1 and 3.
    ///     // The resulting tensor will have dimensions [2, 4].
    ///     let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
    ///     println!("{squeezed}");
    /// }
    /// ```
    pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
        let current_dims = self.shape().dims;
        let mut dim_indices: Vec<usize>;

        // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
        if dims.is_empty() {
            dim_indices = current_dims
                .iter()
                .enumerate()
                .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
                .collect();
        } else {
            // If negative dims, count from the back
            dim_indices = dims
                .iter()
                .map(|&d| {
                    if d < 0 {
                        (current_dims.len() as isize + d) as usize
                    } else {
                        d as usize
                    }
                })
                .collect();
        }

        // Sort indices and remove duplicates
        dim_indices.sort_unstable();
        dim_indices.dedup();

        // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
        check!(TensorCheck::squeeze_dims_input::<D2>(
            &dim_indices,
            &current_dims
        ));

        // Calculate new dimensions
        let mut new_dims = Vec::new();
        for (index, &dim_size) in current_dims.iter().enumerate() {
            // Exclude the dimension if it's explicitly marked for squeezing
            if dim_indices.contains(&index) {
                check!(TensorCheck::squeeze::<D2>(index, &current_dims));
                continue;
            }
            new_dims.push(dim_size);
        }

        // Check that after squeezing, we still respect the D2 size
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }

    /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the unsqueezed tensor.
    ///
    /// # Panics
    ///
    /// If the output size `D2` is smaller than the current number of dimensions.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the tensor up to 4 dimensions.
    ///     // The resulting tensor will have dimensions [1, 1, 3, 3].
    ///     let unsqueezed = tensor.unsqueeze::<4>();
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze::<D, D2>());

        let mut dims = [1; D2];
        let num_ones = D2 - D;
        let shape = self.shape();

        dims[num_ones..(D + num_ones)].copy_from_slice(&shape[..D]);

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with a dimension of size one inserted at the specified position.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 2D tensor with dimensions [3, 3]
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    ///     // Unsqueeze the dimension 1.
    ///     // The resulting tensor will have dimensions [3, 1, 3].
    ///     let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));

        // `dims` is pre-filled with ones, so the inserted dimension is already in place.
        let mut dims = [1; D2];
        let shape = self.shape();

        dims[0..dim].copy_from_slice(&shape[0..dim]);
        if dim < D {
            dims[(dim + 1)..].copy_from_slice(&shape[dim..]);
        }

        let shape = Shape::new(dims);
        self.reshape(shape)
    }

    /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
    /// The indices can be negative, in which case they are counted from the last to the first dimension.
    /// The axes can contain duplicates, in which case the number of dimensions inserted at the index
    /// is the number of duplicates.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     // Create a 3D tensor with dimensions [3, 4, 5]
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
    ///     // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
    ///     // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
    ///     let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
    ///     println!("{unsqueezed}");
    /// }
    /// ```
    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[impl AsIndex]) -> Tensor<B, D2, K> {
        let mut new_dims = [1; D2];
        let old_dims = self.shape().dims;

        // Part 1: convert the negative indices to positive.
        let mut neg_offset = D2;
        let mut dim_indices = axes
            .iter()
            .map(|d| {
                let d = d.as_index();
                // Check if the dimension is in the acceptable range.
                check!(TensorCheck::unsqueeze_dims::<{ D2 }>(d));
                (if d < 0 {
                    neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)
                    d + neg_offset as isize + 1
                } else {
                    d
                }) as usize
            })
            .collect::<Vec<usize>>();

        // Sort the indices.
        dim_indices.sort_unstable();

        // Now use this to copy the chunks of the dims.
        let mut prev_idx: usize = 0;
        let mut current_left_b: usize = 0;
        let mut current_right_b: usize = 0;
        let mut offset: usize = 0;
        dim_indices.iter().for_each(|d| {
            // Check if there is space for at least one dimension.
            if prev_idx < *d {
                current_right_b = *d - offset;
                // Copy the chunks of the dims.
                if current_right_b < D {
                    new_dims[prev_idx..*d]
                        .copy_from_slice(&old_dims[current_left_b..current_right_b]);
                } else {
                    new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
                }
                prev_idx = *d + 1;
                // The offset is equal to the number of elements extracted from the original shape.
                offset += current_right_b - current_left_b;
                current_left_b = current_right_b;
            } else {
                // It's sorted, so the only reason this would happen
                // is if multiple indices are the same.
                prev_idx += 1;
            }
        });
        // Copy over anything past the index of the last new dimension.
        if current_left_b < D {
            new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
        }

        // Lastly, create the shape and reshape.
        let shape = Shape::new(new_dims);
        self.reshape(shape)
    }

    /// Roll operation along a specific dimension, wrapping around the elements.
    ///
    /// ## Parameters
    ///
    /// - `shift`: The roll extent; supports negative values and wraps around.
    /// - `dim`: The dimension to roll; supports negative indexing.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimension rolled by the given shift amount.
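    ///
    /// ## Example
    ///
    /// A minimal sketch of the call shape; the element order after the roll follows
    /// the wrapping semantics described above, and the tensor's shape is always
    /// unchanged:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Int};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 1, Int>::arange(0..5, &device);
    ///     // Roll along dimension 0 by two positions; the shape stays [5].
    ///     let rolled = tensor.roll_dim(2, 0);
    ///     println!("{rolled}");
    /// }
    /// ```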
    pub fn roll_dim<Shift, Dim>(self, shift: Shift, dim: Dim) -> Self
    where
        Shift: AsIndex,
        Dim: AsIndex,
    {
        let dim = dim.expect_dim_index(D);
        let size = self.shape().dims[dim];
        if size == 0 {
            // If the dimension is empty, return the tensor as is.
            return self;
        }

        let shift = wrap_index(shift, size);
        if shift == 0 {
            // If the shift is zero, return the tensor as is.
            return self;
        }

        self.unchecked_roll_dim(shift, dim)
    }

    /// Internal implementation of `roll_dim` that does not canonicalize dimensions or shifts.
    ///
    /// ## Parameters
    ///
    /// - `shift`: The number of positions to shift; must satisfy `0 < shift < size`.
    /// - `dim`: The dimension to roll; must be a valid index for the tensor's shape.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimension rolled by the given shift amount.
    #[inline(always)]
    fn unchecked_roll_dim(self, shift: usize, dim: usize) -> Self {
        #[cfg(debug_assertions)]
        {
            let size = self.shape().dims[dim];
            assert!(
                0 < shift && shift < size,
                "Expected: 0 < shift < size: found shift={shift}, size={size}",
            );
            let num_dims = self.shape().num_dims();
            assert!(
                dim < num_dims,
                "Expected: dim < num_dims: found dim={dim}, num_dims={num_dims}",
            );
        }

        Tensor::cat(
            vec![
                self.clone().slice_dim(dim, shift..),
                self.slice_dim(dim, ..shift),
            ],
            dim,
        )
    }

    /// Roll operation.
    ///
    /// Note: unlike PyTorch, `dims` and `shifts` must have the same length.
    ///
    /// A given `dim` may be rolled multiple times, and the shifts will be applied sequentially.
    ///
    /// ## Parameters
    ///
    /// - `shifts`: A slice of shifts corresponding to each dimension;
    ///   supports negative values and wraps around.
    /// - `dims`: A slice of dimensions to roll; supports negative indexing.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimensions rolled by the given shifts.
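    ///
    /// ## Example
    ///
    /// A minimal sketch rolling two dimensions at once; as with
    /// [`roll_dim`](Tensor::roll_dim), the shape is unchanged and the element order
    /// follows the wrapping semantics described above:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::ones([2, 3], &device);
    ///     // Roll dimension 0 by 1 and dimension 1 by -1; the shape stays [2, 3].
    ///     let rolled = tensor.roll(&[1, -1], &[0, 1]);
    ///     println!("{rolled}");
    /// }
    /// ```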
    pub fn roll<Shift, Dim>(self, shifts: &[Shift], dims: &[Dim]) -> Self
    where
        Shift: AsIndex,
        Dim: AsIndex,
    {
        assert_eq!(
            dims.len(),
            shifts.len(),
            "Dimensions and shifts must align; found dims={dims:#?}, shifts={shifts:#?}",
        );

        // This is a fair amount of complexity, which could be replaced
        // by a simple canonicalization of `dims` and wrapping of `shifts`.
        // The work is done here to ensure that any roll operation
        // which could be a no-op is a no-op; simplifying the accounting
        // needed by backend-specific implementations of the inner roll op.

        let item_count = dims.len();

        let shape = self.shape().dims;

        // Accumulate the effective shifts for each dimension.
        let mut accumulated_shifts: Vec<isize> = vec![0; shape.len()];
        for i in 0..item_count {
            let dim = dims[i].expect_dim_index(D);
            accumulated_shifts[dim] += shifts[i].as_index();
        }

        // Do this after we've checked the validity of `dims` and `shifts`.
        if self.shape().num_elements() == 0 {
            // If the tensor is empty, return it as is.
            return self;
        }

        // Wrap the accumulated shifts, and filter out empty dimensions.
        let mut effective_dims: Vec<usize> = Vec::with_capacity(item_count);
        let mut effective_shifts: Vec<usize> = Vec::with_capacity(item_count);
        for dim in 0..shape.len() {
            // `wrap_index` should inline, and has a fast-exit path for zero shifts.
            let shift = wrap_index(accumulated_shifts[dim], shape[dim]);
            if shift == 0 {
                continue;
            }

            effective_dims.push(dim);
            effective_shifts.push(shift);
        }

        // If no shifts are needed, return the original tensor.
        if effective_shifts.is_empty() {
            return self;
        }

        // At this point:
        // - `effective_dims` contains the effective dimensions to roll, in index order,
        // - `effective_shifts` contains the effective usize shifts for each dimension.
        // - Every shift is non-zero, and less than the size of the corresponding dimension.
        self.unchecked_roll(&effective_shifts, &effective_dims)
    }

    /// `roll` internal implementation.
    ///
    /// ## Parameters
    ///
    /// - `shifts`: A slice of shifts corresponding to each dimension;
    ///   must be non-empty, the same length as `dims`, and all in ``1..<size>``.
    /// - `dims`: A slice of dimensions to roll; must be non-empty,
    ///   the same length as `shifts`, and must not contain repeats.
    ///
    /// ## Panics
    ///
    /// Panics if the shifts and dimensions do not align, or if dimensions contain repeats.
    ///
    /// ## Returns
    ///
    /// A new tensor with the specified dimensions rolled by the given shifts.
    #[inline(always)]
    fn unchecked_roll(self, shifts: &[usize], dims: &[usize]) -> Self {
        #[cfg(debug_assertions)]
        {
            assert!(!shifts.is_empty());
            assert_eq!(
                shifts.len(),
                dims.len(),
                "Shifts and dimensions must align; found {} shifts and {} dims",
                shifts.len(),
                dims.len()
            );

            // Sort before deduplicating so that non-adjacent repeats are caught too.
            let mut unique_dims = dims.to_vec();
            unique_dims.sort_unstable();
            unique_dims.dedup();

            assert_eq!(
                unique_dims.len(),
                dims.len(),
                "Dimensions must not contain repeats; found {} unique dims and {} total dims",
                unique_dims.len(),
                dims.len()
            )
        }

        let x = self.unchecked_roll_dim(shifts[0], dims[0]);

        if dims.len() == 1 {
            x
        } else {
            x.unchecked_roll(&shifts[1..], &dims[1..])
        }
    }

    /// Returns a tensor containing the elements selected from the given slices.
    ///
    /// This method provides flexible tensor slicing with support for various range types,
    /// negative indices, and stepped slicing. The method accepts both single slices and
    /// arrays of slices, with the [`s!`] macro providing convenient syntax for complex patterns.
    ///
    /// # Arguments
    ///
    /// * `slices` - Can be:
    ///   - A single range for 1D slicing (e.g., `0..5`, `..`, `2..`)
    ///   - An array of ranges (e.g., `[0..2, 1..4]`)
    ///   - The [`s!`] macro output for advanced slicing with steps
    ///   - A `&Vec<Slice>` or `&[Slice]`
    ///
    /// # Behavior
    ///
    /// - Supports partial and full slicing in any number of dimensions
    /// - Handles negative indices by wrapping from the end (-1 is the last element)
    /// - Automatically clamps ranges that exceed tensor dimensions
    /// - Supports stepped slicing for selecting every nth element
    /// - Negative steps reverse the selection order
    ///
    /// # Panics
    ///
    /// - If the number of slices exceeds the tensor's dimensions
    /// - If a range is descending (e.g., `2..1`) or empty (e.g., `1..1`) without a negative step
    /// - If a step is zero
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // Single dimension slicing - no brackets needed!
    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..10, &device);
    ///     let slice = tensor.clone().slice(2..8); // Simple range
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![2, 3, 4, 5, 6, 7]);
    ///
    ///     // Using the s! macro for a single dimension with step
    ///     let slice = tensor.clone().slice(s![0..10;2]); // Every 2nd element
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![0, 2, 4, 6, 8]);
    ///
    ///     // Reverse a dimension with negative step
    ///     let slice = tensor.slice(s![..;-1]); // Reverse entire tensor
    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
    ///
    ///     // Multi-dimensional slicing
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
    ///
    ///     // Array syntax for simple ranges
    ///     let slice = tensor.clone().slice([1..3, 2..5]);
    ///     assert_eq!(slice.dims(), [2, 3]);
    ///
    ///     // Advanced multi-dimensional with the s! macro
    ///     let slice = tensor.clone().slice(s![0..4;2, ..;-1]); // Every 2nd row, reverse columns
    ///     assert_eq!(slice.dims(), [2, 6]);
    ///
    ///     // Complex 3D example with mixed slice types
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([4, 6, 8]), &device);
    ///     let slice = tensor.slice(s![1..3, ..;2, -3..]); // Rows 1-2, every 2nd col, last 3 depth
    ///     assert_eq!(slice.dims(), [2, 3, 3]);
    ///
    ///     // Using negative indices
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
    ///     let slice = tensor.slice(s![-2.., ..-1]); // Last 2 rows, all but last column
    ///     assert_eq!(slice.dims(), [2, 5]);
    /// }
    /// ```
    ///
    /// # See Also
    ///
    /// - [`s!`] - The recommended macro for creating complex slice specifications
    /// - [`slice_assign`](Self::slice_assign) - Assign values to a slice
    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
    /// - [`slice_dim`](Self::slice_dim) - Slice a single dimension
    ///
    /// [`s!`]: crate::s!
    pub fn slice<S>(self, slices: S) -> Self
    where
        S: SliceArg,
    {
        let shape = self.shape();
        let slices = slices.into_slices(&shape);

        // Validate slices
        check!(TensorCheck::slice::<D>(&shape, &slices));

        // Calculate output shape and check for empty slices
        let mut output_dims = shape.dims.clone();
        for (dim, slice) in slices.iter().enumerate() {
            output_dims[dim] = slice.output_size(shape.dims[dim]);
        }

        // Return an empty tensor if any dimension is 0 (empty slice)
        if output_dims.contains(&0) {
            return Self::empty(output_dims, &self.device());
        }
        Self::new(K::slice(self.primitive, &slices))
    }

    /// Assigns values to a slice of the tensor and returns the updated tensor.
    ///
    /// This method supports advanced slicing with steps, including negative steps for reverse
    /// assignment. Like `slice`, it accepts both single slices and arrays, with the [`s!`] macro
    /// providing powerful syntax for complex patterns.
    ///
    /// # Arguments
    ///
    /// * `slices` - Slice specification (same format as the `slice` method)
    /// * `values` - Tensor with values to assign (must match slice dimensions)
    ///
    /// # Panics
    ///
    /// - If slices exceed tensor dimensions
    /// - If values dimensions don't match the selected slice shape
    /// - If a step is zero
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // Simple assignment to a sub-region
    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
    ///     let values = Tensor::<B, 2>::ones([2, 3], &device);
    ///     tensor = tensor.slice_assign([1..3, 2..5], values);
    ///     // Now tensor[1..3, 2..5] contains ones
    ///
    ///     // Single dimension assignment with step
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     let values = Tensor::<B, 1>::ones([5], &device);
    ///     tensor = tensor.slice_assign(s![0..10;2], values);
    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
    ///
    ///     // Reverse assignment with negative step
    ///     let mut tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
    ///     let values = Tensor::<B, 1>::from_data([10.0, 11.0, 12.0, 13.0, 14.0], &device);
    ///     tensor = tensor.slice_assign(s![..;-1], values);
    ///     // Assigns in reverse: [14, 13, 12, 11, 10]
    ///
    ///     // Complex multi-dimensional assignment
    ///     let mut tensor = Tensor::<B, 3>::zeros([4, 6, 8], &device);
    ///     let values = Tensor::<B, 3>::ones([2, 3, 3], &device);
    ///     tensor = tensor.slice_assign(s![0..4;2, ..;2, -3..], values);
    ///     // Assigns to every 2nd row, every 2nd column, last 3 in depth
    ///
    ///     // Mixed syntax example
    ///     let mut tensor = Tensor::<B, 2>::zeros([8, 8], &device);
    ///     let pattern = Tensor::<B, 2>::ones([4, 4], &device);
    ///     tensor = tensor.slice_assign(s![..;2, ..;2], pattern);
    ///     // Sets ones at every other row and every other column
    /// }
    /// ```
    ///
    /// # See Also
    ///
    /// - [`s!`] - The recommended macro for creating complex slice specifications
    /// - [`slice`](Self::slice) - Extract a slice from a tensor
    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
    ///
    /// [`s!`]: crate::s!
    pub fn slice_assign<S>(self, slices: S, values: Self) -> Self
    where
        S: SliceArg,
    {
        let shape = self.shape();
        let slices = slices.into_slices(&shape);

        // Check if any slice produces 0 elements (empty assignment).
        // Empty assignments are no-ops and would cause issues in backend implementations.
        let is_empty_assignment = slices
            .iter()
            .enumerate()
            .any(|(i, slice)| slice.output_size(shape.dims[i]) == 0);

        if is_empty_assignment {
            return self;
        }

        check!(TensorCheck::slice_assign::<D>(
            &shape,
            &values.shape(),
            &slices
        ));

        Self::new(K::slice_assign(self.primitive, &slices, values.primitive))
    }

    /// Fills a slice of the tensor with a constant value and returns the updated tensor.
    ///
    /// Like other slice methods, it accepts both single slices and arrays, including stepped
    /// slices created with the [`s!`] macro. Internally it expands the value into a tensor
    /// and delegates to [`slice_assign`](Self::slice_assign).
    ///
    /// # Arguments
    ///
    /// * `slices` - Slice specification (same format as the `slice` method)
    /// * `value` - The value to fill the slice with
    ///
    /// # Panics
    ///
    /// - If slices exceed tensor dimensions
    /// - If a step is zero
    ///
    /// # Examples
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, s};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///
    ///     // Simple fill for a single dimension
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(2..5, 1.0);
    ///     // Now tensor is [0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
    ///
    ///     // Multi-dimensional fill
    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
    ///     tensor = tensor.slice_fill([1..3, 2..5], -1.0);
    ///     // Fills the rectangle at rows 1-2, columns 2-4 with -1
    ///
    ///     // Using negative indices
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(-3.., 2.0);
    ///     // Fills the last 3 elements with 2.0
    ///
    ///     // Complex multi-dimensional example
    ///     let mut tensor = Tensor::<B, 3>::ones([4, 6, 8], &device);
    ///     tensor = tensor.slice_fill(s![1..3, .., -2..], 0.0);
    ///     // Sets rows 1-2, all columns, last 2 in depth to 0
    ///
    ///     // Stepped slicing is supported
    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);
    ///     tensor = tensor.slice_fill(s![0..10;2], 1.0);
    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
    /// }
    /// ```
    ///
    /// # See Also
    ///
    /// - [`s!`] - The macro for creating slice specifications with steps
    /// - [`slice`](Self::slice) - Extract a slice from a tensor
    /// - [`slice_assign`](Self::slice_assign) - Assign tensor values to a slice
    ///
    /// [`s!`]: crate::s!
    pub fn slice_fill<S, E: ElementConversion>(self, slices: S, value: E) -> Self
    where
        S: SliceArg,
    {
        let shape = self.shape();
        let slices = slices.into_slices(&shape);

        check!(TensorCheck::slice::<D>(&shape, &slices));

        let slice_shape = shape.slice(&slices).unwrap();
        let value = Tensor::<B, 1, K>::from_data_dtype(
            [value.elem::<K::Elem>()],
            &self.device(),
            self.dtype(),
        );
        let value = value.expand(slice_shape);
        self.slice_assign(&slices, value)
    }

    /// Returns a new tensor with the specified dimension sliced.
    ///
    /// # Arguments
    ///
    /// * `dim`: The dimension to slice.
    /// * `slice`: The slice specification for the dimension. Can be a range (e.g., `2..5`),
    ///   a slice with step (via the `s!` macro, e.g., `s![0..10;2]`), or any type that implements `Into<Slice>`.
    ///
    /// # Returns
    ///
    /// A new tensor with the specified dimension sliced.
    ///
    /// # Panics
    ///
    /// If the slice is out of bounds for the specified dimension.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use burn_tensor::{Tensor, s};
    /// # use burn_tensor::backend::Backend;
    /// #
    /// # fn example<B: Backend>() {
    /// # let device = B::Device::default();
    /// let tensor = Tensor::<B, 3>::zeros([3, 4, 5], &device);
    ///
    /// // Simple range slicing
    /// let sliced = tensor.clone().slice_dim(1, 1..3);
    /// assert_eq!(sliced.shape().dims, [3, 2, 5]);
    ///
    /// // Slicing with step - take every 2nd element
    /// let sliced = tensor.clone().slice_dim(2, s![0..5;2]);
    /// assert_eq!(sliced.shape().dims, [3, 4, 3]); // Takes indices 0, 2, 4
    ///
    /// // Reverse slicing with negative step
    /// let sliced = tensor.clone().slice_dim(1, s![..;-1]);
    /// assert_eq!(sliced.shape().dims, [3, 4, 5]); // Reverses dimension 1
    ///
    /// // Select from index 2 with step 3
    /// let sliced = tensor.clone().slice_dim(0, s![2..;3]);
    /// assert_eq!(sliced.shape().dims, [1, 4, 5]); // Takes only index 2
    ///
    /// // Select single index (reduces dimension to size 1)
    /// let sliced = tensor.slice_dim(0, 1);
    /// assert_eq!(sliced.shape().dims, [1, 4, 5]);
    /// # }
    /// ```
    ///
    /// # See Also
    ///
    /// - [`slice`](Self::slice) - Slice multiple dimensions simultaneously
    /// - [`s!`] - The macro for creating complex slice specifications
    ///
    /// [`s!`]: crate::s!
    pub fn slice_dim<S>(self, dim: usize, slice: S) -> Self
    where
        S: Into<Slice>,
    {
        check!(TensorCheck::check_dim::<D>(dim));
        let slice: Slice = slice.into();

        let mut slices = vec![Slice::full(); D];
        slices[dim] = slice;

        self.slice(&slices)
    }

    /// Returns the device of the current tensor.
    pub fn device(&self) -> B::Device {
        K::device(&self.primitive)
    }

    /// Move the tensor to the given device.
    pub fn to_device(self, device: &B::Device) -> Self {
        Self::new(K::to_device(self.primitive, device))
    }

    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to select from. Supports negative indexing.
    /// * `indices` - The indices of the elements to select.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Int};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///     let indices = Tensor::<B, 1, Int>::from_data([0], &device);
    ///     let tensor = tensor.select(0, indices);
    ///     println!("{tensor}");
    ///     // [[1.0, -2.0, 3.0]]
    /// }
    /// ```
    pub fn select(self, dim: impl AsIndex, indices: Tensor<B, 1, Int>) -> Self {
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::select::<D>(dim));
        Self::new(K::select(self.primitive, dim, indices.primitive))
    }

    /// Assign the selected elements along the given dimension corresponding to the given indices
    /// from the value tensor to the original tensor, combining the existing and new values with
    /// the given update operation.
    ///
    /// # Note
    /// For booleans, the sum operator is logical or.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which to select. Supports negative indexing.
    /// * `indices` - The indices to select from the tensor.
    /// * `values` - The values to assign to the selected indices.
    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
    ///
    /// # Warning
    ///
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
    ///
    /// # Example
    ///
    /// Using a 3D tensor with an additive update:
    ///
    /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
    /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = -1 (same as dim = 2)`
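    ///
    /// A minimal sketch of the call shape; an `Add` variant of [`IndexingUpdateOp`] is
    /// assumed here for the additive update:
    ///
    /// ```rust,ignore
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{IndexingUpdateOp, Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::zeros([2, 3], &device);
    ///     let indices = Tensor::<B, 1, Int>::from_data([0, 0], &device);
    ///     let values = Tensor::<B, 2>::ones([2, 3], &device);
    ///     // Both value rows target row 0, so an additive update accumulates to:
    ///     // [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0]]
    ///     let tensor = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);
    ///     println!("{tensor}");
    /// }
    /// ```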
    pub fn select_assign(
        self,
        dim: impl AsIndex,
        indices: Tensor<B, 1, Int>,
        values: Tensor<B, D, K>,
        update: IndexingUpdateOp,
    ) -> Self {
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::select_assign::<D>(
            dim,
            &indices.shape(),
            &values.shape()
        ));

        Self::new(K::select_assign(
            self.primitive,
            dim,
            indices.primitive,
            values.primitive,
            update,
        ))
    }

    /// Update the given tensor with the value tensor where the mask is true.
    ///
    /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
    /// a scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Bool};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
    ///     let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor.mask_where(mask, value);
    ///     println!("{tensor}");
    ///     // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
    /// }
    /// ```
    pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
        Self::new(K::mask_where(
            self.primitive,
            mask.primitive,
            value.primitive,
        ))
    }

    /// Update the given tensor with the value where the mask is true.
    ///
    /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
    /// a tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Bool};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
    ///     let tensor = tensor.mask_fill(mask, 3.0);
    ///     println!("{tensor}");
    ///     // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
    /// }
    /// ```
    pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
        Self::new(K::mask_fill(self.primitive, mask.primitive, value.elem()))
    }

1736 /// Gather tensor elements corresponding to the given indices from the specified dim.
1737 ///
1738 /// Example using a 3D tensor:
1739 ///
1740 /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
1741 /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
1742 /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
1743 ///
1744 /// # Notes
1745 ///
1746 /// The index tensor should have the same shape as the original tensor except for the dim
1747 /// specified.
1748 ///
1749 /// # Warning
/// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1751 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
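///
/// # Example
///
/// A small sketch of gathering along `dim = 1`, following the formula above:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Int, Tensor};
///
/// fn example<B: Backend>() {
///     let device = B::Device::default();
///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
///     // output[i, j] = input[i, indices[i, j]]
///     let indices = Tensor::<B, 2, Int>::from_data([[2, 0, 1], [0, 0, 2]], &device);
///     let gathered = tensor.gather(1, indices);
///     println!("{gathered}");
///     // [[3.0, 1.0, 2.0], [4.0, 4.0, 6.0]]
/// }
/// ```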
1752 pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
1753 check!(TensorCheck::gather::<D>(
1754 dim,
1755 &self.shape(),
1756 &indices.shape()
1757 ));
1758
1759 Self::new(K::gather(dim, self.primitive, indices.primitive))
1760 }
1761
/// Assign the gathered elements corresponding to the given indices along the specified dimension
/// from the value tensor to the original tensor, combining existing and new values with the
/// given update operation.
1764 ///
/// Example using a 3D tensor with an additive update (`+=`):
1766 ///
1767 /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
1768 /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
1769 /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
1770 ///
1771 /// # Arguments
1772 /// * `dim` - The axis along which to scatter elements.
1773 /// * `indices` - The indices of the elements to scatter.
1774 /// * `values` - The values to scatter into the tensor.
1775 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1776 ///
1777 /// # Notes
1778 ///
1779 /// The index tensor should have the same shape as the original tensor except for the specified
1780 /// dimension. The value and index tensors should have the same shape.
1781 ///
1782 /// Other references to the input tensor will not be modified by this operation.
1783 ///
1784 /// # Warning
/// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1786 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
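///
/// # Example
///
/// A minimal sketch (assuming `IndexingUpdateOp::Add` is the additive update variant):
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{IndexingUpdateOp, Int, Tensor};
///
/// fn example<B: Backend>() {
///     let device = B::Device::default();
///     let tensor = Tensor::<B, 1>::from_data([0.0, 0.0, 0.0], &device);
///     let indices = Tensor::<B, 1, Int>::from_data([1, 1, 2], &device);
///     let values = Tensor::<B, 1>::from_data([5.0, 6.0, 7.0], &device);
///     // input[indices[i]] += values[i]
///     let tensor = tensor.scatter(0, indices, values, IndexingUpdateOp::Add);
///     println!("{tensor}");
///     // [0.0, 11.0, 7.0]
/// }
/// ```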
1787 pub fn scatter(
1788 self,
1789 dim: usize,
1790 indices: Tensor<B, D, Int>,
1791 values: Self,
1792 update: IndexingUpdateOp,
1793 ) -> Self {
1794 check!(TensorCheck::scatter::<D>(
1795 dim,
1796 &self.shape(),
1797 &indices.shape(),
1798 &values.shape()
1799 ));
1800
1801 Self::new(K::scatter(
1802 dim,
1803 self.primitive,
1804 indices.primitive,
1805 values.primitive,
1806 update,
1807 ))
1808 }
1809
1810 /// Converts the data of the current tensor.
1811 ///
1812 /// # Note
1813 ///
1814 /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1815 /// tensors at once. This may improve laziness, especially if executed on a different
1816 /// thread in native environments.
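///
/// # Example
///
/// A minimal sketch, assuming the tensor holds `f32` values:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
///     // Read the tensor back to the host.
///     let data = tensor.into_data();
///     let values: Vec<f32> = data.to_vec().unwrap();
///     println!("{values:?}");
/// }
/// ```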
1817 pub fn into_data(self) -> TensorData {
1818 self.try_into_data().expect(
1819 "Error while reading data: use `try_into_data` instead to catch the error at runtime",
1820 )
1821 }
1822
1823 /// Converts the data of the current tensor and returns any error that might have occurred since the
1824 /// last time the device was synchronized.
1825 ///
1826 /// # Note
1827 ///
1828 /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1829 /// tensors at once. This may improve laziness, especially if executed on a different
1830 /// thread in native environments.
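///
/// # Example
///
/// A minimal sketch of handling the error explicitly:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);
///     match tensor.try_into_data() {
///         Ok(data) => println!("{data:?}"),
///         Err(err) => println!("read failed: {err:?}"),
///     }
/// }
/// ```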
1831 pub fn try_into_data(self) -> Result<TensorData, ExecutionError> {
1832 crate::try_read_sync(self.into_data_async()).expect(
1833 "Failed to read tensor data synchronously.
1834 This can happen on platforms that don't support blocking futures like WASM.
1835 If possible, try using into_data_async instead.",
1836 )
1837 }
1838
1839 /// Converts the data of the current tensor.
1840 ///
1841 /// # Note
1842 ///
1843 /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1844 /// tensors at once. This may improve laziness, especially if executed on a different
1845 /// thread in native environments.
1846 pub fn to_data(&self) -> TensorData {
1847 self.clone().into_data()
1848 }
1849
1850 /// Returns the data of the current tensor.
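///
/// # Example
///
/// A minimal sketch for async contexts such as WASM:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// async fn example<B: Backend>(tensor: Tensor<B, 1>) {
///     match tensor.into_data_async().await {
///         Ok(data) => println!("{data:?}"),
///         Err(err) => println!("read failed: {err:?}"),
///     }
/// }
/// ```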
1851 pub async fn into_data_async(self) -> Result<TensorData, ExecutionError> {
1852 K::into_data_async(self.primitive).await
1853 }
1854
1855 /// Returns the data of the current tensor.
1856 pub async fn to_data_async(&self) -> Result<TensorData, ExecutionError> {
1857 self.clone().into_data_async().await
1858 }
1859
1860 /// Create a tensor from the given data on the given device.
1861 pub fn from_data<T>(data: T, device: &B::Device) -> Self
1862 where
1863 T: Into<TensorData>,
1864 {
1865 let data = data.into();
1866 check!(TensorCheck::creation_ops::<D>(
1867 "From Data",
1868 data.shape.as_slice()
1869 ));
1870 Self::new(K::from_data(data, device))
1871 }
1872
1873 /// Create a tensor from the given data on the given device enforcing the given data type.
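///
/// # Example
///
/// A minimal sketch, forcing `f32` storage regardless of the backend default:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{DType, Tensor};
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 1>::from_data_dtype([1.0, 2.0, 3.0], &device, DType::F32);
///     println!("{tensor}");
/// }
/// ```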
1874 pub fn from_data_dtype<T>(data: T, device: &B::Device, dtype: DType) -> Self
1875 where
1876 T: Into<TensorData>,
1877 {
1878 let data = data.into();
1879 check!(TensorCheck::creation_ops::<D>(
1880 "From Data",
1881 data.shape.as_slice()
1882 ));
1883 Self::new(K::from_data_dtype(data, device, dtype))
1884 }
1885
1886 /// Repeat the tensor along the given dimension.
1887 ///
1888 /// The output tensor has the same shape, except along the given dimension.
1889 ///
1890 /// # Arguments
1891 /// - `dim`: The dimension to repeat.
1892 /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
1893 ///
1894 /// # Returns
1895 ///
1896 /// A new tensor with the given dimension repeated `times` times.
1897 ///
1898 /// # Example
1899 ///
1900 /// ```rust
1901 /// use burn_tensor::backend::Backend;
1902 /// use burn_tensor::Tensor;
1903 ///
1904 /// fn example<B: Backend>() {
1905 /// let device = Default::default();
1906 /// // Create a 2D tensor with dimensions [3, 2]
1907 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1908 ///
1909 /// // Repeat the tensor along the dimension 0 twice.
1910 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1911 /// // The resulting tensor will have dimensions [6, 2].
1912 /// let repeated = tensor.repeat_dim(0, 2);
1913 /// println!("{repeated}");
1914 /// }
1915 /// ```
1916 pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
1917 if times > 0 {
1918 Self::new(K::repeat_dim(self.primitive, dim, times))
1919 } else {
1920 let shape = self.shape().repeat(dim, times).unwrap();
1921 Self::empty(shape, &self.device())
1922 }
1923 }
1924
1925 /// Repeat the tensor along the given dimensions.
1926 /// # Arguments
1927 /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
1928 ///
1929 /// # Returns
1930 ///
/// A new tensor with each dimension repeated according to `sizes`.
1932 ///
1933 /// # Panics
1934 ///
1935 /// If `sizes` contains more elements than the number of dimensions.
1936 ///
1937 /// # Example
1938 ///
1939 /// ```rust
1941 /// use burn_tensor::backend::Backend;
1942 /// use burn_tensor::Tensor;
1943 ///
1944 /// fn example<B: Backend>() {
1945 /// let device = Default::default();
1946 /// // Create a 2D tensor with dimensions [3, 2]
1947 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1948 ///
/// // Repeat the tensor along the dimension 0 twice and the dimension 1 once.
1950 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1951 /// // The resulting tensor will have dimensions [6, 2].
1952 /// let repeated = tensor.repeat(&[2, 1]);
1953 /// }
1954 /// ```
1955 pub fn repeat(self, sizes: &[usize]) -> Self {
1956 if sizes.contains(&0) {
1957 let mut shape = self.shape();
for (dim, &times) in sizes.iter().enumerate() {
1959 shape = shape.repeat(dim, times).unwrap();
1960 }
1961
1962 return Self::empty(shape, &self.device());
1963 }
1964
1965 let mut tensor = self;
for (dim, &times) in sizes.iter().enumerate() {
1967 if times > 1 {
1968 tensor = tensor.repeat_dim(dim, times);
1969 }
1970 }
1971 tensor
1972 }
1973
1974 /// Applies element-wise equal comparison.
1975 ///
1976 /// # Returns
1977 /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
1978 ///
1979 /// # Panics
1980 ///
1981 /// If the two tensors don't have the same shape.
1982 ///
1983 /// # Example
1984 ///
1985 /// ```rust
1986 /// use burn_tensor::backend::Backend;
1987 /// use burn_tensor::Tensor;
1988 ///
1989 /// fn example<B: Backend>() {
1990 /// let device = Default::default();
1991 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1992 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1993 /// // Compare the elements of the two 2D tensors with dimensions [3, 2].
1994 /// // [[false, true], [true, true], [true, true]]
1995 /// let equal = t1.equal(t2);
1996 /// println!("{equal}");
1997 /// }
1998 /// ```
1999 pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
2000 check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
2001 Tensor::new(K::equal(self.primitive, other.primitive))
2002 }
2003
2004 /// Applies element-wise non-equality comparison.
2005 ///
2006 /// # Returns
2007 /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
2008 ///
2009 /// # Panics
2010 ///
2011 /// If the two tensors don't have the same shape.
2012 ///
2013 /// # Example
2014 ///
2015 /// ```rust
2016 /// use burn_tensor::backend::Backend;
2017 /// use burn_tensor::Tensor;
2018 ///
2019 /// fn example<B: Backend>() {
2020 /// let device = Default::default();
2021 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2022 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2023 /// // Compare the elements of the two 2D tensors for inequality.
2024 /// // [[true, false], [false, false], [false, false]]
2025 /// let not_equal = t1.not_equal(t2);
2026 /// println!("{not_equal}");
2027 /// }
2028 /// ```
2029 pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
2030 check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
2031 Tensor::new(K::not_equal(self.primitive, other.primitive))
2032 }
2033
/// Applies element-wise equality comparison against a scalar element and returns a boolean tensor.
2035 ///
2036 /// # Arguments
2037 ///
2038 /// * `other` - The element to compare.
2039 ///
2040 /// # Example
2041 ///
2042 /// ```rust
2043 /// use burn_tensor::backend::Backend;
2044 /// use burn_tensor::{Tensor, Shape};
2045 ///
2046 /// fn example<B: Backend>() {
2047 /// let device = B::Device::default();
2048 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2049 /// let tensor = tensor.equal_elem(3.0);
2050 /// println!("{tensor}");
2051 /// // [[false, false, true], [false, false, false]]
2052 /// }
2053 /// ```
2054 pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2055 Tensor::new(K::equal_elem(self.primitive, other.elem()))
2056 }
2057
/// Applies element-wise non-equality comparison against a scalar element and returns a boolean tensor.
2059 ///
2060 /// # Arguments
2061 ///
2062 /// * `other` - The element to compare.
2063 ///
2064 /// # Example
2065 ///
2066 /// ```rust
2067 /// use burn_tensor::backend::Backend;
2068 /// use burn_tensor::{Tensor, Shape};
2069 ///
2070 /// fn example<B: Backend>() {
2071 /// let device = B::Device::default();
2072 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2073 /// let tensor = tensor.not_equal_elem(3.0);
2074 /// println!("{tensor}");
2075 /// // [[true, true, false], [true, true, true]]
2076 /// }
2077 /// ```
2078 pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2079 Tensor::new(K::not_equal_elem(self.primitive, other.elem()))
2080 }
2081
2082 /// Concatenates all tensors into a new one along the given dimension.
2083 ///
2084 /// # Panics
2085 ///
2086 /// - If `dim` is higher than the rank.
2087 /// - If `tensors` is an empty vector.
2088 /// - If all tensors don't have the same shape (the dimension `dim` is ignored).
2089 ///
2090 /// # Example
2091 ///
2092 /// ```rust
2093 /// use burn_tensor::backend::Backend;
2094 /// use burn_tensor::Tensor;
2095 ///
2096 /// fn example<B: Backend>() {
2097 /// let device = Default::default();
2098 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
2099 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2100 ///
2101 /// // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
2102 /// // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
2103 /// // The resulting tensor will have shape [2, 7].
2104 /// let concat = Tensor::cat(vec![t1, t2], 1);
2105 /// println!("{concat}");
2106 /// }
2107 /// ```
2108 pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
2109 check!(TensorCheck::cat(&tensors, dim));
2110
2111 // Filter out tensors with size 0 along the concatenation dimension.
2112 // Empty tensors don't contribute to the output and would cause issues
2113 // in backend implementations (e.g., division by zero in slice_assign).
2114 // Safety: TensorCheck::cat ensures tensors is non-empty
2115 let first_tensor = tensors.first().unwrap();
2116 let device = first_tensor.device();
2117 let mut shape = first_tensor.shape();
2118
2119 let non_empty_primitives: Vec<_> = tensors
2120 .into_iter()
2121 .filter(|t| t.shape().dims[dim] > 0)
2122 .map(|t| t.primitive)
2123 .collect();
2124
2125 // If all tensors were empty, return an empty tensor with size 0 on concat dim
2126 if non_empty_primitives.is_empty() {
2127 shape.dims[dim] = 0;
2128 return Self::empty(shape, &device);
2129 }
2130
2131 Self::new(K::cat(non_empty_primitives, dim))
2132 }
2133
2134 /// Concatenates all tensors into a new one along a new dimension.
2135 ///
2136 /// # Panics
2137 ///
2138 /// - If all tensors don't have the same shape.
/// - If the given dimension is not within the range `0..D2`.
2140 ///
2141 /// # Example
2142 ///
2143 /// ```rust
2144 /// use burn_tensor::backend::Backend;
2145 /// use burn_tensor::Tensor;
2146 ///
2147 /// fn example<B: Backend>() {
2148 /// let device = Default::default();
2149 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2150 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2151 /// let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2152 ///
2153 /// // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
2154 /// // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
2155 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
2156 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
2157 /// // The resulting tensor will have shape [3, 2, 3].
/// let stacked = Tensor::stack::<3>(vec![t1, t2, t3], 0);
2159 /// println!("{stacked}");
2160 /// }
2161 /// ```
2162 pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
2163 check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
2164 let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
2165 Tensor::<B, D2, K>::cat(tensors, dim)
2166 }
2167
/// Iterate over slices of the tensor along a given dimension.
2169 ///
2170 /// # Panics
2171 ///
2172 /// If given dimension is greater than or equal to tensor rank.
2173 ///
2174 /// # Returns
2175 ///
2176 /// A tensor iterator.
2177 ///
2178 /// # Example
2179 ///
2180 /// ```rust
2181 /// use burn_tensor::backend::Backend;
2182 /// use burn_tensor::Tensor;
2183 /// fn example<B: Backend>() {
2184 /// let device = Default::default();
2185 /// let tensor = Tensor::<B,2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2186 /// // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
2187 /// let iter = tensor.iter_dim(0);
2188 /// for (i,tensor) in iter.enumerate() {
2189 /// println!("Tensor {}: {}", i, tensor);
2190 /// // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
2191 /// // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
2192 /// }
2193 /// }
2194 /// ```
2195 pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
2196 check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
2197 DimIter::new(self, dim)
2198 }
2199
2200 /// Returns a new tensor with the given dimension narrowed to the given range.
2201 ///
2202 /// # Panics
2203 ///
2204 /// - If the dimension is greater than the number of dimensions of the tensor.
2205 /// - If the given range exceeds the number of elements on the given dimension.
2206 ///
2207 /// # Returns
2208 ///
2209 /// A new tensor with the given dimension narrowed to the given range.
2210 ///
2211 /// # Example
2212 ///
2213 /// ```rust
2214 /// use burn_tensor::backend::Backend;
2215 /// use burn_tensor::Tensor;
2216 ///
2217 /// fn example<B: Backend>() {
2218 /// let device = Default::default();
2219 /// // Create a 2D tensor with dimensions [4, 3]
2220 /// let tensor = Tensor::<B, 2>::from_data(
2221 /// [
2222 /// [3.0, 4.9, 2.0],
2223 /// [2.0, 1.9, 3.0],
2224 /// [6.0, 1.5, 7.0],
2225 /// [3.0, 4.9, 9.0],
2226 /// ],
2227 /// &device,
2228 /// );
2229 /// // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
2230 /// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
2231 /// // The resulting tensor will have dimensions [3, 3].
2232 /// let narrowed = tensor.narrow(0, 1, 3);
2233 /// println!("{narrowed}");
2234 /// }
2235 /// ```
2236 pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
2237 check!(TensorCheck::dim_ops::<D>("narrow", dim));
2238 check!(TensorCheck::narrow(&self, dim, start, length));
2239 let dims = self.dims();
2240
2241 let ranges: [Range<usize>; D] = dims
2242 .iter()
2243 .enumerate()
2244 .map(|(i, d)| {
2245 if i == dim {
2246 start..(start + length)
2247 } else {
2248 0..*d
2249 }
2250 })
2251 .collect::<Vec<_>>()
2252 .try_into()
2253 .unwrap();
2254
2255 Self::slice(self, ranges)
2256 }
2257
2258 /// Attempts to split the tensor into a specified number of chunks along a given dimension.
/// May return fewer chunks than requested if the tensor size along the given dimension is too
/// small to fill every chunk.
2260 ///
2261 /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
2262 /// Otherwise all chunks will be of equal size except for the last one.
2263 ///
2264 /// # Panics
2265 ///
2266 /// If the dimension is greater than the number of dimensions of the tensor.
2267 ///
2268 /// # Returns
2269 /// A vector of tensors.
2270 ///
2271 /// # Example
2272 ///
2273 /// ```rust
2274 /// use burn_tensor::backend::Backend;
2275 /// use burn_tensor::Tensor;
2276 ///
2277 /// fn example<B: Backend>() {
2278 /// let device = Default::default();
2279 /// // Create a 2D tensor with dimensions [4, 3]
2280 /// let tensor = Tensor::<B, 2>::from_data(
2281 /// [
2282 /// [3.0, 4.9, 2.0],
2283 /// [2.0, 1.9, 3.0],
2284 /// [6.0, 1.5, 7.0],
2285 /// [3.0, 4.9, 9.0],
2286 /// ],
2287 /// &device,
2288 /// );
2289 /// // Split the tensor along the dimension 1 into 2 chunks.
/// // The first chunk will have shape [4, 2]:
2291 /// // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
2292 /// // The second chunk will have shape [4, 1]:
2293 /// // [[2.0], [3.0], [7.0], [9.0]]
2294 /// let chunks = tensor.chunk(2, 1);
2295 /// println!("{chunks:?}");
2296 /// }
2297 /// ```
2298 pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
2299 check!(TensorCheck::dim_ops::<D>("chunk", dim));
2300 let size = self.shape().dims[dim];
2301 if size < chunks {
2302 return (0..size)
2303 .map(|i| Self::narrow(self.clone(), dim, i, 1))
2304 .collect();
2305 }
2306
2307 let mut tensors = Vec::with_capacity(chunks);
2308 let mut sum_chunk_size = 0;
2309 if size.is_multiple_of(chunks) {
2310 let chunk_size = size / chunks;
2311 for _ in 0..chunks {
2312 tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
2313 sum_chunk_size += chunk_size;
2314 }
} else {
    let chunk_size = (size / chunks) + 1; // assumes not divisible
    // Push all full chunks, then a smaller remainder chunk if any elements are left.
    let full_chunks = size / chunk_size;
    for _ in 0..full_chunks {
        tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
        sum_chunk_size += chunk_size;
    }
    let remainder = size % chunk_size;
    if remainder > 0 {
        tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, remainder));
    }
}
2324
2325 tensors
2326 }
2327
2328 /// Splits the tensor into chunks of a specified size along a given dimension.
2329 /// Each chunk is a view of the original tensor.
2330 ///
2331 /// If the tensor size along the given dimension is not divisible by `split_size`,
2332 /// then the last chunk will be smaller.
2333 ///
2334 /// # Panics
2335 ///
2336 /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
2337 ///
2338 /// # Returns
2339 ///
2340 /// A vector of tensors.
2341 ///
2342 /// # Example
2343 /// ```rust
2344 /// use burn_tensor::backend::Backend;
2345 /// use burn_tensor::Tensor;
2346 ///
2347 /// fn example<B: Backend>() {
2348 /// let device = Default::default();
2349 /// // Create a 1D tensor with 5 elements
2350 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2351 /// // Split the tensor into chunks of size 2 along dimension 0
2352 /// let chunks = tensor.split(2, 0);
2353 /// // The result is a vector of tensors:
2354 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
2355 /// println!("{:?}", chunks);
2356 /// }
2357 /// ```
2358 pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
2359 check!(TensorCheck::split::<D>(&self.shape(), split_size, dim));
2360 let size = self.shape().dims[dim];
2361 let mut tensors = Vec::new();
2362
2363 let mut start = 0;
2364 while start < size {
2365 let length = usize::min(split_size, size - start);
2366 tensors.push(Self::narrow(self.clone(), dim, start, length));
2367 start += length;
2368 }
2369
2370 tensors
2371 }
2372
2373 /// Splits the tensor into chunks with the specified sizes along a given dimension.
2374 /// Each chunk is a view of the original tensor.
2375 ///
2376 /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2377 /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2378 ///
2379 /// # Panics
2380 ///
2381 /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
/// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
2383 ///
2384 /// # Returns
2385 ///
2386 /// A vector of tensors.
2387 ///
2388 /// # Example
2389 /// ```rust
2390 /// use burn_tensor::backend::Backend;
2391 /// use burn_tensor::Tensor;
2392 ///
2393 /// fn example<B: Backend>() {
2394 /// let device = Default::default();
2395 /// // Create a 1D tensor with 5 elements
2396 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2397 /// // Split the tensor into chunks with sizes [2, 3] along dimension 0
2398 /// let chunks = tensor.split_with_sizes(vec![2, 3], 0);
2399 /// // The result is a vector of tensors:
2400 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
2401 /// println!("{:?}", chunks);
2402 /// }
2403 /// ```
2404 pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
2405 check!(TensorCheck::split_with_sizes::<D>(
2406 &self.shape(),
2407 &split_sizes,
2408 dim
2409 ));
2410 let mut tensors = Vec::new();
2411
2412 let mut start = 0;
2413 for length in split_sizes {
2414 if length == 0 {
2415 continue;
2416 }
2417 tensors.push(Self::narrow(self.clone(), dim, start, length));
2418 start += length;
2419 }
2420
2421 tensors
2422 }
2423
2424 /// Tests if any element in the `tensor` evaluates to True.
2425 ///
2426 /// # Arguments
2427 ///
2428 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2429 ///
2430 /// # Returns
2431 ///
2432 /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
2433 /// evaluates to True, False otherwise.
2434 ///
2435 /// # Example
2436 ///
2437 /// ```rust
2438 /// use burn_tensor::backend::Backend;
2439 /// use burn_tensor::{Tensor, Bool};
2440 ///
2441 /// fn example<B: Backend>() {
2442 /// let device = Default::default();
2443 /// let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);
2444 /// let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);
2445 ///
2446 /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2447 /// let any_tensor = tensor.any();
2448 /// println!("{}", any_tensor);
2449 /// // Tensor { data: [true], ... }
2450 ///
/// // Same test on the second tensor, where every element is false.
2452 /// let any_tensor_two = tensor_two.any();
2453 /// println!("{}", any_tensor_two);
2454 /// // Tensor { data: [false], ... }
2455 /// }
2456 /// ```
2457 pub fn any(self) -> Tensor<B, 1, Bool> {
2458 Tensor::new(K::any(self.primitive))
2459 }
2460
2461 /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
2462 ///
2463 /// # Arguments
2464 ///
2465 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2466 /// * `dim` - The axis along which to test.
2467 ///
2468 /// # Returns
2469 ///
2470 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
/// where the size is 1. The element in the `dim` axis is True if any element along this dim in the input
2472 /// evaluates to True, False otherwise.
2473 ///
2474 /// # Example
2475 ///
2476 /// ```rust
2477 /// use burn_tensor::backend::Backend;
2478 /// use burn_tensor::{Tensor, Bool};
2479 ///
2480 /// fn example<B: Backend>() {
2481 /// let device = Default::default();
2482 /// let tensor =
2483 /// Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
2484 /// // Check if any element in the tensor evaluates to True along the dimension 1.
2485 /// // [[true], [true]],
2486 /// let any_dim = tensor.clone().any_dim(1);
2487 /// println!("{any_dim}");
2488 /// }
2489 /// ```
2490 pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2491 Tensor::new(K::any_dim(self.primitive, dim))
2492 }
2493
2494 /// Tests if all elements in the `tensor` evaluate to True.
2495 ///
2496 /// # Arguments
2497 ///
2498 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2499 ///
2500 /// # Returns
2501 ///
2502 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
2503 /// evaluate to True, False otherwise.
2504 ///
2505 /// # Example
2506 ///
2507 /// ```rust
2508 /// use burn_tensor::backend::Backend;
2509 /// use burn_tensor::{Tensor, Bool};
2510 ///
2511 /// fn example<B: Backend>() {
2512 /// let device = Default::default();
2513 /// let tensor =
2514 /// Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
2515 /// // Check if all elements in the tensor evaluate to True (which is not the case).
2516 /// // [false]
2517 /// let all = tensor.all();
2518 /// println!("{all}");
2519 /// }
2520 /// ```
2521 pub fn all(self) -> Tensor<B, 1, Bool> {
2522 Tensor::new(K::all(self.primitive))
2523 }
2524
2525 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2526 ///
2527 /// # Arguments
2528 ///
2529 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2530 /// * `dim` - The axis along which to test.
2531 ///
2532 /// # Returns
2533 ///
2534 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
/// where the size is 1. The element in the `dim` axis is True if all elements along this dim in the input
/// evaluate to True, False otherwise.
2537 ///
2538 /// # Example
2539 ///
2540 /// ```rust
2541 /// use burn_tensor::backend::Backend;
2542 /// use burn_tensor::{Tensor, Bool};
2543 ///
2544 /// fn example<B: Backend>() {
2545 /// let device = Default::default();
2546 /// let tensor =
2547 /// Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
2548 /// // Check if all elements in the tensor evaluate to True along the dimension 1.
2549 /// // [[true, true, false]]
2550 /// let all_dim = tensor.clone().all_dim(0);
2551 /// println!("{all_dim}");
2552 /// }
2553 /// ```
2554 pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2555 Tensor::new(K::all_dim(self.primitive, dim))
2556 }
2557
2558 /// Convert the tensor into a scalar.
2559 ///
2560 /// # Panics
2561 ///
2562 /// - If the tensor doesn't have one element.
2563 /// - If the backend fails to read the tensor data synchronously.
2564 ///
2565 /// # Returns
2566 ///
2567 /// The scalar value of the tensor.
2568 ///
2569 /// # Example
2570 ///
2571 /// ```rust
2572 /// use burn_tensor::backend::Backend;
2573 /// use burn_tensor::Tensor;
2574 ///
2575 /// fn example<B: Backend>() {
2576 /// let device = Default::default();
2577 /// let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
2578 /// // Convert the tensor with a single element into a scalar.
2579 /// let scalar = tensor.into_scalar();
2580 /// println!("{scalar}");
2581 /// }
2582 /// ```
2583 pub fn into_scalar(self) -> K::Elem {
2584 crate::try_read_sync(self.into_scalar_async())
2585 .expect(
2586 "Failed to read tensor data synchronously. This can happen on platforms
2587 that don't support blocking futures like WASM. Try into_scalar_async instead.",
2588 )
2589 .expect("Error while reading data: use `try_into_scalar` instead to catch the error at runtime")
2590 }
2591
2592 /// Convert the tensor into a scalar and returns any error that might have occurred since the
2593 /// last time the device was synchronized.
2594 ///
2595 /// # Panics
2596 ///
2597 /// - If the tensor doesn't have one element.
2598 /// - If the backend fails to read the tensor data synchronously.
2599 ///
2600 /// # Returns
2601 ///
2602 /// The scalar value of the tensor.
2603 pub fn try_into_scalar(self) -> Result<K::Elem, ExecutionError> {
2604 crate::try_read_sync(self.into_scalar_async()).expect(
2605 "Failed to read tensor data synchronously. This can happen on platforms
2606 that don't support blocking futures like WASM. Try into_scalar_async instead.",
2607 )
2608 }
2609
2610 /// Convert the tensor into a scalar.
2611 ///
2612 /// # Panics
2613 ///
2614 /// If the tensor doesn't have one element.
2615 pub async fn into_scalar_async(self) -> Result<K::Elem, ExecutionError> {
2616 check!(TensorCheck::into_scalar::<D>(&self.shape()));
2617
2618 Ok(self.into_data_async().await?.iter().next().unwrap())
2619 }
2620
2621 /// Broadcast the tensor to the given shape.
2622 ///
2623 /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size
2624 /// (which can be inferred with `-1`).
2625 ///
2626 /// # Arguments
2627 ///
2628 /// * `shape` - The shape to broadcast the tensor to.
2629 /// Can contain -1 for dimensions that should be inferred.
/// The number of elements in the shape must be greater than or equal to
/// the number of dimensions of the tensor.
2632 ///
2633 /// # Panics
2634 ///
2635 /// If the tensor cannot be broadcasted to the given shape.
2636 ///
2637 /// # Returns
2638 ///
2639 /// A new tensor with the given shape.
2640 ///
2641 /// # Example
2642 ///
2643 /// ```rust
2644 /// use burn_tensor::backend::Backend;
2645 /// use burn_tensor::Tensor;
2646 ///
2647 /// fn example<B: Backend>() {
2648 /// let device = Default::default();
2649 /// // Create a 2D tensor with dimensions [3, 1]
2650 /// let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
2651 /// // Expand the tensor to a new shape [3, 4]
2652 /// // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
2653 /// let expanded = tensor.expand([3, 4]);
2654 /// println!("{}", expanded);
2655 /// }
2656 /// ```
2657 pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
2658 let shape = shape.into_shape(&self.shape());
2659 check!(TensorCheck::expand::<D, D2>(
2660 "expand",
2661 &self.shape(),
2662 &shape,
2663 ));
2664
2665 Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
2666 }
2667
2668 /// Unfold windows along a dimension.
2669 ///
/// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`,
/// where windows are advanced by `step` at each index.
///
/// The number of windows is `(shape[dim] - size) / step + 1` (integer division) when
/// `shape[dim] >= size`, and `0` otherwise.
2674 ///
2675 /// The new view will have the unfolded dimension replaced by two dimensions;
2676 /// one in the position of the original dimension, with size equal to the number of windows,
2677 /// and one appended to the right-most position, with size equal to `size`.
2678 ///
2679 /// # Warning
2680 ///
/// For the `ndarray` and `candle` backends, this is not a view but a copy
2682 /// with duplicated data.
2683 ///
2684 /// # Arguments
2685 ///
2686 /// * `dim` - the dimension to unfold.
2687 /// * `size` - the size of each unfolded window.
2688 /// * `step` - the step between each window.
2689 ///
2690 /// # Returns
2691 ///
2692 /// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
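///
/// # Example
///
/// A small sketch: unfolding 5 elements with `size = 2` and `step = 2` yields 2 complete windows.
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
///     // The unfolded dimension is replaced by [windows, size], giving shape [2, 2].
///     let windows = tensor.unfold::<2, _>(0, 2, 2);
///     println!("{windows}");
///     // [[0.0, 1.0], [2.0, 3.0]]
/// }
/// ```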
2693 pub fn unfold<const D2: usize, I: AsIndex>(
2694 self,
2695 dim: I,
2696 size: usize,
2697 step: usize,
2698 ) -> Tensor<B, D2, K> {
2699 let dim = dim.expect_dim_index(D);
2700 check!(TensorCheck::unfold::<D, D2>(
2701 "unfold",
2702 &self.shape(),
2703 dim,
2704 size,
2705 step,
2706 ));
2707 Tensor::<B, D2, K>::new(K::unfold(self.primitive, dim, size, step))
2708 }
2709}
2710
/// Iterator returned by [iter_dim](Tensor::iter_dim).
2712pub struct DimIter<B, const D: usize, K>
2713where
2714 B: Backend,
2715 K: BasicOps<B>,
2716{
2717 start: usize,
2718 end: usize,
2719 dim: usize,
2720 ranges: [Range<usize>; D],
2721 tensor: Tensor<B, D, K>,
2722}
2723
2724impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
2725 type Item = Tensor<B, D, K>;
2726
2727 fn next(&mut self) -> Option<Self::Item> {
2728 if self.start >= self.end {
2729 return None;
2730 }
2731
2732 let mut ranges = self.ranges.clone();
2733 ranges[self.dim] = self.start..(self.start + 1);
2734
2735 let slice = self.tensor.clone().slice(ranges);
2736 self.start += 1;
2737
2738 Some(slice)
2739 }
2740}
2741
2742impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
2743 fn next_back(&mut self) -> Option<Self::Item> {
2744 if self.start >= self.end {
2745 return None;
2746 }
2747
2748 let mut ranges = self.ranges.clone();
2749 ranges[self.dim] = (self.end - 1)..self.end;
2750
2751 let slice = self.tensor.clone().slice(ranges);
2752 self.end = self.end.saturating_sub(1);
2753
2754 Some(slice)
2755 }
2756}
2757
2758impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
2759 fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
2760 let dims = tensor.dims();
2761 let ranges = dims
2762 .iter()
2763 .map(|&dim| 0..dim)
2764 .collect::<Vec<Range<usize>>>();
2765 let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
2766 Self {
2767 end: dims[dim],
2768 ranges,
2769 start: 0,
2770 dim,
2771 tensor,
2772 }
2773 }
2774}
2775
2776impl<B, const D: usize, K> Tensor<B, D, K>
2777where
2778 B: Backend,
2779 K: BasicOps<B>,
2780 <K as BasicOps<B>>::Elem: Debug,
2781{
2782 #[inline]
2783 fn push_newline_indent(acc: &mut String, indent: usize) {
2784 acc.push('\n');
2785 for _ in 0..indent {
2786 acc.push(' ');
2787 }
2788 }
2789 fn fmt_inner_tensor(
2790 &self,
2791 acc: &mut String,
2792 depth: usize,
2793 multi_index: &mut [usize],
2794 range: (usize, usize),
2795 precision: Option<usize>,
2796 ) {
2797 let (start, end) = range;
2798 for i in start..end {
2799 if i > 0 {
2800 acc.push_str(", ");
2801 }
2802 multi_index[depth] = i;
2803 let range: [Range<usize>; D] =
2804 core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
2805
2806 let data = burn_std::reader::try_read_sync(self.clone().slice(range).into_data_async());
2807
2808 if let Some(Ok(data)) = data {
2809 let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
2810 match (precision, K::name()) {
2811 (Some(p), "Float") => acc.push_str(&format!("{elem:.p$}")),
2812 (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
2813 _ => acc.push_str(&format!("{elem:?}")),
2814 }
2815 } else {
2816 acc.push_str("<Tensor data not available>");
2817 }
2818 }
2819 }
2820
2821 fn fmt_outer_tensor(
2822 &self,
2823 acc: &mut String,
2824 depth: usize,
2825 multi_index: &mut [usize],
2826 print_options: &PrintOptions,
2827 summarize: bool,
2828 range: (usize, usize),
2829 ) {
2830 let (start, end) = range;
2831 for i in start..end {
2832 if i > start {
2833 acc.push(',');
2834 Self::push_newline_indent(acc, depth + 1);
2835 }
2836 acc.push('[');
2837 multi_index[depth] = i;
2838 self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
2839 acc.push(']');
2840 }
2841 }
2842
2843 /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
2844 ///
2845 /// This function is designed to work with tensors of any dimensionality.
2846 /// It traverses the tensor dimensions recursively, converting the elements
2847 /// to strings and appending them to the accumulator string with the
2848 /// appropriate formatting.
2849 ///
2850 /// # Arguments
2851 ///
2852 /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
2853 /// * `depth` - The current depth of the tensor dimensions being processed.
2854 /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
2855 fn display_recursive(
2856 &self,
2857 acc: &mut String,
2858 depth: usize,
2859 multi_index: &mut [usize],
2860 print_options: &PrintOptions,
2861 summarize: bool,
2862 ) {
2863 let edge_items = print_options.edge_items;
2864
2865 if depth == 0 {
2866 acc.push('[');
2867 }
2868
2869 if depth == self.dims().len() - 1 {
2870 // if we are at the innermost dimension, just push its elements into the accumulator
2871 if summarize && self.dims()[depth] > 2 * edge_items {
2872 // print the starting `edge_items` elements
2873 self.fmt_inner_tensor(
2874 acc,
2875 depth,
2876 multi_index,
2877 (0, edge_items),
2878 print_options.precision,
2879 );
2880 acc.push_str(", ...");
2881 // print the last `edge_items` elements
2882 self.fmt_inner_tensor(
2883 acc,
2884 depth,
2885 multi_index,
2886 (self.dims()[depth] - edge_items, self.dims()[depth]),
2887 print_options.precision,
2888 );
2889 } else {
2890 // print all the elements
2891 self.fmt_inner_tensor(
2892 acc,
2893 depth,
2894 multi_index,
2895 (0, self.dims()[depth]),
2896 print_options.precision,
2897 );
2898 }
2899 } else {
2900 // otherwise, iterate through the current dimension and recursively display the inner tensors
2901 if summarize && self.dims()[depth] > 2 * edge_items {
2902 self.fmt_outer_tensor(
2903 acc,
2904 depth,
2905 multi_index,
2906 print_options,
2907 summarize,
2908 (0, edge_items),
2909 );
2910
2911 acc.push(',');
2912 Self::push_newline_indent(acc, depth + 1);
2913 acc.push_str("...");
2914 Self::push_newline_indent(acc, depth + 1);
2915
2916 self.fmt_outer_tensor(
2917 acc,
2918 depth,
2919 multi_index,
2920 print_options,
2921 summarize,
2922 (self.dims()[depth] - edge_items, self.dims()[depth]),
2923 );
2924 } else {
2925 self.fmt_outer_tensor(
2926 acc,
2927 depth,
2928 multi_index,
2929 print_options,
2930 summarize,
2931 (0, self.dims()[depth]),
2932 );
2933 }
2934 }
2935
2936 if depth == 0 {
2937 acc.push(']');
2938 }
2939 }
2940}
2941
2942#[derive(Clone, Debug)]
2943/// Options for Tensor pretty printing
2944pub struct PrintOptions {
/// Number of elements above which the tensor output is summarized
2946 pub threshold: usize,
2947
/// Number of starting and ending elements to display when summarizing
2949 pub edge_items: usize,
2950
2951 /// Precision for floating point numbers
2952 pub precision: Option<usize>,
2953}
2954
2955static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
2956
2957impl PrintOptions {
2958 /// Print options with default values
2959 pub const fn const_default() -> Self {
2960 Self {
2961 threshold: 1000,
2962 edge_items: 3,
2963 precision: None,
2964 }
2965 }
2966}
2967
2968impl Default for PrintOptions {
2969 fn default() -> Self {
2970 Self::const_default()
2971 }
2972}
2973
2974/// Set print options
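///
/// # Example
///
/// A minimal sketch (assuming `PrintOptions` and `set_print_options` are re-exported at the
/// crate root):
///
/// ```rust
/// use burn_tensor::{PrintOptions, set_print_options};
///
/// // Summarize tensors above 1000 elements, show 4 elements per edge,
/// // and print floats with 2 decimal places.
/// set_print_options(PrintOptions {
///     threshold: 1000,
///     edge_items: 4,
///     precision: Some(2),
/// });
/// ```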
2975pub fn set_print_options(options: PrintOptions) {
2976 let mut print_opts = PRINT_OPTS.write().unwrap();
2977 *print_opts = options;
2978}
2979
2980/// Pretty print tensors
2981impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
2982where
2983 B: Backend,
2984 B::IntElem: core::fmt::Display,
2985 K: BasicOps<B>,
2986 <K as BasicOps<B>>::Elem: Debug,
2987{
2988 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2989 writeln!(f, "Tensor {{")?;
2990
2991 {
// Do not hold the print-options lock for the whole function
2993 let mut po = { PRINT_OPTS.read().unwrap().clone() };
2994
2995 // Override the precision if it is set from the formatter
2996 // This will be possible when the tensor is printed using the `{:.*}` syntax
2997 if let Some(precision) = f.precision() {
2998 po.precision = Some(precision);
2999 }
3000
3001 let mut acc = String::new();
3002 let mut multi_index = vec![0; D];
3003 let summarize = self.shape().num_elements() > po.threshold;
3004
3005 self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);
3006
3007 writeln!(f, " data:")?;
3008 write!(f, "{acc}")?;
3009 writeln!(f, ",")?;
3010 }
3011
3012 writeln!(f, " shape: {:?},", self.dims())?;
3013 writeln!(f, " device: {:?},", self.device())?;
3014 writeln!(f, " backend: {:?},", B::name(&self.device()))?;
3015 writeln!(f, " kind: {:?},", K::name())?;
3016
3017 let dtype = self.primitive.dtype();
3018
3019 writeln!(f, " dtype: {:?},", dtype.name())?;
3020 write!(f, "}}")
3021 }
3022}
3023
3024/// Trait used for movedim arguments
3025pub trait MovedimArgs {
3026 /// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function
3027 fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
3028}
3029
3030impl MovedimArgs for Vec<i32> {
3031 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3032 let set = self
3033 .iter()
3034 .map(|&dim| {
3035 if dim < 0 {
3036 (D as i32 + dim) as usize
3037 } else {
3038 dim as usize
3039 }
3040 })
3041 .collect::<Vec<usize>>();
3042 check!(TensorCheck::movedim_args_vec::<D>(&set));
3043
3044 set
3045 }
3046}
3047
3048impl MovedimArgs for Vec<usize> {
3049 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3050 check!(TensorCheck::movedim_args_vec::<D>(&self));
3051 self
3052 }
3053}
3054
3055impl MovedimArgs for usize {
3056 #[allow(clippy::vec_init_then_push)]
3057 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3058 check!(TensorCheck::movedim_args_usize::<D>(self));
3059
3060 let mut set = Vec::with_capacity(1);
3061 set.push(self);
3062
3063 set
3064 }
3065}
3066
3067impl MovedimArgs for i32 {
3068 #[allow(clippy::vec_init_then_push)]
3069 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3070 check!(TensorCheck::movedim_args_i32::<D>(self));
3071
3072 let dim = if self < 0 {
3073 (D as i32 + self) as usize
3074 } else {
3075 self as usize
3076 };
3077
3078 let mut set = Vec::with_capacity(1);
3079 set.push(dim);
3080
3081 set
3082 }
3083}
3084
3085/// Trait used for reshape arguments.
3086pub trait ReshapeArgs<const D2: usize>: Debug {
3087 /// Converts to a shape.
3088 fn into_shape<const D: usize>(self, source: Shape) -> Shape;
3089}
3090
3091impl<const D2: usize, I: AsIndex> ReshapeArgs<D2> for [I; D2] {
3092 fn into_shape<const D: usize>(self, source: Shape) -> Shape {
3093 unwrap_shape_reshape(source.reshape(self))
3094 }
3095}
3096
3097impl<const D2: usize> ReshapeArgs<D2> for Shape {
3098 fn into_shape<const D: usize>(self, source: Shape) -> Shape {
3099 unwrap_shape_reshape(source.reshape(self))
3100 }
3101}
3102
3103/// Trait used for broadcast arguments.
3104pub trait BroadcastArgs<const D1: usize, const D2: usize> {
3105 /// Converts to a shape.
3106 fn into_shape(self, shape: &Shape) -> Shape;
3107}
3108
3109impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
3110 fn into_shape(self, _shape: &Shape) -> Shape {
3111 self
3112 }
3113}
3114
3115impl<const D1: usize, const D2: usize, E: AsIndex> BroadcastArgs<D1, D2> for [E; D2] {
3116 // Passing -1 as the size for a dimension means not changing the size of that dimension.
3117 fn into_shape(self, shape: &Shape) -> Shape {
3118 if self.len() < shape.num_dims() {
3119 panic!("Broadcast arguments must be greater than the number of dimensions");
3120 }
3121
3122 // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
3123 let new_shape: Vec<_> = self
3124 .iter()
3125 .rev()
3126 .map(|x| {
3127 let primitive = x.as_index();
3128 if primitive < -1 || primitive == 0 {
3129 panic!("Broadcast arguments must be positive or -1");
3130 }
3131 primitive
3132 })
3133 .zip(shape.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
3134 .map(|(x, &y)| if x == -1 { y } else { x as usize })
3135 .collect::<Vec<_>>()
3136 .into_iter()
3137 .rev()
3138 .collect();
3139
3140 if new_shape.contains(&0) {
3141 panic!("Cannot substitute -1 for a non-existing dimension");
3142 }
3143
3144 let new_shape: [usize; D2] = new_shape.try_into().unwrap();
3145
3146 Shape::from(new_shape)
3147 }
3148}
3149
3150impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
3151where
3152 B: Backend,
3153 K: BasicOps<B>,
3154 K::Elem: Debug + Copy + Serialize,
3155{
3156 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
3157 let data = self.to_data();
3158 data.serialize(serializer)
3159 }
3160}
3161
3162impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
3163where
3164 B: Backend,
3165 K: BasicOps<B>,
3166 K::Elem: Debug + Copy + Deserialize<'de>,
3167{
3168 fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
3169 let tensor = Tensor::from_data(
3170 TensorData::deserialize(deserializer)?,
3171 &<B::Device as Default>::default(),
3172 );
3173 Ok(tensor)
3174 }
3175}
3176
3177#[cfg(test)]
3178mod tests {
3179 use crate::{Shape, s};
3180
3181 #[test]
3182 fn slice_range_single_dim_leading() {
3183 let shape = Shape::new([8, 4]);
3184
3185 // Half-open range
3186 let slices = shape.clone().into_slices([0..5]);
3187 assert_eq!(slices[0].to_range(8), 0..5);
3188 let slices = shape.clone().into_slices([-3..-1]);
3189 assert_eq!(slices[0].to_range(8), 5..7);
3190
3191 // Inclusive range
3192 let slices = shape.clone().into_slices([0..=4]);
3193 assert_eq!(slices[0].to_range(8), 0..5);
3194 let slices = shape.clone().into_slices([-2..=-1]);
3195 assert_eq!(slices[0].to_range(8), 6..8);
3196
3197 // Unbounded start
3198 let slices = shape.clone().into_slices([..3]);
3199 assert_eq!(slices[0].to_range(8), 0..3);
3200 let slices = shape.clone().into_slices([..-5]);
3201 assert_eq!(slices[0].to_range(8), 0..3);
3202
3203 // Unbounded end
3204 let slices = shape.clone().into_slices([5..]);
3205 assert_eq!(slices[0].to_range(8), 5..8);
3206 let slices = shape.clone().into_slices([-3..]);
3207 assert_eq!(slices[0].to_range(8), 5..8);
3208
3209 // Full range
3210 let slices = shape.into_slices([..]);
3211 assert_eq!(slices[0].to_range(8), 0..8);
3212 }
3213
3214 #[test]
3215 fn test_negative_slice_indices() {
3216 use crate::Slice;
3217
3218 // Test negative indices conversion
3219 let slice: Slice = (-3..-1).into();
3220 assert_eq!(slice.start, -3);
3221 assert_eq!(slice.end, Some(-1));
3222
3223 // Test to_range conversion with size 8
3224 let range = slice.to_range(8);
3225 assert_eq!(range, 5..7);
3226
3227 // Test with shape slice
3228 let shape = Shape::new([8, 4]);
3229 let result = shape.clone().into_slices([-3..-1]);
3230 assert_eq!(result[0].to_range(8), 5..7);
3231
3232 // Test more negative index cases
3233 let slice2: Slice = (-5..).into();
3234 assert_eq!(slice2.to_range(10), 5..10);
3235
3236 let slice3: Slice = (..-2).into();
3237 assert_eq!(slice3.to_range(10), 0..8);
3238
3239 // Test with s! macro - single dimension returns Slice directly
3240 let slice4 = s![-3..-1];
3241 assert_eq!(slice4.start, -3);
3242 assert_eq!(slice4.end, Some(-1));
3243 }
3244
3245 #[test]
3246 fn slice_range_multi_dim() {
3247 let shape = Shape::new([8, 4]);
3248
3249 // Multiple ways to provide ranges
3250 let slices = shape.clone().into_slices([0..5, 0..4]);
3251 assert_eq!(slices[0].to_range(8), 0..5);
3252 assert_eq!(slices[1].to_range(4), 0..4);
3253
3254 let slices = shape.clone().into_slices([0.., 0..]);
3255 assert_eq!(slices[0].to_range(8), 0..8);
3256 assert_eq!(slices[1].to_range(4), 0..4);
3257
3258 let slices = shape.clone().into_slices([0..=7, 0..=3]);
3259 assert_eq!(slices[0].to_range(8), 0..8);
3260 assert_eq!(slices[1].to_range(4), 0..4);
3261
3262 let slices = shape.clone().into_slices([0..5, 0..3]);
3263 assert_eq!(slices[0].to_range(8), 0..5);
3264 assert_eq!(slices[1].to_range(4), 0..3);
3265
3266 let slices = shape.into_slices([0.., 0..]);
3267 assert_eq!(slices[0].to_range(8), 0..8);
3268 assert_eq!(slices[1].to_range(4), 0..4);
3269 }
3270
3271 #[test]
3272 fn slice_range_multi_dim_index() {
3273 let shape = Shape::new([8, 4]);
3274
3275 // Indices (single integer) should also convert to correct range
3276 let slices = shape.clone().into_slices([0, 2]);
3277 assert_eq!(slices[0].to_range(8), 0..1);
3278 assert_eq!(slices[1].to_range(4), 2..3);
3279
3280 let slices = shape.into_slices([-1, -1]);
3281 assert_eq!(slices[0].to_range(8), 7..8);
3282 assert_eq!(slices[1].to_range(4), 3..4);
3283 }
3284
3285 #[test]
3286 fn slice_range_multi_dim_heterogeneous() {
3287 // Slice macro `s![]` can be used to provide different range types
3288 let shape = Shape::new([8, 4, 2]);
3289 let slice = s![0..5, .., -1];
3290 let slices = shape.into_slices(slice);
3291 assert_eq!(slices[0].to_range(8), 0..5);
3292 assert_eq!(slices[1].to_range(4), 0..4);
3293 assert_eq!(slices[2].to_range(2), 1..2);
3294
3295 let shape = Shape::new([8, 4, 2, 3]);
3296 let slice = s![..=4, 0..=3, .., -2..];
3297 let slices = shape.into_slices(slice);
3298 assert_eq!(slices[0].to_range(8), 0..5);
3299 assert_eq!(slices[1].to_range(4), 0..4);
3300 assert_eq!(slices[2].to_range(2), 0..2);
3301 assert_eq!(slices[3].to_range(3), 1..3);
3302
3303 let shape = Shape::new([3, 4]);
3304 let slice = s![1..-1, ..];
3305 let slices = shape.into_slices(slice);
3306 assert_eq!(slices[0].to_range(3), 1..2);
3307 assert_eq!(slices[1].to_range(4), 0..4);
3308 }
3309}