burn_tensor/tensor/api/base.rs
1#![allow(clippy::single_range_in_vec_init)]
2
3use alloc::vec::Vec;
4
5use alloc::format;
6use alloc::string::String;
7use alloc::vec;
8
9use burn_common::stub::RwLock;
10use core::future::Future;
11use core::iter::repeat;
12use core::{fmt::Debug, ops::Range};
13use serde::{Deserialize, Deserializer};
14
15use serde::{Serialize, Serializer};
16
17use crate::tensor::api::narrow::narrow;
18use crate::{
19 Bool, Float, Int, Shape, TensorData, TensorKind, backend::Backend, check, ops::Device,
20};
21use crate::{DType, Element, TensorPrimitive};
22use crate::{cast::ToElement, check::TensorCheck};
23
24use super::{Slice, TensorMetadata, Transaction};
25
26/// A tensor with a given backend, shape and data type.
27///
28/// # Indexing
29/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
30/// or [`select`](Tensor::select) for numeric types.
31///
32/// ## Example
33///
34/// ```rust
35/// use burn_tensor::backend::Backend;
36/// use burn_tensor::Tensor;
37/// use burn_tensor::Int;
38///
39/// fn example<B: Backend>() {
40/// let device = Default::default();
41///
42/// let tensor = Tensor::<B, 2>::from_data(
43/// [
44/// [3.0, 4.9, 2.0],
45/// [2.0, 1.9, 3.0],
46/// [6.0, 1.5, 7.0],
47/// [3.0, 4.9, 9.0],
48/// ],
49/// &device,
50/// );
51///
52/// // Slice the tensor to get the second and third rows:
53/// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
54/// // The resulting tensor will have dimensions [2, 3].
55/// let slice = tensor.clone().slice([1..3]);
56/// println!("{slice}");
57///
58/// // Slice the tensor to get the first two rows and the first 2 columns:
59/// // [[3.0, 4.9], [2.0, 1.9]]
60/// // The resulting tensor will have dimensions [2, 2].
61/// let slice = tensor.clone().slice([0..2, 0..2]);
62/// println!("{slice}");
63///
64/// // Index the tensor along the dimension 1 to get the elements 0 and 2:
65/// // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
66/// // The resulting tensor will have dimensions [4, 2]
67/// let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
68/// let indexed = tensor.select(1, indices);
69/// println!("{indexed}");
70/// }
71/// ```
72#[derive(new, Clone, Debug)]
73pub struct Tensor<B, const D: usize, K = Float>
74where
75 B: Backend,
76 K: TensorKind<B>,
77{
78 pub(crate) primitive: K::Primitive,
79}
80
81impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
82where
83 B: Backend,
84 K: BasicOps<B>,
85 T: Into<TensorData>,
86{
87 fn from(value: T) -> Self {
88 Tensor::from_data(value.into(), &Default::default())
89 }
90}
91
92impl<B, const D: usize, K> Tensor<B, D, K>
93where
94 B: Backend,
95 K: BasicOps<B>,
96{
97 /// Converts the tensor into a primitive tensor.
98 pub fn into_primitive(self) -> K::Primitive {
99 self.primitive
100 }
101
102 /// Converts from a primitive tensor into a tensor.
103 pub fn from_primitive(tensor: K::Primitive) -> Self {
104 Self::new(tensor)
105 }
106
107 /// Returns the tensor primitive data type.
108 ///
109 /// # Note
110 /// Some element types are encoded in different primitive types depending on the backend
111 /// (e.g., bool could be encoded as `u8` or `u32`).
112 pub fn dtype(&self) -> DType {
113 self.primitive.dtype()
114 }
115
116 /// Create an empty tensor of the given shape.
117 ///
118 /// # Arguments
119 ///
120 /// - shape: The shape of the tensor.
121 /// - device: The device where the tensor will be created.
122 ///
123 /// # Example
124 /// ```rust
125 /// use burn_tensor::backend::Backend;
126 /// use burn_tensor::Tensor;
127 ///
128 /// fn example<B: Backend>() {
129 /// let device = Default::default();
130 /// // Create an empty tensor with dimensions [2, 3, 4].
131 /// let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
132 /// }
133 /// ```
134 pub fn empty<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
135 let shape = shape.into();
136 check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
137 Self::new(K::empty(shape, device))
138 }
139
140 /// Returns the dimensions of the current tensor.
141 ///
142 /// # Example
143 /// ```rust
144 /// use burn_tensor::backend::Backend;
145 /// use burn_tensor::Tensor;
146 ///
147 /// fn example<B: Backend>() {
148 /// let device = Default::default();
149 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
150 /// let dims = tensor.dims(); // [2, 3, 4]
151 /// println!("{dims:?}");
152 /// }
153 /// ```
154 pub fn dims(&self) -> [usize; D] {
155 Self::shape(self).dims()
156 }
157
158 /// Returns the shape of the current tensor.
159 ///
160 /// # Example
161 /// ```rust
162 /// use burn_tensor::backend::Backend;
163 /// use burn_tensor::Tensor;
164 ///
165 /// fn example<B: Backend>() {
166 /// let device = Default::default();
167 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
168 /// // Shape { dims: [2, 3, 4] }
169 /// let shape = tensor.shape();
170 /// }
171 /// ```
172 pub fn shape(&self) -> Shape {
173 self.primitive.shape()
174 }
175
176 /// Reshape the tensor to have the given shape.
177 ///
178 /// The tensor has the same data and number of elements as the input.
179 ///
180 /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
181 /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
182 ///
183 /// A `0` in the shape instructs to keep the current dimension from the original tensor,
184 /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
185 /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
186 /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
187 /// with [1, 3, 4] dimensions to [1, 12].
188 ///
189 /// # Arguments
190 /// - `shape`: The new shape of the tensor.
191 ///
192 /// # Panics
193 /// - If the tensor contains more than one `-1` in the shape.
194 /// - If the tensor contains values that are not positive (other than -1).
195 /// - If the shape does not match the number of elements of the original shape.
196 ///
197 /// # Example
198 ///
199 /// ```rust
200 /// use burn_tensor::backend::Backend;
201 /// use burn_tensor::Tensor;
202 ///
203 /// fn example<B: Backend>() {
204 /// let device = Default::default();
205 /// // Create a tensor with dimensions [2, 3, 4]
206 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
207 /// // Reshape it to [2, 12], where 12 is inferred from the number of elements.
208 /// let reshaped = tensor.reshape([2, -1]);
209 /// println!("{reshaped}");
210 /// }
211 /// ```
212 pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
213 // Convert reshape args to shape
214 let shape = shape.into_shape(&self);
215 Tensor::new(K::reshape(self.primitive, shape))
216 }
217
218 /// Transpose the tensor.
219 ///
220 /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
221 /// applied on the last two dimensions. For example, the transpose of a tensor with shape
222 /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
223 ///
224 /// See also [`permute`](Tensor::permute).
225 ///
226 /// # Arguments
227 ///
228 /// * `tensor` - The tensor to transpose.
229 ///
230 /// # Returns
231 ///
232 /// The transposed tensor.
233 ///
234 /// # Example
235 ///
236 /// ```rust
237 /// use burn_tensor::backend::Backend;
238 /// use burn_tensor::Tensor;
239 ///
240 /// fn example<B: Backend>() {
241 /// let device = Default::default();
242 /// // Create a 2D tensor of shape [2, 3]
243 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
244 ///
245 /// // Transpose the tensor:
246 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
247 /// // The resulting tensor will have dimensions [3, 2].
248 /// let transposed = tensor.transpose();
249 /// println!("{transposed}");
250 /// }
251 /// ```
252 pub fn transpose(self) -> Tensor<B, D, K> {
253 Tensor::new(K::transpose(self.primitive))
254 }
255
256 /// Swaps two dimensions of a tensor.
257 ///
258 /// # Arguments
259 ///
260 /// * `tensor` - The tensor to swap the dimensions of.
261 /// * `dim1` - The first dimension to swap.
262 /// * `dim2` - The second dimension to swap.
263 ///
264 /// # Returns
265 ///
266 /// The tensor with the dimensions swapped.
267 ///
268 /// # Example
269 ///
270 /// ```rust
271 /// use burn_tensor::backend::Backend;
272 /// use burn_tensor::Tensor;
273 ///
274 /// fn example<B: Backend>() {
275 /// let device = Default::default();
276 /// // Create a 2D tensor of shape [2, 3]
277 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
278 ///
279 /// // Swap the dimensions 0 and 1 (equivalent to `tensor.transpose()`):
280 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
281 /// // The resulting tensor will have dimensions [3, 2].
282 /// let swapped = tensor.swap_dims(0, 1);
283 /// println!("{swapped}");
284 /// }
285 /// ```
286 pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor<B, D, K> {
287 check!(TensorCheck::swap_dims::<D>(dim1, dim2));
288 Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
289 }
290
291 /// Permute the dimensions of the tensor.
292 ///
293 /// # Arguments
294 ///
295 /// * `axes` - The new order of the dimensions. The length of the axes
296 /// must be equal to the number of dimensions of the tensor.
297 /// The values must be unique and in the range of the number of dimensions.
298 /// The values can be negative, in which case they are used as an offset from the end.
299 ///
300 /// # Returns
301 ///
302 /// The tensor with the dimensions permuted.
303 ///
304 /// # Example
305 ///
306 /// ```rust
307 /// use burn_tensor::backend::Backend;
308 /// use burn_tensor::Tensor;
309 ///
310 /// fn example<B: Backend>() {
311 /// let device = Default::default();
312 /// // Create a 2D tensor of shape [3, 2]
313 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
314 ///
315 /// // Permute the dimensions 1 and 0:
316 /// // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
317 /// // The resulting tensor will have dimensions [3, 2].
318 /// let permuted = tensor.permute([1, 0]);
319 /// println!("{permuted}");
320 /// }
321 /// ```
322 pub fn permute(self, axes: [isize; D]) -> Tensor<B, D, K> {
323 // Convert the axes to usize and handle negative values without using vector
324 let mut transformed_axes: [usize; D] = [0; D];
325 for (i, &x) in axes.iter().enumerate() {
326 transformed_axes[i] = if x < 0 {
327 (D as isize + x) as usize
328 } else {
329 x as usize
330 };
331 }
332
333 // Check if the axes are valid after the transformation
334 check!(TensorCheck::permute(transformed_axes));
335
336 Tensor::new(K::permute(self.primitive, &transformed_axes))
337 }
338
339 /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
340 ///
341 /// Other dimensions of input that are not explicitly moved remain in their original order and appear
342 /// at the positions not specified in destination.
343 ///
344 /// # Arguments
345 ///
346 /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
347 /// The values can be negative, in which case they are used as an offset from the end.
348 ///
349 /// * `dst` - Destination positions for each of the original dims. These must also be unique.
350 ///
351 /// # Panics
352 ///
353 /// - If the source and destination dimensions are not of the same length.
354 /// - If the source and destination vectors contain duplicate values.
355 /// - If the source and destination vectors contain values that are out of bounds.
356 ///
357 /// # Returns
358 ///
359 /// The tensor with the dimensions moved.
360 ///
361 /// # Example
362 ///
363 /// ```rust
364 /// use burn_tensor::backend::Backend;
365 /// use burn_tensor::Tensor;
366 ///
367 /// fn example<B: Backend>() {
368 /// let device = Default::default();
369 /// // Create a 3D tensor of shape [3, 2, 1]
370 /// let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
371 ///
372 /// // Move the dimensions 0 and 1:
373 /// // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
374 /// // The resulting tensor will have dimensions [2, 3, 1].
375 /// let moved = tensor.movedim(1, 0);
376 /// println!("{moved}");
377 /// }
378 /// ```
379 ///
380 /// # Note
381 ///
382 /// This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op
383 /// for it
384 pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
385 let source_dims = src.into_dim_vec::<D>();
386 let destination_dims = dst.into_dim_vec::<D>();
387
388 check!(TensorCheck::movedim_args_length(
389 &source_dims,
390 &destination_dims
391 ));
392
393 let mut m = [-1; D];
394 for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
395 m[d] = s as isize;
396 }
397 let mut axes: [isize; D] = [0; D];
398 let mut source_i = 0;
399 for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
400 *item = if m[dest_i] != -1 {
401 m[dest_i]
402 } else {
403 while source_dims.contains(&source_i) {
404 source_i += 1;
405 }
406 let result = source_i as isize;
407 source_i += 1;
408 result
409 };
410 }
411
412 self.permute(axes)
413 }
414
415 /// Reverse the order of elements in the tensor along the given dimensions.
416 ///
417 /// # Arguments
418 ///
419 /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
420 /// The values can be negative, in which case they are used as an offset from the end.
421 ///
422 /// # Returns
423 ///
424 /// The tensor with the axes flipped.
425 ///
426 /// # Example
427 ///
428 /// ```rust
429 /// use burn_tensor::backend::Backend;
430 /// use burn_tensor::Tensor;
431 ///
432 /// fn example<B: Backend>() {
433 /// let device = Default::default();
434 /// // Create a 2D tensor with dimensions [4, 3]
435 /// let tensor = Tensor::<B, 2>::from_data(
436 /// [
437 /// [3.0, 4.9, 2.0],
438 /// [2.0, 1.9, 3.0],
439 /// [4.0, 5.9, 8.0],
440 /// [1.4, 5.8, 6.0],
441 /// ],
442 /// &device,
443 /// );
444 ///
445 /// // Flip the elements in dimensions 0 and 1:
446 /// // [[6.0, 5.8, 1.4],
447 /// // [8.0, 5.9, 4.0],
448 /// // [3.0, 1.9, 2.0],
449 /// // [2.0, 4.9, 3.0]]
450 /// // The resulting tensor will have dimensions [4, 3].
451 /// let flipped = tensor.flip([0, 1]);
452 /// println!("{flipped}");
453 /// }
454 /// ```
455 pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
456 // Convert the axes to usize and handle negative values without using vector
457 let mut transformed_axes: [usize; N] = [0; N];
458 for (i, &x) in axes.iter().enumerate() {
459 transformed_axes[i] = if x < 0 {
460 (D as isize + x) as usize
461 } else {
462 x as usize
463 };
464 }
465
466 // Check if the axes are valid
467 check!(TensorCheck::flip(D, &transformed_axes));
468
469 Tensor::new(K::flip(self.primitive, &transformed_axes))
470 }
471
472 /// Flatten the tensor along a given range of dimensions.
473 ///
474 /// This function collapses the specified range of dimensions into a single dimension,
475 /// effectively flattening the tensor in that range.
476 ///
477 /// # Arguments
478 ///
479 /// - `start_dim`: The starting dimension of the range to be flattened.
480 /// - `end_dim`: The ending dimension of the range to be flattened (inclusive).
481 ///
482 /// # Type Parameters
483 ///
484 /// - `D2`: The resulting number of dimensions in the flattened tensor.
485 ///
486 /// # Returns
487 ///
488 /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
489 ///
490 /// # Example
491 ///
492 /// ```rust
493 ///
494 /// use burn_tensor::backend::Backend;
495 /// use burn_tensor::{Tensor, Shape};
496 ///
497 /// fn example<B: Backend>() {
498 /// let device = Default::default();
499 /// // Create a 3D tensor with dimensions [2, 3, 4]
500 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
501 ///
502 /// // Flatten the tensor from dimensions 1 to 2 (inclusive).
503 /// // The resulting tensor will have dimensions [2, 12]
504 /// let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
505 /// println!("{flattened}");
506 /// }
507 /// ```
508 pub fn flatten<const D2: usize>(self, start_dim: usize, end_dim: usize) -> Tensor<B, D2, K> {
509 check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));
510
511 let current_dims = self.shape().dims;
512 let mut new_dims: [usize; D2] = [0; D2];
513 let mut flatten_dims = 1;
514
515 for i in current_dims[start_dim..=end_dim].iter() {
516 flatten_dims *= i;
517 }
518
519 new_dims[..start_dim].copy_from_slice(¤t_dims[..start_dim]);
520 new_dims[start_dim] = flatten_dims;
521 new_dims[start_dim + 1..].copy_from_slice(¤t_dims[end_dim + 1..]);
522
523 Tensor::new(K::reshape(self.primitive, new_dims.into()))
524 }
525
526 /// Squeeze the tensor along the given dimension, removing the specified dimension
527 /// of size one, and effectively reducing the rank of the tensor by one.
528 ///
529 /// # Arguments
530 ///
531 /// - `dim`: The dimension to be squeezed.
532 ///
533 /// # Type Parameters
534 ///
535 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
536 ///
537 /// # Panics
538 ///
539 /// If the size in the squeezed dimension is not 1.
540 ///
541 /// # Returns
542 ///
543 /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
544 ///
545 /// # Example
546 ///
547 /// ```rust
548 ///
549 /// use burn_tensor::backend::Backend;
550 /// use burn_tensor::{Tensor, Shape};
551 ///
552 /// fn example<B: Backend>() {
553 /// let device = Default::default();
554 /// // Create a 3D tensor with dimensions [3, 1, 3]
555 /// let tensor = Tensor::<B, 3>::from_data(
556 /// [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
557 /// &device,
558 /// );
559 ///
560 /// // Squeeze the dimension 1.
561 /// // The resulting tensor will have dimensions [3, 3].
562 /// let squeezed = tensor.squeeze::<2>(1);
563 /// println!("{squeezed}");
564 /// }
565 /// ```
566 pub fn squeeze<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
567 check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));
568
569 let current_dims = self.shape().dims;
570 let mut new_dims: [usize; D2] = [0; D2];
571
572 new_dims[..dim].copy_from_slice(¤t_dims[..dim]);
573 new_dims[dim..].copy_from_slice(¤t_dims[dim + 1..]);
574
575 Tensor::new(K::reshape(self.primitive, new_dims.into()))
576 }
577
578 /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
579 /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
580 /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
581 /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
582 /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
583 /// from the back.
584 ///
585 /// # Arguments
586 ///
587 /// - `dims`: The dimension(s) to be squeezed.
588 ///
589 /// # Type Parameters
590 ///
591 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
592 ///
593 /// # Returns
594 ///
595 /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
596 ///
597 /// # Example
598 ///
599 /// ```rust
600 ///
601 /// use burn_tensor::backend::Backend;
602 /// use burn_tensor::{Tensor, Shape};
603 ///
604 /// fn example<B: Backend>() {
605 /// let device = Default::default();
606 /// // Create a 4D tensor with dimensions [2, 1, 4, 1]
607 /// let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
608 ///
609 /// // Squeeze the dimensions 1 and 3.
610 /// // The resulting tensor will have dimensions [2, 4].
611 /// let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
612 /// println!("{squeezed}");
613 /// }
614 /// ```
615 pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
616 let current_dims = self.shape().dims;
617 let mut dim_indices: Vec<usize>;
618
619 // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
620 if dims.is_empty() {
621 dim_indices = current_dims
622 .iter()
623 .enumerate()
624 .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
625 .collect();
626 } else {
627 // If negative dims, count from the back
628 dim_indices = dims
629 .iter()
630 .map(|&d| {
631 if d < 0 {
632 (current_dims.len() as isize + d) as usize
633 } else {
634 d as usize
635 }
636 })
637 .collect();
638 }
639
640 // Sort indices and remove duplicates
641 dim_indices.sort_unstable();
642 dim_indices.dedup();
643
644 // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
645 check!(TensorCheck::squeeze_dims_input::<D2>(
646 &dim_indices,
647 ¤t_dims
648 ));
649
650 // Calculate new dimensions
651 let mut new_dims = Vec::new();
652 for (index, &dim_size) in current_dims.iter().enumerate() {
653 // Exclude the dimension if it's explicitly marked for squeezing
654 if dim_indices.contains(&index) {
655 check!(TensorCheck::squeeze::<D2>(index, ¤t_dims));
656 continue;
657 }
658 new_dims.push(dim_size);
659 }
660
661 // Check that after squeezing, we still respect the D2 size
662 check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
663
664 Tensor::new(K::reshape(self.primitive, new_dims.into()))
665 }
666
667 /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
668 ///
669 /// # Type Parameters
670 ///
671 /// - `D2`: The resulting number of dimensions in the unsqueezed tensor.
672 ///
673 /// # Panics
674 ///
675 /// If the output size `D2` is smaller than the current number of dimensions.
676 ///
677 /// # Returns
678 ///
679 /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
680 ///
681 /// # Example
682 ///
683 /// ```rust
684 /// use burn_tensor::backend::Backend;
685 /// use burn_tensor::{Tensor, Shape};
686 ///
687 /// fn example<B: Backend>() {
688 /// let device = Default::default();
689 /// // Create a 2D tensor with dimensions [3, 3]
690 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
691 /// // Unsqueeze the tensor up to 4 dimensions.
692 /// // The resulting tensor will have dimensions [1, 1, 3, 3].
693 /// let unsqueezed = tensor.unsqueeze::<4>();
694 /// println!("{unsqueezed}");
695 /// }
696 /// ```
697 pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
698 check!(TensorCheck::unsqueeze::<D, D2>());
699
700 let mut dims = [1; D2];
701 let num_ones = D2 - D;
702 let shape = self.shape();
703
704 dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]);
705
706 let shape = Shape::new(dims);
707 self.reshape(shape)
708 }
709
710 /// Creates a new tensor with a dimension of size one inserted at the specified position.
711 ///
712 /// # Example
713 ///
714 /// ```rust
715 /// use burn_tensor::backend::Backend;
716 /// use burn_tensor::{Tensor, Shape};
717 ///
718 /// fn example<B: Backend>() {
719 /// let device = Default::default();
720 /// // Create a 2D tensor with dimensions [3, 3]
721 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
722 /// // Unsqueeze the dimension 1.
723 /// // The resulting tensor will have dimensions [3, 1, 3].
724 /// let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
725 /// println!("{unsqueezed}");
726 /// }
727 /// ```
728 pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
729 check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));
730
731 let mut dims = [1; D2];
732 let shape = self.shape();
733
734 dims[0..dim].copy_from_slice(&shape.dims[0..dim]);
735
736 if dim < D {
737 dims[dim] = 1;
738 dims[(dim + 1)..].copy_from_slice(&shape.dims[dim..]);
739 } else {
740 dims[dim] = 1;
741 }
742
743 let shape = Shape::new(dims);
744 self.reshape(shape)
745 }
746
747 /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
748 /// The indices can be negative, in which case they are counted from the last to the first dimension.
749 /// the axes can contain duplicates, in which case the number of dimensions inserted at the index
750 /// is the number of duplicates.
751 /// # Example
752 ///
753 /// ```rust
754 /// use burn_tensor::backend::Backend;
755 /// use burn_tensor::{Tensor, Shape};
756 ///
757 /// fn example<B: Backend>() {
758 /// let device = Default::default();
759 /// // Create a 3D tensor with dimensions [3, 4, 5]
760 /// let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
761 /// // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
762 /// // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
763 /// let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
764 /// println!("{unsqueezed}");
765 /// }
766 /// ```
767 pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K> {
768 let mut new_dims = [1; D2];
769 let old_dims = self.shape().dims;
770 //for checking if the dimension is in the acceptable range
771
772 //part 1: convert the negative indices to positive
773 let mut neg_offset = D2;
774 let mut dim_indices = axes
775 .iter()
776 .map(|d| {
777 // check if the dimension is in the acceptable range
778 check!(TensorCheck::unsqueeze_dims::<{ D2 }>(*d));
779 (if *d < 0 {
780 neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)
781 d + neg_offset as isize + 1
782 } else {
783 *d
784 }) as usize
785 })
786 .collect::<Vec<usize>>();
787
788 //sort the indices
789 dim_indices.sort_unstable();
790
791 //Now use this to copy the chunks of the dims
792 let mut prev_idx: usize = 0;
793 let mut current_left_b: usize = 0;
794 let mut current_right_b: usize = 0;
795 let mut offset: usize = 0;
796 dim_indices.iter().for_each(|d| {
797 //check if there is space for at least one dimension
798 if prev_idx < *d {
799 current_right_b = *d - offset;
800 //copy the chunks of the dims
801 if current_right_b < D {
802 new_dims[prev_idx..*d]
803 .copy_from_slice(&old_dims[current_left_b..current_right_b]);
804 } else {
805 new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
806 }
807 prev_idx = *d + 1;
808 //offset is equal to the number of extracted elements from the original shape
809 offset += current_right_b - current_left_b;
810 current_left_b = current_right_b;
811 } else {
812 //it's sorted so the only reason this would happen
813 //is if multiple indices are the same
814 prev_idx += 1;
815 }
816 });
817 //copy over anything past the index of the last new dimension
818 if current_left_b < D {
819 new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
820 }
821
822 //lastly, create the shape and reshape
823 let shape = Shape::new(new_dims);
824 self.reshape(shape)
825 }
826
827 /// Returns a tensor containing the elements selected from the given ranges.
828 ///
829 /// For more complex indexing with different slice ranges, see also the slice
830 /// macro [`s!`](crate::s).
831 ///
832 /// # Arguments
833 ///
834 /// * `ranges` - A type implementing the `RangesArg` trait, which can be:
835 /// - A single range (slice the first dimension)
836 /// - A single index (slice the first dimension)
837 /// - An array of ranges
838 ///
839 /// # Behavior
840 ///
841 /// - Supports partial and full slicing in any number of dimensions.
842 /// - Missing ranges are treated as full slices if D > D2.
843 /// - Handles negative indices by wrapping around from the end of the dimension.
844 /// - Clamps ranges to the tensor's dimensions if they exceed the bounds.
845 ///
846 /// # Panics
847 ///
848 /// - If the number of ranges provided exceeds the tensor's dimensions.
849 /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1).
850 ///
851 /// # Examples
852 ///
853 /// ```rust
854 /// use burn_tensor::backend::Backend;
855 /// use burn_tensor::{Tensor, Shape, s};
856 ///
857 /// fn example<B: Backend>() {
858 /// let device = B::Device::default();
859 ///
860 /// // 1D slicing
861 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
862 /// let slice = tensor.slice([1..4]);
863 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
864 ///
865 /// // 2D slicing
866 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 4]), &device);
867 /// let slice = tensor.slice([1..3, 0..2]);
868 /// assert_eq!(slice.dims(), [2, 2]);
869 ///
870 /// // Using negative indices
871 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
872 /// let slice = tensor.slice([1..-1]); // Equivalent to 1..4
873 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
874 ///
875 /// // Using the slice macro to select different ranges
876 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device).reshape([3, 4]);
877 /// let slice = tensor.slice(s![1.., ..]); // Select rows 1 and 2, all columns
878 /// assert_eq!(slice.dims(), [2, 4]);
879 ///
880 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..16, &device).reshape([2, 4, 2]);
881 /// let slice = tensor.slice(s![1.., 1..=3, -1]);
882 /// assert_eq!(slice.dims(), [1, 3, 1]);
883 /// }
884 /// ```
885 ///
886 /// # Note
887 ///
888 /// This function uses the `RangesArg` trait for flexible range specification. The trait
889 /// handles the conversion of various range formats and applies clamping and negative
890 /// index handling internally.
891 pub fn slice<const D2: usize, R: RangesArg<D2>>(self, ranges: R) -> Self {
892 let ranges = ranges.into_ranges(self.shape());
893
894 check!(TensorCheck::slice::<D, D2>(&self.shape(), &ranges));
895 Self::new(K::slice(self.primitive, &ranges))
896 }
897
898 /// Returns a copy of the current tensor with the selected elements changed to the new ones at
899 /// the selected indices.
900 ///
901 /// # Panics
902 ///
903 /// - If a range exceeds the number of elements on a dimension.
904 /// - If the given values don't match the given ranges.
905 ///
906 /// # Example
907 ///
908 /// ```rust
909 /// use burn_tensor::backend::Backend;
910 /// use burn_tensor::Tensor;
911 ///
912 /// fn example<B: Backend>() {
913 /// let device = B::Device::default();
914 /// let tensor = Tensor::<B, 3>::ones([2, 3, 3], &device);
915 /// let values = Tensor::<B, 3>::zeros([1, 1, 1], &device);
916 /// let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values);
917 /// println!("{:?}", tensor_sliced.dims()); // [2, 3, 3]
918 /// }
919 /// ```
920 pub fn slice_assign<const D2: usize>(self, ranges: [Range<usize>; D2], values: Self) -> Self {
921 check!(TensorCheck::slice_assign::<D, D2>(
922 &self.shape(),
923 &values.shape(),
924 &ranges
925 ));
926 Self::new(K::slice_assign(self.primitive, &ranges, values.primitive))
927 }
928
929 /// Returns the device of the current tensor.
930 pub fn device(&self) -> B::Device {
931 K::device(&self.primitive)
932 }
933
934 /// Move the tensor to the given device.
935 pub fn to_device(self, device: &B::Device) -> Self {
936 Self::new(K::to_device(self.primitive, device))
937 }
938
939 /// Converts the data of the current tensor.
940 ///
941 /// # Note
942 ///
943 /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
944 /// tensors at once. This may improve laziness, especially if executed on a different
945 /// thread in native environments.
946 pub fn into_data(self) -> TensorData {
947 crate::try_read_sync(self.into_data_async()).expect(
948 "Failed to read tensor data synchronously.
949 This can happen on platforms that don't support blocking futures like WASM.
950 If possible, try using into_data_async instead.",
951 )
952 }
953
954 /// Converts the data of the current tensor.
955 ///
956 /// # Note
957 ///
958 /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
959 /// tensors at once. This may improve laziness, especially if executed on a different
960 /// thread in native environments.
961 pub fn to_data(&self) -> TensorData {
962 self.clone().into_data()
963 }
964
965 /// Returns the data of the current tensor.
966 pub async fn into_data_async(self) -> TensorData {
967 K::into_data_async(self.primitive).await
968 }
969
970 /// Returns the data of the current tensor.
971 pub async fn to_data_async(&self) -> TensorData {
972 self.clone().into_data_async().await
973 }
974
975 /// Create a tensor from the given data on the given device.
976 pub fn from_data<T>(data: T, device: &B::Device) -> Self
977 where
978 T: Into<TensorData>,
979 {
980 let data = data.into();
981 check!(TensorCheck::creation_ops::<D>(
982 "From Data",
983 data.shape.as_slice()
984 ));
985 Self::new(K::from_data(data, device))
986 }
987
988 /// Create a tensor from the given data on the given device enforcing the given data type.
989 pub fn from_data_dtype<T>(data: T, device: &B::Device, dtype: DType) -> Self
990 where
991 T: Into<TensorData>,
992 {
993 let data = data.into();
994 check!(TensorCheck::creation_ops::<D>(
995 "From Data",
996 data.shape.as_slice()
997 ));
998 Self::new(K::from_data_dtype(data, device, dtype))
999 }
1000
1001 /// Repeat the tensor along the given dimension.
1002 ///
1003 /// The output tensor has the same shape, except along the given dimension.
1004 ///
1005 /// # Arguments
1006 /// - `dim`: The dimension to repeat.
1007 /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
1008 ///
1009 /// # Returns
1010 ///
1011 /// A new tensor with the given dimension repeated `times` times.
1012 ///
1013 /// # Example
1014 ///
1015 /// ```rust
1016 /// use burn_tensor::backend::Backend;
1017 /// use burn_tensor::Tensor;
1018 ///
1019 /// fn example<B: Backend>() {
1020 /// let device = Default::default();
1021 /// // Create a 2D tensor with dimensions [3, 2]
1022 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1023 ///
1024 /// // Repeat the tensor along the dimension 0 twice.
1025 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1026 /// // The resulting tensor will have dimensions [6, 2].
1027 /// let repeated = tensor.repeat_dim(0, 2);
1028 /// println!("{repeated}");
1029 /// }
1030 /// ```
1031 pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
1032 Self::new(K::repeat_dim(self.primitive, dim, times))
1033 }
1034
1035 /// Repeat the tensor along the given dimensions.
1036 /// # Arguments
1037 /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
1038 ///
1039 /// # Returns
1040 ///
1041 /// A new tensor with the given dimensions repeated `times` times.
1042 ///
1043 /// # Panics
1044 ///
1045 /// If `sizes` contains more elements than the number of dimensions.
1046 ///
1047 /// # Example
1048 ///
1049 /// ```rust
1050 ///
1051 /// use burn_tensor::backend::Backend;
1052 /// use burn_tensor::Tensor;
1053 ///
1054 /// fn example<B: Backend>() {
1055 /// let device = Default::default();
1056 /// // Create a 2D tensor with dimensions [3, 2]
1057 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1058 ///
/// // Repeat the tensor along the dimension 0 twice and the dimension 1 once.
1060 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1061 /// // The resulting tensor will have dimensions [6, 2].
1062 /// let repeated = tensor.repeat(&[2, 1]);
1063 /// }
1064 /// ```
1065 pub fn repeat(self, sizes: &[usize]) -> Self {
1066 let mut tensor = self;
1067 for (dim, ×) in sizes.iter().enumerate() {
1068 if times > 1 {
1069 tensor = tensor.repeat_dim(dim, times);
1070 }
1071 }
1072 tensor
1073 }
1074
1075 /// Applies element-wise equal comparison.
1076 ///
1077 /// # Returns
1078 /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
1079 ///
1080 /// # Panics
1081 ///
1082 /// If the two tensors don't have the same shape.
1083 ///
1084 /// # Example
1085 ///
1086 /// ```rust
1087 /// use burn_tensor::backend::Backend;
1088 /// use burn_tensor::Tensor;
1089 ///
1090 /// fn example<B: Backend>() {
1091 /// let device = Default::default();
1092 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1093 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1094 /// // Compare the elements of the two 2D tensors with dimensions [3, 2].
1095 /// // [[false, true], [true, true], [true, true]]
1096 /// let equal = t1.equal(t2);
1097 /// println!("{equal}");
1098 /// }
1099 /// ```
1100 pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
1101 check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
1102 Tensor::new(K::equal(self.primitive, other.primitive))
1103 }
1104
1105 /// Applies element-wise non-equality comparison.
1106 ///
1107 /// # Returns
1108 /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
1109 ///
1110 /// # Panics
1111 ///
1112 /// If the two tensors don't have the same shape.
1113 ///
1114 /// # Example
1115 ///
1116 /// ```rust
1117 /// use burn_tensor::backend::Backend;
1118 /// use burn_tensor::Tensor;
1119 ///
1120 /// fn example<B: Backend>() {
1121 /// let device = Default::default();
1122 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1123 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1124 /// // Compare the elements of the two 2D tensors for inequality.
1125 /// // [[true, false], [false, false], [false, false]]
1126 /// let not_equal = t1.not_equal(t2);
1127 /// println!("{not_equal}");
1128 /// }
1129 /// ```
1130 pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
1131 check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
1132 Tensor::new(K::not_equal(self.primitive, other.primitive))
1133 }
1134
1135 /// Concatenates all tensors into a new one along the given dimension.
1136 ///
1137 /// # Panics
1138 ///
1139 /// - If `dim` is higher than the rank.
1140 /// - If `tensors` is an empty vector.
1141 /// - If all tensors don't have the same shape (the dimension `dim` is ignored).
1142 ///
1143 /// # Example
1144 ///
1145 /// ```rust
1146 /// use burn_tensor::backend::Backend;
1147 /// use burn_tensor::Tensor;
1148 ///
1149 /// fn example<B: Backend>() {
1150 /// let device = Default::default();
1151 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
1152 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1153 ///
1154 /// // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
1155 /// // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
1156 /// // The resulting tensor will have shape [2, 7].
1157 /// let concat = Tensor::cat(vec![t1, t2], 1);
1158 /// println!("{concat}");
1159 /// }
1160 /// ```
1161 pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
1162 check!(TensorCheck::cat(&tensors, dim));
1163
1164 Self::new(K::cat(
1165 tensors.into_iter().map(|vector| vector.primitive).collect(),
1166 dim,
1167 ))
1168 }
1169
1170 /// Concatenates all tensors into a new one along a new dimension.
1171 ///
1172 /// # Panics
1173 ///
1174 /// - If all tensors don't have the same shape.
/// - If the given dimension is not within the range of 0..D2
1176 ///
1177 /// # Example
1178 ///
1179 /// ```rust
1180 /// use burn_tensor::backend::Backend;
1181 /// use burn_tensor::Tensor;
1182 ///
1183 /// fn example<B: Backend>() {
1184 /// let device = Default::default();
1185 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
1186 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1187 /// let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1188 ///
1189 /// // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
1190 /// // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
1191 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
1192 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
1193 /// // The resulting tensor will have shape [3, 2, 3].
1194 /// let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0);
1195 /// println!("{stacked}");
1196 /// }
1197 /// ```
1198 pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
1199 check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
1200 let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
1201 Tensor::<B, D2, K>::cat(tensors, dim)
1202 }
1203
1204 /// Iterate over slices of tensors alongside a given dimension.
1205 ///
1206 /// # Panics
1207 ///
1208 /// If given dimension is greater than or equal to tensor rank.
1209 ///
1210 /// # Returns
1211 ///
1212 /// A tensor iterator.
1213 ///
1214 /// # Example
1215 ///
1216 /// ```rust
1217 /// use burn_tensor::backend::Backend;
1218 /// use burn_tensor::Tensor;
1219 /// fn example<B: Backend>() {
1220 /// let device = Default::default();
1221 /// let tensor = Tensor::<B,2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
1222 /// // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
1223 /// let iter = tensor.iter_dim(0);
1224 /// for (i,tensor) in iter.enumerate() {
1225 /// println!("Tensor {}: {}", i, tensor);
1226 /// // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
1227 /// // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
1228 /// }
1229 /// }
1230 /// ```
1231 pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
1232 check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
1233 DimIter::new(self, dim)
1234 }
1235
1236 /// Returns a new tensor with the given dimension narrowed to the given range.
1237 ///
1238 /// # Panics
1239 ///
1240 /// - If the dimension is greater than the number of dimensions of the tensor.
1241 /// - If the given range exceeds the number of elements on the given dimension.
1242 ///
1243 /// # Returns
1244 ///
1245 /// A new tensor with the given dimension narrowed to the given range.
1246 ///
1247 /// # Example
1248 ///
1249 /// ```rust
1250 /// use burn_tensor::backend::Backend;
1251 /// use burn_tensor::Tensor;
1252 ///
1253 /// fn example<B: Backend>() {
1254 /// let device = Default::default();
1255 /// // Create a 2D tensor with dimensions [4, 3]
1256 /// let tensor = Tensor::<B, 2>::from_data(
1257 /// [
1258 /// [3.0, 4.9, 2.0],
1259 /// [2.0, 1.9, 3.0],
1260 /// [6.0, 1.5, 7.0],
1261 /// [3.0, 4.9, 9.0],
1262 /// ],
1263 /// &device,
1264 /// );
1265 /// // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
1266 /// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
1267 /// // The resulting tensor will have dimensions [3, 3].
1268 /// let narrowed = tensor.narrow(0, 1, 3);
1269 /// println!("{narrowed}");
1270 /// }
1271 /// ```
1272 pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
1273 check!(TensorCheck::dim_ops::<D>("narrow", dim));
1274 check!(TensorCheck::narrow(&self, dim, start, length));
1275 Self::new(narrow::<B, K>(self.primitive, dim, start, length))
1276 }
1277
1278 /// Attempts to split the tensor into a specified number of chunks along a given dimension.
/// May return fewer chunks than requested if the tensor size is not divisible by the number of chunks.
1280 ///
1281 /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
1282 /// Otherwise all chunks will be of equal size except for the last one.
1283 ///
1284 /// # Panics
1285 ///
1286 /// If the dimension is greater than the number of dimensions of the tensor.
1287 ///
1288 /// # Returns
1289 /// A vector of tensors.
1290 ///
1291 /// # Example
1292 ///
1293 /// ```rust
1294 /// use burn_tensor::backend::Backend;
1295 /// use burn_tensor::Tensor;
1296 ///
1297 /// fn example<B: Backend>() {
1298 /// let device = Default::default();
1299 /// // Create a 2D tensor with dimensions [4, 3]
1300 /// let tensor = Tensor::<B, 2>::from_data(
1301 /// [
1302 /// [3.0, 4.9, 2.0],
1303 /// [2.0, 1.9, 3.0],
1304 /// [6.0, 1.5, 7.0],
1305 /// [3.0, 4.9, 9.0],
1306 /// ],
1307 /// &device,
1308 /// );
1309 /// // Split the tensor along the dimension 1 into 2 chunks.
/// // The first chunk will have shape [4, 2]:
1311 /// // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
1312 /// // The second chunk will have shape [4, 1]:
1313 /// // [[2.0], [3.0], [7.0], [9.0]]
1314 /// let chunks = tensor.chunk(2, 1);
1315 /// println!("{chunks:?}");
1316 /// }
1317 /// ```
1318 pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
1319 check!(TensorCheck::dim_ops::<D>("chunk", dim));
1320 K::chunk(self.primitive, chunks, dim)
1321 .into_iter()
1322 .map(Self::new)
1323 .collect()
1324 }
1325
1326 /// Splits the tensor into chunks of a specified size along a given dimension.
1327 /// Each chunk is a view of the original tensor.
1328 ///
1329 /// If the tensor size along the given dimension is not divisible by `split_size`,
1330 /// then the last chunk will be smaller.
1331 ///
1332 /// # Panics
1333 ///
1334 /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
1335 ///
1336 /// # Returns
1337 ///
1338 /// A vector of tensors.
1339 ///
1340 /// # Example
1341 /// ```rust
1342 /// use burn_tensor::backend::Backend;
1343 /// use burn_tensor::Tensor;
1344 ///
1345 /// fn example<B: Backend>() {
1346 /// let device = Default::default();
1347 /// // Create a 1D tensor with 5 elements
1348 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1349 /// // Split the tensor into chunks of size 2 along dimension 0
1350 /// let chunks = tensor.split(2, 0);
1351 /// // The result is a vector of tensors:
1352 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
1353 /// println!("{:?}", chunks);
1354 /// }
1355 /// ```
1356 pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
1357 check!(TensorCheck::split::<D>(
1358 self.shape().dims.as_ref(),
1359 split_size,
1360 dim
1361 ));
1362 K::split(self.primitive, split_size, dim)
1363 .into_iter()
1364 .map(Self::new)
1365 .collect()
1366 }
1367
1368 /// Splits the tensor into chunks with the specified sizes along a given dimension.
1369 /// Each chunk is a view of the original tensor.
1370 ///
1371 /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
1372 /// in `split_sizes` must equal the size of the tensor along the specified dimension.
1373 ///
1374 /// # Panics
1375 ///
1376 /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
/// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
1378 ///
1379 /// # Returns
1380 ///
1381 /// A vector of tensors.
1382 ///
1383 /// # Example
1384 /// ```rust
1385 /// use burn_tensor::backend::Backend;
1386 /// use burn_tensor::Tensor;
1387 ///
1388 /// fn example<B: Backend>() {
1389 /// let device = Default::default();
1390 /// // Create a 1D tensor with 5 elements
1391 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1392 /// // Split the tensor into chunks with sizes [2, 3] along dimension 0
1393 /// let chunks = tensor.split_with_sizes(vec![2, 3], 0);
1394 /// // The result is a vector of tensors:
1395 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
1396 /// println!("{:?}", chunks);
1397 /// }
1398 /// ```
1399 pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
1400 check!(TensorCheck::split_with_sizes::<D>(
1401 self.shape().dims.as_ref(),
1402 &split_sizes,
1403 dim
1404 ));
1405 K::split_with_sizes(self.primitive, split_sizes, dim)
1406 .into_iter()
1407 .map(Self::new)
1408 .collect()
1409 }
1410
1411 /// Tests if any element in the `tensor` evaluates to True.
1412 ///
1413 /// # Arguments
1414 ///
1415 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1416 ///
1417 /// # Returns
1418 ///
1419 /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
1420 /// evaluates to True, False otherwise.
1421 ///
1422 /// # Example
1423 ///
1424 /// ```rust
1425 /// use burn_tensor::backend::Backend;
1426 /// use burn_tensor::{Tensor, Bool};
1427 ///
1428 /// fn example<B: Backend>() {
1429 /// let device = Default::default();
1430 /// let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);
1431 /// let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);
1432 ///
1433 /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
1434 /// let any_tensor = tensor.any();
1435 /// println!("{}", any_tensor);
1436 /// // Tensor { data: [true], ... }
1437 ///
1438 /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
1439 /// let any_tensor_two = tensor_two.any();
1440 /// println!("{}", any_tensor_two);
1441 /// // Tensor { data: [false], ... }
1442 /// }
1443 /// ```
1444 pub fn any(self) -> Tensor<B, 1, Bool> {
1445 Tensor::new(K::any(self.primitive))
1446 }
1447
1448 /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
1449 ///
1450 /// # Arguments
1451 ///
1452 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1453 /// * `dim` - The axis along which to test.
1454 ///
1455 /// # Returns
1456 ///
1457 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
1458 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1459 /// evaluates to True, False otherwise.
1460 ///
1461 /// # Example
1462 ///
1463 /// ```rust
1464 /// use burn_tensor::backend::Backend;
1465 /// use burn_tensor::{Tensor, Bool};
1466 ///
1467 /// fn example<B: Backend>() {
1468 /// let device = Default::default();
1469 /// let tensor =
1470 /// Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
1471 /// // Check if any element in the tensor evaluates to True along the dimension 1.
1472 /// // [[true], [true]],
1473 /// let any_dim = tensor.clone().any_dim(1);
1474 /// println!("{any_dim}");
1475 /// }
1476 /// ```
1477 pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
1478 Tensor::new(K::any_dim(self.primitive, dim))
1479 }
1480
1481 /// Tests if all elements in the `tensor` evaluate to True.
1482 ///
1483 /// # Arguments
1484 ///
1485 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1486 ///
1487 /// # Returns
1488 ///
1489 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1490 /// evaluate to True, False otherwise.
1491 ///
1492 /// # Example
1493 ///
1494 /// ```rust
1495 /// use burn_tensor::backend::Backend;
1496 /// use burn_tensor::{Tensor, Bool};
1497 ///
1498 /// fn example<B: Backend>() {
1499 /// let device = Default::default();
1500 /// let tensor =
1501 /// Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
1502 /// // Check if all elements in the tensor evaluate to True (which is not the case).
1503 /// // [false]
1504 /// let all = tensor.all();
1505 /// println!("{all}");
1506 /// }
1507 /// ```
1508 pub fn all(self) -> Tensor<B, 1, Bool> {
1509 Tensor::new(K::all(self.primitive))
1510 }
1511
1512 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1513 ///
1514 /// # Arguments
1515 ///
1516 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1517 /// * `dim` - The axis along which to test.
1518 ///
1519 /// # Returns
1520 ///
1521 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
1522 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1523 /// evaluates to True, False otherwise.
1524 ///
1525 /// # Example
1526 ///
1527 /// ```rust
1528 /// use burn_tensor::backend::Backend;
1529 /// use burn_tensor::{Tensor, Bool};
1530 ///
1531 /// fn example<B: Backend>() {
1532 /// let device = Default::default();
1533 /// let tensor =
1534 /// Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
/// // Check if all elements in the tensor evaluate to True along the dimension 0.
1536 /// // [[true, true, false]]
1537 /// let all_dim = tensor.clone().all_dim(0);
1538 /// println!("{all_dim}");
1539 /// }
1540 /// ```
1541 pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
1542 Tensor::new(K::all_dim(self.primitive, dim))
1543 }
1544
1545 /// Convert the tensor into a scalar.
1546 ///
1547 /// # Panics
1548 ///
1549 /// - If the tensor doesn't have one element.
1550 /// - If the backend fails to read the tensor data synchronously.
1551 ///
1552 /// # Returns
1553 ///
1554 /// The scalar value of the tensor.
1555 ///
1556 /// # Example
1557 ///
1558 /// ```rust
1559 /// use burn_tensor::backend::Backend;
1560 /// use burn_tensor::Tensor;
1561 ///
1562 /// fn example<B: Backend>() {
1563 /// let device = Default::default();
1564 /// let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
1565 /// // Convert the tensor with a single element into a scalar.
1566 /// let scalar = tensor.into_scalar();
1567 /// println!("{scalar}");
1568 /// }
1569 /// ```
1570 pub fn into_scalar(self) -> K::Elem {
1571 crate::try_read_sync(self.into_scalar_async()).expect(
1572 "Failed to read tensor data synchronously. This can happen on platforms
1573 that don't support blocking futures like WASM. Try into_scalar_async instead.",
1574 )
1575 }
1576
1577 /// Convert the tensor into a scalar.
1578 ///
1579 /// # Panics
1580 ///
1581 /// If the tensor doesn't have one element.
1582 pub async fn into_scalar_async(self) -> K::Elem {
1583 check!(TensorCheck::into_scalar::<D>(&self.shape()));
1584 self.into_data_async().await.iter().next().unwrap()
1585 }
1586
1587 /// Broadcast the tensor to the given shape.
1588 ///
1589 /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size
1590 /// (which can be inferred with `-1`).
1591 ///
1592 /// # Arguments
1593 ///
1594 /// * `shape` - The shape to broadcast the tensor to.
1595 /// Can contain -1 for dimensions that should be inferred.
/// The number of elements in the shape must be greater than or equal to
1597 /// the number of dimensions of the tensor.
1598 ///
1599 /// # Panics
1600 ///
1601 /// If the tensor cannot be broadcasted to the given shape.
1602 ///
1603 /// # Returns
1604 ///
1605 /// A new tensor with the given shape.
1606 ///
1607 /// # Example
1608 ///
1609 /// ```rust
1610 /// use burn_tensor::backend::Backend;
1611 /// use burn_tensor::Tensor;
1612 ///
1613 /// fn example<B: Backend>() {
1614 /// let device = Default::default();
1615 /// // Create a 2D tensor with dimensions [3, 1]
1616 /// let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
1617 /// // Expand the tensor to a new shape [3, 4]
1618 /// // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
1619 /// let expanded = tensor.expand([3, 4]);
1620 /// println!("{}", expanded);
1621 /// }
1622 /// ```
1623 pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
1624 let shape = shape.into_shape(&self.shape());
1625 check!(TensorCheck::expand::<D, D2>(
1626 "expand",
1627 &self.shape(),
1628 &shape,
1629 ));
1630
1631 Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
1632 }
1633}
1634
1635/// Iterator given by (Tensor::iter_dim).
1636pub struct DimIter<B, const D: usize, K>
1637where
1638 B: Backend,
1639 K: BasicOps<B>,
1640{
1641 start: usize,
1642 end: usize,
1643 dim: usize,
1644 ranges: [Range<usize>; D],
1645 tensor: Tensor<B, D, K>,
1646}
1647
1648impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
1649 type Item = Tensor<B, D, K>;
1650
1651 fn next(&mut self) -> Option<Self::Item> {
1652 if self.start >= self.end {
1653 return None;
1654 }
1655
1656 let mut ranges = self.ranges.clone();
1657 ranges[self.dim] = self.start..(self.start + 1);
1658
1659 let slice = self.tensor.clone().slice(ranges);
1660 self.start += 1;
1661
1662 Some(slice)
1663 }
1664}
1665
1666impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
1667 fn next_back(&mut self) -> Option<Self::Item> {
1668 if self.start >= self.end {
1669 return None;
1670 }
1671
1672 let mut ranges = self.ranges.clone();
1673 ranges[self.dim] = (self.end - 1)..self.end;
1674
1675 let slice = self.tensor.clone().slice(ranges);
1676 self.end = self.end.saturating_sub(1);
1677
1678 Some(slice)
1679 }
1680}
1681
1682impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
1683 fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
1684 let dims = tensor.dims();
1685 let ranges = dims
1686 .iter()
1687 .map(|&dim| 0..dim)
1688 .collect::<Vec<Range<usize>>>();
1689 let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
1690 Self {
1691 end: dims[dim],
1692 ranges,
1693 start: 0,
1694 dim,
1695 tensor,
1696 }
1697 }
1698}
1699
1700impl<B, const D: usize, K> Tensor<B, D, K>
1701where
1702 B: Backend,
1703 K: BasicOps<B>,
1704 <K as BasicOps<B>>::Elem: Debug,
1705{
1706 #[inline]
1707 fn push_newline_indent(acc: &mut String, indent: usize) {
1708 acc.push('\n');
1709 for _ in 0..indent {
1710 acc.push(' ');
1711 }
1712 }
1713 fn fmt_inner_tensor(
1714 &self,
1715 acc: &mut String,
1716 depth: usize,
1717 multi_index: &mut [usize],
1718 range: (usize, usize),
1719 precision: Option<usize>,
1720 ) {
1721 let (start, end) = range;
1722 for i in start..end {
1723 if i > 0 {
1724 acc.push_str(", ");
1725 }
1726 multi_index[depth] = i;
1727 let range: [Range<usize>; D] =
1728 core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
1729
1730 let data =
1731 burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());
1732
1733 if let Some(data) = data {
1734 let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
1735 match (precision, K::name()) {
1736 (Some(p), "Float") => acc.push_str(&format!("{:.1$}", elem, p)),
1737 (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
1738 _ => acc.push_str(&format!("{:?}", elem)),
1739 }
1740 } else {
1741 acc.push_str("<Tensor data not available>");
1742 }
1743 }
1744 }
1745
1746 fn fmt_outer_tensor(
1747 &self,
1748 acc: &mut String,
1749 depth: usize,
1750 multi_index: &mut [usize],
1751 print_options: &PrintOptions,
1752 summarize: bool,
1753 range: (usize, usize),
1754 ) {
1755 let (start, end) = range;
1756 for i in start..end {
1757 if i > start {
1758 acc.push(',');
1759 Self::push_newline_indent(acc, depth + 1);
1760 }
1761 acc.push('[');
1762 multi_index[depth] = i;
1763 self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
1764 acc.push(']');
1765 }
1766 }
1767
1768 /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
1769 ///
1770 /// This function is designed to work with tensors of any dimensionality.
1771 /// It traverses the tensor dimensions recursively, converting the elements
1772 /// to strings and appending them to the accumulator string with the
1773 /// appropriate formatting.
1774 ///
1775 /// # Arguments
1776 ///
1777 /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
1778 /// * `depth` - The current depth of the tensor dimensions being processed.
1779 /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
1780 fn display_recursive(
1781 &self,
1782 acc: &mut String,
1783 depth: usize,
1784 multi_index: &mut [usize],
1785 print_options: &PrintOptions,
1786 summarize: bool,
1787 ) {
1788 let edge_items = print_options.edge_items;
1789
1790 if depth == 0 {
1791 acc.push('[');
1792 }
1793
1794 if depth == self.dims().len() - 1 {
1795 // if we are at the innermost dimension, just push its elements into the accumulator
1796 if summarize && self.dims()[depth] > 2 * edge_items {
1797 // print the starting `edge_items` elements
1798 self.fmt_inner_tensor(
1799 acc,
1800 depth,
1801 multi_index,
1802 (0, edge_items),
1803 print_options.precision,
1804 );
1805 acc.push_str(", ...");
1806 // print the last `edge_items` elements
1807 self.fmt_inner_tensor(
1808 acc,
1809 depth,
1810 multi_index,
1811 (self.dims()[depth] - edge_items, self.dims()[depth]),
1812 print_options.precision,
1813 );
1814 } else {
1815 // print all the elements
1816 self.fmt_inner_tensor(
1817 acc,
1818 depth,
1819 multi_index,
1820 (0, self.dims()[depth]),
1821 print_options.precision,
1822 );
1823 }
1824 } else {
1825 // otherwise, iterate through the current dimension and recursively display the inner tensors
1826 if summarize && self.dims()[depth] > 2 * edge_items {
1827 self.fmt_outer_tensor(
1828 acc,
1829 depth,
1830 multi_index,
1831 print_options,
1832 summarize,
1833 (0, edge_items),
1834 );
1835
1836 acc.push(',');
1837 Self::push_newline_indent(acc, depth + 1);
1838 acc.push_str("...");
1839 Self::push_newline_indent(acc, depth + 1);
1840
1841 self.fmt_outer_tensor(
1842 acc,
1843 depth,
1844 multi_index,
1845 print_options,
1846 summarize,
1847 (self.dims()[depth] - edge_items, self.dims()[depth]),
1848 );
1849 } else {
1850 self.fmt_outer_tensor(
1851 acc,
1852 depth,
1853 multi_index,
1854 print_options,
1855 summarize,
1856 (0, self.dims()[depth]),
1857 );
1858 }
1859 }
1860
1861 if depth == 0 {
1862 acc.push(']');
1863 }
1864 }
1865}
1866
/// Options for Tensor pretty printing.
#[derive(Debug, Clone)]
pub struct PrintOptions {
    /// Number of elements above which the printed tensor is summarized.
    pub threshold: usize,

