use std::sync::Arc;
use hpt_common::{
axis::axis::{process_axes, Axis},
error::shape::ShapeError,
shape::shape::Shape,
shape::shape_utils::{get_broadcast_axes_from, mt_intervals, try_pad_shape},
strides::strides::Strides,
};
use crate::iterator_traits::{ParStridedHelper, StridedHelper};
#[track_caller]
pub(crate) fn par_reshape<S: Into<Shape>, T: ParStridedHelper>(mut iterator: T, shape: S) -> T {
let tmp = shape.into();
let res_shape = tmp;
if iterator._layout().shape() == &res_shape {
return iterator;
}
let size = res_shape.size() as usize;
let inner_loop_size = res_shape[res_shape.len() - 1] as usize;
let outer_loop_size = size / inner_loop_size;
let num_threads;
if outer_loop_size < rayon::current_num_threads() {
num_threads = outer_loop_size;
} else {
num_threads = rayon::current_num_threads();
}
let intervals = mt_intervals(outer_loop_size, num_threads);
let len = intervals.len();
iterator._set_intervals(Arc::new(intervals));
iterator._set_end_index(len);
let self_size = iterator._layout().size();
if size > (self_size as usize) {
let self_shape = try_pad_shape(iterator._layout().shape(), res_shape.len());
let axes_to_broadcast =
get_broadcast_axes_from(&self_shape, &res_shape).expect("Cannot broadcast shapes");
let mut new_strides = vec![0; res_shape.len()];
new_strides
.iter_mut()
.rev()
.zip(iterator._layout().strides().iter().rev())
.for_each(|(a, b)| {
*a = *b;
});
for &axis in axes_to_broadcast.iter() {
assert_eq!(self_shape[axis], 1);
new_strides[axis] = 0;
}
iterator._set_last_strides(new_strides[new_strides.len() - 1]);
iterator._set_strides(new_strides.into());
} else {
ShapeError::check_size_match(iterator._layout().shape().size(), res_shape.size())
.expect("Cannot reshape iterator");
if let Some(new_strides) = iterator._layout().is_reshape_possible(&res_shape) {
iterator._set_strides(new_strides);
iterator._set_last_strides(
iterator._layout().strides()[iterator._layout().strides().len() - 1],
);
} else {
panic!("Cannot reshape iterator");
}
}
iterator._set_shape(res_shape.clone());
iterator
}
#[track_caller]
pub(crate) fn par_transpose<AXIS: Into<Axis>, T: ParStridedHelper>(
mut iterator: T,
axes: AXIS,
) -> T {
let axes = process_axes(axes, iterator._layout().shape().len()).unwrap();
let mut new_shape = iterator._layout().shape().to_vec();
for i in axes.iter() {
new_shape[*i] = iterator._layout().shape()[axes[*i]];
}
let mut new_strides = iterator._layout().strides().to_vec();
for i in axes.iter() {
new_strides[*i] = iterator._layout().strides()[axes[*i]];
}
let new_strides: Strides = new_strides.into();
let new_shape = Arc::new(new_shape);
let outer_loop_size =
(new_shape.iter().product::<i64>() as usize) / (new_shape[new_shape.len() - 1] as usize);
let num_threads;
if outer_loop_size < rayon::current_num_threads() {
num_threads = outer_loop_size;
} else {
num_threads = rayon::current_num_threads();
}
let intervals = Arc::new(mt_intervals(outer_loop_size, num_threads));
let len = intervals.len();
iterator._set_intervals(intervals.clone());
iterator._set_end_index(len);
iterator._set_last_strides(new_strides[new_strides.len() - 1]);
iterator._set_strides(new_strides);
iterator._set_shape(Shape::from(new_shape));
iterator
}
#[track_caller]
pub(crate) fn par_expand<S: Into<Shape>, T: ParStridedHelper>(mut iterator: T, shape: S) -> T {
let res_shape = shape.into();
let new_strides = iterator
._layout()
.expand_strides(&res_shape)
.expect("Cannot expand iterator");
let outer_loop_size =
(res_shape.iter().product::<i64>() as usize) / (res_shape[res_shape.len() - 1] as usize);
let num_threads;
if outer_loop_size < rayon::current_num_threads() {
num_threads = outer_loop_size;
} else {
num_threads = rayon::current_num_threads();
}
let intervals = Arc::new(mt_intervals(outer_loop_size, num_threads));
let len = intervals.len();
iterator._set_intervals(intervals.clone());
iterator._set_end_index(len);
iterator._set_shape(res_shape.clone());
iterator._set_strides(new_strides);
iterator
}
#[track_caller]
pub(crate) fn reshape<S: Into<Shape>, T: StridedHelper>(mut iterator: T, shape: S) -> T {
let tmp = shape.into();
let res_shape = tmp;
if iterator._layout().shape() == &res_shape {
return iterator;
}
let size = res_shape.size() as usize;
let self_size = iterator._layout().size();
if size > (self_size as usize) {
let self_shape = try_pad_shape(iterator._layout().shape(), res_shape.len());
let axes_to_broadcast =
get_broadcast_axes_from(&self_shape, &res_shape).expect("Cannot broadcast shapes");
let mut new_strides = vec![0; res_shape.len()];
new_strides
.iter_mut()
.rev()
.zip(iterator._layout().strides().iter().rev())
.for_each(|(a, b)| {
*a = *b;
});
for &axis in axes_to_broadcast.iter() {
assert_eq!(self_shape[axis], 1);
new_strides[axis] = 0;
}
iterator._set_last_strides(new_strides[new_strides.len() - 1]);
iterator._set_strides(new_strides.into());
} else {
ShapeError::check_size_match(
iterator._layout().shape().inner().iter().product(),
res_shape.size(),
)
.expect("Cannot reshape iterator");
if let Some(new_strides) = iterator._layout().is_reshape_possible(&res_shape) {
iterator._set_strides(new_strides);
iterator._set_last_strides(
iterator._layout().strides()[iterator._layout().strides().len() - 1],
);
} else {
panic!("Cannot reshape iterator");
}
}
iterator._set_shape(res_shape.clone());
iterator
}
#[track_caller]
pub(crate) fn expand<T: StridedHelper, S: Into<Shape>>(mut iterator: T, shape: S) -> T {
let res_shape: Shape = shape.into();
let new_strides = iterator
._layout()
.expand_strides(&res_shape)
.expect("Cannot expand iterator");
iterator._set_shape(res_shape.clone());
iterator._set_strides(new_strides);
iterator
}
#[track_caller]
pub(crate) fn transpose<T: StridedHelper, AXIS: Into<Axis>>(mut iterator: T, axes: AXIS) -> T {
let axes = process_axes(axes, iterator._layout().shape().len()).unwrap();
let mut new_shape = iterator._layout().shape().to_vec();
for i in axes.iter() {
new_shape[*i] = iterator._layout().shape()[axes[*i]];
}
let mut new_strides = iterator._layout().strides().to_vec();
for i in axes.iter() {
new_strides[*i] = iterator._layout().strides()[axes[*i]];
}
let new_strides: Strides = new_strides.into();
let new_shape = Arc::new(new_shape);
iterator._set_last_strides(new_strides[new_strides.len() - 1]);
iterator._set_strides(new_strides);
iterator._set_shape(Shape::from(new_shape));
iterator
}