rstsr_common/layout/
rearrangement.rs

//! Layout rearrangement.
//!
//! Purposes of layout rearrangement:
//! - Faster iteration for in-place storage modification and for binary/ternary operations.
//! - Splitting a layout into multiple layouts.

use crate::prelude_dev::*;

// type alias for this file
type Order = TensorIterOrder;

/* #region translate tensor order to col-major with TensorIterOrder */

/// This function returns an f-prefer layout that minimizes memory-access
/// effort during iteration (pointers will not frequently move back and forth).
///
/// Note that this function should only be used for iteration.
///
/// # Parameter `keep_shape`
///
/// Keep the size of the output layout when the input layout is broadcasted.
/// This option should be true for [`TensorIterOrder::K`] and false for
/// [`TensorIterOrder::G`].
///
/// For example, given a layout with shape `[5, 1, 2, 1, 3, 6]` and stride
/// `[1000, 10, 10, 40, 0, 100]`,
/// - false: shape `[2, 6, 5, 1, 1, 1]` and stride `[10, 100, 1000, 0, 0, 0]`; meaning that
///   broadcasted axes are eliminated and moved to the last positions.
/// - true: shape `[1, 1, 3, 2, 6, 5]` and stride `[10, 40, 0, 10, 100, 1000]`; meaning that
///   broadcasted axes are iterated with the highest priority.
///
/// # Returns
///
/// - `layout`: The output layout of greedy iteration.
/// - `index`: Transpose index from input layout to output layout.
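///
/// # Example
///
/// A sketch of the intended behavior (mirrors the unit tests below; marked
/// `ignore` since it is not compiled as a doc-test):
///
/// ```ignore
/// // a C-contiguous layout is rearranged into its axis-reversed,
/// // F-contiguous equivalent; `index` records the transposition
/// let layout = [2, 3, 4].c();
/// let (greedy, index) = greedy_layout(&layout, false);
/// assert_eq!(greedy, [4, 3, 2].f());
/// assert_eq!(index, vec![2, 1, 0]);
/// ```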
pub fn greedy_layout<D>(layout: &Layout<D>, keep_shape: bool) -> (Layout<D>, Vec<isize>)
where
    D: DimDevAPI,
{
    let mut layout = layout.clone();

    // if no elements in layout, return itself
    if layout.size() == 0 {
        return (layout.clone(), (0..layout.ndim() as isize).collect_vec());
    }

    // revert negative strides when keep_shape is required
    if keep_shape {
        for n in 0..layout.ndim() {
            if layout.stride()[n] < 0 {
                // should not panic here
                layout = layout.dim_narrow(n as isize, slice!(None, None, -1)).unwrap();
            }
        }
    }

    let shape_old = layout.shape.as_ref();
    let stride_old = layout.stride.as_ref();

    let mut index = (0..layout.ndim() as isize).collect_vec();
    if keep_shape {
        // sort shape and strides when keeping shape
        // - broadcastable axes (shape = 1 / stride = 0) compare smallest (the pointer does not move for them)
        // - among broadcastable axes, the original order is preserved
        // - otherwise, compare absolute strides (smaller stride first)
        index.sort_by(|&i1, &i2| {
            let d1 = shape_old[i1 as usize];
            let d2 = shape_old[i2 as usize];
            let t1 = stride_old[i1 as usize];
            let t2 = stride_old[i2 as usize];
            match (d1 == 1 || t1 == 0, d2 == 1 || t2 == 0) {
                (true, true) => i1.cmp(&i2),
                (true, false) => core::cmp::Ordering::Less,
                (false, true) => core::cmp::Ordering::Greater,
                (false, false) => t1.abs().cmp(&t2.abs()),
            }
        });
    } else {
        // sort shape and strides when not keeping shape
        // everything is similar, though broadcastable axes are moved to the last positions
        index.sort_by(|&i1, &i2| {
            let d1 = shape_old[i1 as usize];
            let d2 = shape_old[i2 as usize];
            let t1 = stride_old[i1 as usize];
            let t2 = stride_old[i2 as usize];
            match (d1 == 1 || t1 == 0, d2 == 1 || t2 == 0) {
                (true, true) => i1.cmp(&i2),
                (true, false) => core::cmp::Ordering::Greater,
                (false, true) => core::cmp::Ordering::Less,
                (false, false) => t1.abs().cmp(&t2.abs()),
            }
        });
    }

    let mut layout = layout.transpose(&index).unwrap();

    // when not keeping shape, dimensions of broadcastable axes are set to 1
    // and their strides are set to 0.
    if !keep_shape {
        let mut shape = layout.shape().clone();
        let mut stride = layout.stride().clone();
        shape.as_mut().iter_mut().zip(stride.as_mut().iter_mut()).for_each(|(d, t)| {
            if *d == 1 || *t == 0 {
                *d = 1;
                *t = 0;
            }
        });
        layout = unsafe { Layout::new_unchecked(shape, stride, layout.offset()) };
    }

    return (layout, index);
}

/// Reversed permutation indices.
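///
/// For example, `reversed_permute(&[2, 0, 1])` returns `[1, 2, 0]`: input axis 2
/// moved to output position 0, so the reversed permutation maps position 2 back
/// to 0, and similarly for the remaining axes.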
pub fn reversed_permute(indices: &[isize]) -> Vec<isize> {
    let mut new_indices = vec![0; indices.len()];
    for (idx, &i) in indices.iter().enumerate() {
        new_indices[i as usize] = idx as isize;
    }
    return new_indices;
}

/// Return a layout that is suitable for array copy.
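///
/// # Example
///
/// A sketch of the `C` arm (marked `ignore` since it is not compiled as a
/// doc-test):
///
/// ```ignore
/// // copying an F-contiguous layout in C order yields a fresh
/// // C-contiguous layout of the same shape
/// let layout = [2, 3, 4].f();
/// let copied = layout_for_array_copy(&layout, TensorIterOrder::C)?;
/// assert_eq!(copied, [2, 3, 4].c());
/// ```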
pub fn layout_for_array_copy<D>(layout: &Layout<D>, order: TensorIterOrder) -> Result<Layout<D>>
where
    D: DimDevAPI,
{
    let layout = match order {
        Order::C => layout.shape().c(),
        Order::F => layout.shape().f(),
        Order::A => {
            if layout.c_contig() {
                layout.shape().c()
            } else if layout.f_contig() {
                layout.shape().f()
            } else {
                match TensorOrder::default() {
                    RowMajor => layout.shape().c(),
                    ColMajor => layout.shape().f(),
                }
            }
        },
        Order::K => {
            let (greedy, indices) = greedy_layout(layout, true);
            let layout = greedy.shape().f();
            layout.transpose(&reversed_permute(&indices))?
        },
        _ => rstsr_invalid!(order, "Iter order for copy only accepts C/F/A/K.")?,
    };
    return Ok(layout);
}

/// Translate one layout to column-major iteration.
///
/// For how parameter `order` works, refer to the definition of
/// [`TensorIterOrder`].
///
/// - C: reverse axes
/// - F: preserve axes
/// - A: B if contiguous, C if c-prefer, F if f-prefer; otherwise default
/// - K: greedy layout, keep shape
/// - G: greedy layout, eliminate broadcastable dimensions
/// - B: sequential memory; valid option if `size = bound_max - bound_min`, otherwise raises an error
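///
/// # Example
///
/// A sketch of the `C` variant (marked `ignore` since it is not compiled as a
/// doc-test):
///
/// ```ignore
/// // C: axes are reversed, so a C-contiguous layout becomes F-contiguous
/// let layout = [2, 3, 4].c();
/// let translated = translate_to_col_major_unary(&layout, TensorIterOrder::C)?;
/// assert_eq!(translated, [4, 3, 2].f());
/// ```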
pub fn translate_to_col_major_unary<D>(layout: &Layout<D>, order: TensorIterOrder) -> Result<Layout<D>>
where
    D: DimDevAPI,
{
    let fn_c = |l: &Layout<D>| Ok(l.reverse_axes());
    let fn_f = |l: &Layout<D>| Ok(l.clone());
    let fn_b = |l: &Layout<D>| {
        let (bounds_min, bounds_max) = l.bounds_index()?;
        rstsr_assert_eq!(
            bounds_max - bounds_min,
            l.size(),
            InvalidLayout,
            "Data in this layout could not be represented as sequential memory."
        )?;
        let mut shape = l.new_shape();
        let mut stride = l.new_stride();
        shape[0] = l.size();
        stride[0] = 1;
        for i in 1..l.ndim() {
            shape[i] = 1;
            stride[i] = l.size() as isize;
        }
        Ok(unsafe { Layout::new_unchecked(shape, stride, l.offset()) })
    };
    match order {
        Order::C => fn_c(layout),
        Order::F => fn_f(layout),
        Order::A => {
            let c_contig = layout.c_contig();
            let f_contig = layout.f_contig();
            if c_contig || f_contig {
                fn_b(layout)
            } else {
                let c_prefer = layout.c_prefer();
                let f_prefer = layout.f_prefer();
                match (c_prefer, f_prefer) {
                    (true, false) => fn_c(layout),
                    (false, true) => fn_f(layout),
                    (_, _) => match FlagOrder::default() {
                        RowMajor => fn_c(layout),
                        ColMajor => fn_f(layout),
                    },
                }
            }
        },
        Order::K => Ok(greedy_layout(layout, true).0),
        Order::G => Ok(greedy_layout(layout, false).0),
        Order::B => fn_b(layout),
    }
}

/// Translate multiple layouts to column-major iteration.
///
/// This function requires all layouts to have the same shape.
///
/// For how parameter `order` works, refer to the definition of
/// [`TensorIterOrder`].
///
/// - C: reverse axes
/// - F: preserve axes
/// - A: B if contiguous, C if c-prefer, F if f-prefer; otherwise default
/// - K: greedy layout of the layout with the largest non-broadcast size; if all sizes are equal,
///   the left-most layout (usually for mutable-assign/in-place ops)
/// - G: invalid option here
/// - B: sequential memory; valid option if `size = bound_max - bound_min`, otherwise raises an error
///
/// This operation will not flip any strides.
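///
/// # Example
///
/// A sketch of the `K` variant with one broadcasted operand (marked `ignore`
/// since it is not compiled as a doc-test):
///
/// ```ignore
/// let a = [3, 4].c();
/// // `b` broadcasts its first axis (stride 0), so `a` has the larger
/// // non-broadcast size and its greedy permutation [1, 0] is applied to both
/// let b = unsafe { Layout::new_unchecked([3, 4], [0, 1], 0) };
/// let translated = translate_to_col_major(&[&a, &b], TensorIterOrder::K)?;
/// assert_eq!(translated[0], [4, 3].f());
/// ```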
pub fn translate_to_col_major<D>(layouts: &[&Layout<D>], order: TensorIterOrder) -> Result<Vec<Layout<D>>>
where
    D: DimAPI,
{
    if layouts.is_empty() {
        return Ok(vec![]);
    }

    // this function maps all layouts to column-major iteration by a single
    // iter-order.
    let fn_single = |ls: &[&Layout<D>], order| ls.iter().map(|l| translate_to_col_major_unary(l, order)).collect();

    // make sure all layouts have the same shape
    let is_same_shape = layouts.windows(2).all(|w| w[0].shape() == w[1].shape());
    rstsr_assert!(is_same_shape, InvalidLayout, "All shapes of layouts in this function must be the same.")?;

    match order {
        Order::C | Order::F | Order::B => fn_single(layouts, order),
        Order::A => {
            let c_contig = layouts.iter().all(|&l| l.c_contig());
            let f_contig = layouts.iter().all(|&l| l.f_contig());
            if c_contig || f_contig {
                fn_single(layouts, Order::B)
            } else {
                let c_prefer = layouts.iter().all(|&l| l.c_prefer());
                let f_prefer = layouts.iter().all(|&l| l.f_prefer());
                match (c_prefer, f_prefer) {
                    (true, false) => fn_single(layouts, Order::C),
                    (false, true) => fn_single(layouts, Order::F),
                    (_, _) => match FlagOrder::default() {
                        RowMajor => fn_single(layouts, Order::C),
                        ColMajor => fn_single(layouts, Order::F),
                    },
                }
            }
        },
        Order::K => {
            // find the layout with the largest non-broadcast-size
            let size_iter = layouts.iter().map(|l| l.size_non_broadcast()).collect_vec();
            let idx_layout = if size_iter.iter().max() == size_iter.iter().min() {
                0
            } else {
                size_iter.into_iter().enumerate().max_by_key(|(_, v)| *v).unwrap_or((0, 0)).0
            };
            // make same permutation for all layouts
            let (_, permute_index) = greedy_layout(layouts[idx_layout], true);
            layouts.iter().map(|l| l.transpose(&permute_index)).collect()
        },
        Order::G => rstsr_invalid!(order, "This option is not valid for multiple layouts")?,
    }
}

/// This function returns layouts of minimal dimension, with the leading
/// f-contiguous axes collapsed into a single contiguous size.
///
/// For example, if shape [2, 4, 6, 8, 10] is contiguous in f-order for the
/// first three axes, it returns layouts of shape [8, 10] together with the
/// contiguous size 48 (conceptually the layout [48, 8, 10], whose first axis
/// is f-contiguous and handled separately).
///
/// # Notes
///
/// - Should be used after [`translate_to_col_major`].
/// - Accepts multiple layouts to be compared.
/// - Since the final dimension is not known at compile time, this function returns dynamic
///   layouts.
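///
/// # Example
///
/// A sketch (marked `ignore` since it is not compiled as a doc-test):
///
/// ```ignore
/// // all three axes of an F-contiguous layout are f-contiguous,
/// // so they collapse completely into the contiguous size
/// let layout = [2, 3, 4].f();
/// let (layouts, size_contig) = translate_to_col_major_with_contig(&[&layout]);
/// assert_eq!(size_contig, 24);
/// assert_eq!(layouts[0].ndim(), 0);
/// ```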
pub fn translate_to_col_major_with_contig<D>(layouts: &[&Layout<D>]) -> (Vec<Layout<IxD>>, usize)
where
    D: DimAPI,
{
    if layouts.is_empty() {
        return (vec![], 0);
    }

    let dims_f_contig = layouts.iter().map(|l| l.ndim_of_f_contig()).collect_vec();
    let ndim_f_contig = *dims_f_contig.iter().min().unwrap();
    // following is the worst case: no axes are contiguous in f-order
    if ndim_f_contig == 0 {
        return (layouts.iter().map(|&l| l.clone().into_dim::<IxD>().unwrap()).collect(), 0);
    } else {
        let size_contig = layouts[0].shape().as_ref()[0..ndim_f_contig].iter().product::<usize>();
        let result = layouts
            .iter()
            .map(|l| {
                let shape = l.shape().as_ref()[ndim_f_contig..].iter().cloned().collect_vec();
                let stride = l.stride().as_ref()[ndim_f_contig..].iter().cloned().collect_vec();
                unsafe { Layout::new_unchecked(shape, stride, l.offset()) }
            })
            .collect_vec();
        return (result, size_contig);
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_greedy_layout() {
        unsafe {
            // c-contiguous layout
            let layout = [2, 3, 4].c();
            let (greedy, _) = greedy_layout(&layout, false);
            assert_eq!(greedy, [4, 3, 2].f());
            let (greedy, _) = greedy_layout(&layout, true);
            assert_eq!(greedy, [4, 3, 2].f());
            // f-contiguous layout
            let layout = [2, 3, 4].f();
            let (greedy, _) = greedy_layout(&layout, false);
            assert_eq!(greedy, [2, 3, 4].f());
            let (greedy, _) = greedy_layout(&layout, true);
            assert_eq!(greedy, [2, 3, 4].f());
            // dimension-size 1 or stride-size 0
            let layout = Layout::new_unchecked([5, 1, 2, 1, 3, 6], [1000, 10, 10, 40, 0, 100], 0);
            let (greedy, _) = greedy_layout(&layout, false);
            let expect = Layout::new_unchecked([2, 6, 5, 1, 1, 1], [10, 100, 1000, 0, 0, 0], 0);
            assert_eq!(greedy, expect);
            let (greedy, _) = greedy_layout(&layout, true);
            let expect = Layout::new_unchecked([1, 1, 3, 2, 6, 5], [10, 40, 0, 10, 100, 1000], 0);
            assert_eq!(greedy, expect);
            // negative strides
            let layout = [2, 3, 4].f().dim_narrow(1, slice!(None, None, -1)).unwrap();
            let layout = layout.swapaxes(-1, -2).unwrap();
            let (greedy, _) = greedy_layout(&layout, true);
            assert_eq!(greedy, [2, 3, 4].f());
            let (greedy, _) = greedy_layout(&layout, false);
            assert_eq!(greedy, [2, 3, 4].f().dim_narrow(1, slice!(None, None, -1)).unwrap());
        }
    }
}
359}