    /// Number of leading and trailing entries displayed per dimension when summarizing.
    pub edge_items: usize,

    /// Precision (number of decimals) for floating point numbers; `None` leaves the
    /// default formatting of the element in place.
    pub precision: Option<usize>,
}
1879
1880static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
1881
1882impl PrintOptions {
1883 /// Print options with default values
1884 pub const fn const_default() -> Self {
1885 Self {
1886 threshold: 1000,
1887 edge_items: 3,
1888 precision: None,
1889 }
1890 }
1891}
1892
1893impl Default for PrintOptions {
1894 fn default() -> Self {
1895 Self::const_default()
1896 }
1897}
1898
1899/// Set print options
1900pub fn set_print_options(options: PrintOptions) {
1901 let mut print_opts = PRINT_OPTS.write().unwrap();
1902 *print_opts = options;
1903}
1904
1905/// Pretty print tensors
1906impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
1907where
1908 B: Backend,
1909 B::IntElem: core::fmt::Display,
1910 K: BasicOps<B>,
1911 <K as BasicOps<B>>::Elem: Debug,
1912{
1913 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1914 writeln!(f, "Tensor {{")?;
1915
1916 {
1917 // Do not lock the mutex for the whole function
1918 let mut po = { PRINT_OPTS.read().unwrap().clone() };
1919
1920 // Override the precision if it is set from the formatter
1921 // This will be possible when the tensor is printed using the `{:.*}` syntax
1922 if let Some(precision) = f.precision() {
1923 po.precision = Some(precision);
1924 }
1925
1926 let mut acc = String::new();
1927 let mut multi_index = vec![0; D];
1928 let summarize = self.shape().num_elements() > po.threshold;
1929
1930 self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);
1931
1932 writeln!(f, " data:")?;
1933 write!(f, "{acc}")?;
1934 writeln!(f, ",")?;
1935 }
1936
1937 writeln!(f, " shape: {:?},", self.dims())?;
1938 writeln!(f, " device: {:?},", self.device())?;
1939 writeln!(f, " backend: {:?},", B::name(&self.device()))?;
1940 writeln!(f, " kind: {:?},", K::name())?;
1941
1942 let dtype = self.primitive.dtype();
1943
1944 writeln!(f, " dtype: {:?},", dtype.name())?;
1945 write!(f, "}}")
1946 }
1947}
1948
/// Transpose marker (zero-size type).
///
/// Used as syntactic sugar for transposing a tensor: `tensor ^ T` calls
/// `tensor.transpose()` through the `BitXor<T>` implementation on `Tensor<B, D>`.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, T};
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
///     let transposed = tensor ^ T;
/// }
/// ```
pub struct T;
1961
1962impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
1963 type Output = Self;
1964 fn bitxor(self, _: T) -> Self::Output {
1965 self.transpose()
1966 }
1967}
1968
/// Trait that lists all operations that can be applied on all tensors.
///
/// # Warnings
///
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
pub trait BasicOps<B: Backend>: TensorKind<B> {
    /// The type of the tensor elements.
    type Elem: Element;

