candle_core/indexer.rs

1use crate::{Error, Tensor};
2use std::ops::{
3    Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
4};
5
6impl Tensor {
7    /// Intended to be use by the trait `.i()`
8    ///
9    /// ```
10    /// # use candle_core::{Tensor, DType, Device, IndexOp};
11    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
12    ///
13    /// let c = a.i(0..1)?;
14    /// assert_eq!(c.shape().dims(), &[1, 3]);
15    ///
16    /// let c = a.i(0)?;
17    /// assert_eq!(c.shape().dims(), &[3]);
18    ///
19    /// let c = a.i((.., ..2) )?;
20    /// assert_eq!(c.shape().dims(), &[2, 2]);
21    ///
22    /// let c = a.i((.., ..=2))?;
23    /// assert_eq!(c.shape().dims(), &[2, 3]);
24    ///
25    /// # Ok::<(), candle_core::Error>(())
26    /// ```
27    fn index(&self, indexers: &[TensorIndexer]) -> Result<Self, Error> {
28        let mut x = self.clone();
29        let dims = self.shape().dims();
30        let mut current_dim = 0;
31        for (i, indexer) in indexers.iter().enumerate() {
32            x = match indexer {
33                TensorIndexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?,
34                TensorIndexer::Narrow(left_bound, right_bound) => {
35                    let start = match left_bound {
36                        Bound::Included(n) => *n,
37                        Bound::Excluded(n) => *n + 1,
38                        Bound::Unbounded => 0,
39                    };
40                    let stop = match right_bound {
41                        Bound::Included(n) => *n + 1,
42                        Bound::Excluded(n) => *n,
43                        Bound::Unbounded => dims[i],
44                    };
45                    let out = x.narrow(current_dim, start, stop.saturating_sub(start))?;
46                    current_dim += 1;
47                    out
48                }
49                TensorIndexer::IndexSelect(indexes) => {
50                    if indexes.rank() != 1 {
51                        crate::bail!("multi-dimensional tensor indexing is not supported")
52                    }
53                    let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
54                    current_dim += 1;
55                    out
56                }
57                TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
58            };
59        }
60        Ok(x)
61    }
62}
63
64#[derive(Debug)]
65/// Generic structure used to index a slice of the tensor
66pub enum TensorIndexer {
67    /// This selects the elements for which an index has some specific value.
68    Select(usize),
69    /// This is a regular slice, purely indexing a chunk of the tensor
70    Narrow(Bound<usize>, Bound<usize>),
71    /// Indexing via a 1d tensor
72    IndexSelect(Tensor),
73    Err(Error),
74}
75
76impl From<usize> for TensorIndexer {
77    fn from(index: usize) -> Self {
78        TensorIndexer::Select(index)
79    }
80}
81
82impl From<&[u32]> for TensorIndexer {
83    fn from(index: &[u32]) -> Self {
84        match Tensor::new(index, &crate::Device::Cpu) {
85            Ok(tensor) => TensorIndexer::IndexSelect(tensor),
86            Err(e) => TensorIndexer::Err(e),
87        }
88    }
89}
90
91impl From<Vec<u32>> for TensorIndexer {
92    fn from(index: Vec<u32>) -> Self {
93        let len = index.len();
94        match Tensor::from_vec(index, len, &crate::Device::Cpu) {
95            Ok(tensor) => TensorIndexer::IndexSelect(tensor),
96            Err(e) => TensorIndexer::Err(e),
97        }
98    }
99}
100
101impl From<&Tensor> for TensorIndexer {
102    fn from(tensor: &Tensor) -> Self {
103        TensorIndexer::IndexSelect(tensor.clone())
104    }
105}
106
/// Private marker for the standard `usize` range types.
///
/// The blanket `impl<T: RB> From<T> for TensorIndexer` is bounded by this local trait rather
/// than by `RangeBounds<usize>` directly. Because only the range types below implement it,
/// the blanket impl cannot overlap with the other `From` impls (`usize`, `&[u32]`, ...).
trait RB: RangeBounds<usize> {}

/// Emits an empty `RB` impl for every listed type.
macro_rules! impl_rb {
    ($($range:ty),* $(,)?) => {
        $(impl RB for $range {})*
    };
}

