candle_core/indexer.rs
1use crate::{Error, Tensor};
2use std::ops::{
3 Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
4};
5
6impl Tensor {
7 /// Intended to be use by the trait `.i()`
8 ///
9 /// ```
10 /// # use candle_core::{Tensor, DType, Device, IndexOp};
11 /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
12 ///
13 /// let c = a.i(0..1)?;
14 /// assert_eq!(c.shape().dims(), &[1, 3]);
15 ///
16 /// let c = a.i(0)?;
17 /// assert_eq!(c.shape().dims(), &[3]);
18 ///
19 /// let c = a.i((.., ..2) )?;
20 /// assert_eq!(c.shape().dims(), &[2, 2]);
21 ///
22 /// let c = a.i((.., ..=2))?;
23 /// assert_eq!(c.shape().dims(), &[2, 3]);
24 ///
25 /// # Ok::<(), candle_core::Error>(())
26 /// ```
27 fn index(&self, indexers: &[TensorIndexer]) -> Result<Self, Error> {
28 let mut x = self.clone();
29 let dims = self.shape().dims();
30 let mut current_dim = 0;
31 for (i, indexer) in indexers.iter().enumerate() {
32 x = match indexer {
33 TensorIndexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?,
34 TensorIndexer::Narrow(left_bound, right_bound) => {
35 let start = match left_bound {
36 Bound::Included(n) => *n,
37 Bound::Excluded(n) => *n + 1,
38 Bound::Unbounded => 0,
39 };
40 let stop = match right_bound {
41 Bound::Included(n) => *n + 1,
42 Bound::Excluded(n) => *n,
43 Bound::Unbounded => dims[i],
44 };
45 let out = x.narrow(current_dim, start, stop.saturating_sub(start))?;
46 current_dim += 1;
47 out
48 }
49 TensorIndexer::IndexSelect(indexes) => {
50 if indexes.rank() != 1 {
51 crate::bail!("multi-dimensional tensor indexing is not supported")
52 }
53 let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
54 current_dim += 1;
55 out
56 }
57 TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
58 };
59 }
60 Ok(x)
61 }
62}
63
64#[derive(Debug)]
65/// Generic structure used to index a slice of the tensor
66pub enum TensorIndexer {
67 /// This selects the elements for which an index has some specific value.
68 Select(usize),
69 /// This is a regular slice, purely indexing a chunk of the tensor
70 Narrow(Bound<usize>, Bound<usize>),
71 /// Indexing via a 1d tensor
72 IndexSelect(Tensor),
73 Err(Error),
74}
75
76impl From<usize> for TensorIndexer {
77 fn from(index: usize) -> Self {
78 TensorIndexer::Select(index)
79 }
80}
81
82impl From<&[u32]> for TensorIndexer {
83 fn from(index: &[u32]) -> Self {
84 match Tensor::new(index, &crate::Device::Cpu) {
85 Ok(tensor) => TensorIndexer::IndexSelect(tensor),
86 Err(e) => TensorIndexer::Err(e),
87 }
88 }
89}
90
91impl From<Vec<u32>> for TensorIndexer {
92 fn from(index: Vec<u32>) -> Self {
93 let len = index.len();
94 match Tensor::from_vec(index, len, &crate::Device::Cpu) {
95 Ok(tensor) => TensorIndexer::IndexSelect(tensor),
96 Err(e) => TensorIndexer::Err(e),
97 }
98 }
99}
100
101impl From<&Tensor> for TensorIndexer {
102 fn from(tensor: &Tensor) -> Self {
103 TensorIndexer::IndexSelect(tensor.clone())
104 }
105}
106
/// Private marker for the `usize` range types accepted by the blanket range
/// conversion into `TensorIndexer`: `..`, `a..`, `..b`, `..=b`, `a..b` and `a..=b`.
trait RB: RangeBounds<usize> {}

