use super::layout::Layout;
/// Computes the stride (counted in elements) of every axis of `shape`.
///
/// With `Layout::RowMajor` the last axis gets stride 1 and each earlier axis gets the
/// product of all extents that follow it. With any other layout the first axis gets
/// stride 1 and each later axis gets the product of all extents that precede it.
///
/// The returned vector has the same length as `shape`; an empty shape yields an empty
/// vector.
pub fn calc_strides_from_shape(shape: &Vec<usize>, layout: Layout) -> Vec<usize> {
    let mut strides = vec![0; shape.len()];
    // Product of the extents visited so far, i.e. the stride of the axis visited next.
    let mut running: usize = 1;
    if layout == Layout::RowMajor {
        // Walk from the innermost (last) axis outwards.
        for (stride, &extent) in strides.iter_mut().zip(shape).rev() {
            *stride = running;
            running *= extent;
        }
    } else {
        // Walk from the first axis onwards.
        for (stride, &extent) in strides.iter_mut().zip(shape) {
            *stride = running;
            running *= extent;
        }
    }
    strides
}
/// Returns the total number of elements described by `shape`, i.e. the product of all
/// its extents.
///
/// An empty shape yields 1 (the empty product) instead of panicking. A shape containing
/// a zero extent yields 0.
pub fn calc_size_from_shape(shape: &[usize]) -> usize {
    shape.iter().product()
}
/// Reports whether two shapes can be concatenated along `axis`.
///
/// Returns `true` only when all of the following hold:
/// - `lhs` and `rhs` have the same number of dimensions,
/// - `axis` is a valid index into them,
/// - every extent other than the one at `axis` is equal in both shapes.
///
/// The extents at `axis` itself are not compared.
pub fn check_concat_dims(lhs: &Vec<usize>, rhs: &Vec<usize>, axis: usize) -> bool {
    // An axis that does not exist in the shapes cannot be concatenated along; rejecting it
    // here keeps callers from indexing out of bounds afterwards.
    if lhs.len() != rhs.len() || axis >= lhs.len() {
        return false;
    }
    lhs.iter()
        .zip(rhs)
        .enumerate()
        .all(|(i, (l, r))| i == axis || l == r)
}
/// Computes the shape produced by concatenating `rhs` onto `lhs` along `axis`.
///
/// The result equals `lhs` except at `axis`, where the extents of both inputs are summed.
///
/// Returns `None` when:
/// - `check_concat_dims` rejects the pair of shapes,
/// - `axis` is not a valid index into the shapes,
/// - the summed extent would overflow `usize`.
pub fn calc_concat_shape(lhs: &Vec<usize>, rhs: &Vec<usize>, axis: usize) -> Option<Vec<usize>> {
    if !check_concat_dims(lhs, rhs, axis) {
        return None;
    }
    // Checked lookups: do not rely on the compatibility check having validated `axis`;
    // an out-of-range axis becomes `None` instead of an index panic.
    let (&left, &right) = (lhs.get(axis)?, rhs.get(axis)?);
    let joined = left.checked_add(right)?;
    let mut out = lhs.clone();
    out[axis] = joined;
    Some(out)
}