use crate::{
iterator_traits::{IterGetSet, ParStridedHelper, ParStridedIteratorZip, ShapeManipulator},
par_strided_fold::ParStridedFold,
par_strided_map::ParStridedMap,
shape_manipulate::{par_expand, par_reshape, par_transpose},
};
use hpt_common::{
axis::axis::Axis,
layout::layout::Layout,
shape::shape::Shape,
shape::shape_utils::{mt_intervals, try_pad_shape},
strides::strides::Strides,
strides::strides_utils::preprocess_strides,
utils::pointer::Pointer,
};
use hpt_traits::tensor::{CommonBounds, TensorInfo};
use rayon::iter::{
plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
ParallelIterator,
};
use std::sync::Arc;
pub mod par_strided_simd {
use hpt_types::vectors::traits::VecTrait;
use std::sync::Arc;
use crate::{CommonBounds, TensorInfo};
use hpt_common::{
axis::axis::Axis,
layout::layout::Layout,
shape::shape::Shape,
shape::shape_utils::{mt_intervals, try_pad_shape},
strides::strides::Strides,
strides::strides_utils::preprocess_strides,
utils::pointer::Pointer,
utils::simd_ref::MutVec,
};
use rayon::iter::{
plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
ParallelIterator,
};
use crate::{
iterator_traits::{
IterGetSetSimd, ParStridedHelper, ParStridedIteratorSimd, ParStridedIteratorSimdZip,
ShapeManipulator,
},
par_strided_map::par_strided_map_simd::ParStridedMapSimd,
shape_manipulate::{par_expand, par_reshape, par_transpose},
};
/// Rayon producer that walks a strided tensor in parallel and can also load `T::Vec`
/// SIMD values from the inner (last) dimension via `inner_loop_next_simd`.
///
/// The outer loop (every dimension except the last) is pre-partitioned into `intervals`;
/// each producer owns the half-open sub-range `intervals[start_index..end_index]`.
///
/// No trait bounds are placed on the struct itself (matching `ParStrided<T>`); the impl
/// blocks state the bounds they actually need.
#[derive(Clone)]
pub struct ParStridedSimd<T> {
    /// Current position in the tensor's data; advanced by the final `split` and by `next`.
    pub(crate) ptr: Pointer<T>,
    /// Shape and strides of the iterated view. After the final `split` the stored shape
    /// holds each extent minus one, which is the limit `next` compares `prg` against.
    pub(crate) layout: Layout,
    /// Per-dimension progress counters; empty until the final `split` primes them.
    pub(crate) prg: Vec<i64>,
    /// Shared `(start, end)` outer-row ranges; `outer_loop_size` reports `end - start`.
    pub(crate) intervals: Arc<Vec<(usize, usize)>>,
    /// First index into `intervals` owned by this producer.
    pub(crate) start_index: usize,
    /// One past the last index into `intervals` owned by this producer.
    pub(crate) end_index: usize,
    /// Stride of the last (inner) dimension.
    pub(crate) last_stride: i64,
}
impl<T: CommonBounds> ParStridedSimd<T> {
    /// Shape currently stored in the iterator's layout.
    pub fn shape(&self) -> &Shape {
        self.layout.shape()
    }

    /// Strides currently stored in the iterator's layout.
    pub fn strides(&self) -> &Strides {
        self.layout.strides()
    }

    /// Builds a parallel SIMD strided iterator over `tensor`.
    ///
    /// The outer loop (`tensor.size() / last_extent` rows) is partitioned with
    /// `mt_intervals` into chunks. The number of chunks requested is capped both by rayon's
    /// thread count and by the number of outer rows. Rayon later splits the iterator along
    /// those chunks.
    ///
    /// # Panics
    ///
    /// Panics if the tensor has no dimensions. Also panics (integer division by zero) if the
    /// last dimension has extent zero.
    pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
        let inner_loop_size = *tensor
            .shape()
            .last()
            .expect("ParStridedSimd::new requires a tensor with at least one dimension")
            as usize;
        let outer_loop_size = tensor.size() / inner_loop_size;
        // Never request more chunks than there are outer-loop rows.
        let num_threads = outer_loop_size.min(rayon::current_num_threads());
        let intervals = mt_intervals(outer_loop_size, num_threads);
        let end_index = intervals.len();
        ParStridedSimd {
            ptr: tensor.ptr(),
            layout: tensor.layout().clone(),
            prg: vec![],
            intervals: Arc::new(intervals),
            start_index: 0,
            end_index,
            last_stride: *tensor
                .strides()
                .last()
                .expect("ParStridedSimd::new requires a tensor with at least one dimension"),
        }
    }

