// candle_core/layout.rs

//! Tensor layouts, including contiguous and non-contiguous (strided) variants.
use crate::{Error, Result, Shape};

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Layout {
    shape: Shape,
    // The strides are given in number of elements and not in bytes.
    stride: Vec<usize>,
    start_offset: usize,
}

impl Layout {
    pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
        Self {
            shape,
            stride,
            start_offset,
        }
    }

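    /// Creates a C-contiguous (row major) layout for `shape` whose data starts
    /// `start_offset` elements into the underlying storage.
    ///
    /// A minimal doctest sketch (assuming `Layout` is re-exported at the crate root):
    ///
    /// ```
    /// use candle_core::Layout;
    /// let layout = Layout::contiguous_with_offset((2, 3), 4);
    /// assert_eq!(layout.stride(), &[3, 1]);
    /// assert_eq!(layout.start_offset(), 4);
    /// ```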
    pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
        let shape = shape.into();
        let stride = shape.stride_contiguous();
        Self {
            shape,
            stride,
            start_offset,
        }
    }

    pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
        Self::contiguous_with_offset(shape, 0)
    }

    pub fn dims(&self) -> &[usize] {
        self.shape.dims()
    }

    /// The dimension size for a specified dimension index.
    pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
        let dim = dim.to_index(&self.shape, "dim")?;
        Ok(self.dims()[dim])
    }

    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    pub fn stride(&self) -> &[usize] {
        &self.stride
    }

    pub fn start_offset(&self) -> usize {
        self.start_offset
    }

    /// Returns the appropriate start and stop offset if the data is stored in a C
    /// contiguous (aka row major) way.
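    ///
    /// A minimal doctest sketch (assuming `Layout` is re-exported at the crate root):
    ///
    /// ```
    /// use candle_core::Layout;
    /// let layout = Layout::contiguous_with_offset((2, 3), 4);
    /// // Six elements starting at offset 4.
    /// assert_eq!(layout.contiguous_offsets(), Some((4, 10)));
    /// ```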
    pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
        if self.is_contiguous() {
            let start_o = self.start_offset;
            Some((start_o, start_o + self.shape.elem_count()))
        } else {
            None
        }
    }

    /// Returns true if the data is stored in a C contiguous (aka row major) way.
    /// Note that this does not imply that the start offset is 0 or that there are no
    /// extra elements at the end of the storage.
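    ///
    /// A minimal doctest sketch (assuming `Layout` is re-exported at the crate root):
    ///
    /// ```
    /// use candle_core::Layout;
    /// let layout = Layout::contiguous((2, 3));
    /// assert!(layout.is_contiguous());
    /// // Transposing swaps the strides, so the layout is no longer row major.
    /// assert!(!layout.transpose(0, 1).unwrap().is_contiguous());
    /// ```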
    pub fn is_contiguous(&self) -> bool {
        self.shape.is_contiguous(&self.stride)
    }

    /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
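    ///
    /// A minimal doctest sketch (assuming `Layout` is re-exported at the crate root):
    ///
    /// ```
    /// use candle_core::Layout;
    /// // Transposing a row-major 2d layout yields a column-major one.
    /// let layout = Layout::contiguous((2, 3)).transpose(0, 1).unwrap();
    /// assert!(layout.is_fortran_contiguous());
    /// ```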
    pub fn is_fortran_contiguous(&self) -> bool {
        self.shape.is_fortran_contiguous(&self.stride)
    }

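    /// Restricts dimension `dim` to `len` elements starting at `start`. The strides
    /// are kept and the start offset moves by `stride[dim] * start` elements.
    ///
    /// A minimal doctest sketch (assuming `Layout` is re-exported at the crate root):
    ///
    /// ```
    /// use candle_core::Layout;
    /// let layout = Layout::contiguous((2, 4)).narrow(1, 1, 2).unwrap();
    /// assert_eq!(layout.dims(), &[2, 2]);
    /// assert_eq!(layout.stride(), &[4, 1]);
    /// assert_eq!(layout.start_offset(), 1);
    /// ```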
    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
        let dims = self.shape().dims();
        if dim >= dims.len() {
            Err(Error::DimOutOfRange {
                shape: self.shape().clone(),
                dim: dim as i32,
                op: "narrow",
            }
            .bt())?
        }
        if start + len > dims[dim] {
            Err(Error::NarrowInvalidArgs {
                shape: self.shape.clone(),
                dim,
                start,
                len,
                msg: "start + len > dim_len",
            }
            .bt())?
        }
        let mut dims = dims.to_vec();
        dims[dim] = len;
        Ok(Self {
            shape: Shape::from(dims),
            stride: self.stride.clone(),
            start_offset: self.start_offset + self.stride[dim] * start,
        })
    }

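    /// Swaps dimensions `dim1` and `dim2`, exchanging both the dims and the strides
    /// without touching the underlying data.
    ///
    /// A minimal doctest sketch (assuming `Layout` is re-exported at the crate root):
    ///
    /// ```
    /// use candle_core::Layout;
    /// let layout = Layout::contiguous((2, 3, 4)).transpose(0, 2).unwrap();
    /// assert_eq!(layout.dims(), &[4, 3, 2]);
    /// assert_eq!(layout.stride(), &[1, 4, 12]);
    /// ```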
    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
        let rank = self.shape.rank();
        if rank <= dim1 || rank <= dim2 {
            Err(Error::UnexpectedNumberOfDims {
                expected: usize::max(dim1, dim2),
                got: rank,
                shape: self.shape().clone(),
            }
            .bt())?
        }
        let mut stride = self.stride().to_vec();
        let mut dims = self.shape().dims().to_vec();
        dims.swap(dim1, dim2);
        stride.swap(dim1, dim2);
        Ok(Self {
            shape: Shape::from(dims),
            stride,
            start_offset: self.start_offset,
        })
    }

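    /// Reorders the dimensions so that output dim `i` takes the dim and stride at
    /// `idxs[i]`; `idxs` must be a permutation of `0..rank`.
    ///
    /// A minimal doctest sketch (assuming `Layout` is re-exported at the crate root):
    ///
    /// ```
    /// use candle_core::Layout;
    /// let layout = Layout::contiguous((2, 3, 4)).permute(&[2, 0, 1]).unwrap();
    /// assert_eq!(layout.dims(), &[4, 2, 3]);
    /// assert_eq!(layout.stride(), &[1, 12, 4]);
    /// ```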
    pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
        let is_permutation =
            idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
        if !is_permutation {
            crate::bail!(
                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
                self.dims(),
                idxs
            )
        }
        let stride = self.stride();
        let dims = self.shape().dims();
        let mut perm_stride = stride.to_vec();
        let mut perm_dims = dims.to_vec();
        for (i, &idx) in idxs.iter().enumerate() {
            perm_stride[i] = stride[idx];
            perm_dims[i] = dims[idx];
        }
        Ok(Self {
            shape: Shape::from(perm_dims),
            stride: perm_stride,
            start_offset: self.start_offset,
        })
    }

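    /// Broadcasts the layout to `shape`: newly added leading dims and existing dims
    /// of size 1 get a stride of 0, so that every index along them maps back to the
    /// same elements.
    ///
    /// A minimal doctest sketch (assuming `Layout` is re-exported at the crate root):
    ///
    /// ```
    /// use candle_core::Layout;
    /// let layout = Layout::contiguous((3, 1)).broadcast_as((2, 3, 4)).unwrap();
    /// assert_eq!(layout.dims(), &[2, 3, 4]);
    /// assert_eq!(layout.stride(), &[0, 1, 0]);
    /// ```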
    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
        let shape = shape.into();
        if shape.rank() < self.shape().rank() {
            return Err(Error::BroadcastIncompatibleShapes {
                src_shape: self.shape().clone(),
                dst_shape: shape,
            }
            .bt());
        }
        let added_dims = shape.rank() - self.shape().rank();
        let mut stride = vec![0; added_dims];
        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
            .iter()
            .zip(self.dims().iter().zip(self.stride()))
        {
            let s = if dst_dim == src_dim {
                src_stride
            } else if src_dim != 1 {
                return Err(Error::BroadcastIncompatibleShapes {
                    src_shape: self.shape().clone(),
                    dst_shape: shape,
                }
                .bt());
            } else {
                0
            };
            stride.push(s)
        }
        Ok(Self {
            shape,
            stride,
            start_offset: self.start_offset,
        })
    }

    pub(crate) fn strided_index(&self) -> crate::StridedIndex {
        crate::StridedIndex::from_layout(self)
    }

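    // Groups iteration over the data into maximal contiguous blocks. An illustrative
    // sketch (not from the original source): transposing a contiguous (2, 3, 4)
    // layout on its first two dims gives dims [3, 2, 4] and stride [4, 12, 1]; only
    // the innermost dim is then contiguous, so this returns MultipleBlocks with
    // block_len = 4 and a StridedIndex over the outer dims [3, 2] / strides [4, 12].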
    pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
        let mut block_len = 1;
        let mut contiguous_dims = 0; // These are counted from the right.
        for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
            if stride != block_len {
                break;
            }
            block_len *= dim;
            contiguous_dims += 1;
        }
        let index_dims = self.dims().len() - contiguous_dims;
        if index_dims == 0 {
            crate::StridedBlocks::SingleBlock {
                start_offset: self.start_offset,
                len: block_len,
            }
        } else {
            let block_start_index = crate::StridedIndex::new(
                &self.dims()[..index_dims],
                &self.stride[..index_dims],
                self.start_offset,
            );
            crate::StridedBlocks::MultipleBlocks {
                block_start_index,
                block_len,
            }
        }
    }

    // Returns the contiguous offsets with broadcast if applicable.
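    // An illustrative sketch (not from the original source): broadcasting a
    // contiguous (3, 1) layout to (2, 3, 4) gives stride [0, 1, 0]; the leading and
    // trailing zero-stride dims peel off as left_broadcast = 2 and
    // right_broadcast = 4, leaving a contiguous middle block of len = 3.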
    pub(crate) fn offsets_b(&self) -> Option<ContiguousOffsetsWithBroadcast> {
        let mut left_broadcast = 1;
        let mut right_broadcast = 1;
        let strides = self.stride();
        let dims = self.dims();
        let mut start_cont = 0;
        let mut end_cont = dims.len();
        for (&s, &d) in strides.iter().zip(dims.iter()) {
            if s != 0 {
                break;
            }
            start_cont += 1;
            left_broadcast *= d;
        }
        if start_cont == dims.len() {
            return Some(ContiguousOffsetsWithBroadcast {
                start: self.start_offset,
                len: 1,
                left_broadcast,
                right_broadcast: 1,
            });
        }
        for (&s, &d) in strides.iter().zip(dims.iter()).rev() {
            if s != 0 {
                break;
            }
            end_cont -= 1;
            right_broadcast *= d;
        }
        // Check that the inner dims are contiguous
        let strides = &strides[start_cont..end_cont];
        let dims = &dims[start_cont..end_cont];
        let mut len = 1;
        for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
            if stride != len {
                return None;
            }
            len *= dim;
        }
        Some(ContiguousOffsetsWithBroadcast {
            start: self.start_offset,
            len,
            left_broadcast,
            right_broadcast,
        })
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContiguousOffsetsWithBroadcast {
    pub start: usize,
    pub len: usize,
    pub left_broadcast: usize,
    pub right_broadcast: usize,
}