    /// Creates an empty tensor with the given shape.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device on which the tensor will be allocated.
    ///
    /// # Returns
    ///
    /// The empty tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function,
    /// which is more high-level and designed for public use.
    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive;

    /// Reshapes the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The reshaped tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function,
    /// which is more high-level and designed for public use.
    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;

    /// Transposes a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to transpose.
    ///
    /// # Returns
    ///
    /// The transposed tensor.
    fn transpose(tensor: Self::Primitive) -> Self::Primitive;

    /// Swaps two dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;

    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;

    /// Flips the tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to flip.
    /// * `axes` - The axes to flip the tensor along.
    ///
    /// # Returns
    ///
    /// The tensor with the axes flipped.
    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;

    /// Selects the tensor elements corresponding to the given ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `range` - The ranges of the elements to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function,
    /// which is more high-level and designed for public use.
    fn slice(tensor: Self::Primitive, range: &[Range<usize>]) -> Self::Primitive;

    /// Assigns the given value to the tensor elements corresponding to the given ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `ranges` - The ranges of the elements to assign to.
    /// * `value` - The value to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the assigned values.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function,
    /// which is more high-level and designed for public use.
    fn slice_assign(
        tensor: Self::Primitive,
        ranges: &[Range<usize>],
        value: Self::Primitive,
    ) -> Self::Primitive;