impl RB for RangeFull {}
impl RB for RangeFrom<usize> {}
impl RB for RangeTo<usize> {}
impl RB for RangeToInclusive<usize> {}
impl RB for Range<usize> {}
impl RB for RangeInclusive<usize> {}
114
115impl<T: RB> From<T> for TensorIndexer {
116 fn from(range: T) -> Self {
117 use std::ops::Bound::*;
118 let start = match range.start_bound() {
119 Included(idx) => Included(*idx),
120 Excluded(idx) => Excluded(*idx),
121 Unbounded => Unbounded,
122 };
123 let end = match range.end_bound() {
124 Included(idx) => Included(*idx),
125 Excluded(idx) => Excluded(*idx),
126 Unbounded => Unbounded,
127 };
128 TensorIndexer::Narrow(start, end)
129 }
130}
131
132/// Trait used to implement multiple signatures for ease of use of the slicing
133/// of a tensor
134pub trait IndexOp<T> {
135 /// Returns a slicing iterator which are the chunks of data necessary to
136 /// reconstruct the desired tensor.
137 fn i(&self, index: T) -> Result<Tensor, Error>;
138}
139
140impl<T> IndexOp<T> for Tensor
141where
142 T: Into<TensorIndexer>,
143{
144 ///```rust
145 /// use candle_core::{Tensor, DType, Device, IndexOp};
146 /// let a = Tensor::new(&[
147 /// [0., 1.],
148 /// [2., 3.],
149 /// [4., 5.]
150 /// ], &Device::Cpu)?;
151 ///
152 /// let b = a.i(0)?;
153 /// assert_eq!(b.shape().dims(), &[2]);
154 /// assert_eq!(b.to_vec1::<f64>()?, &[0., 1.]);
155 ///
156 /// let c = a.i(..2)?;
157 /// assert_eq!(c.shape().dims(), &[2, 2]);
158 /// assert_eq!(c.to_vec2::<f64>()?, &[
159 /// [0., 1.],
160 /// [2., 3.]
161 /// ]);
162 ///
163 /// let d = a.i(1..)?;
164 /// assert_eq!(d.shape().dims(), &[2, 2]);
165 /// assert_eq!(d.to_vec2::<f64>()?, &[
166 /// [2., 3.],
167 /// [4., 5.]
168 /// ]);
169 /// # Ok::<(), candle_core::Error>(())
170 /// ```
171 fn i(&self, index: T) -> Result<Tensor, Error> {
172 self.index(&[index.into()])
173 }
174}
175
176impl<A> IndexOp<(A,)> for Tensor
177where
178 A: Into<TensorIndexer>,
179{
180 ///```rust
181 /// use candle_core::{Tensor, DType, Device, IndexOp};
182 /// let a = Tensor::new(&[
183 /// [0f32, 1.],
184 /// [2. , 3.],
185 /// [4. , 5.]
186 /// ], &Device::Cpu)?;
187 ///
188 /// let b = a.i((0,))?;
189 /// assert_eq!(b.shape().dims(), &[2]);
190 /// assert_eq!(b.to_vec1::<f32>()?, &[0., 1.]);
191 ///
192 /// let c = a.i((..2,))?;
193 /// assert_eq!(c.shape().dims(), &[2, 2]);
194 /// assert_eq!(c.to_vec2::<f32>()?, &[
195 /// [0., 1.],
196 /// [2., 3.]
197 /// ]);
198 ///
199 /// let d = a.i((1..,))?;
200 /// assert_eq!(d.shape().dims(), &[2, 2]);
201 /// assert_eq!(d.to_vec2::<f32>()?, &[
202 /// [2., 3.],
203 /// [4., 5.]
204 /// ]);
205 /// # Ok::<(), candle_core::Error>(())
206 /// ```
207 fn i(&self, (a,): (A,)) -> Result<Tensor, Error> {
208 self.index(&[a.into()])
209 }
210}
211#[allow(non_snake_case)]
212impl<A, B> IndexOp<(A, B)> for Tensor
213where
214 A: Into<TensorIndexer>,
215 B: Into<TensorIndexer>,
216{
217 ///```rust
218 /// use candle_core::{Tensor, DType, Device, IndexOp};
219 /// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?;
220 ///
221 /// let b = a.i((1, 0))?;
222 /// assert_eq!(b.to_vec0::<f32>()?, 3.);
223 ///
224 /// let c = a.i((..2, 1))?;
225 /// assert_eq!(c.shape().dims(), &[2]);
226 /// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
227 ///
228 /// let d = a.i((2.., ..))?;
229 /// assert_eq!(c.shape().dims(), &[2]);
230 /// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
231 /// # Ok::<(), candle_core::Error>(())
232 /// ```
233 fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
234 self.index(&[a.into(), b.into()])
235 }
236}
237
/// Implements `IndexOp` for `Tensor` over a tuple of indexers, one type parameter
/// per tuple element, attaching `$doc` as the method's documentation. The type
/// parameter identifiers are reused as the names of the destructured tuple
/// bindings, hence the `non_snake_case` allowance.
macro_rules! index_op_tuple {
    ($doc:tt, $($ty:ident),+) => {
        #[allow(non_snake_case)]
        impl<$($ty: Into<TensorIndexer>),*> IndexOp<($($ty,)*)> for Tensor {
            #[doc=$doc]
            fn i(&self, ($($ty,)*): ($($ty,)*)) -> Result<Tensor, Error> {
                self.index(&[$($ty.into(),)*])
            }
        }
    };
}
252
253index_op_tuple!("see [TensorIndex#method.i]", A, B, C);
254index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D);
255index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E);
256index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F);
257index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G);