impl_rb!(
    Range<usize>,
    RangeFrom<usize>,
    RangeFull,
    RangeInclusive<usize>,
    RangeTo<usize>,
    RangeToInclusive<usize>,
);
114
115impl<T: RB> From<T> for TensorIndexer {
116    fn from(range: T) -> Self {
117        use std::ops::Bound::*;
118        let start = match range.start_bound() {
119            Included(idx) => Included(*idx),
120            Excluded(idx) => Excluded(*idx),
121            Unbounded => Unbounded,
122        };
123        let end = match range.end_bound() {
124            Included(idx) => Included(*idx),
125            Excluded(idx) => Excluded(*idx),
126            Unbounded => Unbounded,
127        };
128        TensorIndexer::Narrow(start, end)
129    }
130}
131
132/// Trait used to implement multiple signatures for ease of use of the slicing
133/// of a tensor
134pub trait IndexOp<T> {
135    /// Returns a slicing iterator which are the chunks of data necessary to
136    /// reconstruct the desired tensor.
137    fn i(&self, index: T) -> Result<Tensor, Error>;
138}
139
140impl<T> IndexOp<T> for Tensor
141where
142    T: Into<TensorIndexer>,
143{
144    ///```rust
145    /// use candle_core::{Tensor, DType, Device, IndexOp};
146    /// let a = Tensor::new(&[
147    ///     [0., 1.],
148    ///     [2., 3.],
149    ///     [4., 5.]
150    /// ], &Device::Cpu)?;
151    ///
152    /// let b = a.i(0)?;
153    /// assert_eq!(b.shape().dims(), &[2]);
154    /// assert_eq!(b.to_vec1::<f64>()?, &[0., 1.]);
155    ///
156    /// let c = a.i(..2)?;
157    /// assert_eq!(c.shape().dims(), &[2, 2]);
158    /// assert_eq!(c.to_vec2::<f64>()?, &[
159    ///     [0., 1.],
160    ///     [2., 3.]
161    /// ]);
162    ///
163    /// let d = a.i(1..)?;
164    /// assert_eq!(d.shape().dims(), &[2, 2]);
165    /// assert_eq!(d.to_vec2::<f64>()?, &[
166    ///     [2., 3.],
167    ///     [4., 5.]
168    /// ]);
169    /// # Ok::<(), candle_core::Error>(())
170    /// ```
171    fn i(&self, index: T) -> Result<Tensor, Error> {
172        self.index(&[index.into()])
173    }
174}
175
176impl<A> IndexOp<(A,)> for Tensor
177where
178    A: Into<TensorIndexer>,
179{
180    ///```rust
181    /// use candle_core::{Tensor, DType, Device, IndexOp};
182    /// let a = Tensor::new(&[
183    ///     [0f32, 1.],
184    ///     [2.  , 3.],
185    ///     [4.  , 5.]
186    /// ], &Device::Cpu)?;
187    ///
188    /// let b = a.i((0,))?;
189    /// assert_eq!(b.shape().dims(), &[2]);
190    /// assert_eq!(b.to_vec1::<f32>()?, &[0., 1.]);
191    ///
192    /// let c = a.i((..2,))?;
193    /// assert_eq!(c.shape().dims(), &[2, 2]);
194    /// assert_eq!(c.to_vec2::<f32>()?, &[
195    ///     [0., 1.],
196    ///     [2., 3.]
197    /// ]);
198    ///
199    /// let d = a.i((1..,))?;
200    /// assert_eq!(d.shape().dims(), &[2, 2]);
201    /// assert_eq!(d.to_vec2::<f32>()?, &[
202    ///     [2., 3.],
203    ///     [4., 5.]
204    /// ]);
205    /// # Ok::<(), candle_core::Error>(())
206    /// ```
207    fn i(&self, (a,): (A,)) -> Result<Tensor, Error> {
208        self.index(&[a.into()])
209    }
210}
211#[allow(non_snake_case)]
212impl<A, B> IndexOp<(A, B)> for Tensor
213where
214    A: Into<TensorIndexer>,
215    B: Into<TensorIndexer>,
216{
217    ///```rust
218    /// use candle_core::{Tensor, DType, Device, IndexOp};
219    /// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?;
220    ///
221    /// let b = a.i((1, 0))?;
222    /// assert_eq!(b.to_vec0::<f32>()?, 3.);
223    ///
224    /// let c = a.i((..2, 1))?;
225    /// assert_eq!(c.shape().dims(), &[2]);
226    /// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
227    ///
228    /// let d = a.i((2.., ..))?;
229    /// assert_eq!(c.shape().dims(), &[2]);
230    /// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
231    /// # Ok::<(), candle_core::Error>(())
232    /// ```
233    fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
234        self.index(&[a.into(), b.into()])
235    }
236}
237
/// Generates `IndexOp` for a tuple of indexers, where the `k`-th tuple element indexes the
/// `k`-th dimension. `$doc` becomes the doc comment of the generated `i` method.
macro_rules! index_op_tuple {
    ($doc:tt, $($idx:ident),+) => {
        // Each identifier is both an (uppercase) type parameter and the binding for the
        // matching tuple element, hence the lint allowance.
        #[allow(non_snake_case)]
        impl<$($idx),*> IndexOp<($($idx,)*)> for Tensor
        where
            $($idx: Into<TensorIndexer>,)*
        {
            #[doc = $doc]
            fn i(&self, ($($idx,)*): ($($idx,)*)) -> Result<Tensor, Error> {
                let indexers = [$(Into::<TensorIndexer>::into($idx)),*];
                self.index(&indexers)
            }
        }
    };
}
252
253index_op_tuple!("see [TensorIndex#method.i]", A, B, C);
254index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D);
255index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E);
256index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F);
257index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G);