    /// Returns the device on which the tensor is allocated.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device on which the tensor is allocated.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function,
    /// which is more high-level and designed for public use.
    fn device(tensor: &Self::Primitive) -> B::Device;

    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device on which the tensor will be moved.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function,
    /// which is more high-level and designed for public use.
    fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;

    /// Extracts the data from the tensor asynchronously.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A future resolving to the data of the tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
    /// which is more high-level and designed for public use.
    fn into_data_async(
        tensor: Self::Primitive,
    ) -> impl Future<Output = TensorData> + 'static + Send;

    /// Reads the data from the tensor using a transaction.
    ///
    /// # Arguments
    ///
    /// * `tr` - The transaction in which the tensor is registered.
    /// * `tensor` - The tensor to register.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive);

    /// Creates a tensor from the given data.
    ///
    /// # Arguments
    ///
    /// * `data` - The data of the tensor.
    /// * `device` - The device on which the tensor will be allocated.
    ///
    /// # Returns
    ///
    /// The tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function,
    /// which is more high-level and designed for public use.
    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;

    /// Creates a tensor from the given data enforcing the given data type.
    ///
    /// # Arguments
    ///
    /// * `data` - The data of the tensor.
    /// * `device` - The device on which the tensor will be allocated.
    /// * `dtype` - The data type to enforce.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For creating a tensor from data, users should prefer the [Tensor::from_data_dtype](Tensor::from_data_dtype)
    /// function, which is more high-level and designed for public use.
    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;