    /// Wraps this iterator in a [`ParStridedMapSimd`] adaptor.
    ///
    /// `f` receives `(&mut T, T)` pairs and `vec_op` receives `(MutVec<T::Vec>, T::Vec)`
    /// pairs. Presumably `vec_op` handles full SIMD vectors and `f` handles scalar
    /// elements — confirm against `ParStridedMapSimd`.
    pub fn strided_map_simd<'a, F, F2>(
        self,
        f: F,
        vec_op: F2,
    ) -> ParStridedMapSimd<'a, ParStridedSimd<T>, T, F, F2>
    where
        F: Fn((&mut T, <Self as IterGetSetSimd>::Item)) + Sync + Send + 'a,
        <Self as IterGetSetSimd>::Item: Send,
        F2: Send + Sync + Copy + Fn((MutVec<'_, T::Vec>, <Self as IterGetSetSimd>::SimdItem)),
    {
        ParStridedMapSimd {
            iter: self,
            f,
            f2: vec_op,
            phantom: std::marker::PhantomData,
        }
    }
}
// Opt `ParStridedSimd` into the provided methods of the SIMD iterator/zip traits.
impl<T> ParStridedIteratorSimdZip for ParStridedSimd<T> where T: CommonBounds {}
impl<T> ParStridedIteratorSimd for ParStridedSimd<T> where T: CommonBounds {}
/// Element- and SIMD-level access used by the strided map/zip adaptors.
///
/// After the final `split`, `prg` holds the multi-index of the current row and the stored
/// shape holds each extent minus one; `next` relies on that convention.
impl<T: CommonBounds> IterGetSetSimd for ParStridedSimd<T>
where
    T::Vec: Send,
{
    type Item = T;
    type SimdItem = T::Vec;

    fn set_end_index(&mut self, end_index: usize) {
        self.end_index = end_index;
    }

    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
        self.intervals = intervals;
    }

    fn set_strides(&mut self, strides: Strides) {
        self.layout.set_strides(strides);
    }

    fn set_shape(&mut self, shape: Shape) {
        self.layout.set_shape(shape);
    }

    fn set_prg(&mut self, prg: Vec<i64>) {
        self.prg = prg;
    }

    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
        &self.intervals
    }

    fn strides(&self) -> &Strides {
        self.layout.strides()
    }

    fn shape(&self) -> &Shape {
        self.layout.shape()
    }

    fn layout(&self) -> &Layout {
        &self.layout
    }

    /// Pads this iterator's shape to `shape.len()` dimensions with `try_pad_shape`.
    /// Recomputes strides from the padded shape with `preprocess_strides`, then refreshes
    /// `last_stride`.
    fn broadcast_set_strides(&mut self, shape: &Shape) {
        let self_shape = try_pad_shape(self.shape(), shape.len());
        self.set_strides(preprocess_strides(&self_shape, self.strides()).into());
        self.last_stride = self.strides()[self.strides().len() - 1];
    }

    /// Number of outer-loop rows in the first interval owned by this producer.
    fn outer_loop_size(&self) -> usize {
        let (start, end) = self.intervals[self.start_index];
        end - start
    }

    /// Extent stored for the last dimension. After the final `split` this is the real
    /// extent minus one.
    fn inner_loop_size(&self) -> usize {
        *self.shape().last().unwrap() as usize
    }

    /// Moves `ptr` to the start of the next outer row, odometer-style. It walks the outer
    /// dimensions from the fastest- to the slowest-varying and carries on wrap-around.
    fn next(&mut self) {
        let outer_dims = self.shape().len().saturating_sub(1);
        for j in (0..outer_dims).rev() {
            if self.prg[j] < self.shape()[j] {
                self.prg[j] += 1;
                self.ptr.offset(self.strides()[j]);
                break;
            }
            // This dimension hit its limit: rewind it to 0 and carry into the next one.
            self.prg[j] = 0;
            self.ptr.offset(-self.strides()[j] * self.shape()[j]);
        }
    }

    /// Intentionally empty for this iterator as written.
    /// NOTE(review): confirm SIMD callers advance rows through `next`.
    fn next_simd(&mut self) {}

    /// Reads the `index`-th element of the current inner row.
    ///
    /// The offset is computed in signed arithmetic, so zero (broadcast) and negative
    /// `last_stride` values are both addressed correctly. Casting a negative stride to
    /// `usize`, as `pointer::add` would require, wraps to a huge value and overflows.
    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
        let offset = (index as i64 * self.last_stride) as isize;
        // SAFETY: relies on the caller passing an `index` inside the current inner row, so
        // `ptr + index * last_stride` stays within the tensor's allocation.
        unsafe { *self.ptr.get_ptr().offset(offset) }
    }

    /// Loads the `index`-th `T::Vec` of the current row. It assumes the inner dimension is
    /// contiguous (see `all_last_stride_one`).
    #[inline(always)]
    fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
        // SAFETY: relies on the caller only using this when `last_stride == 1` and with an
        // `index` whose full vector lies inside the current row.
        unsafe { T::Vec::from_ptr(self.ptr.get_ptr().add(index * T::Vec::SIZE)) }
    }

    fn all_last_stride_one(&self) -> bool {
        self.last_stride == 1
    }

    fn lanes(&self) -> Option<usize> {
        Some(T::Vec::SIZE)
    }
}
/// Rayon entry point. `drive_unindexed` hands `self` to `bridge_unindexed`, which works
/// through this type's `UnindexedProducer` impl (`split` / `fold_with`).
impl<T: CommonBounds> ParallelIterator for ParStridedSimd<T>
where
    T::Vec: Send,
{
    type Item = T;

    fn drive_unindexed<C: UnindexedConsumer<Self::Item>>(self, consumer: C) -> C::Result {
        bridge_unindexed(self, consumer)
    }
}
impl<T> UnindexedProducer for ParStridedSimd<T>
where
T: CommonBounds,
T::Vec: Send,
{
type Item = T;
fn split(mut self) -> (Self, Option<Self>) {
if self.end_index - self.start_index <= 1 {
let mut curent_shape_prg: Vec<i64> = vec![0; self.shape().len()];
let mut amount =
self.intervals[self.start_index].0 * (*self.shape().last().unwrap() as usize);
for j in (0..self.shape().len()).rev() {
curent_shape_prg[j] = (amount as i64) % self.shape()[j];
amount /= self.shape()[j] as usize;
self.ptr += curent_shape_prg[j] * self.strides()[j];
}
self.prg = curent_shape_prg;
let mut new_shape = self.shape().to_vec();
new_shape.iter_mut().for_each(|x| {
*x -= 1;
});
self.last_stride = self.strides()[self.strides().len() - 1];
self.set_shape(Shape::from(new_shape));
return (self, None);
}
let _left_interval = &self.intervals[self.start_index..self.end_index];
let left = _left_interval.len() / 2;
let right = _left_interval.len() / 2 + (_left_interval.len() % 2);
(
ParStridedSimd {
ptr: self.ptr.clone(),
layout: self.layout.clone(),
prg: vec![],
intervals: self.intervals.clone(),
start_index: self.start_index,
end_index: self.start_index + left,
last_stride: self.last_stride,
},
Some(ParStridedSimd {
ptr: self.ptr.clone(),
layout: self.layout.clone(),
prg: vec![],
intervals: self.intervals.clone(),
start_index: self.start_index + left,
end_index: self.start_index + left + right,
last_stride: self.last_stride,
}),
)
}
fn fold_with<F>(self, folder: F) -> F
where
F: Folder<Self::Item>,
{
folder
}
}
/// Raw field accessors for `ParStridedSimd`. Presumably consumed by the shared
/// `par_reshape` / `par_transpose` / `par_expand` helpers — confirm.
impl<T: CommonBounds> ParStridedHelper for ParStridedSimd<T> {
    fn _layout(&self) -> &Layout {
        &self.layout
    }

