//! Tensor layout: shape, per-dimension element strides, and the start offset
//! of the first element inside the backing storage.

1use crate::{Error, Result, Shape};
2
/// Describes how a tensor's elements are laid out in their backing storage:
/// a shape, one stride per dimension, and the offset of the first element.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Layout {
    shape: Shape,
    // The strides are given in number of elements and not in bytes.
    stride: Vec<usize>,
    // Offset (in number of elements) of the first element within the storage.
    start_offset: usize,
}
10
11impl Layout {
12    pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
13        Self {
14            shape,
15            stride,
16            start_offset,
17        }
18    }
19
20    pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
21        let shape = shape.into();
22        let stride = shape.stride_contiguous();
23        Self {
24            shape,
25            stride,
26            start_offset,
27        }
28    }
29
30    pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
31        Self::contiguous_with_offset(shape, 0)
32    }
33
34    pub fn dims(&self) -> &[usize] {
35        self.shape.dims()
36    }
37
    /// The shape described by this layout.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }
41
42    pub fn stride(&self) -> &[usize] {
43        &self.stride
44    }
45
    /// Offset (in number of elements) of the first element in the storage.
    pub fn start_offset(&self) -> usize {
        self.start_offset
    }
49
50    /// Returns the appropriate start and stop offset if the data is stored in a C
51    /// contiguous (aka row major) way.
52    pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
53        if self.is_contiguous() {
54            let start_o = self.start_offset;
55            Some((start_o, start_o + self.shape.elem_count()))
56        } else {
57            None
58        }
59    }
60
    /// Returns true if the data is stored in a C contiguous (aka row major) way.
    /// Note that this does not imply that the start offset is 0 or that there are
    /// no extra elements at the end of the storage.
    pub fn is_contiguous(&self) -> bool {
        self.shape.is_contiguous(&self.stride)
    }
67
    /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
    /// As with [`Self::is_contiguous`], the start offset may still be non-zero.
    pub fn is_fortran_contiguous(&self) -> bool {
        self.shape.is_fortran_contiguous(&self.stride)
    }
72
73    pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
74        let dims = self.shape().dims();
75        if dim >= dims.len() {
76            Err(Error::DimOutOfRange {
77                shape: self.shape().clone(),
78                dim: dim as i32,
79                op: "narrow",
80            }
81            .bt())?
82        }
83        if start + len > dims[dim] {
84            Err(Error::NarrowInvalidArgs {
85                shape: self.shape.clone(),
86                dim,
87                start,
88                len,
89                msg: "start + len > dim_len",
90            }
91            .bt())?
92        }
93        let mut dims = dims.to_vec();
94        dims[dim] = len;
95        Ok(Self {
96            shape: Shape::from(dims),
97            stride: self.stride.clone(),
98            start_offset: self.start_offset + self.stride[dim] * start,
99        })
100    }
101
102    pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
103        let rank = self.shape.rank();
104        if rank <= dim1 || rank <= dim2 {
105            Err(Error::UnexpectedNumberOfDims {
106                expected: usize::max(dim1, dim2),
107                got: rank,
108                shape: self.shape().clone(),
109            }
110            .bt())?
111        }
112        let mut stride = self.stride().to_vec();
113        let mut dims = self.shape().dims().to_vec();
114        dims.swap(dim1, dim2);
115        stride.swap(dim1, dim2);
116        Ok(Self {
117            shape: Shape::from(dims),
118            stride,
119            start_offset: self.start_offset,
120        })
121    }
122
123    pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
124        let is_permutation =
125            idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
126        if !is_permutation {
127            crate::bail!(
128                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
129                self.dims(),
130                idxs
131            )
132        }
133        let stride = self.stride();
134        let dims = self.shape().dims();
135        let mut perm_stride = stride.to_vec();
136        let mut perm_dims = dims.to_vec();
137        for (i, &idx) in idxs.iter().enumerate() {
138            perm_stride[i] = stride[idx];
139            perm_dims[i] = dims[idx];
140        }
141        Ok(Self {
142            shape: Shape::from(perm_dims),
143            stride: perm_stride,
144            start_offset: self.start_offset,
145        })
146    }
147
148    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
149        let shape = shape.into();
150        if shape.rank() < self.shape().rank() {
151            return Err(Error::BroadcastIncompatibleShapes {
152                src_shape: self.shape().clone(),
153                dst_shape: shape,
154            }
155            .bt());
156        }
157        let added_dims = shape.rank() - self.shape().rank();
158        let mut stride = vec![0; added_dims];
159        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
160            .iter()
161            .zip(self.dims().iter().zip(self.stride()))
162        {
163            let s = if dst_dim == src_dim {
164                src_stride
165            } else if src_dim != 1 {
166                return Err(Error::BroadcastIncompatibleShapes {
167                    src_shape: self.shape().clone(),
168                    dst_shape: shape,
169                }
170                .bt());
171            } else {
172                0
173            };
174            stride.push(s)
175        }
176        Ok(Self {
177            shape,
178            stride,
179            start_offset: self.start_offset,
180        })
181    }
182
    /// Iterator over the storage offsets of the elements, in row-major order.
    pub(crate) fn strided_index(&self) -> crate::StridedIndex {
        crate::StridedIndex::from_layout(self)
    }