    /// Repeats the tensor along the given dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension along which the tensor will be repeated.
    /// * `times` - The number of times the tensor will be repeated.
    ///
    /// # Returns
    ///
    /// The repeated tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
    /// which is more high-level and designed for public use.
    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;

    /// Concatenates the given tensors along the given dimension.
    ///
    /// # Arguments
    ///
    /// * `vectors` - The tensors to concatenate.
    /// * `dim` - The dimension along which the tensors will be concatenated.
    ///
    /// # Returns
    ///
    /// The concatenated tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function,
    /// which is more high-level and designed for public use.
    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;

    /// Attempts to split the tensor along the given dimension into chunks.
    /// May return fewer chunks than requested if the tensor size is not divisible by the number of chunks.
    ///
    /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
    /// Otherwise all chunks will be of equal size except for the last one.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to split.
    /// * `chunks` - The requested number of chunks.
    /// * `dim` - The dimension along which to split.
    ///
    /// # Panics
    ///
    /// If the dimension is greater than the number of dimensions of the tensor.
    ///
    /// # Returns
    ///
    /// A vector of tensors.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// To chunk a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function,
    /// which is more high-level and designed for public use.
    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive>;

    /// Splits the tensor into chunks of a specified size along a given dimension.
    /// Each chunk is a view of the original tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to split.
    /// * `split_size` - The size of each chunk along `dim`.
    /// * `dim` - The dimension along which to split.
    ///
    /// # Panics
    ///
    /// If the dimension to split along is greater than the number of dimensions of the tensor.
    ///
    /// # Returns
    ///
    /// A vector of tensors.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// To split a tensor, users should prefer the [Tensor::split](Tensor::split) function,
    /// which is more high-level and designed for public use.
    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive>;

