//! Tensor memory layout (`lumen_core/layout.rs`): a shape together with
//! per-dimension strides and a start offset into the storage buffer.
1use crate::{Error, Result};
2use super::{Dim, Shape};
3
/// Describes how a tensor's elements are placed in its storage buffer:
/// a shape plus one stride per dimension and the storage index of the
/// first element.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Layout {
    // Dimension sizes of the view.
    pub(crate) shape: Shape,
    // Storage-index step for each dimension (one entry per dimension).
    pub(crate) stride: Vec<usize>,
    // Storage index of the first element of the view.
    pub(crate) start_offset: usize,
}
10
11impl Layout {
12    pub fn new<S: Into<Shape>>(shape: S, stride: Vec<usize>, start_offset: usize) -> Self {
13        Self {
14            shape: shape.into(), stride, start_offset
15        }
16    }
17    
18    pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
19        let shape = shape.into();
20        let stride = shape.stride_contiguous();
21        Self {
22            shape,
23            stride,
24            start_offset: 0,
25        }
26    }
27
28    pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
29        let shape = shape.into();
30        let stride = shape.stride_contiguous();
31        Self {
32            shape,
33            stride,
34            start_offset,
35        }
36    }
37
38    pub fn dims(&self) -> &[usize] {
39        self.shape.dims()
40    }
41
42    pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
43        let dim = dim.to_index(&self.shape, "dim")?;
44        Ok(self.dims()[dim])
45    }
46
47    pub fn shape(&self) -> &Shape {
48        &self.shape
49    }
50
51    pub fn stride(&self) -> &[usize] {
52        &self.stride
53    }
54
55    pub fn start_offset(&self) -> usize {
56        self.start_offset
57    }
58
59    pub fn element_count(&self) -> usize {
60        self.shape().element_count()
61    }
62
63    pub fn is_contiguous(&self) -> bool {
64        self.shape.is_contiguous(&self.stride)
65    }
66
67    pub fn slice(&self, dim: usize, start: usize, end: usize, step: usize) -> Result<Self> {
68        let dims = self.shape().dims();
69        
70        if dim >= dims.len() {
71            Err(Error::DimOutOfRange { 
72                shape: self.shape().clone(), 
73                dim: dim as i32, 
74                op: "slice" 
75            })?;
76        }
77        
78        if step == 0 {
79            return Err(Error::NarrowInvalidArgs {
80                shape: self.shape.clone(),
81                dim, start, len: 0,
82                msg: "step cannot be 0",
83            }.into());
84        }
85
86        if start > end || end > dims[dim] {
87                return Err(Error::NarrowInvalidArgs {
88                shape: self.shape.clone(),
89                dim, start, len: end.saturating_sub(start), 
90                msg: "index out of range",
91            }.into());
92        }
93
94        let new_len = if start == end { 0 } else { (start..end).step_by(step).len() };
95
96        let mut new_dims = dims.to_vec();
97        new_dims[dim] = new_len;
98
99        let mut new_stride = self.stride.clone();
100        new_stride[dim] *= step; 
101
102        Ok(Self::new(
103            new_dims, 
104            new_stride,
105            self.start_offset + self.stride[dim] * start 
106        ))
107    }
108
109    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
110        self.slice(dim, start, start + len, 1)
111    }
112
113    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
114        let rank = self.shape.rank();
115        if rank <= dim1 || rank <= dim2 {
116            Err(Error::UnexpectedNumberOfDims {
117                expected: usize::max(dim1, dim2),
118                got: rank,
119                shape: self.shape().clone(),
120            })?
121        }
122
123        let mut stride = self.stride().to_vec();
124        let mut dims = self.shape().dims().to_vec();
125        dims.swap(dim1, dim2);
126        stride.swap(dim1, dim2);
127
128        Ok(Self::new(dims, stride, self.start_offset))
129    }
130
131    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
132        let shape = shape.into();
133        if shape.rank() < self.shape().rank() {
134            return Err(Error::BroadcastIncompatibleShapes {
135                src_shape: self.shape().clone(),
136                dst_shape: shape,
137            })?;
138        }
139
140        let added_dims = shape.rank() - self.shape().rank();
141        let mut stride = vec![0; added_dims];
142        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
143            .iter()
144            .zip(self.dims().iter().zip(self.stride()))
145        {
146            let s = if dst_dim == src_dim {
147                src_stride
148            } else if src_dim != 1 {
149                return Err(Error::BroadcastIncompatibleShapes {
150                    src_shape: self.shape().clone(),
151                    dst_shape: shape,
152                })?;
153            } else {
154                0
155            };
156            stride.push(s)
157        }
158        Ok(Self {
159            shape,
160            stride,
161            start_offset: self.start_offset,
162        })
163    }
164
165    pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
166        let is_permutation =
167            idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
168        if !is_permutation {
169            crate::bail!(
170                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
171                self.dims(),
172                idxs
173            )
174        }
175        let stride = self.stride();
176        let dims = self.shape().dims();
177        let mut perm_stride = stride.to_vec();
178        let mut perm_dims = dims.to_vec();
179        for (i, &idx) in idxs.iter().enumerate() {
180            perm_stride[i] = stride[idx];
181            perm_dims[i] = dims[idx];
182        }
183        Ok(Self {
184            shape: Shape::from(perm_dims),
185            stride: perm_stride,
186            start_offset: self.start_offset,
187        })
188    }
189
190    /// Returns an iterator over **storage indices**.
191    ///
192    /// This iterator yields the linear (flat) indices as they are laid out
193    /// in the underlying storage buffer. The order depends on the memory
194    /// layout (e.g., row-major / column-major / with strides).
195    ///
196    /// Example for shape = (2, 2) in row-major layout:
197    /// yields: `0, 1, 2, 3`
198    pub fn storage_indices(&self) -> StorageIndices {
199        StorageIndices::from_layout(self)
200    }
201}
202
//////////////////////////////////////////////////////////////////////////////////////
//                  StorageIndices
//////////////////////////////////////////////////////////////////////////////////////
206
/// Iterator over the storage indices of a `Layout`, dispatching to a plain
/// range counter for contiguous layouts and a strided multi-index walker
/// otherwise.
#[derive(Debug, Clone)]
pub enum StorageIndices<'a> {
    // Strided walk; borrows dims/stride from the layout.
    UncontiguousStorageIndices(UncontiguousStorageIndices<'a>),
    // Consecutive range walk.
    ContiguousStorageIndices(ContiguousStorageIndices),
}
212
213impl<'a> StorageIndices<'a> {
214    pub fn from_layout(l: &'a Layout) -> Self {
215        if l.is_contiguous() {
216            Self::ContiguousStorageIndices(ContiguousStorageIndices::from_layout(l))
217        } else {
218            Self::UncontiguousStorageIndices(UncontiguousStorageIndices::from_layout(l))
219        }
220    }
221
222    pub fn reset(&mut self) {
223        match self {
224            Self::UncontiguousStorageIndices(index) => index.reset(),
225            Self::ContiguousStorageIndices(index) => index.reset(),
226        }
227    }
228
229    pub fn len(&self) -> usize {
230        match self {
231            Self::UncontiguousStorageIndices(index) => index.len(),
232            Self::ContiguousStorageIndices(index) => index.len(),
233        }
234    }
235}
236
237impl<'a> Iterator for StorageIndices<'a> {
238    type Item = usize;
239
240    fn next(&mut self) -> Option<Self::Item> {
241        match self {
242            Self::ContiguousStorageIndices(i) => i.next(),
243            Self::UncontiguousStorageIndices(i) => i.next(),
244        }
245    }
246}
247
/// Iterator state for a contiguous layout: the storage indices are simply
/// the consecutive range `start_offset..start_offset + element_count`.
#[derive(Debug, Clone)]
pub struct ContiguousStorageIndices {
    init_storage_index: usize, // first index; kept so `reset` can rewind
    storage_index: usize, // next index to yield
    end_index: usize, // one past the last index
}
254
255impl ContiguousStorageIndices {
256    fn from_layout(l: &Layout) -> Self {
257        Self {
258            init_storage_index: l.start_offset(),
259            storage_index: l.start_offset(),
260            end_index: l.start_offset() + l.element_count(),
261        }
262    }
263
264    fn reset(&mut self) {
265        self.storage_index = self.init_storage_index;
266    }
267
268    fn len(&self) -> usize {
269        self.end_index - self.init_storage_index
270    }
271}
272
273impl Iterator for ContiguousStorageIndices {
274    type Item = usize;
275
276    fn next(&mut self) -> Option<Self::Item> {
277        if self.storage_index >= self.end_index {
278            None
279        } else {
280            let index = self.storage_index;
281            self.storage_index += 1;
282            Some(index)
283        }
284    }
285}
286
287impl<S: Into<Shape>> From<S> for Layout { 
288    fn from(value: S) -> Self {
289        Layout::contiguous(value.into())
290    }
291}
292
/// Iterator state for a non-contiguous layout: walks a multi-dimensional
/// index (last dimension fastest) and maps it to storage indices through
/// the strides.
#[derive(Debug, Clone)]
pub struct UncontiguousStorageIndices<'a> {
    // First index, kept for `reset`. (Was a trailing `///` doc comment,
    // which would have attached to the *next* field.)
    init_storage_index: Option<usize>,
    next_storage_index: Option<usize>, // next index to yield; None when exhausted
    multi_index: Vec<usize>, // current multi-dimensional index, one entry per dim
    dims: &'a [usize], // dimension sizes, borrowed from the layout
    stride: &'a [usize], // strides, borrowed from the layout
    len: usize, // total element count
}
302
303impl<'a> UncontiguousStorageIndices<'a> {
304    fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
305        let elem_count: usize = dims.iter().product();
306        let next_storage_index = if elem_count == 0 {
307            None
308        } else {
309            // This applies to the scalar case.
310            Some(start_offset)
311        };
312        UncontiguousStorageIndices {
313            init_storage_index: next_storage_index,
314            next_storage_index,
315            multi_index: vec![0; dims.len()],
316            dims,
317            stride,
318            len: elem_count,
319        }
320    }
321
322    fn from_layout(l: &'a Layout) -> Self {
323        Self::new(l.dims(), l.stride(), l.start_offset())
324    }
325
326    pub fn reset(&mut self) {
327        self.next_storage_index = self.init_storage_index;
328    }
329
330    pub fn len(&self) -> usize {
331        self.len
332    }
333}
334
impl Iterator for UncontiguousStorageIndices<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        // Yield the precomputed index, then advance `multi_index` like an
        // odometer (last dimension varies fastest) to derive the next one.
        let storage_index = self.next_storage_index?;
        let mut updated = false;
        let mut next_storage_index = storage_index;
        for ((multi_i, max_i), stride_i) in self
            .multi_index
            .iter_mut()
            .zip(self.dims.iter())
            .zip(self.stride.iter())
            .rev()
        {
            let next_i = *multi_i + 1;
            if next_i < *max_i {
                // This dimension can still advance: step along its stride.
                *multi_i = next_i;
                updated = true;
                next_storage_index += stride_i;
                break;
            } else {
                // This dimension wraps to 0: remove its contribution from the
                // storage index and carry into the next (slower) dimension.
                next_storage_index -= *multi_i * stride_i;
                *multi_i = 0
            }
        }
        // If no dimension could advance, the traversal is complete.
        self.next_storage_index = if updated {
            Some(next_storage_index)
        } else {
            None
        };
        Some(storage_index)
    }
}
368
#[cfg(test)]
mod tests {
    use super::{Layout, StorageIndices};

    /// A contiguous layout visits storage indices in order: `0..element_count`.
    /// (The original test only printed the indices and asserted nothing.)
    #[test]
    fn test_strided_index1() {
        let layout = Layout::contiguous((2, 5, 4));
        let indices: Vec<usize> = StorageIndices::from_layout(&layout).collect();
        let expected: Vec<usize> = (0..40).collect();
        assert_eq!(indices, expected);
    }

    /// Narrowing dim 1 of a contiguous (2, 3, 3) layout to `[1, 2)` keeps
    /// stride (9, 3, 1) with start offset 3, so the strided walk yields the
    /// middle row of each outer slice.
    #[test]
    fn test_strided_index2() {
        let layout = Layout::contiguous((2, 3, 3));
        let layout = layout.narrow(1, 1, 1).unwrap();
        assert_eq!(layout.stride(), &[9, 3, 1]);
        assert_eq!(layout.start_offset(), 3);
        let indices: Vec<usize> = StorageIndices::from_layout(&layout).collect();
        assert_eq!(indices, vec![3, 4, 5, 12, 13, 14]);
    }
}
394