186
187    pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
188        let mut block_len = 1;
189        let mut contiguous_dims = 0; // These are counted from the right.
190        for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
191            if stride != block_len {
192                break;
193            }
194            block_len *= dim;
195            contiguous_dims += 1;
196        }
197        let index_dims = self.dims().len() - contiguous_dims;
198        if index_dims == 0 {
199            crate::StridedBlocks::SingleBlock {
200                start_offset: self.start_offset,
201                len: block_len,
202            }
203        } else {
204            let block_start_index = crate::StridedIndex::new(
205                &self.dims()[..index_dims],
206                &self.stride[..index_dims],
207                self.start_offset,
208            );
209            crate::StridedBlocks::MultipleBlocks {
210                block_start_index,
211                block_len,
212            }
213        }
214    }
215
    /// Returns the contiguous offsets with broadcast if applicable.
    ///
    /// Recognizes layouts of the form: leading zero-stride (broadcast)
    /// dimensions, then a C-contiguous core of `len` elements, then trailing
    /// zero-stride dimensions. Returns `None` when the core dimensions are
    /// not contiguous.
    pub(crate) fn offsets_b(&self) -> Option<ContiguousOffsetsWithBroadcast> {
        let mut left_broadcast = 1;
        let mut right_broadcast = 1;
        let strides = self.stride();
        let dims = self.dims();
        let mut start_cont = 0;
        let mut end_cont = dims.len();
        // Peel off leading broadcast (stride == 0) dimensions from the left.
        for (&s, &d) in strides.iter().zip(dims.iter()) {
            if s != 0 {
                break;
            }
            start_cont += 1;
            left_broadcast *= d;
        }
        if start_cont == dims.len() {
            // Every dimension broadcasts: a single element repeated.
            return Some(ContiguousOffsetsWithBroadcast {
                start: self.start_offset,
                len: 1,
                left_broadcast,
                right_broadcast: 1,
            });
        }
        // Peel off trailing broadcast (stride == 0) dimensions from the right.
        for (&s, &d) in strides.iter().zip(dims.iter()).rev() {
            if s != 0 {
                break;
            }
            end_cont -= 1;
            right_broadcast *= d;
        }
        // Check that the inner dims are contiguous
        let strides = &strides[start_cont..end_cont];
        let dims = &dims[start_cont..end_cont];
        let mut len = 1;
        for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
            if stride != len {
                return None;
            }
            len *= dim;
        }
        Some(ContiguousOffsetsWithBroadcast {
            start: self.start_offset,
            len,
            left_broadcast,
            right_broadcast,
        })
    }
263}
264
/// Result of [`Layout::offsets_b`]: a contiguous run of `len` elements
/// starting at `start`, logically repeated `left_broadcast` times on the
/// left and `right_broadcast` times on the right.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContiguousOffsetsWithBroadcast {
    pub start: usize,
    pub len: usize,
    pub left_broadcast: usize,
    pub right_broadcast: usize,
}