    /// Splits the tensor into chunks with the specified sizes along a given dimension.
    /// Each chunk is a view of the original tensor.
    ///
    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
    /// in `split_sizes` must equal the size of the tensor along the specified dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to split.
    /// * `split_sizes` - The size of each chunk along `dim`.
    /// * `dim` - The dimension along which to split.
    ///
    /// # Panics
    ///
    /// If the dimension to split along is greater than the number of dimensions of the tensor or
    /// if the sum of `split_sizes` does not equal the size of the tensor along `dim`.
    ///
    /// # Returns
    ///
    /// A vector of tensors.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// To split a tensor, users should prefer the [Tensor::split_with_sizes](Tensor::split_with_sizes) function,
    /// which is more high-level and designed for public use.
    fn split_with_sizes(
        tensor: Self::Primitive,
        split_sizes: Vec<usize>,
        dim: usize,
    ) -> Vec<Self::Primitive>;

    /// Applies element-wise equality comparison between the given tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor of booleans indicating whether the corresponding elements are equal.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function,
    /// which is more high-level and designed for public use.
    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;

    /// Applies element-wise non-equality comparison between the given tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor of booleans indicating whether the corresponding elements are not equal.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For non-equality comparison of tensors, users should prefer the [Tensor::not_equal](Tensor::not_equal)
    /// function, which is more high-level and designed for public use.
    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;