    fn _set_shape(&mut self, new_shape: Shape) {
        self.layout.set_shape(new_shape);
    }

    fn _set_strides(&mut self, new_strides: Strides) {
        self.layout.set_strides(new_strides);
    }

    fn _set_last_strides(&mut self, stride: i64) {
        self.last_stride = stride;
    }

    fn _set_intervals(&mut self, new_intervals: Arc<Vec<(usize, usize)>>) {
        self.intervals = new_intervals;
    }

    fn _set_end_index(&mut self, index: usize) {
        self.end_index = index;
    }
}
/// Shape manipulation for `ParStridedSimd`, delegated to the shared `par_*` helpers.
impl<T: CommonBounds> ShapeManipulator for ParStridedSimd<T>
where
    T::Vec: Send,
{
    /// Reshapes via `par_reshape`. `#[track_caller]` matches `ParStrided::reshape`, so
    /// caller-location reporting is consistent between the scalar and SIMD iterators.
    #[track_caller]
    fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
        par_reshape(self, shape)
    }

    /// Permutes axes via `par_transpose`.
    fn transpose<AXIS: Into<Axis>>(self, axes: AXIS) -> Self {
        par_transpose(self, axes)
    }

    /// Expands to `shape` via `par_expand`.
    fn expand<S: Into<Shape>>(self, shape: S) -> Self {
        par_expand(self, shape)
    }
}
}
/// Rayon producer that walks a strided tensor in parallel, yielding scalar `T` items.
///
/// The outer loop (every dimension except the last) is pre-partitioned into `intervals`;
/// each producer owns the half-open sub-range `intervals[start_index..end_index]`.
#[derive(Clone)]
pub struct ParStrided<T> {
    /// Current position in the tensor's data.
    pub(crate) ptr: Pointer<T>,
    /// Stride of the last (inner) dimension, used when reading inner-row elements.
    pub(crate) last_stride: i64,
    /// Shape and strides being iterated. After the final `split` the stored shape holds
    /// each extent minus one.
    pub(crate) layout: Layout,
    /// Per-dimension progress counters, primed by the final `split`.
    pub(crate) prg: Vec<i64>,
    /// Shared `(start, end)` outer-row ranges, initially built by `new`.
    pub(crate) intervals: Arc<Vec<(usize, usize)>>,
    /// First index into `intervals` owned by this producer.
    pub(crate) start_index: usize,
    /// One past the last index into `intervals` owned by this producer.
    pub(crate) end_index: usize,
}
impl<T: CommonBounds> ParStrided<T> {
    /// Shape currently stored in the iterator's layout.
    pub fn shape(&self) -> &Shape {
        self.layout.shape()
    }

