use crate::prelude_dev::*;
/// Local shorthand for [`TensorIterOrder`] used throughout this module.
type Order = TensorIterOrder;
/// Compute a "greedy" layout that visits `layout` in (roughly) memory order.
///
/// Returns the rearranged layout together with the axis permutation that was
/// applied, so the caller can undo it with [`reversed_permute`].
///
/// - `keep_shape == true`: axes are only flipped (negative strides are made
///   positive via narrowing) and permuted; trivial axes (size 1 or stride 0)
///   move to the front and the rest are ordered by ascending `|stride|`.
/// - `keep_shape == false`: trivial axes are additionally normalized to
///   `(shape = 1, stride = 0)` and pushed to the back.
pub fn greedy_layout<D>(layout: &Layout<D>, keep_shape: bool) -> (Layout<D>, Vec<isize>)
where
    D: DimDevAPI,
{
    let mut layout = layout.clone();

    // Empty tensor: nothing to reorder, return the identity permutation.
    if layout.size() == 0 {
        let ndim = layout.ndim();
        return (layout, (0..ndim as isize).collect_vec());
    }

    // When the shape must be preserved, flip negatively-strided axes so that
    // all strides are non-negative before sorting by magnitude.
    if keep_shape {
        for n in 0..layout.ndim() {
            if layout.stride()[n] < 0 {
                layout = layout.dim_narrow(n as isize, slice!(None, None, -1)).unwrap();
            }
        }
    }

    let shape_old = layout.shape.as_ref();
    let stride_old = layout.stride.as_ref();

    // Shared comparator for both modes. "Trivial" axes (size 1 or broadcast
    // stride 0) go to the front when `trivial_first`, to the back otherwise;
    // all other axes are ordered by ascending |stride|. Ties between trivial
    // axes keep their original relative order.
    let axis_cmp = |i1: isize, i2: isize, trivial_first: bool| -> core::cmp::Ordering {
        let d1 = shape_old[i1 as usize];
        let d2 = shape_old[i2 as usize];
        let t1 = stride_old[i1 as usize];
        let t2 = stride_old[i2 as usize];
        match (d1 == 1 || t1 == 0, d2 == 1 || t2 == 0) {
            (true, true) => i1.cmp(&i2),
            (true, false) => {
                if trivial_first {
                    core::cmp::Ordering::Less
                } else {
                    core::cmp::Ordering::Greater
                }
            },
            (false, true) => {
                if trivial_first {
                    core::cmp::Ordering::Greater
                } else {
                    core::cmp::Ordering::Less
                }
            },
            (false, false) => t1.abs().cmp(&t2.abs()),
        }
    };

    // keep_shape = true  -> trivial axes first (they stay meaningful);
    // keep_shape = false -> trivial axes last (they will be collapsed below).
    let mut index = (0..layout.ndim() as isize).collect_vec();
    index.sort_by(|&i1, &i2| axis_cmp(i1, i2, keep_shape));

    let mut layout = layout.transpose(&index).unwrap();

    // Without shape preservation, collapse trivial axes to (1, 0) so they are
    // fully inert for iteration.
    if !keep_shape {
        let mut shape = layout.shape().clone();
        let mut stride = layout.stride().clone();
        shape.as_mut().iter_mut().zip(stride.as_mut().iter_mut()).for_each(|(d, t)| {
            if *d == 1 || *t == 0 {
                *d = 1;
                *t = 0;
            }
        });
        // SAFETY: only trivial axes were rewritten; the layout still addresses
        // the same memory region with the same offset.
        layout = unsafe { Layout::new_unchecked(shape, stride, layout.offset()) };
    }

    (layout, index)
}
/// Invert a permutation: the result maps each target axis back to its source,
/// i.e. `result[indices[k]] == k` for every `k`.
pub fn reversed_permute(indices: &[isize]) -> Vec<isize> {
    let mut inverse = vec![0; indices.len()];
    indices.iter().enumerate().for_each(|(pos, &axis)| {
        inverse[axis as usize] = pos as isize;
    });
    inverse
}
/// Build a contiguous layout suitable for holding a copy of `layout`'s data,
/// honoring the requested iteration order (only C/F/A/K are accepted).
pub fn layout_for_array_copy<D>(layout: &Layout<D>, order: TensorIterOrder) -> Result<Layout<D>>
where
    D: DimDevAPI,
{
    let result = match order {
        // Plain row-major / column-major over the same shape.
        Order::C => layout.shape().c(),
        Order::F => layout.shape().f(),
        // "Automatic": follow the input's contiguity, falling back to the
        // crate-wide default order when neither side is contiguous.
        Order::A => match (layout.c_contig(), layout.f_contig()) {
            (true, _) => layout.shape().c(),
            (false, true) => layout.shape().f(),
            (false, false) => match TensorOrder::default() {
                RowMajor => layout.shape().c(),
                ColMajor => layout.shape().f(),
            },
        },
        // "Keep": greedily sort axes by stride, lay the copy out f-contiguous
        // in that order, then permute the axes back to the caller's order.
        Order::K => {
            let (greedy, permutation) = greedy_layout(layout, true);
            let back = reversed_permute(&permutation);
            greedy.shape().f().transpose(&back)?
        },
        _ => rstsr_invalid!(order, "Iter order for copy only accepts CFAK.")?,
    };
    Ok(result)
}
/// Translate a single layout into an equivalent one oriented for col-major
/// iteration, according to `order`.
pub fn translate_to_col_major_unary<D>(layout: &Layout<D>, order: TensorIterOrder) -> Result<Layout<D>>
where
    D: DimDevAPI,
{
    // C order: reverse the axes so a col-major walk visits row-major order.
    let as_c = |l: &Layout<D>| Ok(l.reverse_axes());
    // F order: already what a col-major walk expects.
    let as_f = |l: &Layout<D>| Ok(l.clone());
    // Buffer order: flatten into one contiguous run; only valid when the
    // layout's index bounds span exactly `size()` elements.
    let as_buffer = |l: &Layout<D>| {
        let (lower, upper) = l.bounds_index()?;
        rstsr_assert_eq!(
            upper - lower,
            l.size(),
            InvalidLayout,
            "Data in this layout could not be represented as sequential memory."
        )?;
        let mut shape = l.new_shape();
        let mut stride = l.new_stride();
        // First axis carries everything; the rest are inert (size 1).
        shape[0] = l.size();
        stride[0] = 1;
        for axis in 1..l.ndim() {
            shape[axis] = 1;
            stride[axis] = l.size() as isize;
        }
        Ok(unsafe { Layout::new_unchecked(shape, stride, l.offset()) })
    };
    match order {
        Order::C => as_c(layout),
        Order::F => as_f(layout),
        Order::A => {
            if layout.c_contig() || layout.f_contig() {
                // Fully contiguous either way: treat as a flat buffer.
                as_buffer(layout)
            } else {
                // Otherwise pick whichever ordering the strides prefer,
                // falling back to the crate-wide default.
                match (layout.c_prefer(), layout.f_prefer()) {
                    (true, false) => as_c(layout),
                    (false, true) => as_f(layout),
                    _ => match FlagOrder::default() {
                        RowMajor => as_c(layout),
                        ColMajor => as_f(layout),
                    },
                }
            }
        },
        Order::K => Ok(greedy_layout(layout, true).0),
        Order::G => Ok(greedy_layout(layout, false).0),
        Order::B => as_buffer(layout),
    }
}
/// Translate multiple same-shape layouts into col-major-oriented layouts.
///
/// All layouts must share the same shape; `Order::G` is rejected because a
/// per-layout greedy ordering would not be consistent across operands.
pub fn translate_to_col_major<D>(layouts: &[&Layout<D>], order: TensorIterOrder) -> Result<Vec<Layout<D>>>
where
    D: DimAPI,
{
    if layouts.is_empty() {
        return Ok(vec![]);
    }
    // Apply the unary translation to each layout independently.
    let fn_single = |ls: &[&Layout<D>], order| ls.iter().map(|l| translate_to_col_major_unary(l, order)).collect();
    let is_same_shape = layouts.windows(2).all(|w| w[0].shape() == w[1].shape());
    rstsr_assert!(is_same_shape, InvalidLayout, "All shape of layout in this function must be the same.")?;
    match order {
        Order::C | Order::F | Order::B => fn_single(layouts, order),
        Order::A => {
            let c_contig = layouts.iter().all(|&l| l.c_contig());
            let f_contig = layouts.iter().all(|&l| l.f_contig());
            if c_contig || f_contig {
                fn_single(layouts, Order::B)
            } else {
                // Bug fix: this previously re-queried `c_contig()`/`f_contig()`,
                // which are necessarily false on this branch, so the C/F arms
                // below were unreachable. Query the *preference* instead,
                // mirroring `translate_to_col_major_unary`.
                let c_prefer = layouts.iter().all(|&l| l.c_prefer());
                let f_prefer = layouts.iter().all(|&l| l.f_prefer());
                match (c_prefer, f_prefer) {
                    (true, false) => fn_single(layouts, Order::C),
                    (false, true) => fn_single(layouts, Order::F),
                    (_, _) => match FlagOrder::default() {
                        RowMajor => fn_single(layouts, Order::C),
                        ColMajor => fn_single(layouts, Order::F),
                    },
                }
            }
        },
        Order::K => {
            // Choose the operand whose greedy permutation drives all layouts:
            // if non-broadcast sizes differ, take the largest operand;
            // otherwise the first.
            let size_iter = layouts.iter().map(|l| l.size_non_broadcast()).collect_vec();
            let idx_layout = if size_iter.iter().max() == size_iter.iter().min() {
                0
            } else {
                size_iter.into_iter().enumerate().max_by_key(|(_, v)| *v).unwrap_or((0, 0)).0
            };
            let (_, permute_index) = greedy_layout(layouts[idx_layout], true);
            layouts.iter().map(|l| l.transpose(&permute_index)).collect()
        },
        Order::G => rstsr_invalid!(order, "This option is not valid for multiple layouts")?,
    }
}
/// Convert layouts into dyn-dim (`IxD`) layouts, splitting off the leading
/// f-contiguous dimensions shared by all of them.
///
/// Returns the reduced layouts plus the element count of the split-off
/// contiguous prefix (0 when the layouts share no f-contiguous prefix).
pub fn translate_to_col_major_with_contig<D>(layouts: &[&Layout<D>]) -> (Vec<Layout<IxD>>, usize)
where
    D: DimAPI,
{
    if layouts.is_empty() {
        return (vec![], 0);
    }
    // Number of leading f-contiguous dims common to every layout.
    let ndim_f_contig = layouts.iter().map(|l| l.ndim_of_f_contig()).min().unwrap();
    if ndim_f_contig == 0 {
        // No shared contiguous prefix: just erase the dimensionality.
        let converted = layouts.iter().map(|&l| l.clone().into_dim::<IxD>().unwrap()).collect();
        return (converted, 0);
    }
    // The contiguous size can be read off any operand; use the first.
    let size_contig: usize = layouts[0].shape().as_ref()[..ndim_f_contig].iter().product();
    let reduced = layouts
        .iter()
        .map(|l| {
            // Keep only the non-contiguous suffix of each layout.
            let shape = l.shape().as_ref()[ndim_f_contig..].to_vec();
            let stride = l.stride().as_ref()[ndim_f_contig..].to_vec();
            // SAFETY: shape/stride are a suffix of an existing valid layout.
            unsafe { Layout::new_unchecked(shape, stride, l.offset()) }
        })
        .collect_vec();
    (reduced, size_contig)
}
/// Classify the axes of `layout` into four groups:
/// - `comp_i`: size-1 axes (inert for iteration),
/// - `comp_b`: broadcast axes (stride 0 with size > 1),
/// - `comp_c`: axes forming a contiguous run when visited by ascending |stride|,
/// - `comp_d`: the remaining (discontiguous) axes.
pub fn get_axes_composition<D>(layout: &Layout<D>) -> (Vec<usize>, Vec<usize>, Vec<usize>, Vec<usize>)
where
    D: DimDevAPI,
{
    let mut comp_i = vec![];
    let mut comp_b = vec![];
    let mut comp_c = vec![];
    let mut remain = (0..layout.ndim()).collect_vec();
    for i in 0..layout.ndim() {
        if layout.shape()[i] == 1 {
            comp_i.push(i);
        } else if layout.stride()[i] == 0 {
            comp_b.push(i);
        }
    }
    // Keep only axes that carry real extent and a non-zero stride.
    remain.retain(|&i| !comp_i.contains(&i) && !comp_b.contains(&i));
    // Visit axes from smallest to largest |stride| to detect contiguity.
    remain.sort_by(|&i1, &i2| layout.stride()[i1].abs().cmp(&layout.stride()[i2].abs()));
    let mut current_stride = 1;
    for &i in &remain {
        let (s, t) = (layout.shape()[i], layout.stride()[i]);
        // Axis `i` continues the contiguous run iff its stride equals the
        // running product of the sizes already in the run.
        if t == current_stride {
            comp_c.push(i);
            current_stride *= s as isize;
        }
    }
    // Whatever is left after removing the contiguous run is discontiguous.
    // (Bound here directly instead of pre-initializing a dead `vec![]`.)
    let comp_d = remain.into_iter().filter(|i| !comp_c.contains(i)).collect_vec();
    (comp_i, comp_b, comp_c, comp_d)
}
/// Derive an output layout for a binary elementwise operation on `la` and
/// `lb` (same shape required), choosing strides that follow both inputs'
/// memory ordering as far as they agree, then falling back to `order`.
pub fn get_layout_for_binary_op<D>(la: &Layout<D>, lb: &Layout<D>, order: FlagOrder) -> Result<Layout<D>>
where
    D: DimDevAPI,
{
    rstsr_assert_eq!(
        la.shape(),
        lb.shape(),
        InvalidLayout,
        "Shape of two layouts must be the same for this function."
    )?;
    let ndim = la.ndim();
    // Axis composition of each input: size-1 axes, broadcast (stride-0) axes,
    // contiguous-run axes, and the rest (see `get_axes_composition`).
    let (a1, a0_a, ac_a, _) = get_axes_composition(la);
    let (_, a0_b, ac_b, _) = get_axes_composition(lb);
    // Broadcast axes common to both inputs.
    let a0_o = a0_a.iter().filter(|&&i| a0_b.contains(&i)).cloned().collect_vec();
    // Longest common prefix of the two contiguous runs: these axes can keep
    // a contiguous ordering in the output as well.
    let mut ac_o = vec![];
    for (ic_a, ic_b) in ac_a.iter().zip(ac_b.iter()) {
        if *ic_a == *ic_b {
            ac_o.push(*ic_a);
        } else {
            break;
        }
    }
    // Remaining non-trivial axes; reversed for row-major so that the last
    // axis varies fastest when strides are assigned below.
    let mut ad_o = (0..ndim).filter(|&i| !a0_o.contains(&i) && !ac_o.contains(&i) && !a1.contains(&i)).collect_vec();
    if order == RowMajor {
        ad_o.reverse();
    }
    // Assign strides: common contiguous axes first (starting at 1), then the
    // remaining axes, each taking the running product of sizes so far.
    let mut stride = la.new_stride();
    let mut current_stride = 1;
    for i in ac_o.iter().chain(ad_o.iter()) {
        stride[*i] = current_stride;
        current_stride *= la.shape()[*i] as isize;
    }
    // Size-1 axes get a placeholder stride copied from the nearest axis with
    // a non-zero stride (searching forward for row-major, backward for
    // col-major), defaulting to 1 when none is found.
    for i in a1 {
        let mut s = 1;
        match order {
            RowMajor => {
                if let Some(&x) = stride.as_ref()[i..].iter().find(|&&x| x != 0) {
                    s = x;
                }
            },
            ColMajor => {
                if let Some(&x) = stride.as_ref()[..i].iter().rev().find(|&&x| x != 0) {
                    s = x;
                }
            },
        }
        stride[i] = s;
    }
    let shape = la.shape().clone();
    // SAFETY: shape matches both inputs and strides were constructed above
    // as a dense ordering with offset 0.
    let layout = unsafe { Layout::new_unchecked(shape, stride, 0) };
    Ok(layout)
}
#[cfg(test)]
mod test {
    use super::*;