    /// Returns the name of the element type, as reported by `core::any::type_name`
    /// for [`Self::Elem`](BasicOps::Elem).
    fn elem_type_name() -> &'static str {
        core::any::type_name::<Self::Elem>()
    }

    /// Returns the tensor data type, as reported by the primitive itself.
    fn dtype(tensor: &Self::Primitive) -> DType {
        tensor.dtype()
    }

    /// Tests if any element in the `tensor` evaluates to True.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly. Users should prefer the [Tensor::any](Tensor::any) function,
    /// which is more high-level and designed for public use.
    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;

    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test.
    /// * `dim` - The axis along which to test.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
    /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly. Users should prefer the [Tensor::any_dim](Tensor::any_dim) function,
    /// which is more high-level and designed for public use.
    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;

    /// Tests if all elements in the `tensor` evaluate to True.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element, True if all elements in the input tensor evaluate to True, False otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function,
    /// which is more high-level and designed for public use.
    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;

    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test.
    /// * `dim` - The axis along which to test.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
    /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function,
    /// which is more high-level and designed for public use.
    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;

    /// Broadcasts the given tensor to the specified shape.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The shape to broadcast to.
    ///
    /// # Returns
    ///
    /// The broadcasted tensor.
    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
}
2472
2473impl<B: Backend> BasicOps<B> for Float {
2474 type Elem = B::FloatElem;
2475
2476 fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2477 TensorPrimitive::Float(B::float_empty(shape, device))
2478 }
2479
2480 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2481 tr.register_float(tensor);
2482 }
2483
2484 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2485 match tensor {
2486 TensorPrimitive::Float(tensor) => {
2487 TensorPrimitive::Float(B::float_reshape(tensor, shape))
2488 }
2489 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
2490 }
2491 }
2492
2493 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2494 match tensor {
2495 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
2496 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
2497 }
2498 }
2499
2500 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2501 match tensor {
2502 TensorPrimitive::Float(tensor) => {
2503 TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
2504 }
2505 TensorPrimitive::QFloat(tensor) => {
2506 TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
2507 }
2508 }
2509 }
2510
2511 fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2512 match tensor {
2513 TensorPrimitive::Float(tensor) => {
2514 TensorPrimitive::Float(B::float_slice(tensor, ranges))
2515 }
2516 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, ranges)),
2517 }
2518 }
2519
2520 fn slice_assign(
2521 tensor: Self::Primitive,
2522 ranges: &[Range<usize>],
2523 value: Self::Primitive,
2524 ) -> Self::Primitive {
2525 match (tensor, value) {
2526 (TensorPrimitive::Float(tensor), TensorPrimitive::Float(value)) => {
2527 TensorPrimitive::Float(B::float_slice_assign(tensor, ranges, value))
2528 }
2529 (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(value)) => {
2530 TensorPrimitive::QFloat(B::q_slice_assign(tensor, ranges, value))
2531 }
2532 _ => panic!("Primitive type mismatch for tensor and value"),
2533 }
2534 }
2535
2536 fn device(tensor: &Self::Primitive) -> Device<B> {
2537 match tensor {
2538 TensorPrimitive::Float(tensor) => B::float_device(tensor),
2539 TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
2540 }
2541 }
2542
2543 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2544 match tensor {
2545 TensorPrimitive::Float(tensor) => {
2546 TensorPrimitive::Float(B::float_to_device(tensor, device))
2547 }
2548 TensorPrimitive::QFloat(tensor) => {
2549 TensorPrimitive::QFloat(B::q_to_device(tensor, device))
2550 }
2551 }
2552 }
2553
2554 async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2555 match tensor {
2556 TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
2557 TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
2558 }
2559 }
2560
2561 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2562 match data.dtype {
2563 DType::QFloat(_strategy) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
2564 _ => TensorPrimitive::Float(B::float_from_data(data.convert::<B::FloatElem>(), device)),
2565 }
2566 }
2567
2568 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
2569 match dtype {
2570 DType::QFloat(_strategy) => {
2571 TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device))
2572 }
2573 _ if dtype.is_float() => {
2574 TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
2575 }
2576 _ => panic!("Expected float dtype, got {dtype:?}"),
2577 }
2578 }
2579
2580 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2581 match tensor {
2582 TensorPrimitive::Float(tensor) => {
2583 TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
2584 }
2585 TensorPrimitive::QFloat(tensor) => {
2586 TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
2587 }
2588 }
2589 }
2590
2591 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2592 match vectors.first().unwrap() {
2593 TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
2594 vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
2595 dim,
2596 )),
2597 TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
2598 vectors
2599 .into_iter()
2600 .map(|tensor| {
2601 if let TensorPrimitive::QFloat(t) = tensor {
2602 t
2603 } else {
2604 panic!("Concatenation only works with vector of QFloat")
2605 }
2606 })
2607 .collect(),
2608 dim,
2609 )),
2610 }
2611 }
2612
2613 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2614 B::float_equal(lhs.tensor(), rhs.tensor())
2615 }
2616
2617 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2618 B::float_not_equal(lhs.tensor(), rhs.tensor())
2619 }
2620
2621 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2622 B::float_any(tensor.tensor())
2623 }
2624
2625 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2626 B::float_any_dim(tensor.tensor(), dim)
2627 }
2628
2629 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2630 B::float_all(tensor.tensor())
2631 }
2632
2633 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2634 B::float_all_dim(tensor.tensor(), dim)
2635 }
2636
2637 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2638 match tensor {
2639 TensorPrimitive::Float(tensor) => {
2640 TensorPrimitive::Float(B::float_permute(tensor, axes))
2641 }
2642 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
2643 }
2644 }
2645
2646 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2647 TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
2648 }
2649
2650 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2651 match tensor {
2652 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
2653 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
2654 }
2655 }
2656
2657 fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2658 match tensor {
2659 TensorPrimitive::Float(tensor) => B::float_chunk(tensor, chunks, dim)
2660 .into_iter()
2661 .map(TensorPrimitive::Float)
2662 .collect(),
2663 TensorPrimitive::QFloat(tensor) => B::q_chunk(tensor, chunks, dim)
2664 .into_iter()
2665 .map(TensorPrimitive::QFloat)
2666 .collect(),
2667 }
2668 }
2669
2670 fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2671 match tensor {
2672 TensorPrimitive::Float(tensor) => B::float_split(tensor, split_size, dim)
2673 .into_iter()
2674 .map(TensorPrimitive::Float)
2675 .collect(),
2676 TensorPrimitive::QFloat(tensor) => B::q_split(tensor, split_size, dim)
2677 .into_iter()
2678 .map(TensorPrimitive::QFloat)
2679 .collect(),
2680 }
2681 }
2682
2683 fn split_with_sizes(
2684 tensor: Self::Primitive,
2685 split_sizes: Vec<usize>,
2686 dim: usize,
2687 ) -> Vec<Self::Primitive> {
2688 match tensor {
2689 TensorPrimitive::Float(tensor) => B::float_split_with_sizes(tensor, split_sizes, dim)
2690 .into_iter()
2691 .map(TensorPrimitive::Float)
2692 .collect(),
2693 TensorPrimitive::QFloat(tensor) => B::q_split_with_sizes(tensor, split_sizes, dim)
2694 .into_iter()
2695 .map(TensorPrimitive::QFloat)
2696 .collect(),
2697 }
2698 }
2699}
2700
2701impl<B: Backend> BasicOps<B> for Int {
2702 type Elem = B::IntElem;
2703
2704 fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2705 B::int_empty(shape, device)
2706 }
2707
2708 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2709 tr.register_int(tensor);
2710 }
2711
2712 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2713 B::int_reshape(tensor, shape)
2714 }
2715
2716 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2717 B::int_transpose(tensor)
2718 }
2719
2720 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2721 B::int_swap_dims(tensor, dim1, dim2)
2722 }
2723
2724 fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2725 B::int_slice(tensor, ranges)
2726 }
2727
2728 fn slice_assign(
2729 tensor: Self::Primitive,
2730 ranges: &[Range<usize>],
2731 value: Self::Primitive,
2732 ) -> Self::Primitive {
2733 B::int_slice_assign(tensor, ranges, value)
2734 }
2735
2736 fn device(tensor: &Self::Primitive) -> Device<B> {
2737 B::int_device(tensor)
2738 }
2739
2740 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2741 B::int_to_device(tensor, device)
2742 }
2743
2744 async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2745 B::int_into_data(tensor).await
2746 }
2747
2748 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2749 B::int_from_data(data.convert::<B::IntElem>(), device)
2750 }
2751
2752 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
2753 if !dtype.is_int() {
2754 panic!("Expected int dtype, got {dtype:?}")
2755 }
2756
2757 B::int_from_data(data.convert_dtype(dtype), device)
2758 }
2759
2760 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2761 B::int_repeat_dim(tensor, dim, times)
2762 }
2763
2764 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2765 B::int_equal(lhs, rhs)
2766 }
2767
2768 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2769 B::int_not_equal(lhs, rhs)
2770 }
2771
2772 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2773 B::int_cat(vectors, dim)
2774 }
2775
2776 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2777 B::int_any(tensor)
2778 }
2779
2780 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2781 B::int_any_dim(tensor, dim)
2782 }
2783
2784 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2785 B::int_all(tensor)
2786 }
2787
2788 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2789 B::int_all_dim(tensor, dim)
2790 }
2791
2792 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2793 B::int_permute(tensor, axes)
2794 }
2795
2796 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2797 B::int_expand(tensor, shape)
2798 }
2799
2800 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2801 B::int_flip(tensor, axes)
2802 }
2803
2804 fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2805 B::int_chunk(tensor, chunks, dim)
2806 }
2807
/// Forwards to the backend's `int_split` with the same `split_size` and `dim`, returning
/// the resulting list of primitives.
fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
    B::int_split(tensor, split_size, dim)
}
2811
/// Forwards to the backend's `int_split_with_sizes` with the same `split_sizes` and `dim`,
/// returning the resulting list of primitives.
fn split_with_sizes(
    tensor: Self::Primitive,
    split_sizes: Vec<usize>,
    dim: usize,
) -> Vec<Self::Primitive> {
    B::int_split_with_sizes(tensor, split_sizes, dim)
}
2819}
2820
/// [`BasicOps`] for the [`Bool`] tensor kind.
///
/// Every method forwards its arguments unchanged to the backend function of the same name
/// prefixed with `bool_`; only the data-loading methods do extra work before delegating.
impl<B: Backend> BasicOps<B> for Bool {
    // Element type used by the backend to represent booleans.
    type Elem = B::BoolElem;

    fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
        B::bool_empty(shape, device)
    }

    // Registers the primitive on the transaction through its bool-specific entry point.
    fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
        tr.register_bool(tensor);
    }

    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
        B::bool_reshape(tensor, shape)
    }

    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
        B::bool_transpose(tensor)
    }

    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
        B::bool_swap_dims(tensor, dim1, dim2)
    }

    fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
        B::bool_slice(tensor, ranges)
    }

    fn slice_assign(
        tensor: Self::Primitive,
        ranges: &[Range<usize>],
        value: Self::Primitive,
    ) -> Self::Primitive {
        B::bool_slice_assign(tensor, ranges, value)
    }

    fn device(tensor: &Self::Primitive) -> Device<B> {
        B::bool_device(tensor)
    }

    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
        B::bool_to_device(tensor, device)
    }

    // Awaits the future returned by the backend's `bool_into_data`.
    async fn into_data_async(tensor: Self::Primitive) -> TensorData {
        B::bool_into_data(tensor).await
    }

    // The data is converted to the backend's bool element type before being loaded.
    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
        B::bool_from_data(data.convert::<B::BoolElem>(), device)
    }

    // Panics unless `dtype` is exactly the dtype of `B::BoolElem`.
    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
        // Backends only use one bool representation dtype
        if dtype != B::BoolElem::dtype() {
            panic!("Expected bool dtype, got {dtype:?}")
        }
        B::bool_from_data(data.convert_dtype(dtype), device)
    }

    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
        B::bool_repeat_dim(tensor, dim, times)
    }

    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_equal(lhs, rhs)
    }

    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_not_equal(lhs, rhs)
    }

    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
        B::bool_cat(vectors, dim)
    }

    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_any(tensor)
    }

    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
        B::bool_any_dim(tensor, dim)
    }

    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
        B::bool_all(tensor)
    }

    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
        B::bool_all_dim(tensor, dim)
    }

    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
        B::bool_permute(tensor, axes)
    }

    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
        B::bool_expand(tensor, shape)
    }

    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
        B::bool_flip(tensor, axes)
    }

    fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
        B::bool_chunk(tensor, chunks, dim)
    }

    fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
        B::bool_split(tensor, split_size, dim)
    }

    fn split_with_sizes(
        tensor: Self::Primitive,
        split_sizes: Vec<usize>,
        dim: usize,
    ) -> Vec<Self::Primitive> {
        B::bool_split_with_sizes(tensor, split_sizes, dim)
    }
}
2940
/// Trait used for movedim arguments.
///
/// Implemented in this file for a single dimension (`usize`, `i32`) and for a list of
/// dimensions (`Vec<usize>`, `Vec<i32>`).
pub trait MovedimArgs {
    /// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function.
    ///
    /// The signed implementations resolve a negative value by adding `D` to it, so `D` is
    /// presumably the rank of the tensor being manipulated (confirm against the caller).
    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}
