use crate::prelude_dev::*;
/// Try to express a reshape of an existing strided layout as a pure change of
/// strides (a view), without copying any data.
///
/// `old_dims` / `old_strides` describe the input layout. `newdims` is the
/// requested shape. `is_f_order` selects column-major (`true`) or row-major
/// (`false`) element ordering for the reshape. Strides appear to be in element
/// units: the fallback stride for trailing axes below is `1`.
///
/// Returns `Some(strides)`, with exactly one stride per entry of `newdims`, when
/// every group of input axes that must be merged is contiguous in the requested
/// order. Returns `None` otherwise, in which case a copy would be needed.
///
/// The structure follows NumPy's `_attempt_nocopy_reshape`:
/// 1. Length-1 input axes are dropped, because their strides never matter.
/// 2. Old and new axes are then consumed in windows with equal element counts.
///
/// # Panics
///
/// Assumes `old_dims` and `newdims` describe the same, non-zero number of
/// elements. If they do not, the window-growing loop can index past the end of
/// `newdims` / `olddims`. In the visible code, `layout_reshapeable` runs
/// `quick_check` first. `quick_check` asserts equal sizes and handles sizes 0
/// and 1, so this function is not reached in those cases.
// Index loops are kept on purpose: several parallel arrays are indexed in
// lockstep with the same counters, mirroring the reference C implementation.
#[allow(clippy::needless_range_loop)]
fn attempt_nocopy_reshape(
old_dims: &[usize],
old_strides: &[isize],
newdims: &[usize],
is_f_order: bool,
) -> Option<Vec<isize>> {
// Compacted copy of the input layout with every length-1 axis removed.
// Only the first `oldnd` entries of `olddims` / `oldstrides` are meaningful.
let mut oldnd = 0;
let mut olddims = vec![0; old_dims.len()];
let mut oldstrides = vec![0; old_strides.len()];
let mut newstrides = vec![0; newdims.len()];
let newnd = newdims.len();
for i in 0..old_dims.len() {
if old_dims[i] != 1 {
olddims[oldnd] = old_dims[i];
oldstrides[oldnd] = old_strides[i];
oldnd += 1;
}
}
// Half-open axis windows currently being matched:
// old axes `oi..oj` and new axes `ni..nj`.
let mut oi = 0;
let mut oj = 1;
let mut ni = 0;
let mut nj = 1;
while ni < newnd && oi < oldnd {
let mut np = newdims[ni];
let mut op = olddims[oi];
// Grow whichever window currently covers fewer elements until both windows
// cover the same number. Length-1 new axes at the very end are not absorbed
// here; they are handled after the loop.
while np != op {
if np < op {
np *= newdims[nj];
nj += 1;
} else {
op *= olddims[oj];
oj += 1;
}
}
// The old axes in this window can only be merged into one run of memory if
// they are contiguous with each other in the requested order:
//   F order: stride[k + 1] == dim[k] * stride[k]
//   C order: stride[k] == dim[k + 1] * stride[k + 1]
for ok in oi..(oj - 1) {
if is_f_order {
if oldstrides[ok + 1] != olddims[ok] as isize * oldstrides[ok] {
return None; }
} else {
if oldstrides[ok] != olddims[ok + 1] as isize * oldstrides[ok + 1] {
return None; }
}
}
// Derive strides for the new axes of this window.
// F order: the first new axis inherits the first old axis's stride, and each
// following axis multiplies by the previous axis length.
// C order: the last new axis inherits the last old axis's stride, and the
// others are filled in backwards.
if is_f_order {
newstrides[ni] = oldstrides[oi];
for nk in (ni + 1)..nj {
newstrides[nk] = newstrides[nk - 1] * newdims[nk - 1] as isize;
}
} else {
newstrides[nj - 1] = oldstrides[oj - 1];
for nk in ((ni + 1)..nj).rev() {
newstrides[nk - 1] = newstrides[nk] * newdims[nk] as isize;
}
}
// Advance both windows to the next group of axes.
ni = nj;
nj += 1;
oi = oj;
oj += 1;
}
// New axes not consumed by the loop are the trailing ones. With equal element
// counts they all have length 1. Their stride is derived from the last
// assigned axis, or is 1 if no axis was assigned.
let last_stride = if ni >= 1 {
let mut stride = newstrides[ni - 1];
if is_f_order {
stride *= newdims[ni - 1] as isize;
}
stride
} else {
1 };
for nk in ni..newnd {
newstrides[nk] = last_stride;
}
Some(newstrides)
}
/// Resolve a requested reshape target that may contain one `-1` placeholder
/// into a concrete, non-negative shape.
///
/// `shape_out` is the user-supplied shape. At most one entry may be `-1`, and
/// no other entry may be negative. When a `-1` is present, it is replaced by
/// `size_in` divided by the product of the remaining entries.
///
/// When there is no `-1`, the entries are returned unchanged. This function
/// does not check that their product equals `size_in`.
///
/// # Errors
///
/// Returns an `InvalidValue` error if:
/// - more than one `-1` is given;
/// - any entry is negative and not `-1`;
/// - a `-1` is combined with a zero-length dimension (the placeholder cannot be
///   inferred, because any value would give zero elements);
/// - `size_in` is not divisible by the product of the explicit entries.
pub fn reshape_substitute_negatives(shape_out: &[isize], size_in: usize) -> Result<Vec<usize>> {
    let mut shape = shape_out.to_vec();
    let mut idx_neg1: Option<usize> = None;
    for (i, &v) in shape_out.iter().enumerate() {
        match v {
            -1 => match idx_neg1 {
                Some(_) => rstsr_raise!(InvalidValue, "Only one -1 is allowed in shape.")?,
                None => idx_neg1 = Some(i),
            },
            ..-1 => {
                rstsr_raise!(InvalidValue, "Negative index must be -1.")?;
            },
            _ => (),
        }
    }
    if let Some(idx_neg1) = idx_neg1 {
        let size_in = size_in as isize;
        // Product of the explicitly given (non-placeholder) dimensions.
        let size_neg: isize = shape.iter().filter(|&&v| v != -1).product();
        // A zero-length dimension next to `-1` makes this product 0. Without
        // this check, the `%` and `/` below would panic on division by zero.
        rstsr_assert!(
            size_neg != 0,
            InvalidValue,
            "Shape '-1' in {:?} could not be determined because another dimension is zero.",
            shape
        )?;
        rstsr_assert!(
            size_in % size_neg == 0,
            InvalidValue,
            "Shape '-1' in {:?} could not be determined to original tensor size {:?}",
            shape,
            size_in
        )?;
        shape[idx_neg1] = size_in / size_neg;
    }
    // Every entry is now non-negative, so the cast to `usize` is lossless.
    Ok(shape.into_iter().map(|v| v as usize).collect())
}
/// Shortcut reshape paths that need no per-axis stride analysis.
///
/// First checks that `shape_out` holds exactly as many elements as
/// `layout_in`. It then returns a ready-made layout for the easy cases:
/// - zero or one element: unit strides for every output axis;
/// - unchanged shape: a clone of `layout_in`;
/// - input contiguous in the requested `order`: a fresh contiguous layout of
///   `shape_out` in that order.
///
/// The input offset is kept in every case. `Ok(None)` means none of these
/// cases applied, so a stride-by-stride analysis is still required.
///
/// # Errors
///
/// Returns `InvalidValue` if the element counts differ. Any error from
/// `Layout::new` in the zero/one-element branch is propagated.
fn quick_check(shape_out: &Vec<usize>, layout_in: &Layout<IxD>, order: FlagOrder) -> Result<Option<Layout<IxD>>> {
    let size_in = layout_in.size();
    let size_out: usize = shape_out.iter().product();
    rstsr_assert_eq!(size_in, size_out, InvalidValue, "Size mismatch between input tensor and output tensor.",)?;

    let offset = layout_in.offset();

    // With at most one element there is nothing for the strides to distinguish.
    if size_in <= 1 {
        let unit_strides = vec![1; shape_out.len()];
        let trivial = Layout::<IxD>::new(shape_out.clone(), unit_strides, offset)?;
        return Ok(Some(trivial));
    }

    // The shape did not change, so the input layout already is the answer.
    if shape_out == layout_in.shape() {
        return Ok(Some(layout_in.clone()));
    }

    // Input that is contiguous in the requested order can simply be re-laid out.
    let contiguous = match order {
        RowMajor if layout_in.c_contig() => Some(shape_out.new_c_contig(Some(offset))),
        ColMajor if layout_in.f_contig() => Some(shape_out.new_f_contig(Some(offset))),
        _ => None,
    };
    Ok(contiguous)
}
/// Try to reshape `layout_in` to `shape_out` without copying any data.
///
/// `order` selects how elements are enumerated for the reshape: `RowMajor`
/// (checked against C contiguity) or `ColMajor` (checked against F contiguity).
///
/// Returns `Ok(Some(layout))` with a layout of shape `shape_out` that keeps the
/// input offset. Returns `Ok(None)` when the input strides cannot express the
/// new shape, so a copy is required.
///
/// # Errors
///
/// Propagates errors from `quick_check`, e.g. when `shape_out` and `layout_in`
/// hold different numbers of elements.
pub fn layout_reshapeable(
    layout_in: &Layout<IxD>,
    shape_out: &Vec<usize>,
    order: FlagOrder,
) -> Result<Option<Layout<IxD>>> {
    // Cheap cases first: size check, trivial sizes, same shape, contiguous input.
    if let Some(layout_out) = quick_check(shape_out, layout_in, order)? {
        return Ok(Some(layout_out));
    }

    // General case: derive the new strides group of axes by group of axes.
    let is_f_order = order == ColMajor;
    let maybe_stride = attempt_nocopy_reshape(layout_in.shape(), layout_in.stride(), shape_out, is_f_order);
    Ok(maybe_stride.map(|stride_out| {
        // SAFETY: `quick_check` returned `Ok(None)`. It has therefore already
        // asserted that `shape_out` holds as many elements as `layout_in`, and
        // that this count is greater than one. `attempt_nocopy_reshape` only
        // succeeds when every merged group of input axes is contiguous in the
        // requested order, and it returns one stride per axis of `shape_out`.
        // With the unchanged offset, the new layout visits the same elements as
        // `layout_in`.
        // NOTE(review): confirm this matches the precondition of
        // `Layout::new_unchecked`.
        unsafe { Layout::<IxD>::new_unchecked(shape_out.to_vec(), stride_out, layout_in.offset()) }
    }))
}