    /// Exercise `greedy_layout` on c/f-contiguous, mixed broadcast, and
    /// negative-stride layouts, in both `keep_shape` modes.
    #[test]
    fn test_greedy_layout() {
        unsafe {
            // c-contiguous input becomes f-contiguous with reversed axes.
            let layout = [2, 3, 4].c();
            let (greedy, _) = greedy_layout(&layout, false);
            assert_eq!(greedy, [4, 3, 2].f());
            let (greedy, _) = greedy_layout(&layout, true);
            assert_eq!(greedy, [4, 3, 2].f());
            // f-contiguous input is already in greedy order.
            let layout = [2, 3, 4].f();
            let (greedy, _) = greedy_layout(&layout, false);
            assert_eq!(greedy, [2, 3, 4].f());
            let (greedy, _) = greedy_layout(&layout, true);
            assert_eq!(greedy, [2, 3, 4].f());
            // Mixed layout with size-1 and broadcast (stride-0) axes.
            let layout = Layout::new_unchecked([5, 1, 2, 1, 3, 6], [1000, 10, 10, 40, 0, 100], 0);
            // keep_shape = false: trivial axes collapse to (1, 0) and move to
            // the back; the rest are sorted by ascending stride.
            let (greedy, _) = greedy_layout(&layout, false);
            let expect = Layout::new_unchecked([2, 6, 5, 1, 1, 1], [10, 100, 1000, 0, 0, 0], 0);
            assert_eq!(greedy, expect);
            // keep_shape = true: trivial axes keep shape/stride but move to
            // the front.
            let (greedy, _) = greedy_layout(&layout, true);
            let expect = Layout::new_unchecked([1, 1, 3, 2, 6, 5], [10, 40, 0, 10, 100, 1000], 0);
            assert_eq!(greedy, expect);
            // Negative-stride axis: keep_shape = true flips it positive;
            // keep_shape = false leaves the negative stride in place.
            let layout = [2, 3, 4].f().dim_narrow(1, slice!(None, None, -1)).unwrap();
            let layout = layout.swapaxes(-1, -2).unwrap();
            let (greedy, _) = greedy_layout(&layout, true);
            assert_eq!(greedy, [2, 3, 4].f());
            let (greedy, _) = greedy_layout(&layout, false);
            assert_eq!(greedy, [2, 3, 4].f().dim_narrow(1, slice!(None, None, -1)).unwrap());
        }
    }
}