2946
2947impl MovedimArgs for Vec<i32> {
2948 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2949 let set = self
2950 .iter()
2951 .map(|&dim| {
2952 if dim < 0 {
2953 (D as i32 + dim) as usize
2954 } else {
2955 dim as usize
2956 }
2957 })
2958 .collect::<Vec<usize>>();
2959 check!(TensorCheck::movedim_args_vec::<D>(&set));
2960
2961 set
2962 }
2963}
2964
impl MovedimArgs for Vec<usize> {
    /// Validates the list with `TensorCheck::movedim_args_vec` and returns it unchanged.
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        check!(TensorCheck::movedim_args_vec::<D>(&self));
        self
    }
}
2971
2972impl MovedimArgs for usize {
2973 #[allow(clippy::vec_init_then_push)]
2974 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2975 check!(TensorCheck::movedim_args_usize::<D>(self));
2976
2977 let mut set = Vec::with_capacity(1);
2978 set.push(self);
2979
2980 set
2981 }
2982}
2983
2984impl MovedimArgs for i32 {
2985 #[allow(clippy::vec_init_then_push)]
2986 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2987 check!(TensorCheck::movedim_args_i32::<D>(self));
2988
2989 let dim = if self < 0 {
2990 (D as i32 + self) as usize
2991 } else {
2992 self as usize
2993 };
2994
2995 let mut set = Vec::with_capacity(1);
2996 set.push(dim);
2997
2998 set
2999 }
3000}
3001
/// Trait used for slice arguments.
///
/// `D2` is the number of ranges produced. The implementations in this file accept either
/// an array of `D2` values convertible into [`Slice`], or a single such value (`D2 == 1`).
pub trait RangesArg<const D2: usize> {
    /// Converts into a set of ranges to `[Range<usize>; D2]` for the `tensor.slice()` function.
    ///
    /// `shape` supplies the dimension sizes that each argument is resolved against.
    fn into_ranges(self, shape: Shape) -> [Range<usize>; D2];
}
3007
3008impl<const D2: usize, T: Into<Slice>> RangesArg<D2> for [T; D2] {
3009 fn into_ranges(self, shape: Shape) -> [Range<usize>; D2] {
3010 // clamp the ranges to the shape dimensions
3011 let ranges = self
3012 .into_iter()
3013 .enumerate()
3014 .map(|(i, range)| range.into().into_range(shape.dims[i]))
3015 .collect::<Vec<_>>();
3016 ranges.try_into().unwrap()
3017 }
3018}
3019
3020impl<T: Into<Slice>> RangesArg<1> for T {
3021 fn into_ranges(self, shape: Shape) -> [Range<usize>; 1] {
3022 [self.into().into_range(shape.dims[0])]
3023 }
3024}
3025
/// Trait used for reshape arguments.
///
/// `D2` is the number of dimensions of the target shape. Implemented in this file for
/// [`Shape`], `[usize; D2]` and `[i32; D2]`.
pub trait ReshapeArgs<const D2: usize> {
    /// Converts to a shape.
    ///
    /// `tensor` is the tensor being reshaped; implementations read its current shape to
    /// validate or complete the requested one.
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape;
}
3034
impl<const D2: usize> ReshapeArgs<D2> for Shape {
    /// Validates `self` against the tensor's current shape with
    /// `TensorCheck::reshape_args_usize` and returns it unchanged.
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape {
        check!(TensorCheck::reshape_args_usize::<D, D2>(
            &tensor.shape(),
            &self
        ));

        self
    }
}
impl<const D2: usize> ReshapeArgs<D2> for [usize; D2] {
    /// Converts the array into a [`Shape`], validates it against the tensor's current shape
    /// with `TensorCheck::reshape_args_usize`, and returns it.
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape {
        let shape = Shape::from(self);

        check!(TensorCheck::reshape_args_usize::<D, D2>(
            &tensor.shape(),
            &shape
        ));

        shape
    }
}
3063
3064impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
3065 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3066 self,
3067 tensor: &Tensor<B, D, K>,
3068 ) -> Shape {
3069 // Validate the reshape arguments
3070 check!(TensorCheck::reshape_args_i32(&self));
3071
3072 // Temporary shape
3073 let mut new_shape: [i32; D2] = [1; D2];
3074
3075 // We need to find the index of the 0 dimension and
3076 // replace it with the actual dimension value.
3077 for (i, &s) in self.iter().enumerate() {
3078 if s != 0 {
3079 new_shape[i] = s;
3080 } else {
3081 new_shape[i] = tensor.dims()[i] as i32;
3082 }
3083 }
3084
3085 // Find the index of the inferred dimension (-1)
3086 let infer_index = new_shape.iter().position(|x| x == &-1);
3087
3088 // Handle the case where the dimension is inferred (via -1)
3089 if let Some(index) = infer_index {
3090 // Handle the case where the dimension is inferred
3091 let mut product = 1;
3092 for (i, &s) in new_shape.iter().enumerate() {
3093 if i != index {
3094 product *= s;
3095 }
3096 }
3097 let product_current = tensor.shape().num_elements() as i32;
3098
3099 new_shape[index] = product_current / product;
3100
3101 // Check if the reshape is valid
3102 if product_current % product != 0 {
3103 panic!(
3104 "Cannot reshape tensor of shape {:?} to shape {:?}",
3105 tensor.shape(),
3106 new_shape
3107 );
3108 }
3109 };
3110
3111 // Convert each element to usize
3112 let new_shape: [usize; D2] = new_shape.map(|x| x as usize);
3113
3114 Shape::from(new_shape)
3115 }
3116}
3117
/// Trait used for broadcast arguments.
///
/// `D2` is the number of dimensions of the target shape. `D1` is not used by any of the
/// implementations in this file; presumably it is the rank of the source tensor (confirm
/// against the caller).
pub trait BroadcastArgs<const D1: usize, const D2: usize> {
    /// Converts to a shape.
    ///
    /// `shape` is the current shape that the arguments are resolved against.
    fn into_shape(self, shape: &Shape) -> Shape;
}
3123
impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
    /// Returns `self` unchanged; the current shape is ignored and nothing is validated here.
    fn into_shape(self, _shape: &Shape) -> Shape {
        self
    }
}
impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
    /// Converts the array into a [`Shape`]; the current shape is ignored and nothing is
    /// validated here.
    fn into_shape(self, _shape: &Shape) -> Shape {
        Shape::from(self)
    }
}
3134
3135impl<const D1: usize, const D2: usize, E: Element> BroadcastArgs<D1, D2> for [E; D2] {
3136 // Passing -1 as the size for a dimension means not changing the size of that dimension.
3137 fn into_shape(self, shape: &Shape) -> Shape {
3138 if self.len() < shape.num_dims() {
3139 panic!("Broadcast arguments must be greater than the number of dimensions");
3140 }
3141
3142 // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
3143 let new_shape: Vec<_> = self
3144 .iter()
3145 .rev()
3146 .map(|x| {
3147 let primitive = x.to_i64();
3148 if primitive < -1 || primitive == 0 {
3149 panic!("Broadcast arguments must be positive or -1");
3150 }
3151 primitive
3152 })
3153 .zip(shape.dims.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
3154 .map(|(x, &y)| if x == -1 { y } else { x as usize })
3155 .collect::<Vec<_>>()
3156 .into_iter()
3157 .rev()
3158 .collect();
3159
3160 if new_shape.contains(&0) {
3161 panic!("Cannot substitute -1 for a non-existing dimension");
3162 }
3163
3164 let new_shape: [usize; D2] = new_shape.try_into().unwrap();
3165
3166 Shape::from(new_shape)
3167 }
3168}
3169
3170impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
3171where
3172 B: Backend,
3173 K: BasicOps<B>,
3174 K::Elem: Debug + Copy + Serialize,
3175{
3176 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
3177 let data = self.to_data();
3178 data.serialize(serializer)
3179 }
3180}
3181
3182impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
3183where
3184 B: Backend,
3185 K: BasicOps<B>,
3186 K::Elem: Debug + Copy + Deserialize<'de>,
3187{
3188 fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
3189 let tensor = Tensor::from_data(
3190 TensorData::deserialize(deserializer)?,
3191 &<B::Device as Default>::default(),
3192 );
3193 Ok(tensor)
3194 }
3195}
3196
// Unit tests for the `RangesArg::into_ranges` conversions.
#[cfg(test)]
mod tests {
    use crate::Shape;
    use crate::s;

    use super::*;

    /// A single argument (bare or in a one-element array) is resolved against the leading
    /// dimension only, which has size 8 here.
    #[test]
    fn slice_range_single_dim_leading() {
        let shape = Shape::new([8, 4]);

        // Half-open range
        assert_eq!([0..5], (0..5).into_ranges(shape.clone()));
        assert_eq!([0..5], [0..5].into_ranges(shape.clone()));
        assert_eq!([5..7], [-3..-1].into_ranges(shape.clone()));

        // Inclusive range
        assert_eq!([0..5], (0..=4).into_ranges(shape.clone()));
        assert_eq!([0..5], [0..=4].into_ranges(shape.clone()));
        assert_eq!([6..8], [-2..=-1].into_ranges(shape.clone()));

        // Unbounded start
        assert_eq!([0..3], (..3).into_ranges(shape.clone()));
        assert_eq!([0..3], [..3].into_ranges(shape.clone()));
        assert_eq!([0..3], [..-5].into_ranges(shape.clone()));

        // Unbounded end
        assert_eq!([5..8], (5..).into_ranges(shape.clone()));
        assert_eq!([5..8], [5..].into_ranges(shape.clone()));
        assert_eq!([5..8], [-3..].into_ranges(shape.clone()));

        // Full range
        assert_eq!([0..8], [..].into_ranges(shape));
    }

    /// With one argument per dimension, each argument is resolved against the size of its
    /// own dimension.
    #[test]
    fn slice_range_multi_dim() {
        let shape = Shape::new([8, 4]);

        // Multiple ways to provide ranges
        assert_eq!([0..5, 0..4], [0..5, 0..4].into_ranges(shape.clone()));
        assert_eq!([0..8, 0..4], [0.., 0..].into_ranges(shape.clone()));
        assert_eq!([0..8, 0..4], [0..=7, 0..=3].into_ranges(shape.clone()));

        assert_eq!([0..5, 0..3], [0..5, 0..3].into_ranges(shape.clone()));

        assert_eq!([0..8, 0..4], [0.., 0..].into_ranges(shape));
    }

    /// A plain integer selects the one-element range at that index; negative integers
    /// count from the end of the dimension.
    #[test]
    fn slice_range_multi_dim_index() {
        let shape = Shape::new([8, 4]);

        // Indices (single integer) should also convert to correct range
        assert_eq!([0..1, 2..3], [0, 2].into_ranges(shape.clone()));
        assert_eq!([7..8, 3..4], [-1, -1].into_ranges(shape.clone()));
        assert_eq!([7..8], (-1).into_ranges(shape.clone()));
        assert_eq!([7..8], 7.into_ranges(shape));
    }

    /// The `s![]` macro allows mixing range kinds and indices in a single argument list.
    #[test]
    fn slice_range_multi_dim_heterogeneous() {
        // Slice macro `s![]` can be used to provide different range types
        let shape = Shape::new([8, 4, 2]);
        let slice = s![0..5, .., -1];
        assert_eq!([0..5, 0..4, 1..2], slice.into_ranges(shape));

        let shape = Shape::new([8, 4, 2, 3]);
        let slice = s![..=4, 0..=3, .., -2..];
        assert_eq!([0..5, 0..4, 0..2, 1..3], slice.into_ranges(shape));

        let shape = Shape::new([3, 4]);
        let slice = s![1..-1, ..];
        assert_eq!([1..2, 0..4], slice.into_ranges(shape));
    }
}