// burn_tensor/tensor/api/take.rs
1use alloc::vec::Vec;
2
3use crate::{
4 BasicOps, Int, Tensor,
5 backend::Backend,
6 check,
7 check::TensorCheck,
8 indexing::{AsIndex, canonicalize_dim},
9};
10
11impl<B, const D: usize, K> Tensor<B, D, K>
12where
13 B: Backend,
14 K: BasicOps<B>,
15{
16 /// Takes elements from the tensor along the given dimension using indices of any dimensionality.
17 ///
18 /// This behaves like numpy's take function. When indices is multi-dimensional,
19 /// the output shape will be: input.shape\[:dim\] + indices.shape + input.shape\[dim+1:\]
20 ///
21 /// # Arguments
22 ///
23 /// * `dim` - The dimension along which to select elements. Supports negative indexing.
24 /// * `indices` - The indices of elements to select. Can be any dimensionality.
25 /// Must be valid indices in the range [0, dim_size).
26 ///
27 /// # Example
28 ///
29 /// ```rust
30 /// use burn_tensor::backend::Backend;
31 /// use burn_tensor::{Tensor, Int};
32 ///
33 /// fn example<B: Backend>() {
34 /// let device = B::Device::default();
35 ///
36 /// // Example with 1D indices
37 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
38 /// let indices = Tensor::<B, 1, Int>::from_data([2, 0, 1], &device);
39 /// let result: Tensor<B, 2> = tensor.clone().take::<1, 2>(-1, indices); // -1 refers to last dimension
40 /// println!("{result}");
41 /// // [[3.0, 1.0, 2.0], [6.0, 4.0, 5.0]]
42 ///
43 /// // Example with 2D indices - output will have +1 dimension (2D -> 3D)
44 /// let indices_2d = Tensor::<B, 2, Int>::from_data([[0, 2], [1, 0]], &device);
45 /// let result: Tensor<B, 3> = tensor.take::<2, 3>(1, indices_2d);
46 /// println!("{result}");
47 /// // [[[1.0, 3.0], [2.0, 1.0]], [[4.0, 6.0], [5.0, 4.0]]]
48 /// }
49 /// ```
50 pub fn take<const DI: usize, const DO: usize>(
51 self,
52 dim: impl AsIndex,
53 indices: Tensor<B, DI, Int>,
54 ) -> Tensor<B, DO, K> {
55 let dim = canonicalize_dim(dim, D, false);
56 check!(TensorCheck::take::<D, DI, DO>(dim));
57
58 // Store the indices shape for reshaping later
59 let indices_shape = indices.shape();
60 let indices_dims = indices_shape.clone();
61
62 // Flatten indices to 1D for processing
63 let indices_flat = indices.reshape([indices_shape.num_elements()]);
64
65 // Perform the selection with the flattened indices
66 let selected = self.select(dim, indices_flat);
67
68 // Build the output shape
69 // Output shape = input.shape[:dim] + indices.shape + input.shape[dim+1:]
70 let selected_shape = selected.shape();
71 let mut new_shape = Vec::with_capacity(DO);
72
73 // Add dimensions before the selected dimension
74 for i in 0..dim {
75 new_shape.push(selected_shape[i]);
76 }
77
78 // Add all indices dimensions
79 for idx_dim in indices_dims {
80 new_shape.push(idx_dim);
81 }
82
83 // Add dimensions after the selected dimension
84 for i in (dim + 1)..D {
85 new_shape.push(selected_shape[i]);
86 }
87
88 // Verify we have the correct number of dimensions
89 assert_eq!(
90 new_shape.len(),
91 DO,
92 "Internal error: shape calculation resulted in {} dims but expected {}",
93 new_shape.len(),
94 DO
95 );
96
97 // Convert to fixed-size array for reshape
98 let mut shape_array = [0; DO];
99 for (i, &s) in new_shape.iter().enumerate() {
100 shape_array[i] = s;
101 }
102
103 selected.reshape(shape_array)
104 }
105}