    /// Strides currently stored in the iterator's layout.
    pub fn strides(&self) -> &Strides {
        self.layout.strides()
    }

    /// Builds a parallel strided iterator over `tensor`.
    ///
    /// The outer loop (`tensor.size() / last_extent` rows) is partitioned with
    /// `mt_intervals` into chunks. The number of chunks requested is capped both by rayon's
    /// thread count and by the number of outer rows.
    ///
    /// # Panics
    ///
    /// Panics if the tensor has no dimensions. Also panics (integer division by zero) if the
    /// last dimension has extent zero.
    pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
        let inner_loop_size = *tensor
            .shape()
            .last()
            .expect("ParStrided::new requires a tensor with at least one dimension")
            as usize;
        let outer_loop_size = tensor.size() / inner_loop_size;
        // Never request more chunks than there are outer-loop rows.
        let num_threads = outer_loop_size.min(rayon::current_num_threads());
        let intervals = mt_intervals(outer_loop_size, num_threads);
        let end_index = intervals.len();
        ParStrided {
            ptr: tensor.ptr(),
            layout: tensor.layout().clone(),
            prg: vec![],
            intervals: Arc::new(intervals),
            start_index: 0,
            end_index,
            last_stride: *tensor
                .strides()
                .last()
                .expect("ParStrided::new requires a tensor with at least one dimension"),
        }
    }

    /// Wraps this iterator in a [`ParStridedFold`] adaptor configured with `identity` and
    /// `fold_op`.
    pub fn par_strided_fold<ID, F>(self, identity: ID, fold_op: F) -> ParStridedFold<Self, ID, F>
    where
        F: Fn(ID, T) -> ID + Sync + Send + Copy,
        ID: Sync + Send + Copy,
    {
        ParStridedFold {
            iter: self,
            identity,
            fold_op,
        }
    }

    /// Wraps this iterator in a [`ParStridedMap`] adaptor whose closure receives
    /// `(&mut U, T)` pairs. Presumably `&mut U` is an output slot — confirm.
    pub fn strided_map<'a, F, U>(self, f: F) -> ParStridedMap<'a, ParStrided<T>, T, F>
    where
        F: Fn((&mut U, T)) + Sync + Send + 'a,
        U: CommonBounds,
    {
        ParStridedMap {
            iter: self,
            f,
            phantom: std::marker::PhantomData,
        }
    }
}
// Opt `ParStrided` into the provided methods of `ParStridedIteratorZip`.
impl<T> ParStridedIteratorZip for ParStrided<T> where T: CommonBounds {}
/// Element-level access used by the strided map/zip/fold adaptors.
///
/// After the final `split`, `prg` holds the multi-index of the current row and the stored
/// shape holds each extent minus one; `next` relies on that convention.
impl<T: CommonBounds> IterGetSet for ParStrided<T> {
    type Item = T;

