Skip to main content

hanzo_ml/
layout.rs

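//! Tensor layouts: a shape plus per-dimension strides and a start offset, covering both
//! contiguous and arbitrarily strided (e.g. narrowed, transposed, or broadcast) views.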
2use crate::{Error, Result, Shape};
3
4#[derive(Debug, PartialEq, Eq, Clone)]
5pub struct Layout {
6    shape: Shape,
7    // The strides are given in number of elements and not in bytes.
8    stride: Vec<usize>,
9    start_offset: usize,
10}
11
12impl Layout {
13    pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
14        Self {
15            shape,
16            stride,
17            start_offset,
18        }
19    }
20
21    pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
22        let shape = shape.into();
23        let stride = shape.stride_contiguous();
24        Self {
25            shape,
26            stride,
27            start_offset,
28        }
29    }
30
31    pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
32        Self::contiguous_with_offset(shape, 0)
33    }
34
35    pub fn dims(&self) -> &[usize] {
36        self.shape.dims()
37    }
38
39    /// The dimension size for a specified dimension index.
40    pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
41        let dim = dim.to_index(&self.shape, "dim")?;
42        Ok(self.dims()[dim])
43    }
44
45    pub fn shape(&self) -> &Shape {
46        &self.shape
47    }
48
49    pub fn stride(&self) -> &[usize] {
50        &self.stride
51    }
52
53    pub fn start_offset(&self) -> usize {
54        self.start_offset
55    }
56
57    /// Returns outer stride along `dim` if valid.
58    ///
59    /// Two conditions must hold:
60    ///  1. Inner dims `[dim..]` has standard contiguous strides.
61    ///  2. Outer dims `[..dim]` are contiguous among themselves, i.e.
62    ///     `stride[k] == dims[k+1] * stride[k+1]` for `k` in `0..dim-1`.
63    ///
64    /// When the tensor is fully contiguous this returns `Some(dims[dim..].product())`.
65    pub(crate) fn outer_stride_for_dim(&self, dim: usize) -> Option<usize> {
66        let dims = self.dims();
67        let strides = self.stride();
68
69        // 1. Inner `dims[dim..]` must have contiguous strides.
70        let mut expected = 1usize;
71        for i in (dim..dims.len()).rev() {
72            if strides[i] != expected {
73                return None;
74            }
75            expected *= dims[i];
76        }
77
78        if dim == 0 {
79            // No outer dims.
80            // `expected = dims[dim..].product()`
81            return Some(expected);
82        }
83
84        // 2. Outer `dims[0..dim]` must be internally contiguous.
85        let outer_stride = strides[dim - 1];
86        let mut expected_outer = outer_stride;
87        for k in (0..dim - 1).rev() {
88            expected_outer *= dims[k + 1];
89            if strides[k] != expected_outer {
90                return None;
91            }
92        }
93
94        Some(outer_stride)
95    }
96
97    /// Returns the appropriate start and stop offset if the data is stored in a C
98    /// contiguous (aka row major) way.
99    pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
100        if self.is_contiguous() {
101            let start_o = self.start_offset;
102            Some((start_o, start_o + self.shape.elem_count()))
103        } else {
104            None
105        }
106    }
107
108    /// Returns true if the data is stored in a C contiguous (aka row major) way.
109    /// Note that this does not implies that the start offset is 0 or that there are no extra
110    /// elements at the end of the storage.
111    pub fn is_contiguous(&self) -> bool {
112        self.shape.is_contiguous(&self.stride)
113    }
114
115    /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
116    pub fn is_fortran_contiguous(&self) -> bool {
117        self.shape.is_fortran_contiguous(&self.stride)
118    }
119
120    pub fn is_scalar(&self) -> bool {
121        let dims = self.dims();
122        dims.is_empty() || dims.iter().all(|d| *d == 1)
123    }
124
125    /// Returns true if the data is actually a scalar during broadcast
126    pub fn is_scalar_broadcast(&self) -> bool {
127        self.stride().iter().all(|s| *s == 0)
128    }
129
130    pub fn is_scalar_like(&self) -> bool {
131        self.is_scalar() || self.is_scalar_broadcast()
132    }
133
134    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
135        let dims = self.shape().dims();
136        if dim >= dims.len() {
137            Err(Error::DimOutOfRange {
138                shape: self.shape().clone(),
139                dim: dim as i32,
140                op: "narrow",
141            }
142            .bt())?
143        }
144        if start + len > dims[dim] {
145            Err(Error::NarrowInvalidArgs {
146                shape: self.shape.clone(),
147                dim,
148                start,
149                len,
150                msg: "start + len > dim_len",
151            }
152            .bt())?
153        }
154        let mut dims = dims.to_vec();
155        dims[dim] = len;
156        Ok(Self {
157            shape: Shape::from(dims),
158            stride: self.stride.clone(),
159            start_offset: self.start_offset + self.stride[dim] * start,
160        })
161    }
162
163    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
164        let rank = self.shape.rank();
165        if rank <= dim1 || rank <= dim2 {
166            Err(Error::UnexpectedNumberOfDims {
167                expected: usize::max(dim1, dim2),
168                got: rank,
169                shape: self.shape().clone(),
170            }
171            .bt())?
172        }
173        let mut stride = self.stride().to_vec();
174        let mut dims = self.shape().dims().to_vec();
175        dims.swap(dim1, dim2);
176        stride.swap(dim1, dim2);
177        Ok(Self {
178            shape: Shape::from(dims),
179            stride,
180            start_offset: self.start_offset,
181        })
182    }
183
184    pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
185        let is_permutation =
186            idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
187        if !is_permutation {
188            crate::bail!(
189                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
190                self.dims(),
191                idxs
192            )
193        }
194        let stride = self.stride();
195        let dims = self.shape().dims();
196        let mut perm_stride = stride.to_vec();
197        let mut perm_dims = dims.to_vec();
198        for (i, &idx) in idxs.iter().enumerate() {
199            perm_stride[i] = stride[idx];
200            perm_dims[i] = dims[idx];
201        }
202        Ok(Self {
203            shape: Shape::from(perm_dims),
204            stride: perm_stride,
205            start_offset: self.start_offset,
206        })
207    }
208
209    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
210        let shape = shape.into();
211        if shape.rank() < self.shape().rank() {
212            return Err(Error::BroadcastIncompatibleShapes {
213                src_shape: self.shape().clone(),
214                dst_shape: shape,
215            }
216            .bt());
217        }
218        let added_dims = shape.rank() - self.shape().rank();
219        let mut stride = vec![0; added_dims];
220        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
221            .iter()
222            .zip(self.dims().iter().zip(self.stride()))
223        {
224            let s = if dst_dim == src_dim {
225                src_stride
226            } else if src_dim != 1 {
227                return Err(Error::BroadcastIncompatibleShapes {
228                    src_shape: self.shape().clone(),
229                    dst_shape: shape,
230                }
231                .bt());
232            } else {
233                0
234            };
235            stride.push(s)
236        }
237        Ok(Self {
238            shape,
239            stride,
240            start_offset: self.start_offset,
241        })
242    }
243
244    pub(crate) fn strided_index(&self) -> crate::StridedIndex<'_> {
245        crate::StridedIndex::from_layout(self)
246    }
247
248    pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
249        let mut block_len = 1usize;
250        let mut contiguous_dims = 0usize; // Counted from the right.
251        for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
252            // Size-1 dimensions are trivially contiguous regardless of their stride.
253            if dim == 1 {
254                contiguous_dims += 1;
255                continue;
256            }
257            if stride != block_len {
258                break;
259            }
260            block_len *= dim;
261            contiguous_dims += 1;
262        }
263        let index_dims = self.dims().len() - contiguous_dims;
264        match index_dims {
265            0 => crate::StridedBlocks::SingleBlock {
266                start_offset: self.start_offset,
267                len: block_len,
268            },
269            1 => crate::StridedBlocks::UniformBlocks {
270                start_offset: self.start_offset,
271                block_len,
272                count: self.dims()[0],
273                src_stride: self.stride[0],
274            },
275            _ => {
276                let block_start_index = crate::StridedIndex::new(
277                    &self.dims()[..index_dims],
278                    &self.stride[..index_dims],
279                    self.start_offset,
280                );
281                crate::StridedBlocks::MultipleBlocks {
282                    block_start_index,
283                    block_len,
284                }
285            }
286        }
287    }
288
289    // Returns the contiguous offsets with broadcast if applicable.
290    pub(crate) fn offsets_b(&self) -> Option<ContiguousOffsetsWithBroadcast> {
291        let mut left_broadcast = 1;
292        let mut right_broadcast = 1;
293        let strides = self.stride();
294        let dims = self.dims();
295        let mut start_cont = 0;
296        let mut end_cont = dims.len();
297        for (&s, &d) in strides.iter().zip(dims.iter()) {
298            if s != 0 {
299                break;
300            }
301            start_cont += 1;
302            left_broadcast *= d;
303        }
304        if start_cont == dims.len() {
305            return Some(ContiguousOffsetsWithBroadcast {
306                start: self.start_offset,
307                len: 1,
308                left_broadcast,
309                right_broadcast: 1,
310            });
311        }
312        for (&s, &d) in strides.iter().zip(dims.iter()).rev() {
313            if s != 0 {
314                break;
315            }
316            end_cont -= 1;
317            right_broadcast *= d;
318        }
319        // Check that the inner dims are contiguous
320        let strides = &strides[start_cont..end_cont];
321        let dims = &dims[start_cont..end_cont];
322        let mut len = 1;
323        for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
324            if stride != len {
325                return None;
326            }
327            len *= dim;
328        }
329        Some(ContiguousOffsetsWithBroadcast {
330            start: self.start_offset,
331            len,
332            left_broadcast,
333            right_broadcast,
334        })
335    }
336}
337
/// A single contiguous run of storage that is repeated through broadcasting,
/// as computed by `Layout::offsets_b`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ContiguousOffsetsWithBroadcast {
    /// Storage offset of the first element of the contiguous run.
    pub start: usize,
    /// Number of elements in the contiguous run.
    pub len: usize,
    /// Product of the sizes of the leading zero-stride dimensions.
    pub left_broadcast: usize,
    /// Product of the sizes of the trailing zero-stride dimensions.
    pub right_broadcast: usize,
}