1use crate::prelude_dev::*;
8
/// Module-local shorthand: `Order` aliases [`TensorIterOrder`] for the match arms below.
type Order = TensorIterOrder;
11
12pub fn greedy_layout<D>(layout: &Layout<D>, keep_shape: bool) -> (Layout<D>, Vec<isize>)
37where
38 D: DimDevAPI,
39{
40 let mut layout = layout.clone();
41
42 if layout.size() == 0 {
44 return (layout.clone(), (0..layout.ndim() as isize).collect_vec());
45 }
46
47 if keep_shape {
49 for n in 0..layout.ndim() {
50 if layout.stride()[n] < 0 {
51 layout = layout.dim_narrow(n as isize, slice!(None, None, -1)).unwrap();
53 }
54 }
55 }
56
57 let shape_old = layout.shape.as_ref();
58 let stride_old = layout.stride.as_ref();
59
60 let mut index = (0..layout.ndim() as isize).collect_vec();
61 if keep_shape {
62 index.sort_by(|&i1, &i2| {
67 let d1 = shape_old[i1 as usize];
68 let d2 = shape_old[i2 as usize];
69 let t1 = stride_old[i1 as usize];
70 let t2 = stride_old[i2 as usize];
71 match (d1 == 1 || t1 == 0, d2 == 1 || t2 == 0) {
72 (true, true) => i1.cmp(&i2),
73 (true, false) => core::cmp::Ordering::Less,
74 (false, true) => core::cmp::Ordering::Greater,
75 (false, false) => t1.abs().cmp(&t2.abs()),
76 }
77 });
78 } else {
79 index.sort_by(|&i1, &i2| {
82 let d1 = shape_old[i1 as usize];
83 let d2 = shape_old[i2 as usize];
84 let t1 = stride_old[i1 as usize];
85 let t2 = stride_old[i2 as usize];
86 match (d1 == 1 || t1 == 0, d2 == 1 || t2 == 0) {
87 (true, true) => i1.cmp(&i2),
88 (true, false) => core::cmp::Ordering::Greater,
89 (false, true) => core::cmp::Ordering::Less,
90 (false, false) => t1.abs().cmp(&t2.abs()),
91 }
92 });
93 }
94
95 let mut layout = layout.transpose(&index).unwrap();
96
97 if !keep_shape {
100 let mut shape = layout.shape().clone();
101 let mut stride = layout.stride().clone();
102 shape.as_mut().iter_mut().zip(stride.as_mut().iter_mut()).for_each(|(d, t)| {
103 if *d == 1 || *t == 0 {
104 *d = 1;
105 *t = 0;
106 }
107 });
108 layout = unsafe { Layout::new_unchecked(shape, stride, layout.offset()) };
109 }
110
111 return (layout, index);
112}
113
/// Invert a permutation: for every position `pos`, `result[indices[pos]] = pos`.
///
/// E.g. `[2, 0, 1]` (axis 0 → slot 2, axis 1 → slot 0, axis 2 → slot 1)
/// inverts to `[1, 2, 0]`.
pub fn reversed_permute(indices: &[isize]) -> Vec<isize> {
    let mut inverse = vec![0isize; indices.len()];
    indices
        .iter()
        .enumerate()
        .for_each(|(pos, &axis)| inverse[axis as usize] = pos as isize);
    inverse
}
122
123pub fn layout_for_array_copy<D>(layout: &Layout<D>, order: TensorIterOrder) -> Result<Layout<D>>
125where
126 D: DimDevAPI,
127{
128 let layout = match order {
129 Order::C => layout.shape().c(),
130 Order::F => layout.shape().f(),
131 Order::A => {
132 if layout.c_contig() {
133 layout.shape().c()
134 } else if layout.f_contig() {
135 layout.shape().f()
136 } else {
137 match TensorOrder::default() {
138 RowMajor => layout.shape().c(),
139 ColMajor => layout.shape().f(),
140 }
141 }
142 },
143 Order::K => {
144 let (greedy, indices) = greedy_layout(layout, true);
145 let layout = greedy.shape().f();
146 layout.transpose(&reversed_permute(&indices))?
147 },
148 _ => rstsr_invalid!(order, "Iter order for copy only accepts CFAK.")?,
149 };
150 return Ok(layout);
151}
152
153pub fn translate_to_col_major_unary<D>(layout: &Layout<D>, order: TensorIterOrder) -> Result<Layout<D>>
165where
166 D: DimDevAPI,
167{
168 let fn_c = |l: &Layout<D>| Ok(l.reverse_axes());
169 let fn_f = |l: &Layout<D>| Ok(l.clone());
170 let fn_b = |l: &Layout<D>| {
171 let (bounds_min, bounds_max) = l.bounds_index()?;
172 rstsr_assert_eq!(
173 bounds_max - bounds_min,
174 l.size(),
175 InvalidLayout,
176 "Data in this layout could not be represented as sequential memory."
177 )?;
178 let mut shape = l.new_shape();
179 let mut stride = l.new_stride();
180 shape[0] = l.size();
181 stride[0] = 1;
182 for i in 1..l.ndim() {
183 shape[i] = 1;
184 stride[i] = l.size() as isize;
185 }
186 Ok(unsafe { Layout::new_unchecked(shape, stride, l.offset()) })
187 };
188 match order {
189 Order::C => fn_c(layout),
190 Order::F => fn_f(layout),
191 Order::A => {
192 let c_contig = layout.c_contig();
193 let f_contig = layout.f_contig();
194 if c_contig || f_contig {
195 fn_b(layout)
196 } else {
197 let c_prefer = layout.c_prefer();
198 let f_prefer = layout.f_prefer();
199 match (c_prefer, f_prefer) {
200 (true, false) => fn_c(layout),
201 (false, true) => fn_f(layout),
202 (_, _) => match FlagOrder::default() {
203 RowMajor => fn_c(layout),
204 ColMajor => fn_f(layout),
205 },
206 }
207 }
208 },
209 Order::K => Ok(greedy_layout(layout, true).0),
210 Order::G => Ok(greedy_layout(layout, false).0),
211 Order::B => fn_b(layout),
212 }
213}
214
215pub fn translate_to_col_major<D>(layouts: &[&Layout<D>], order: TensorIterOrder) -> Result<Vec<Layout<D>>>
232where
233 D: DimAPI,
234{
235 if layouts.is_empty() {
236 return Ok(vec![]);
237 }
238
239 let fn_single = |ls: &[&Layout<D>], order| ls.iter().map(|l| translate_to_col_major_unary(l, order)).collect();
242
243 let is_same_shape = layouts.windows(2).all(|w| w[0].shape() == w[1].shape());
245 rstsr_assert!(is_same_shape, InvalidLayout, "All shape of layout in this function must be the same.")?;
246
247 match order {
248 Order::C | Order::F | Order::B => fn_single(layouts, order),
249 Order::A => {
250 let c_contig = layouts.iter().all(|&l| l.c_contig());
251 let f_contig = layouts.iter().all(|&l| l.f_contig());
252 if c_contig || f_contig {
253 fn_single(layouts, Order::B)
254 } else {
255 let c_prefer = layouts.iter().all(|&l| l.c_contig());
256 let f_prefer = layouts.iter().all(|&l| l.f_contig());
257 match (c_prefer, f_prefer) {
258 (true, false) => fn_single(layouts, Order::C),
259 (false, true) => fn_single(layouts, Order::F),
260 (_, _) => match FlagOrder::default() {
261 RowMajor => fn_single(layouts, Order::C),
262 ColMajor => fn_single(layouts, Order::F),
263 },
264 }
265 }
266 },
267 Order::K => {
268 let size_iter = layouts.iter().map(|l| l.size_non_broadcast()).collect_vec();
270 let idx_layout = if size_iter.iter().max() == size_iter.iter().min() {
271 0
272 } else {
273 size_iter.into_iter().enumerate().max_by_key(|(_, v)| *v).unwrap_or((0, 0)).0
274 };
275 let (_, permute_index) = greedy_layout(layouts[idx_layout], true);
277 layouts.iter().map(|l| l.transpose(&permute_index)).collect()
278 },
279 Order::G => rstsr_invalid!(order, "This option is not valid for multiple layouts")?,
280 }
281}
282
283pub fn translate_to_col_major_with_contig<D>(layouts: &[&Layout<D>]) -> (Vec<Layout<IxD>>, usize)
297where
298 D: DimAPI,
299{
300 if layouts.is_empty() {
301 return (vec![], 0);
302 }
303
304 let dims_f_contig = layouts.iter().map(|l| l.ndim_of_f_contig()).collect_vec();
305 let ndim_f_contig = *dims_f_contig.iter().min().unwrap();
306 if ndim_f_contig == 0 {
308 return (layouts.iter().map(|&l| l.clone().into_dim::<IxD>().unwrap()).collect(), 0);
309 } else {
310 let size_contig = layouts[0].shape().as_ref()[0..ndim_f_contig].iter().product::<usize>();
311 let result = layouts
312 .iter()
313 .map(|l| {
314 let shape = l.shape().as_ref()[ndim_f_contig..].iter().cloned().collect_vec();
315 let stride = l.stride().as_ref()[ndim_f_contig..].iter().cloned().collect_vec();
316 unsafe { Layout::new_unchecked(shape, stride, l.offset()) }
317 })
318 .collect_vec();
319 return (result, size_contig);
320 }
321}
322
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_greedy_layout() {
        // `unsafe` is needed only for `Layout::new_unchecked` fixtures below.
        unsafe {
            // C-contiguous input: greedy order reverses axes into F-contiguous.
            let layout = [2, 3, 4].c();
            let (greedy, _) = greedy_layout(&layout, false);
            assert_eq!(greedy, [4, 3, 2].f());
            let (greedy, _) = greedy_layout(&layout, true);
            assert_eq!(greedy, [4, 3, 2].f());
            // F-contiguous input is already in greedy order; unchanged either way.
            let layout = [2, 3, 4].f();
            let (greedy, _) = greedy_layout(&layout, false);
            assert_eq!(greedy, [2, 3, 4].f());
            let (greedy, _) = greedy_layout(&layout, true);
            assert_eq!(greedy, [2, 3, 4].f());
            // Mixed case with size-1 axes (dims 1, 3) and a broadcast axis (dim 4, stride 0).
            let layout = Layout::new_unchecked([5, 1, 2, 1, 3, 6], [1000, 10, 10, 40, 0, 100], 0);
            // keep_shape = false: trivial axes squashed to (1, 0) and moved last;
            // real axes sorted by ascending |stride|.
            let (greedy, _) = greedy_layout(&layout, false);
            let expect = Layout::new_unchecked([2, 6, 5, 1, 1, 1], [10, 100, 1000, 0, 0, 0], 0);
            assert_eq!(greedy, expect);
            // keep_shape = true: trivial axes kept intact and moved first.
            let (greedy, _) = greedy_layout(&layout, true);
            let expect = Layout::new_unchecked([1, 1, 3, 2, 6, 5], [10, 40, 0, 10, 100, 1000], 0);
            assert_eq!(greedy, expect);
            // Negative-stride axis plus a transpose: keep_shape = true flips the
            // negative stride back, recovering the plain F layout...
            let layout = [2, 3, 4].f().dim_narrow(1, slice!(None, None, -1)).unwrap();
            let layout = layout.swapaxes(-1, -2).unwrap();
            let (greedy, _) = greedy_layout(&layout, true);
            assert_eq!(greedy, [2, 3, 4].f());
            // ...while keep_shape = false leaves the negative stride in place.
            let (greedy, _) = greedy_layout(&layout, false);
            assert_eq!(greedy, [2, 3, 4].f().dim_narrow(1, slice!(None, None, -1)).unwrap());
        }
    }
}