    fn set_end_index(&mut self, end_index: usize) {
        self.end_index = end_index;
    }

    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
        self.intervals = intervals;
    }

    fn set_strides(&mut self, strides: Strides) {
        self.layout.set_strides(strides);
    }

    fn set_shape(&mut self, shape: Shape) {
        self.layout.set_shape(shape);
    }

    fn set_prg(&mut self, prg: Vec<i64>) {
        self.prg = prg;
    }

    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
        &self.intervals
    }

    fn strides(&self) -> &Strides {
        self.layout.strides()
    }

    fn shape(&self) -> &Shape {
        self.layout.shape()
    }

    fn layout(&self) -> &Layout {
        &self.layout
    }

    /// Pads this iterator's shape to `shape.len()` dimensions with `try_pad_shape`.
    /// Recomputes strides from the padded shape with `preprocess_strides`, then refreshes
    /// `last_stride`.
    fn broadcast_set_strides(&mut self, shape: &Shape) {
        let self_shape = try_pad_shape(self.shape(), shape.len());
        self.set_strides(preprocess_strides(&self_shape, self.strides()).into());
        self.last_stride = self.strides()[self.strides().len() - 1];
    }

    /// Number of outer-loop rows in the first interval owned by this producer.
    fn outer_loop_size(&self) -> usize {
        let (start, end) = self.intervals[self.start_index];
        end - start
    }

    /// Extent stored for the last dimension. After the final `split` this is the real
    /// extent minus one.
    fn inner_loop_size(&self) -> usize {
        *self.shape().last().unwrap() as usize
    }

    /// Moves `ptr` to the start of the next outer row, odometer-style. It walks the outer
    /// dimensions from the fastest- to the slowest-varying and carries on wrap-around.
    fn next(&mut self) {
        let outer_dims = self.shape().len().saturating_sub(1);
        for j in (0..outer_dims).rev() {
            if self.prg[j] < self.shape()[j] {
                self.prg[j] += 1;
                self.ptr.offset(self.strides()[j]);
                break;
            }
            // This dimension hit its limit: rewind it to 0 and carry into the next one.
            self.prg[j] = 0;
            self.ptr.offset(-self.strides()[j] * self.shape()[j]);
        }
    }

    /// Reads the `index`-th element of the current inner row.
    ///
    /// The offset is computed in signed arithmetic, so zero (broadcast) and negative
    /// `last_stride` values are both addressed correctly. Casting a negative stride to
    /// `usize`, as `pointer::add` would require, wraps to a huge value and overflows.
    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
        let offset = (index as i64 * self.last_stride) as isize;
        // SAFETY: relies on the caller passing an `index` inside the current inner row, so
        // `ptr + index * last_stride` stays within the tensor's allocation.
        unsafe { *self.ptr.get_ptr().offset(offset) }
    }
}
/// Rayon entry point. `drive_unindexed` hands `self` to `bridge_unindexed`, which works
/// through this type's `UnindexedProducer` impl (`split` / `fold_with`).
impl<T: CommonBounds> ParallelIterator for ParStrided<T>
where
    T::Vec: Send,
{
    type Item = T;

    fn drive_unindexed<C: UnindexedConsumer<Self::Item>>(self, consumer: C) -> C::Result {
        bridge_unindexed(self, consumer)
    }
}
impl<T> UnindexedProducer for ParStrided<T>
where
T: CommonBounds,
T::Vec: Send,
{
type Item = T;
fn split(mut self) -> (Self, Option<Self>) {
if self.end_index - self.start_index <= 1 {
let mut curent_shape_prg: Vec<i64> = vec![0; self.shape().len()];
let mut amount =
self.intervals[self.start_index].0 * (*self.shape().last().unwrap() as usize);
let mut index = 0;
for j in (0..self.shape().len()).rev() {
curent_shape_prg[j] = (amount as i64) % self.shape()[j];
amount /= self.shape()[j] as usize;
index += curent_shape_prg[j] * self.strides()[j];
}
self.ptr.offset(index);
self.prg = curent_shape_prg;
let mut new_shape = self.shape().to_vec();
new_shape.iter_mut().for_each(|x| {
*x -= 1;
});
self.last_stride = self.strides()[self.strides().len() - 1];
self.set_shape(Shape::from(new_shape));
return (self, None);
}
let _left_interval = &self.intervals[self.start_index..self.end_index];
let left = _left_interval.len() / 2;
let right = _left_interval.len() / 2 + (_left_interval.len() % 2);
(
ParStrided {
ptr: self.ptr.clone(),
layout: self.layout.clone(),
prg: vec![],
intervals: self.intervals.clone(),
start_index: self.start_index,
end_index: self.start_index + left,
last_stride: self.last_stride,
},
Some(ParStrided {
ptr: self.ptr.clone(),
layout: self.layout.clone(),
prg: vec![],
intervals: self.intervals.clone(),
start_index: self.start_index + left,
end_index: self.start_index + left + right,
last_stride: self.last_stride,
}),
)
}
fn fold_with<F>(self, folder: F) -> F
where
F: Folder<Self::Item>,
{
folder
}
}
/// Raw field accessors for `ParStrided`. Presumably consumed by the shared
/// `par_reshape` / `par_transpose` / `par_expand` helpers — confirm.
impl<T> ParStridedHelper for ParStrided<T> {
    fn _layout(&self) -> &Layout {
        &self.layout
    }

