Skip to main content

candle_core/
layout.rs

1//! Tensor Layouts including contiguous or sparse strides
2
3use crate::{Error, Result, Shape};
4
5#[derive(Debug, PartialEq, Eq, Clone)]
6pub struct Layout {
7    shape: Shape,
8    // The strides are given in number of elements and not in bytes.
9    stride: Vec<usize>,
10    start_offset: usize,
11}
12
13impl Layout {
14    pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
15        Self {
16            shape,
17            stride,
18            start_offset,
19        }
20    }
21
22    pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
23        let shape = shape.into();
24        let stride = shape.stride_contiguous();
25        Self {
26            shape,
27            stride,
28            start_offset,
29        }
30    }
31
32    pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
33        Self::contiguous_with_offset(shape, 0)
34    }
35
36    pub fn dims(&self) -> &[usize] {
37        self.shape.dims()
38    }
39
40    /// The dimension size for a specified dimension index.
41    pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
42        let dim = dim.to_index(&self.shape, "dim")?;
43        Ok(self.dims()[dim])
44    }
45
46    pub fn shape(&self) -> &Shape {
47        &self.shape
48    }
49
50    pub fn stride(&self) -> &[usize] {
51        &self.stride
52    }
53
54    pub fn start_offset(&self) -> usize {
55        self.start_offset
56    }
57
58    /// Returns outer stride along `dim` if valid.
59    ///
60    /// Two conditions must hold:
61    ///  1. Inner dims `[dim..]` has standard contiguous strides.
62    ///  2. Outer dims `[..dim]` are contiguous among themselves, i.e.
63    ///     `stride[k] == dims[k+1] * stride[k+1]` for `k` in `0..dim-1`.
64    ///
65    /// When the tensor is fully contiguous this returns `Some(dims[dim..].product())`.
66    pub(crate) fn outer_stride_for_dim(&self, dim: usize) -> Option<usize> {
67        let dims = self.dims();
68        let strides = self.stride();
69
70        // 1. Inner `dims[dim..]` must have contiguous strides.
71        let mut expected = 1usize;
72        for i in (dim..dims.len()).rev() {
73            if strides[i] != expected {
74                return None;
75            }
76            expected *= dims[i];
77        }
78
79        if dim == 0 {
80            // No outer dims.
81            // `expected = dims[dim..].product()`
82            return Some(expected);
83        }
84
85        // 2. Outer `dims[0..dim]` must be internally contiguous.
86        let outer_stride = strides[dim - 1];
87        let mut expected_outer = outer_stride;
88        for k in (0..dim - 1).rev() {
89            expected_outer *= dims[k + 1];
90            if strides[k] != expected_outer {
91                return None;
92            }
93        }
94
95        Some(outer_stride)
96    }
97
98    /// Returns the appropriate start and stop offset if the data is stored in a C
99    /// contiguous (aka row major) way.
100    pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
101        if self.is_contiguous() {
102            let start_o = self.start_offset;
103            Some((start_o, start_o + self.shape.elem_count()))
104        } else {
105            None
106        }
107    }
108
109    /// Returns true if the data is stored in a C contiguous (aka row major) way.
110    /// Note that this does not implies that the start offset is 0 or that there are no extra
111    /// elements at the end of the storage.
112    pub fn is_contiguous(&self) -> bool {
113        self.shape.is_contiguous(&self.stride)
114    }
115
116    /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
117    pub fn is_fortran_contiguous(&self) -> bool {
118        self.shape.is_fortran_contiguous(&self.stride)
119    }
120
121    pub fn is_scalar(&self) -> bool {
122        let dims = self.dims();
123        dims.is_empty() || dims.iter().all(|d| *d == 1)
124    }
125
126    /// Returns true if the data is actually a scalar during broadcast
127    pub fn is_scalar_broadcast(&self) -> bool {
128        self.stride().iter().all(|s| *s == 0)
129    }
130
131    pub fn is_scalar_like(&self) -> bool {
132        self.is_scalar() || self.is_scalar_broadcast()
133    }
134
135    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
136        let dims = self.shape().dims();
137        if dim >= dims.len() {
138            Err(Error::DimOutOfRange {
139                shape: self.shape().clone(),
140                dim: dim as i32,
141                op: "narrow",
142            }
143            .bt())?
144        }
145        if start + len > dims[dim] {
146            Err(Error::NarrowInvalidArgs {
147                shape: self.shape.clone(),
148                dim,
149                start,
150                len,
151                msg: "start + len > dim_len",
152            }
153            .bt())?
154        }
155        let mut dims = dims.to_vec();
156        dims[dim] = len;
157        Ok(Self {
158            shape: Shape::from(dims),
159            stride: self.stride.clone(),
160            start_offset: self.start_offset + self.stride[dim] * start,
161        })
162    }
163
164    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
165        let rank = self.shape.rank();
166        if rank <= dim1 || rank <= dim2 {
167            Err(Error::UnexpectedNumberOfDims {
168                expected: usize::max(dim1, dim2),
169                got: rank,
170                shape: self.shape().clone(),
171            }
172            .bt())?
173        }
174        let mut stride = self.stride().to_vec();
175        let mut dims = self.shape().dims().to_vec();
176        dims.swap(dim1, dim2);
177        stride.swap(dim1, dim2);
178        Ok(Self {
179            shape: Shape::from(dims),
180            stride,
181            start_offset: self.start_offset,
182        })
183    }
184
185    pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
186        let is_permutation =
187            idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
188        if !is_permutation {
189            crate::bail!(
190                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
191                self.dims(),
192                idxs
193            )
194        }
195        let stride = self.stride();
196        let dims = self.shape().dims();
197        let mut perm_stride = stride.to_vec();
198        let mut perm_dims = dims.to_vec();
199        for (i, &idx) in idxs.iter().enumerate() {
200            perm_stride[i] = stride[idx];
201            perm_dims[i] = dims[idx];
202        }
203        Ok(Self {
204            shape: Shape::from(perm_dims),
205            stride: perm_stride,
206            start_offset: self.start_offset,
207        })
208    }
209
210    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
211        let shape = shape.into();
212        if shape.rank() < self.shape().rank() {
213            return Err(Error::BroadcastIncompatibleShapes {
214                src_shape: self.shape().clone(),
215                dst_shape: shape,
216            }
217            .bt());
218        }
219        let added_dims = shape.rank() - self.shape().rank();
220        let mut stride = vec![0; added_dims];
221        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
222            .iter()
223            .zip(self.dims().iter().zip(self.stride()))
224        {
225            let s = if dst_dim == src_dim {
226                src_stride
227            } else if src_dim != 1 {
228                return Err(Error::BroadcastIncompatibleShapes {
229                    src_shape: self.shape().clone(),
230                    dst_shape: shape,
231                }
232                .bt());
233            } else {
234                0
235            };
236            stride.push(s)
237        }
238        Ok(Self {
239            shape,
240            stride,
241            start_offset: self.start_offset,
242        })
243    }
244
245    pub(crate) fn strided_index(&self) -> crate::StridedIndex<'_> {
246        crate::StridedIndex::from_layout(self)
247    }
248
249    pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
250        let mut block_len = 1usize;
251        let mut contiguous_dims = 0usize; // Counted from the right.
252        for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
253            // Size-1 dimensions are trivially contiguous regardless of their stride.
254            if dim == 1 {
255                contiguous_dims += 1;
256                continue;
257            }
258            if stride != block_len {
259                break;
260            }
261            block_len *= dim;
262            contiguous_dims += 1;
263        }
264        let index_dims = self.dims().len() - contiguous_dims;
265        match index_dims {
266            0 => crate::StridedBlocks::SingleBlock {
267                start_offset: self.start_offset,
268                len: block_len,
269            },
270            1 => crate::StridedBlocks::UniformBlocks {
271                start_offset: self.start_offset,
272                block_len,
273                count: self.dims()[0],
274                src_stride: self.stride[0],
275            },
276            _ => {
277                let block_start_index = crate::StridedIndex::new(
278                    &self.dims()[..index_dims],
279                    &self.stride[..index_dims],
280                    self.start_offset,
281                );
282                crate::StridedBlocks::MultipleBlocks {
283                    block_start_index,
284                    block_len,
285                }
286            }
287        }
288    }
289}