    fn _set_shape(&mut self, new_shape: Shape) {
        self.layout.set_shape(new_shape);
    }

    fn _set_strides(&mut self, new_strides: Strides) {
        self.layout.set_strides(new_strides);
    }

    fn _set_last_strides(&mut self, stride: i64) {
        self.last_stride = stride;
    }

    fn _set_intervals(&mut self, new_intervals: Arc<Vec<(usize, usize)>>) {
        self.intervals = new_intervals;
    }

    fn _set_end_index(&mut self, index: usize) {
        self.end_index = index;
    }
}
/// Shape manipulation for `ParStrided`, delegated to the shared `par_*` helpers.
impl<T> ShapeManipulator for ParStrided<T>
where
    T: CommonBounds,
{
    /// Expands to `target` via `par_expand`.
    fn expand<Sh: Into<Shape>>(self, target: Sh) -> Self {
        par_expand(self, target)
    }

    /// Permutes axes via `par_transpose`.
    fn transpose<A: Into<Axis>>(self, order: A) -> Self {
        par_transpose(self, order)
    }

    /// Reshapes via `par_reshape`; `#[track_caller]` attributes panics to the caller.
    #[track_caller]
    fn reshape<Sh: Into<Shape>>(self, target: Sh) -> Self {
        par_reshape(self, target)
    }
}